ai/SpeechRecognition/asr_service.py

303 lines
9.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FunASR 语音识别服务
支持句级时间戳、说话人分离FunASR CAM++)、抗噪
"""
import os
import sys
MODEL_CACHE_DIR = os.path.dirname(os.path.abspath(__file__))
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
os.environ["MODELSCOPE_CACHE"] = MODEL_CACHE_DIR
os.environ["FUNASR_MODELS_DIR"] = MODEL_CACHE_DIR
if sys.platform == "win32":
os.environ["PYTHONLEGACYWINDOWSFSENCODING"] = "1"
import json
from pathlib import Path
from typing import List, Dict, Union, Optional
from dataclasses import dataclass
import warnings
# Silence every warning category for the whole process (third-party model
# libraries tend to be noisy). NOTE(review): this also hides warnings from
# any other code that imports this module; consider scoping it with
# warnings.catch_warnings() around model loading/inference instead.
warnings.simplefilter("ignore")
@dataclass
class Sentence:
    """One recognised sentence: speaker label, text and time span in seconds."""

    speaker: str
    text: str
    begin_time: float
    end_time: float

    def to_dict(self) -> Dict:
        """Return a JSON-friendly dict; all times are rounded to 2 decimals.

        ``duration`` is computed from the unrounded times and then rounded.
        """
        span = self.end_time - self.begin_time
        return dict(
            speaker=self.speaker,
            text=self.text,
            begin_time=round(self.begin_time, 2),
            end_time=round(self.end_time, 2),
            duration=round(span, 2),
        )

    def __str__(self) -> str:
        start, stop = self.begin_time, self.end_time
        return f"[{self.speaker}] {self.text} ({start:.2f}s - {stop:.2f}s)"
class ASRService:
    """
    Speech-recognition service built on FunASR ``AutoModel``.

    Capabilities, as configured in ``_load_model``:
    1. Speech recognition (ASR).
    2. Sentence-level timestamps, parsed from ``sentence_info``.
    3. Speaker model: the "paraformer-zh" pipeline loads FunASR's CAM++
       speaker model, but ``_parse_result`` currently labels every
       sentence ``"speaker_0"``.
    4. Voice activity detection (VAD) model for segmenting the audio.

    The FunASR model is loaded lazily on the first ``recognize`` call.
    """

    def __init__(
        self,
        model_name: str = "paraformer-zh",
        device: str = "auto",
        cache_dir: Optional[str] = None,
        merge_segments: bool = True,
        min_segment_duration: float = 0.3,
        merge_gap: float = 0.5,
    ):
        """
        Initialise the service (the model itself is loaded lazily).

        Args:
            model_name: Model to use:
                - "paraformer-zh": DAMO Paraformer-large (recommended for Chinese)
                - "SenseVoice": SenseVoiceSmall multilingual model
            device: "cpu", "cuda", an indexed CUDA device such as "cuda:1",
                or "auto" (CUDA when available, otherwise CPU).
            cache_dir: Model cache directory; defaults to MODEL_CACHE_DIR.
                Created if it does not exist.
            merge_segments: Whether to merge adjacent segments of the same
                speaker.
            min_segment_duration: Minimum segment duration threshold, meant
                to filter noise (presumably seconds).
            merge_gap: Maximum gap for merging segments (presumably seconds).

        NOTE(review): merge_segments, min_segment_duration and merge_gap are
        stored but not used by any method of this class yet.

        Raises:
            ValueError: If ``device`` is not a supported device string.
        """
        self.model_name = model_name
        self.cache_dir = cache_dir or MODEL_CACHE_DIR
        self.merge_segments = merge_segments
        self.min_segment_duration = min_segment_duration
        self.merge_gap = merge_gap
        os.makedirs(self.cache_dir, exist_ok=True)
        self.device = self._get_device(device)
        self._model = None

    def _get_device(self, device: str) -> str:
        """Resolve and validate the requested device string.

        "auto" becomes "cuda" when ``torch.cuda.is_available()`` and "cpu"
        otherwise. "cpu", "cuda" and indexed forms like "cuda:0" are
        returned unchanged.

        Raises:
            ValueError: For any other device string.
        """
        import torch

        if device == "auto":
            if torch.cuda.is_available():
                device = "cuda"
                print(f"检测到 GPU: {torch.cuda.get_device_name(0)}")
            else:
                device = "cpu"
                print("未检测到 GPU使用 CPU 运行")
            return device

        # Accept "cuda:<index>" (e.g. "cuda:1") to select a specific GPU.
        base, sep, index = device.partition(":")
        is_indexed_cuda = (
            base == "cuda" and sep == ":" and index.isascii() and index.isdigit()
        )
        if device not in ["cpu", "cuda"] and not is_indexed_cuda:
            raise ValueError(f"不支持的设备: {device},请选择 'cpu', 'cuda''auto'")
        return device

    def _load_model(self):
        """Load the FunASR model on first use; a no-op once loaded.

        Raises:
            ImportError: If ``funasr`` is not installed.
            ValueError: If ``self.model_name`` is not supported.
        """
        if self._model is not None:
            return
        try:
            from funasr import AutoModel
        except ImportError as exc:
            # Chain the original error so the real import failure stays visible.
            raise ImportError("请安装 FunASR: pip install funasr") from exc
        print(f"正在加载模型: {self.model_name}")
        print(f"设备: {self.device}")
        print(f"模型缓存目录: {self.cache_dir}")
        if self.model_name == "paraformer-zh":
            # Paraformer-large ASR + FSMN VAD + CT-Transformer punctuation
            # + CAM++ speaker model.
            self._model = AutoModel(
                model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
                spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                device=self.device,
                ncpu=4,
                disable_pbar=True,
                disable_log=True,
            )
        elif self.model_name == "SenseVoice":
            self._model = AutoModel(
                model="iic/SenseVoiceSmall",
                vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                vad_kwargs={"max_single_segment_time": 30000},
                device=self.device,
                disable_pbar=True,
                disable_log=True,
            )
        else:
            raise ValueError(f"不支持的模型: {self.model_name}")
        print("模型加载完成!")

    def recognize(
        self,
        audio_path: Union[str, Path],
        batch_size_s: int = 300,
        return_raw: bool = False,
    ) -> Union[List[Sentence], Dict]:
        """
        Recognise a single audio file.

        Args:
            audio_path: Path to the audio file.
            batch_size_s: Batching budget passed to ``generate(batch_size_s=...)``.
            return_raw: Return FunASR's raw ``generate()`` output instead of
                parsed sentences.

        Returns:
            List[Sentence] by default, or the raw result if ``return_raw``.

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
            ImportError, ValueError: Propagated from model loading.
        """
        audio_path = Path(audio_path)
        # Validate the input before the (potentially slow) model load/download.
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")
        self._load_model()
        print(f"正在识别: {audio_path}")
        # _load_model either sets self._model or raises; this guard mainly
        # narrows the Optional type for static checkers.
        if self._model is None:
            raise RuntimeError("模型加载失败,无法执行识别")
        result = self._model.generate(
            input=str(audio_path),
            batch_size_s=batch_size_s,
            return_raw_text=True,
            return_spk_res=True,
        )
        if return_raw:
            return result
        return self._parse_result(result)

    def _parse_result(self, result: List[Dict]) -> List[Sentence]:
        """Convert ``generate()`` output into a list of ``Sentence``.

        Only the first entry is used, because ``recognize`` passes one file
        per call. If it has ``sentence_info``, each item becomes a sentence
        whose ``start``/``end`` values are divided by 1000 (milliseconds to
        seconds). Otherwise its whole ``text`` becomes one sentence with zero
        timestamps. Sentences whose stripped text is empty are dropped.
        """
        sentences = []
        if not result:
            return sentences
        res = result[0] if isinstance(result, list) else result
        if "sentence_info" in res:
            for sent_info in res["sentence_info"]:
                text = sent_info.get("text", "").strip()
                if not text:
                    continue
                sentences.append(Sentence(
                    speaker="speaker_0",  # every sentence is labelled speaker_0
                    text=text,
                    begin_time=sent_info.get("start", 0) / 1000.0,
                    end_time=sent_info.get("end", 0) / 1000.0,
                ))
        elif "text" in res:
            text = res["text"].strip()
            # Skip empty transcripts here too, matching the branch above.
            if text:
                sentences.append(Sentence(
                    speaker="speaker_0",  # every sentence is labelled speaker_0
                    text=text,
                    begin_time=0.0,
                    end_time=0.0,
                ))
        return sentences

    def recognize_batch(
        self,
        audio_paths: List[Union[str, Path]],
        batch_size_s: int = 300,
        use_3d_speaker: bool = False,
    ) -> List[List[Sentence]]:
        """Recognise several audio files one after another.

        Best-effort: a failure is printed and an empty list is stored for
        that file, so the output stays index-aligned with ``audio_paths``.

        NOTE(review): ``use_3d_speaker`` is accepted but currently unused.
        """
        results = []
        for audio_path in audio_paths:
            try:
                result = self.recognize(audio_path, batch_size_s)
            except Exception as e:
                print(f"识别失败 [{audio_path}]: {e}")
                result = []
            results.append(result)
        return results

    def export_to_json(
        self,
        sentences: List[Sentence],
        output_path: Union[str, Path]
    ):
        """Write sentences to a UTF-8 JSON file, creating parent directories.

        The file holds ``total_sentences`` and a ``sentences`` list built
        with ``Sentence.to_dict()``.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "total_sentences": len(sentences),
            "sentences": [s.to_dict() for s in sentences]
        }
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        print(f"结果已保存: {output_path}")

    def export_to_srt(
        self,
        sentences: List[Sentence],
        output_path: Union[str, Path]
    ):
        """Write sentences as an SRT subtitle file, creating parent directories.

        Each cue is numbered from 1, uses ``HH:MM:SS,mmm`` timestamps and
        prefixes the text with ``[speaker]``.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        def format_time(seconds: float) -> str:
            # Round to whole milliseconds first: truncating the float fraction
            # turns e.g. 2.3 s (2.3 % 1 == 0.2999...) into ",299".
            total_ms = int(round(seconds * 1000))
            hours, rem_ms = divmod(total_ms, 3_600_000)
            minutes, rem_ms = divmod(rem_ms, 60_000)
            secs, millis = divmod(rem_ms, 1000)
            return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

        with open(output_path, "w", encoding="utf-8") as f:
            for i, sentence in enumerate(sentences, 1):
                f.write(f"{i}\n")
                f.write(f"{format_time(sentence.begin_time)} --> {format_time(sentence.end_time)}\n")
                f.write(f"[{sentence.speaker}] {sentence.text}\n\n")
        print(f"字幕已保存: {output_path}")
def recognize_audio(
    audio_path: Union[str, Path],
    model_name: str = "paraformer-zh",
    device: str = "auto",
    use_3d_speaker: bool = False
) -> List[Sentence]:
    """Recognise one audio file using a freshly constructed ``ASRService``.

    Convenience wrapper that builds the service from ``model_name`` and
    ``device`` and returns the parsed sentence list.
    NOTE(review): ``use_3d_speaker`` is accepted but not used.
    """
    sentences = ASRService(model_name=model_name, device=device).recognize(
        audio_path, return_raw=False
    )
    # With return_raw=False the result is a list; the assert narrows the
    # Union return type of recognize() for type checkers.
    assert isinstance(sentences, list)
    return sentences
if __name__ == "__main__":
print("=" * 60)
print("FunASR 语音识别服务")
print("=" * 60)
print("\n使用方法:")
print(' from asr_service import ASRService')
print(' service = ASRService()')
print(' results = service.recognize("your_audio.wav")')
print(' for sent in results:')
print(' print(sent)')
print("\n使用 3D-Speaker 替换说话人:")
print(' results = service.recognize("your_audio.wav", use_3d_speaker=True)')