split pkgs only for eval usage address #97; clean-up

This commit is contained in:
SWivid
2024-10-15 21:14:44 +08:00
parent 423fe4a0a5
commit bc6331529a
5 changed files with 18 additions and 17 deletions

View File

@@ -148,6 +148,12 @@ bash scripts/eval_infer_batch.sh
### Objective Evaluation
Install packages for evaluation:
```bash
pip install -r requirements_eval.txt
```
**Some Notes**
For faster-whisper with CUDA 11:

View File

@@ -1,4 +1,3 @@
import os
import re
import torch
import torchaudio
@@ -17,7 +16,6 @@ from model.utils import (
save_spectrogram,
)
from transformers import pipeline
import librosa
import click
import soundfile as sf

View File

@@ -22,12 +22,6 @@ from einops import rearrange, reduce
import jieba
from pypinyin import lazy_pinyin, Style
import zhconv
from zhon.hanzi import punctuation
from jiwer import compute_measures
from funasr import AutoModel
from faster_whisper import WhisperModel
from model.ecapa_tdnn import ECAPA_TDNN_SMALL
from model.modules import MelSpec
@@ -432,6 +426,7 @@ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path
def load_asr_model(lang, ckpt_dir = ""):
if lang == "zh":
from funasr import AutoModel
model = AutoModel(
model = os.path.join(ckpt_dir, "paraformer-zh"),
# vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
@@ -440,6 +435,7 @@ def load_asr_model(lang, ckpt_dir = ""):
disable_update=True,
) # following seed-tts setting
elif lang == "en":
from faster_whisper import WhisperModel
model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
model = WhisperModel(model_size, device="cuda", compute_type="float16")
return model
@@ -451,6 +447,7 @@ def run_asr_wer(args):
rank, lang, test_set, ckpt_dir = args
if lang == "zh":
import zhconv
torch.cuda.set_device(rank)
elif lang == "en":
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
@@ -458,10 +455,12 @@ def run_asr_wer(args):
raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
from zhon.hanzi import punctuation
punctuation_all = punctuation + string.punctuation
wers = []
from jiwer import compute_measures
for gen_wav, prompt_wav, truth in tqdm(test_set):
if lang == "zh":
res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)

View File

@@ -5,11 +5,8 @@ datasets
einops>=0.8.0
einx>=0.3.0
ema_pytorch>=0.5.2
faster_whisper
funasr
gradio
jieba
jiwer
librosa
matplotlib
numpy<=1.26.4
@@ -17,14 +14,10 @@ pydub
pypinyin
safetensors
soundfile
# torch>=2.0
# torchaudio>=2.3.0
tomli
torchdiffeq
tqdm>=4.65.0
transformers
vocos
wandb
x_transformers>=1.31.14
zhconv
zhon
tomli

5
requirements_eval.txt Normal file
View File

@@ -0,0 +1,5 @@
faster_whisper
funasr
jiwer
zhconv
zhon