Kaggle Bengali AI コンペで学ぶ音声認識 (Speech Recognition)

Bengali.AI Speech Recognition

ベンガル語の音声データから文字起こしをする精度を競うkaggleコンペ “Bengali.AI Speech Recognition“が2023年10月18日まで行われていました。このコンペで使用されていた深層学習による音声認識の手法の概要、および上位入賞者の解法を紹介する中で、音声認識についての知見を共有することがこの記事の目的です。

コンペ概要

このコンペの目標は、音声記録からベンガル語の文章を認識することです。ベンガル語は3億人以上に話される世界で最も話されている言語の一つです。しかし、さまざまな地域や場面で話されることから言語としての複雑性が高く、現在でもコンピューターによる正確な音声認識は難しいというのが現状です。(たとえば、GoogleのGoogle Speech API for Bengaliでもベンガル語の宗教説法の音声認識の誤り率は74%) そのため、このコンペによって新たにベンガル語の強固な汎用音声認識モデルが誕生する可能性もあり、その点においてとても有意義なコンペであったと言えるでしょう。

学習用データセットとしては、インドとバングラデシュの24,000人から収集された1,200時間のベンガル語音声データセットが与えられました。一方テストデータセットには、学習用には存在しない17の異なるドメイン (Audio book, ドラマ、CMなど)の音声が含まれており、ドメインごとに収録環境やノイズ、発話のスピードなどもさまざまでした。そのため、コンペ参加者たちはこれらに対してロバストなモデルの構築を求められていました。

なお提出ファイルは,Word Error Rate (WER) によって評価されました。WERは音声認識や機械翻訳で用いられる一般的な評価指標です。

Wav2Vec2.0を用いた音声認識

このコンペでは、多くの参加者がWav2Vec2.0と呼ばれるモデルを採用していました。ここでは、このWav2Vec2.0を使ってベンガル語の音声認識を行う手順について説明をします。Wav2Vec2.0とは、2020年6月にFacebook AI (現Meta) によって発表された音声認識フレームワークです。音声を入力とするEnd-to-endの自己教師あり対照学習で事前学習を行った後、ラベル付きデータによりfine-tuneされています。

Wav2Vec2.0を用いてベンガル語の音声から文章を文字起こしする手順は、以下のように表されます。

  • 音声データの読み込み
  • モデルの読み込み
  • 出力の計算

それぞれの手順について順番に説明していきます。また、実装はこちらを参考に行いました。

音声データの読み込み

まずは、モデルに入力する音声データを読み込みます。なお、このコンペで使用されていたベンガル語の音声データはこちらからダウンロードすることができます。

import librosa

#パスは自分の環境に合わせて設定してください
DATA_PATH = "/kaggle/input/bengaliai-speech/train_mp3s/4b8e4d7a06fc.wav"
sr = 32000
target_sr = 16000
audio, _ = librosa.load(DATA_PATH, sr=sr, mono=False) #音声データ
audio_array = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
target_sentence = "হুমম, ওহ হেই, দেখো।" #正解ラベル

モデルの読み込み

次にモデルを読み込みます。ここで読み込むモデルは、Wav2Vec2.0モデルとLMモデルの2つであり、どちらも公開モデルです。Wav2Vec2.0モデルはこちら、LMモデルはこちらからダウンロードできます。

Wav2Vec2.0モデルは音声データを入力としてコネクショニスト時間分類 (Connectionist Temporal Classification/ CTC) を出力します。コネクショニスト時間分類とは、時系列データである音声の各時間において、どの語彙 (文字) が発話されているかの確率のようなもの (ロジット) であると考えてください。

一方、LMモデルはコネクショニスト時間分類を入力として復号化された文を予測結果として出力します。

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
import pyctcdecode

#パスは自分の環境に合わせて設定してください
MODEL_PATH = "/kaggle/input/indicwav2vec_v1_bengali/"
LM_PATH = "/kaggle/input/wav2vec2-xls-r-300m-bengali/language_model"

#Wav2vec2.0モデルの読み込み
model = Wav2Vec2ForCTC.from_pretrained(MODEL_PATH)
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)

#語彙の取り出し
vocab_dict = processor.tokenizer.get_vocab()
sorted_vocab_dict = {k: v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])}

#LMモデルの読み込み
decoder = pyctcdecode.build_ctcdecoder(
    list(sorted_vocab_dict.keys()),
    str(LM_PATH + "/5gram.bin"),
)

processor_with_lm = Wav2Vec2ProcessorWithLM(
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    decoder=decoder
)

出力の計算

音声データとモデルを読み込めたので、実際に出力を計算していきます。上で説明したように、音声データはまずWav2.0Vec2モデルに入力され、ロジットに変換されます。そして、ロジットがLMモデルに入力され、予測結果である文として出力されます。ここでは、予測結果および、予測結果と正解ラベルの誤差 (WER) を求めてみました。

