SpeechRecognition/asr_service.py

356 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FunASR 语音识别服务
支持:句级时间戳、说话人分离、抗噪
"""
import os
import sys
# Work around the Windows path-length limit: keep the model cache in a short
# "models" directory that sits right next to this file.
MODEL_CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

# Point both cache-related environment variables at that directory.
os.environ["MODELSCOPE_CACHE"] = MODEL_CACHE_DIR
os.environ["FUNASR_MODELS_DIR"] = MODEL_CACHE_DIR

# The original comment labels this "Windows long-path support (Windows 10 1607+)".
# NOTE(review): PYTHONLEGACYWINDOWSFSENCODING is documented as a filesystem
# *encoding* switch that the interpreter reads at startup, so setting it here
# presumably only affects child processes and does not enable long paths.
# Behaviour is left unchanged; confirm whether this line is still wanted.
if sys.platform == "win32":
    os.environ["PYTHONLEGACYWINDOWSFSENCODING"] = "1"
import json
from pathlib import Path
from typing import List, Dict, Union, Optional
from dataclasses import dataclass
import warnings
warnings.filterwarnings('ignore')
@dataclass
class Sentence:
    """One recognized sentence: who spoke, what was said, and when.

    ``begin_time`` and ``end_time`` are in seconds (``__str__`` renders them
    with an ``s`` suffix).
    """

    speaker: str
    text: str
    begin_time: float
    end_time: float

    def to_dict(self) -> Dict:
        """Return a plain-dict view of the sentence.

        Times are rounded to two decimals.  ``duration`` is derived from the
        unrounded times and then rounded, so it may differ by 0.01 from the
        difference of the two rounded values.
        """
        start, stop = self.begin_time, self.end_time
        payload = {"speaker": self.speaker, "text": self.text}
        payload["begin_time"] = round(start, 2)
        payload["end_time"] = round(stop, 2)
        payload["duration"] = round(stop - start, 2)
        return payload

    def __str__(self) -> str:
        """Render as ``[speaker] text (begin - end)`` with two-decimal times."""
        return f"[{self.speaker}] {self.text} ({self.begin_time:.2f}s - {self.end_time:.2f}s)"
class ASRService:
    """Speech recognition service built on FunASR.

    Features, as configured in ``_load_model``:
      1. Automatic speech recognition (ASR).
      2. Sentence-level timestamps.
      3. Speaker diarization (a speaker model is attached for "paraformer-zh").
      4. Voice activity detection (VAD), listed by the original author as the
         noise-robustness measure.

    The underlying model is created lazily on the first ``recognize`` call.
    """

    def __init__(
        self,
        model_name: str = "paraformer-zh",  # "paraformer-zh" or "SenseVoice"
        device: str = "auto",
        cache_dir: Optional[str] = None
    ):
        """Configure the service; no model is loaded yet.

        Args:
            model_name: Which model configuration to use.
                - "paraformer-zh": Paraformer model (the original author
                  recommends it for Chinese).
                - "SenseVoice": SenseVoice multilingual model.
              The name is validated when the model is loaded, not here.
            device: "cpu", "cuda", or "auto" (pick CUDA when available).
            cache_dir: Model cache directory; defaults to ``MODEL_CACHE_DIR``.

        Raises:
            ValueError: If ``device`` is not one of the supported values.
        """
        self.model_name = model_name
        self.cache_dir = cache_dir or MODEL_CACHE_DIR
        # Make sure the cache directory exists.
        # NOTE(review): cache_dir is only created and logged by this class; it
        # is not passed to AutoModel.  The download location is presumably
        # governed by the environment variables set at module import - confirm.
        os.makedirs(self.cache_dir, exist_ok=True)
        # Resolve "auto" to a concrete device and validate explicit values.
        self.device = self._get_device(device)
        # The model is loaded lazily by _load_model().
        self._model = None

    def _get_device(self, device: str) -> str:
        """Resolve the user-supplied device string to "cpu" or "cuda".

        Args:
            device: "cpu", "cuda", or "auto".

        Returns:
            The concrete device: "cpu" or "cuda".

        Raises:
            ValueError: If ``device`` is not a supported value.
        """
        if device == "auto":
            # Imported only here so that an explicit "cpu"/"cuda" request does
            # not require torch at construction time.
            import torch

            if torch.cuda.is_available():
                print(f"检测到 GPU: {torch.cuda.get_device_name(0)}")
                return "cuda"
            print("未检测到 GPU使用 CPU 运行")
            return "cpu"
        if device not in ("cpu", "cuda"):
            raise ValueError(f"不支持的设备: {device},请选择 'cpu', 'cuda' 或 'auto'")
        return device

    def _load_model(self):
        """Create the FunASR ``AutoModel`` on first use; later calls are no-ops.

        Raises:
            ImportError: If the ``funasr`` package is not installed.
            ValueError: If ``self.model_name`` is not a supported model.
        """
        if self._model is not None:
            return
        try:
            from funasr import AutoModel
        except ImportError as exc:
            # Keep the original failure attached as the cause.
            raise ImportError("请安装 FunASR: pip install funasr") from exc
        print(f"正在加载模型: {self.model_name}")
        print(f"设备: {self.device}")
        print(f"模型缓存目录: {self.cache_dir}")
        if self.model_name == "paraformer-zh":
            # Paraformer Chinese configuration with VAD, punctuation and
            # speaker models attached (timestamps + speaker diarization).
            # Per the original author's note, only these models support
            # timestamps:
            # - speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch
            # - speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
            self._model = AutoModel(
                model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
                spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                device=self.device,
                ncpu=4,
                disable_pbar=True,
                disable_log=True,
            )
        elif self.model_name == "SenseVoice":
            # SenseVoice multilingual configuration (VAD only, no punctuation
            # or speaker model).
            self._model = AutoModel(
                model="iic/SenseVoiceSmall",
                vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                vad_kwargs={"max_single_segment_time": 30000},
                device=self.device,
                disable_pbar=True,
                disable_log=True,
            )
        else:
            raise ValueError(f"不支持的模型: {self.model_name}")
        print("模型加载完成!")

    def recognize(
        self,
        audio_path: Union[str, Path],
        batch_size_s: int = 300,
        return_raw: bool = False
    ) -> "Union[List[Sentence], Dict]":
        """Recognize a single audio file.

        Args:
            audio_path: Path of the audio file.
            batch_size_s: Batch duration in seconds, forwarded to the model.
            return_raw: When True, return the model output unparsed.

        Returns:
            A list of ``Sentence`` objects by default, or whatever
            ``AutoModel.generate`` returned when ``return_raw`` is True
            (``_parse_result`` treats that output as a list of dicts).

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
            RuntimeError: If no model is available after loading.
        """
        audio_path = Path(audio_path)
        # Fail fast on a bad path before paying for the model load.
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")
        self._load_model()
        print(f"正在识别: {audio_path}")
        # Defensive: _load_model() either sets the model or raises.
        if self._model is None:
            raise RuntimeError("模型加载失败,无法执行识别")
        result = self._model.generate(
            input=str(audio_path),
            batch_size_s=batch_size_s,
            return_raw_text=True,
            return_spk_res=True,
        )
        if return_raw:
            return result
        return self._parse_result(result)

    @staticmethod
    def _speaker_label(sent_info: Dict) -> str:
        """Return the speaker label for one ``sentence_info`` entry.

        An explicit "speaker" value wins.  Otherwise an integer "spk" index is
        rendered as ``SPEAKER_NN``.  With neither present the label defaults to
        "SPEAKER_00".
        """
        speaker = sent_info.get("speaker")
        if speaker is not None:
            return str(speaker)
        # NOTE(review): FunASR's diarization output appears to carry the
        # speaker as an integer under "spk" - confirm against the installed
        # FunASR version.
        spk = sent_info.get("spk")
        if spk is None:
            return "SPEAKER_00"
        try:
            return f"SPEAKER_{int(spk):02d}"
        except (TypeError, ValueError):
            return str(spk)

    def _parse_result(self, result: List[Dict]) -> "List[Sentence]":
        """Convert model output into a list of ``Sentence`` objects.

        Only the first element is used when ``result`` is a list.  Entries
        whose text is empty after stripping are dropped in both branches.
        """
        if not result:
            return []
        # FunASR returns a list; take its first element.
        res = result[0] if isinstance(result, list) else result
        sentences = []
        if "sentence_info" in res:
            # Per-sentence entries (timestamps, and speakers when available).
            for sent_info in res["sentence_info"]:
                # NOTE(review): the "sentence" fallback covers output variants
                # that presumably use that key instead of "text" - confirm.
                text = (sent_info.get("text") or sent_info.get("sentence") or "").strip()
                if not text:
                    continue
                sentences.append(Sentence(
                    speaker=self._speaker_label(sent_info),
                    text=text,
                    begin_time=sent_info.get("start", 0) / 1000.0,  # ms -> s
                    end_time=sent_info.get("end", 0) / 1000.0,
                ))
        elif "text" in res:
            # Plain-text result: no timestamps or speaker information.
            text = res["text"].strip()
            if text:
                sentences.append(Sentence(
                    speaker="SPEAKER_00",
                    text=text,
                    begin_time=0.0,
                    end_time=0.0,
                ))
        return sentences

    def recognize_batch(
        self,
        audio_paths: List[Union[str, Path]],
        batch_size_s: int = 300
    ) -> "List[List[Sentence]]":
        """Recognize several audio files one after another.

        Args:
            audio_paths: Paths of the audio files.
            batch_size_s: Batch duration in seconds, forwarded to ``recognize``.

        Returns:
            One result list per input path, in input order.  A file that fails
            contributes an empty list and the error is printed.
        """
        results = []
        for audio_path in audio_paths:
            try:
                results.append(self.recognize(audio_path, batch_size_s))
            except Exception as e:
                # Deliberate best-effort: one bad file must not abort the batch.
                print(f"识别失败 [{audio_path}]: {e}")
                results.append([])
        return results

    def export_to_json(
        self,
        sentences: "List[Sentence]",
        output_path: Union[str, Path]
    ):
        """Write the sentences to a UTF-8 JSON file.

        Args:
            sentences: Recognition results.
            output_path: Destination file; missing parent directories are created.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "total_sentences": len(sentences),
            "sentences": [s.to_dict() for s in sentences]
        }
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        print(f"结果已保存: {output_path}")

    @staticmethod
    def _format_srt_time(seconds: float) -> str:
        """Format a time in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to whole milliseconds first, so float noise such
        as ``2.3 % 1 == 0.2999...`` cannot drop a millisecond.  Negative
        inputs are clamped to zero.
        """
        total_ms = max(0, round(seconds * 1000))
        total_secs, millis = divmod(total_ms, 1000)
        total_mins, secs = divmod(total_secs, 60)
        hours, minutes = divmod(total_mins, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def export_to_srt(
        self,
        sentences: "List[Sentence]",
        output_path: Union[str, Path]
    ):
        """Write the sentences to a UTF-8 SRT subtitle file.

        Each cue is numbered from 1 and its text is prefixed with the speaker
        label in square brackets.

        Args:
            sentences: Recognition results.
            output_path: Destination file; missing parent directories are created.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        format_time = self._format_srt_time
        with open(output_path, "w", encoding="utf-8") as f:
            for i, sentence in enumerate(sentences, 1):
                f.write(f"{i}\n")
                f.write(f"{format_time(sentence.begin_time)} --> {format_time(sentence.end_time)}\n")
                f.write(f"[{sentence.speaker}] {sentence.text}\n\n")
        print(f"字幕已保存: {output_path}")
# 便捷函数
def recognize_audio(
    audio_path: Union[str, Path],
    model_name: str = "paraformer-zh",
    device: str = "auto"
) -> List[Sentence]:
    """Recognize one audio file with a throwaway ``ASRService``.

    A new service is constructed on every call, so the model is loaded anew
    each time.  Create an ``ASRService`` directly to reuse a loaded model
    across several files.

    Args:
        audio_path: Path of the audio file.
        model_name: Model name, forwarded to ``ASRService``.
        device: Device to run on, forwarded to ``ASRService``.

    Returns:
        The parsed recognition result as a list of ``Sentence`` objects.
    """
    service = ASRService(model_name=model_name, device=device)
    # Parsed output is requested explicitly, so the result is always a list of
    # Sentence objects and needs no dict handling afterwards.
    return service.recognize(audio_path, return_raw=False)
if __name__ == "__main__":
    # Script entry point: print a banner followed by a short usage example.
    banner = "=" * 60
    usage_lines = (
        banner,
        "FunASR 语音识别服务",
        banner,
        "\n支持的音频格式: wav, mp3, m4a, flac 等",
        "\n使用方法:",
        ' from asr_service import ASRService',
        ' service = ASRService()',
        ' results = service.recognize("your_audio.wav")',
        ' for sent in results:',
        ' print(sent)',
    )
    for usage_line in usage_lines:
        print(usage_line)