Author
Vikas Reddy - University of Maryland
Abstract
In the rapidly evolving landscape of communication technology, recent breakthroughs, notably the OpenAI Whisper model, have significantly enhanced the accuracy and accessibility of multilingual speech-to-text. Despite these developments, there remains room for improvement, particularly in transcription accuracy. This research is dedicated to improving automatic speech recognition (ASR) models, with a specific focus on the Vietnamese and Japanese languages.
Our evaluation employs standard measures to gauge performance. For Vietnamese, we use the Word Error Rate (WER) metric, which assesses how often the recognized words deviate from the actual spoken words. For Japanese, we employ the Character Error Rate (CER) metric, which scrutinizes the accuracy of individual characters.
Key Results:
- Vietnamese (FOSD + Common Voice + Google Fleurs + Vivos): WER 9.46%
- Japanese (ReazonSpeech + Common Voice + Google Fleurs): CER 8.15%
Table of Contents
- Background Information
- Environment Setup
- Load Datasets
- Data Preprocessing
- Training
- Parameter Efficient Fine-tuning
- Results
- Evaluation
- Azure Speech Studio
- Conclusion
1. Background Information
In today’s society, communication and technology have become indispensable, yet several challenges persist that affect accessibility, inclusivity, and efficient knowledge dissemination. This is where advancements such as automatic speech recognition (ASR) step in, streamlining interactions between humans and computers, most visibly in online meetings.
ASR is the process of converting speech signals into their corresponding text. The task has recently gained traction in industry, thanks to the availability of large speech datasets along with their corresponding transcripts.
OpenAI Whisper is a Transformer-based encoder-decoder model, specifically designed as a sequence-to-sequence architecture. It takes audio spectrogram features as the input and converts them into a sequence of text tokens. This involves:
- A feature extractor that transforms raw audio into a log-Mel spectrogram
- A Transformer encoder generating a sequence of encoder hidden states
- A decoder that predicts text tokens using cross-attention mechanisms
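In code, these components map onto the WhisperProcessor (feature extractor plus tokenizer) and WhisperForConditionalGeneration classes from transformers. Below is a minimal loading sketch; the checkpoint size and language are placeholders rather than the exact configuration used in these experiments, and the processor, tokenizer, and model names are reused in the later snippets.

from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Placeholder checkpoint and language; substitute the Whisper size and language you fine-tune
model_name_or_path = "openai/whisper-small"
processor = WhisperProcessor.from_pretrained(
    model_name_or_path, language="Japanese", task="transcribe"
)
tokenizer = processor.tokenizer  # used by the metric functions later on
model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path)

# No tokens are forced or suppressed during generation; language/task come from the processor
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []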
2. Environment Setup
There are two distinct approaches to fine-tuning Whisper: using Google Colab or running the code on a local PC.
Required Packages
python -m pip install -U pip
pip install evaluate pandas numpy huggingface_hub pydub tqdm spacy ginza audiomentations
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install "datasets>=2.6.1"
pip install git+https://github.com/huggingface/transformers
pip install librosa
pip install "evaluate>=0.30"
pip install jiwer
pip install gradio
pip install -q bitsandbytes datasets accelerate loralib
pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git@main
The fine-tuning in this article was run on a Windows 11 Pro PC with an AMD Ryzen 7 3700X 8-core processor, 80 GB of RAM, and an NVIDIA GeForce RTX 3090 graphics card.
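Before starting a run it is worth confirming that PyTorch can actually see the GPU, since the fp16 settings used later assume CUDA is available; a quick sanity check:

import torch

# Verify that CUDA is available before enabling fp16 training
print(torch.cuda.is_available())                 # should print True
print(torch.cuda.get_device_name(0))             # e.g. NVIDIA GeForce RTX 3090
print(torch.cuda.get_device_properties(0).total_memory / 1024 ** 3)  # VRAM in GiB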
3. Load Datasets
Method 1: Using Hugging Face
from datasets import load_dataset, DatasetDict

common_voice = DatasetDict()
common_voice["train"] = load_dataset(
    "mozilla-foundation/common_voice_11_0", "ja",
    split="train+validation", use_auth_token=True
)
common_voice["test"] = load_dataset(
    "mozilla-foundation/common_voice_11_0", "ja",
    split="test", use_auth_token=True
)

common_voice = common_voice.remove_columns([
    "accent", "age", "client_id", "down_votes",
    "gender", "locale", "path", "segment", "up_votes"
])
Method 2: Manual Dataset Preparation
import os, csv, codecs

def text_change_csv(input_path, output_path):
    """Convert a pipe-delimited transcript file into a CSV next to the input file."""
    file_csv = os.path.splitext(output_path)[0] + ".csv"
    output_dir = os.path.dirname(input_path)
    output_file = os.path.join(output_dir, file_csv)
    encodings = ["utf-8", "latin-1"]
    for encoding in encodings:
        try:
            with open(input_path, 'r', encoding=encoding) as rf, \
                 codecs.open(output_file, 'w', encoding=encoding, errors='replace') as wf:
                writer = csv.writer(wf, delimiter=',')
                for read_text in rf.readlines():
                    # Each pipe-delimited line becomes one CSV row
                    writer.writerow(read_text.split('|'))
            print(f"CSV has been created using encoding: {encoding}")
            return True
        except UnicodeDecodeError:
            continue
    return False
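Once the pipe-delimited transcript file has been converted to CSV, it can be turned into a Hugging Face dataset and paired with the audio files. The following is a minimal sketch; the file location and the column names ("path", "transcription") are assumptions about how the downloaded archive is laid out.

import pandas as pd
from datasets import Dataset, Audio

# The CSV location and column names are assumptions about the extracted archive
df = pd.read_csv("fosd/train.csv", names=["path", "transcription"])
manual_dataset = Dataset.from_pandas(df)

# Treat the path column as audio and decode it at the 16 kHz Whisper expects
manual_dataset = manual_dataset.rename_column("path", "audio")
manual_dataset = manual_dataset.cast_column("audio", Audio(sampling_rate=16000))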
Datasets Used
| Dataset | Language | Usage | Speech Audio (Hours) |
|---|---|---|---|
| Common Voice 13.0 | Vietnamese, Japanese | Hugging Face | 19h (VN), 10h (JP) |
| Google Fleurs | Vietnamese, Japanese | Hugging Face | 11h (VN), 8h (JP) |
| Vivos | Vietnamese | Hugging Face | 15h |
| FPT Open Speech Dataset | Vietnamese | Download & extract | 30h |
| VLSP2020 | Vietnamese | Download & extract | 100h |
| ReazonSpeech | Japanese | Hugging Face | 5h |
| JSUT | Japanese | Download & extract | 10h |
| JVS | Japanese | Download & extract | 30h |
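Several of these corpora are combined into a single training set before preprocessing. The sketch below uses concatenate_datasets; fleurs and vivos stand in for the other DatasetDicts loaded in the same way as common_voice above, and all splits are assumed to share the same columns and features.

from datasets import concatenate_datasets, DatasetDict

# "fleurs" and "vivos" are placeholders for other DatasetDicts prepared like common_voice;
# concatenation requires identical column names and features across the datasets.
combined = DatasetDict()
combined["train"] = concatenate_datasets(
    [common_voice["train"], fleurs["train"], vivos["train"]]
)
combined["test"] = concatenate_datasets(
    [common_voice["test"], fleurs["test"], vivos["test"]]
)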
4. Data Preprocessing
Data Augmentation
from datasets import Audio
from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift

# Resample all audio to the 16 kHz expected by Whisper
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

augment_waveform = Compose([
    AddGaussianNoise(min_amplitude=0.005, max_amplitude=0.015, p=0.2),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.2, leave_length_unchanged=False),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.2)
])

def augment_dataset(batch):
    audio = batch["audio"]["array"]
    augmented_audio = augment_waveform(samples=audio, sample_rate=16000)
    batch["audio"]["array"] = augmented_audio
    return batch

# Augment the training split only; the test split stays untouched
common_voice['train'] = common_voice['train'].map(augment_dataset, keep_in_memory=True)
Transcript Normalization
import string

def remove_punctuation(sentence):
    translator = str.maketrans('', '', string.punctuation)
    modified_sentence = sentence.translate(translator)
    return modified_sentence

def fix_sentence(sentence):
    transcription = sentence
    if transcription.startswith('"') and transcription.endswith('"'):
        transcription = transcription[1:-1]
    transcription = remove_punctuation(transcription)
    transcription = transcription.lower()
    return transcription
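A quick check of the normalization on a sample transcript. Note that string.punctuation only covers ASCII punctuation, so Japanese marks such as 「」 and 。 pass through unchanged:

# Surrounding quotes and ASCII punctuation are removed, and the text is lowercased
print(fix_sentence('"Hello, World!"'))  # -> hello world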
Prepare Dataset for Whisper
def prepare_dataset(batch):
    audio = batch["audio"]
    # Compute log-Mel input features from the raw waveform
    batch["input_features"] = processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_features[0]
    # Keep the clip length (in seconds) so overly long clips can be filtered out later
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
    # Normalize the transcript and encode it to label token ids
    # (the transcript column is "sentence" in Common Voice and "transcription" in Fleurs)
    transcription = fix_sentence(batch["transcription"])
    batch["labels"] = processor.tokenizer(
        transcription, max_length=225, truncation=True
    ).input_ids
    return batch

common_voice = common_voice.map(
    prepare_dataset,
    remove_columns=common_voice.column_names['train'],
    num_proc=1,
    keep_in_memory=True
)
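The input_length column computed above can be used to drop clips that exceed Whisper's 30-second input window before training; a short sketch:

MAX_INPUT_LENGTH = 30.0  # seconds; Whisper's log-Mel features cover a 30-second window

def is_audio_in_length_range(length):
    return 0.0 < length < MAX_INPUT_LENGTH

common_voice["train"] = common_voice["train"].filter(
    is_audio_in_length_range, input_columns=["input_length"]
)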
5. Training
Data Collator
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad the audio input features to a uniform length and return PyTorch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad the label sequences and replace padding token ids with -100 so they are ignored by the loss
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Drop the BOS token if it was prepended during tokenization; it is re-added at generation time
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
Evaluation Metrics (Vietnamese - WER)
import evaluate

metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    # Replace -100 with the pad token id so the labels can be decoded
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
Evaluation Metrics (Japanese - CER)
import evaluate
import spacy, ginza

metric = evaluate.load("cer")
nlp = spacy.load("ja_ginza")
ginza.set_split_mode(nlp, "C")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    # Segment the Japanese text into words with GiNZA before scoring
    pred_str = [" ".join([str(i) for i in nlp(j)]) for j in pred_str]
    label_str = [" ".join([str(i) for i in nlp(j)]) for j in label_str]
    cer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
Training Arguments
from transformers import Seq2SeqTrainingArguments
model.config.dropout = 0.05  # a little dropout to combat overfitting

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-fine-tuned",
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=1e-6,
    lr_scheduler_type='linear',
    optim="adamw_bnb_8bit",
    warmup_steps=200,
    num_train_epochs=5,
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    fp16=True,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=255,
    eval_steps=500,
    logging_steps=500,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",  # must match the key returned by compute_metrics ("cer" for the Japanese setup)
    greater_is_better=False,
    push_to_hub=False,
    save_total_limit=1
)
Key training parameters:
- learning_rate: 1e-5 or 1e-6 works best
- warmup_steps: Use 10% of overall steps
- per_device_train_batch_size: Set based on GPU capacity (16 for RTX 3090)
- dropout: 0.05 or 0.10 to combat overfitting
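With the collator, the metric function, and the training arguments in place, the trainer is constructed and launched in the standard way; below is a minimal sketch following the usual Seq2SeqTrainer pattern (the dataset variable names match the Common Voice example above):

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)
trainer.train()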
6. Parameter Efficient Fine-tuning (PEFT)
PEFT achieves competitive performance while training only about 1% of the model's parameters.
| Fine-tuning | Parameter Efficient Fine-tuning |
|---|---|
| Faster training time | Longer training time |
| Requires larger computational resources | Uses fewer computational resources |
| Re-trains the entire model | Modifies only a small subset of parameters |
| More prone to overfitting | Less prone to overfitting |
LoRA Setup
from transformers import WhisperForConditionalGeneration
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

# Load the base checkpoint in 8-bit to reduce GPU memory usage
model = WhisperForConditionalGeneration.from_pretrained(
    model_name_or_path, load_in_8bit=True, device_map="auto"
)
model = prepare_model_for_int8_training(model)

# Keep gradients flowing through the frozen convolutional front end
def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

# Inject low-rank adapters into the attention query/value projections
config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none"
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
# Output: trainable params: 15728640 || all params: 1559033600 || trainable%: 1.01%
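When training the LoRA-wrapped model with Seq2SeqTrainer, the PEFT examples additionally set remove_unused_columns=False and label_names=["labels"] in the training arguments. After training, only the small adapter needs to be saved and can later be re-attached to the base checkpoint; a sketch with a placeholder output path:

# Save only the LoRA adapter weights (a few tens of megabytes)
model.save_pretrained("./whisper-lora-adapter")

# For inference, re-attach the adapter to the 8-bit base checkpoint
from peft import PeftModel
from transformers import WhisperForConditionalGeneration

base_model = WhisperForConditionalGeneration.from_pretrained(
    model_name_or_path, load_in_8bit=True, device_map="auto"
)
peft_model = PeftModel.from_pretrained(base_model, "./whisper-lora-adapter")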
7. Results
Vietnamese Results
The model fine-tuned on FOSD + Google Fleurs + Vivos + CV datasets achieved the lowest WER of 9.46%.
Japanese Results
The model fine-tuned on the JSUT + ReazonSpeech + Google Fleurs + Common Voice datasets achieved the lowest CER of 8.15%.
Optimization Loss Curve (training-loss figure not reproduced here)
8. Evaluation
Vietnamese Evaluation
Among the evaluated datasets, the Google Fleurs + Common Voice + Vivos combination achieved the lowest WER of 7.84%, indicating highly accurate transcriptions.
Japanese Evaluation
The combined ReazonSpeech + Google Fleurs + Common Voice dataset achieved the lowest CER of 7.44%.
Faster-Whisper Conversion
from ctranslate2.converters import TransformersConverter
from faster_whisper import WhisperModel

model_id = "./whisper-fine-tuned/checkpoint-5000"
output_dir = "whisper-ct2"

# Convert the fine-tuned checkpoint to the CTranslate2 format used by faster-whisper
converter = TransformersConverter(model_id, load_as_float16=True)
converter.convert(output_dir, quantization="float16")

model = WhisperModel(output_dir, device="cuda", compute_type="float16")
Faster Whisper offers approximately 40% faster inference compared to standard fine-tuned Whisper while maintaining comparable accuracy.
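Transcription with the converted model then follows the regular faster-whisper API; a short example where the audio path and language code are placeholders:

# "audio.wav" and the language code are placeholders
segments, info = model.transcribe("audio.wav", language="ja", beam_size=5)
for segment in segments:
    print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")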
9. Azure Speech Studio
Azure Speech Studio provides an alternative approach to fine-tuning ASR models.
Transcription with Azure
import os, evaluate
from azure.cognitiveservices.speech import SpeechConfig, SpeechRecognizer, AudioConfig

subscription_key = "your_subscription_key"
location = "japaneast"
endpoint = "your_endpoint"           # endpoint ID of the custom (fine-tuned) model
wav_base_path = "path/to/test_wavs"  # placeholder: directory containing the test .wav files

# Point the recognizer at the custom model's endpoint and set the recognition language
speech_config = SpeechConfig(
    subscription=subscription_key,
    region=location,
    speech_recognition_language="ja-JP"
)
speech_config.endpoint_id = endpoint

predictions = []
for root, _, files in os.walk(wav_base_path):
    for file_name in files:
        if file_name.endswith(".wav"):
            audio_file_path = os.path.join(root, file_name)
            audio_config = AudioConfig(filename=audio_file_path)
            speech_recognizer = SpeechRecognizer(
                speech_config=speech_config,
                audio_config=audio_config
            )
            # Recognize a single utterance from the file
            result = speech_recognizer.recognize_once()
            if result.text:
                predictions.append(result.text)
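The collected predictions can then be scored with the same evaluate metrics used earlier. A sketch, assuming a references list holds the ground-truth transcripts in the same order the .wav files were processed:

# "references" must be aligned one-to-one with the processed .wav files
references = []  # fill with the ground-truth transcripts for the same files

cer_metric = evaluate.load("cer")
cer = 100 * cer_metric.compute(predictions=predictions, references=references)
print(f"CER: {cer:.2f}%")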
Azure Results
Vietnamese: The model trained on Common Voice 14.0 achieves WER of 7.33%
Japanese: The model trained on JSUT achieves CER of 6.97%
While Azure Speech Studio may yield lower WER during training, Whisper tends to achieve better evaluation results on unseen data, particularly on diverse and complex audio.
10. Conclusion
The process of fine-tuning the Whisper ASR model emerges as a robust technique for enhancing its performance. Key findings:
- Fine-tuning consistently yields notable performance enhancements (WER 7.33-12.15% for Vietnamese, CER 8.15-17.93% for Japanese)
- Data augmentation via audiomentations library introduces valuable diversity
- Dataset quality matters: Amount of data, audio clarity, and topic variety all impact performance
- The fine-tuned Whisper models outperform the Azure custom models on unseen, real-world audio
References
- Radford, A., et al. (2022). Robust speech recognition via large-scale weak supervision. arXiv:2212.04356
- Ardila, R., et al. (2020). Common Voice: A Massively-Multilingual Speech Corpus. arXiv:1912.06670
- Conneau, A., et al. (2022). FLEURS: Few-Shot Learning Evaluation of Universal Representations of Speech. arXiv:2205.12446
- Gandhi, S. (2022). Fine-Tune Whisper for Multilingual ASR with Transformers. Hugging Face Blog
- Mangrulkar, S. & Paul, S. Parameter-Efficient Fine-Tuning Using PEFT. Hugging Face Blog