348 lines
10 KiB
Python
348 lines
10 KiB
Python
"""
FunASR speech-recognition service.

Supports: sentence-level timestamps, speaker diarization, noise robustness (via VAD).
"""

import os
import sys

# Work around the Windows path-length limit:
# keep the model cache in a short, project-local directory ("models" next to this file).
MODEL_CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
# Runs at import time: importing this module creates the directory.
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
# Redirect model downloads to MODEL_CACHE_DIR. This module only imports funasr
# lazily (inside ASRService._load_model), so these are set before that import.
# NOTE(review): MODELSCOPE_CACHE is presumably ModelScope's cache-location
# variable; confirm that FUNASR_MODELS_DIR is actually read by the installed FunASR.
os.environ["MODELSCOPE_CACHE"] = MODEL_CACHE_DIR
os.environ["FUNASR_MODELS_DIR"] = MODEL_CACHE_DIR

# NOTE(review): the original comment called this "Windows long path support
# (Windows 10 1607+)". As far as I know, PYTHONLEGACYWINDOWSFSENCODING instead
# reverts the filesystem encoding to the legacy 'mbcs' codec and is only read
# at interpreter startup. If so, setting it here has no effect on the current
# process, but it is inherited by child Python processes, where it can break
# non-ASCII (e.g. Chinese) paths. Long-path support is an OS-level setting
# (LongPathsEnabled). Confirm against the Python docs and consider removing this.
if sys.platform == "win32":
    os.environ["PYTHONLEGACYWINDOWSFSENCODING"] = "1"

import json
from pathlib import Path
from typing import List, Dict, Union, Optional
from dataclasses import dataclass
import warnings

# Silences ALL warnings via the global warnings filter list, so it affects
# every module in the process, not just this one.
warnings.filterwarnings('ignore')
|
||
|
||
|
||
@dataclass
class Sentence:
    """A single recognized sentence: speaker label, text, and its time span in seconds."""

    speaker: str
    text: str
    begin_time: float
    end_time: float

    def to_dict(self) -> Dict:
        """Return a JSON-friendly dict; all times are rounded to 2 decimals."""
        span = self.end_time - self.begin_time
        payload = {"speaker": self.speaker, "text": self.text}
        payload["begin_time"] = round(self.begin_time, 2)
        payload["end_time"] = round(self.end_time, 2)
        payload["duration"] = round(span, 2)
        return payload

    def __str__(self) -> str:
        """Render as ``[speaker] text (start s - end s)`` with 2-decimal times."""
        start, stop = self.begin_time, self.end_time
        return "[{}] {} ({:.2f}s - {:.2f}s)".format(self.speaker, self.text, start, stop)