import torch
import jiwer

input_data = dict()
waveform = processor(audio.reshape(-1), sampling_rate=target_sr).input_values[0] 
input_data["input_values"] = torch.from_numpy(waveform[None, :])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model.eval()
input_data["input_values"] = input_data["input_values"].to(device)
y = model(**input_data).logits
y = y.detach().cpu().numpy()[0]

pred_sentence = processor_with_lm.decode(y, beam_width=768).text
print(f"pred:{pred_sentence}")
print(f"target:{target_sentence}")

score = jiwer.wer(target_sentence, pred_sentence)
print(f"score: {score}")

WERは0に近いほうが誤りが少ないことを表し、1.0が最大値なので、この場合は残念ながら全然予測ができていないということになりますね。

pred:খুনে
target:হুমম, ওহ হেই, দেখো।
score: 1.0

モデルの学習

公開モデルだけだと全然予測ができない (1サンプルだけではなんとも言えませんが、、、) という結果になりましたが、与えられた学習データを用いてWav2Vec2.0モデルをfine-tuneすることで性能を向上させることができます。このページでは割愛しましたが、ぜひチャレンジしてみてください。

上位チームの解法まとめ

上位チームの解法の振り返りを通じて、音声認識モデルの性能向上のノウハウについて紹介していこうと思います。まず、上位チームの解法をざっとまとめると以下のようになります。

  • 多くのチームがWav2Vec2.0を用いたCTCモデル + LMモデルを採用していたが,一部のチームはSTTモデルを採用してスコアを伸ばしていた
  • データ拡張や外部データを利用して学習し,スコアを伸ばしていたチームも多く存在した
  • 句読点を予測するpunctuationモデル,音源から音声だけを分離するDenoise, アンサンブルを採用してスコアを伸ばしたチームもあった

それぞれの内容について説明していきます。

ASRモデル

多くのチームはWav2Vec2.0を用いたCTCモデル + LMモデルを採用していましたが (引用: 2nd, 3rd, 4th, 5th, 8th, 12th, 14th, etc…)、一部のチームはSTT (Speech-to-Text)モデルのWhisperを使用することでベーススコアを伸ばしていました (引用: 1st, 11th, etc…)。1st Solutionによると、WhisperはOOD (分布外データ) に強く、テストデータの分布が学習データと大きく異なっていたこのコンペにおいて有効であったとのことです。

また、CTC+LMモデルを採用したチームの中には、テストデータにのみ存在する語彙が多数あったため、LMモデルを自作したチームもありました (引用: 2nd, 3rd, 4th, etc…)。

データ拡張

データ拡張によりスコアを伸ばしたチームも多数存在しました。1st Solutionでは、16kHz→8kHz→16kHzと音声をリサンプルしたり、スピードやピッチを変更したりしてデータを拡張していました。4th Solution12th Solutionでは、音声データの分割、結合やノイズの追加を行なっていました。

データセット

外部データの追加によりスコアを伸ばしたチームも同じく多数存在しました。OpenSLR、MADASR、Shrutilipi、Kathbath、ULCAや、YouTubeの動画などが利用されていたようでした。

Punctuationモデル

句読点を予測するPunctuationモデルも上位入賞のためには重要であったようでした (引用: 1st, 2nd, 3rd, 5th, 12th, 14th, etc…) 。句読点を予測することによって、以下のように大きくスコアを上げることができる可能性があります。

(label) hello. how are you?
(predict) hello how are you
→ wer: 0.5
(restore) hello. how are you.
→ wer: 0.25

Denoise

3rd Solutionでは、Demucsを用いた音源分離によりスコアを向上させていました。また、Demucsを適用することで音声の質が悪くなってしまう可能性があるため、音声が悪くなるかどうかを評価し、Demucsを使用するかどうかを切り替えるという工夫がなされていました。

アンサンブル

ASRモデルが出力する予測結果はモデルによってその長さが異なるため、ロジットを単純にアンサンブルすることはできません。しかし、5th Solution8th Solutionはそれぞれ異なる方法でアンサンブルを成功させていました。

5th Solutionでは、複数の学習済のASRモデルの凍結されたembeddingを連結してTransformer encoderに入力することでアンサンブルに成功していました。

8th Solutionでは、複数のCTCモデルに予測に対する自信をスコアとして出力させ、最高スコアの予測結果を採用することでアンサンブルに成功していました。

まとめ

今回のコンペではベースモデル、外部データ、punctuationモデルやDenoiseなど、工夫できるポイントが豊富に用意されていました。そのため、スコアを伸ばすための手段としてこれらの手法の情報をしっかりと集めること、そして各手法を地道に検証し取捨選択することが上位入賞のためには必要であったように思います。

また、このコンペを通じて音声認識の最先端の手法についてある程度知ることができため、そういった点でもよいコンペであったと感じています。

参照