|
||
|
||
|
||
class ASRService:
    """
    Speech-recognition service built on FunASR's ``AutoModel``.

    Features:
    1. Speech recognition (ASR)
    2. Sentence-level timestamps
    3. Speaker diarization
    4. Voice activity detection (VAD), which helps with noisy audio

    The model is created lazily on the first call to ``recognize``.
    """

    def __init__(
        self,
        model_name: str = "paraformer-zh",  # "paraformer-zh" or "SenseVoice"
        device: str = "auto",
        cache_dir: Optional[str] = None
    ):
        """
        Initialize the service (the model itself is not loaded yet).

        Args:
            model_name: Model preset:
                - "paraformer-zh": DAMO Paraformer-large (recommended for Chinese)
                - "SenseVoice": SenseVoiceSmall multilingual model
                An unsupported name raises ValueError on first use, not here.
            device: "cpu", "cuda", "cuda:<index>" (e.g. "cuda:1") or "auto".
            cache_dir: Model cache directory; defaults to MODEL_CACHE_DIR.
                It is created if it does not exist.

        Raises:
            ValueError: If ``device`` is not a supported value.
        """
        self.model_name = model_name
        self.cache_dir = cache_dir or MODEL_CACHE_DIR

        # NOTE(review): cache_dir is created and printed but never passed to
        # AutoModel; downloads follow the MODELSCOPE_CACHE environment variable
        # set at import time. Confirm whether a custom cache_dir should take effect.
        os.makedirs(self.cache_dir, exist_ok=True)

        self.device = self._get_device(device)

        # Created lazily by _load_model().
        self._model = None

    def _get_device(self, device: str) -> str:
        """
        Resolve the requested device into a concrete one.

        Args:
            device: "cpu", "cuda", "cuda:<index>" or "auto".

        Returns:
            str: "cpu", "cuda" or "cuda:<index>". "auto" becomes "cuda" when
            torch reports CUDA as available, otherwise "cpu".

        Raises:
            ValueError: For any other value.
        """
        if device in ("cpu", "cuda"):
            return device
        if isinstance(device, str) and device.startswith("cuda:"):
            index = device[len("cuda:"):]
            if index.isascii() and index.isdigit():
                return device
        if device != "auto":
            raise ValueError(f"不支持的设备: {device},请选择 'cpu', 'cuda' 或 'auto'")

        # torch is only needed to probe for a GPU, so import it lazily here;
        # explicit "cpu"/"cuda" requests no longer require torch at this point.
        import torch

        if torch.cuda.is_available():
            print(f"检测到 GPU: {torch.cuda.get_device_name(0)}")
            return "cuda"
        print("未检测到 GPU,使用 CPU 运行")
        return "cpu"

    def _load_model(self):
        """
        Create the FunASR AutoModel for ``self.model_name`` (no-op if already loaded).

        Raises:
            ImportError: If the ``funasr`` package is not installed.
            ValueError: If ``self.model_name`` is not a supported preset.
        """
        if self._model is not None:
            return

        try:
            from funasr import AutoModel
        except ImportError as err:
            raise ImportError("请安装 FunASR: pip install funasr") from err

        print(f"正在加载模型: {self.model_name}")
        print(f"设备: {self.device}")
        print(f"模型缓存目录: {self.cache_dir}")

        if self.model_name == "paraformer-zh":
            # Paraformer-large Chinese pipeline: ASR + VAD + punctuation + speaker
            # model, used for sentence timestamps and speaker diarization.
            # Per the original notes, only these ASR models support timestamps:
            #   - speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch
            #   - speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
            self._model = AutoModel(
                model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
                spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                device=self.device,
                ncpu=4,
                disable_pbar=True,
                disable_log=True,
            )
        elif self.model_name == "SenseVoice":
            # SenseVoiceSmall multilingual model with VAD. No speaker model is
            # configured here, so speaker labels will presumably fall back to
            # "SPEAKER_00".
            # NOTE(review): SenseVoice transcripts may contain rich-transcription
            # tags (language/emotion/event markers); confirm whether FunASR's
            # rich_transcription_postprocess should be applied to the text.
            self._model = AutoModel(
                model="iic/SenseVoiceSmall",
                vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                vad_kwargs={"max_single_segment_time": 30000},
                device=self.device,
                disable_pbar=True,
                disable_log=True,
            )
        else:
            raise ValueError(f"不支持的模型: {self.model_name}")

        print("模型加载完成!")

    def recognize(
        self,
        audio_path: Union[str, Path],
        batch_size_s: int = 300,
        return_raw: bool = False
    ) -> Union[List["Sentence"], Dict]:
        """
        Recognize a single audio file.

        Args:
            audio_path: Path to the audio file.
            batch_size_s: Batch duration in seconds, forwarded to ``generate``.
            return_raw: If True, return FunASR's raw output unparsed.

        Returns:
            List[Sentence]: Parsed sentences (default).
            Raw ``generate`` output when ``return_raw`` is True.

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
        """
        self._load_model()

        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")

        print(f"正在识别: {audio_path}")

        result = self._model.generate(
            input=str(audio_path),
            batch_size_s=batch_size_s,
            return_raw_text=True,
            return_spk_res=True,
        )

        if return_raw:
            return result

        return self._parse_result(result)

    @staticmethod
    def _speaker_label(sent_info: Dict) -> str:
        """
        Choose the speaker label for one "sentence_info" item.

        An explicit "speaker" value wins. Otherwise a numeric speaker index
        under "spk" is rendered as "SPEAKER_XX" (zero-padded to two digits).
        When neither key is present, "SPEAKER_00" is returned.
        """
        speaker = sent_info.get("speaker")
        if speaker is not None:
            return str(speaker)
        spk = sent_info.get("spk")
        if spk is None:
            return "SPEAKER_00"
        # NOTE(review): FunASR's diarization output appears to report an integer
        # speaker index under "spk". The original code only read "speaker", so
        # every sentence ended up as SPEAKER_00. Confirm for your FunASR version.
        try:
            return f"SPEAKER_{int(spk):02d}"
        except (TypeError, ValueError):
            return str(spk)

    def _parse_result(self, result: List[Dict]) -> List["Sentence"]:
        """
        Convert raw FunASR output into Sentence objects.

        Args:
            result: Value returned by ``AutoModel.generate``. If it is a list,
                only its first entry is parsed (recognize() submits one file);
                a single dict is also accepted.

        Returns:
            List[Sentence]: One Sentence per non-empty "sentence_info" item
            (start/end converted from milliseconds to seconds), or a single
            untimed Sentence built from "text" when there is no
            "sentence_info". Empty if ``result`` is empty or has neither key.
        """
        if not result:
            return []

        res = result[0] if isinstance(result, list) else result

        if "sentence_info" in res:
            # Sentence-level output with timestamps (and speakers when available).
            sentences = []
            for sent_info in res["sentence_info"]:
                # NOTE(review): text is expected under "text"; "sentence" is
                # accepted as a fallback used by some segmentation modes.
                text = (sent_info.get("text") or sent_info.get("sentence") or "").strip()
                if not text:
                    continue
                sentences.append(Sentence(
                    speaker=self._speaker_label(sent_info),
                    text=text,
                    begin_time=(sent_info.get("start") or 0) / 1000.0,  # ms -> s
                    end_time=(sent_info.get("end") or 0) / 1000.0,
                ))
            return sentences

        if "text" in res:
            # Plain-text result: no timestamps or speakers available.
            return [Sentence(
                speaker="SPEAKER_00",
                text=(res["text"] or "").strip(),
                begin_time=0.0,
                end_time=0.0,
            )]

        return []

    def recognize_batch(
        self,
        audio_paths: List[Union[str, Path]],
        batch_size_s: int = 300
    ) -> List[List["Sentence"]]:
        """
        Recognize several audio files sequentially.

        Best-effort: if one file fails, the error is printed and an empty list
        is stored for it, so the output stays index-aligned with ``audio_paths``.

        Args:
            audio_paths: Audio file paths.
            batch_size_s: Batch duration in seconds, forwarded to ``recognize``.

        Returns:
            List[List[Sentence]]: Recognition results per input file.
        """
        results = []
        for audio_path in audio_paths:
            try:
                results.append(self.recognize(audio_path, batch_size_s))
            except Exception as e:
                print(f"识别失败 [{audio_path}]: {e}")
                results.append([])
        return results

    def export_to_json(
        self,
        sentences: List["Sentence"],
        output_path: Union[str, Path]
    ):
        """
        Write the sentences to a UTF-8 JSON file (parent directories are created).

        Args:
            sentences: Recognition results.
            output_path: Destination file path.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        data = {
            "total_sentences": len(sentences),
            "sentences": [s.to_dict() for s in sentences]
        }

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        print(f"结果已保存: {output_path}")

    @staticmethod
    def _format_srt_time(seconds: float) -> str:
        """
        Format seconds as an SRT timestamp ``HH:MM:SS,mmm``.

        Works in whole, rounded milliseconds. Values that are inexact in
        binary floating point (e.g. 2.3 s) therefore render as ``,300`` rather
        than being truncated to ``,299``, and carries propagate correctly
        (59.9996 s -> 00:01:00,000). Negative input is clamped to zero.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        secs, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def export_to_srt(
        self,
        sentences: List["Sentence"],
        output_path: Union[str, Path]
    ):
        """
        Write the sentences as an SRT subtitle file (parent directories are created).

        Each cue's text line is prefixed with ``[speaker]``.

        Args:
            sentences: Recognition results.
            output_path: Destination file path.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            for i, sentence in enumerate(sentences, 1):
                start = self._format_srt_time(sentence.begin_time)
                end = self._format_srt_time(sentence.end_time)
                f.write(f"{i}\n")
                f.write(f"{start} --> {end}\n")
                f.write(f"[{sentence.speaker}] {sentence.text}\n\n")

        print(f"字幕已保存: {output_path}")
|
||
|
||
|
||
# Convenience function
def recognize_audio(
    audio_path: Union[str, Path],
    model_name: str = "paraformer-zh",
    device: str = "auto"
) -> List[Sentence]:
    """
    One-shot helper: build an ASRService and recognize a single audio file.

    A new ASRService (and therefore a fresh lazy model load) is created on
    every call; reuse one ASRService instance when recognizing many files.

    Args:
        audio_path: Path to the audio file.
        model_name: Model preset forwarded to ASRService.
        device: Device forwarded to ASRService.

    Returns:
        List[Sentence]: The recognized sentences.
    """
    return ASRService(model_name=model_name, device=device).recognize(audio_path)
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: print a short usage banner (no recognition is run).
    rule = "=" * 60
    usage_lines = [
        rule,
        "FunASR 语音识别服务",
        rule,
        "\n支持的音频格式: wav, mp3, m4a, flac 等",
        "\n使用方法:",
        ' from asr_service import ASRService',
        ' service = ASRService()',
        ' results = service.recognize("your_audio.wav")',
        ' for sent in results:',
        ' print(sent)',
    ]
    for line in usage_lines:
        print(line)
|