# 303 lines · 9.5 KiB · Python
"""
FunASR speech recognition service.

Features: sentence-level timestamps, speaker diarization (FunASR CAM++),
noise robustness (VAD).
"""

import json
import os
import sys
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Union

# Models are cached in the directory that contains this file. The variables
# are set at import time; funasr itself is only imported lazily inside
# ASRService._load_model, so presumably it sees these values — confirm.
MODEL_CACHE_DIR = os.path.dirname(os.path.abspath(__file__))
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
os.environ.update({
    "MODELSCOPE_CACHE": MODEL_CACHE_DIR,
    "FUNASR_MODELS_DIR": MODEL_CACHE_DIR,
})

if sys.platform == "win32":
    # NOTE(review): CPython reads this variable at interpreter startup, so
    # setting it here most likely only affects child processes — confirm intent.
    os.environ["PYTHONLEGACYWINDOWSFSENCODING"] = "1"

# Suppress all warnings process-wide (this affects every importer of this module).
warnings.filterwarnings('ignore')


@dataclass
class Sentence:
    """A single recognized sentence: speaker label, text, and time span in seconds."""

    speaker: str
    text: str
    begin_time: float
    end_time: float

    def to_dict(self) -> Dict:
        """Return a JSON-friendly mapping; times and duration are rounded to 2 decimals."""
        span = self.end_time - self.begin_time
        return dict(
            speaker=self.speaker,
            text=self.text,
            begin_time=round(self.begin_time, 2),
            end_time=round(self.end_time, 2),
            duration=round(span, 2),
        )

    def __str__(self) -> str:
        start = f"{self.begin_time:.2f}s"
        stop = f"{self.end_time:.2f}s"
        return f"[{self.speaker}] {self.text} ({start} - {stop})"


class ASRService:
    """
    Speech recognition service built on FunASR ``AutoModel``.

    Features:
    1. Speech recognition (ASR)
    2. Sentence-level timestamps
    3. Speaker model (FunASR built-in CAM++) loaded for "paraformer-zh";
       note that ``_parse_result`` currently labels every sentence "speaker_0"
    4. Voice activity detection (VAD) for noise robustness

    The underlying model is loaded lazily on the first ``recognize`` call.
    """

    def __init__(
        self,
        model_name: str = "paraformer-zh",
        device: str = "auto",
        cache_dir: Optional[str] = None,
        merge_segments: bool = True,
        min_segment_duration: float = 0.3,
        merge_gap: float = 0.5
    ):
        """
        Initialize the ASR service without loading the model.

        Args:
            model_name: Model name
                - "paraformer-zh": DAMO Academy Paraformer model (recommended for Chinese)
                - "SenseVoice": SenseVoice multilingual model
                Any other name raises ValueError when the model is loaded.
            device: Device to run on ("cpu", "cuda", "auto")
            cache_dir: Model cache directory; defaults to MODEL_CACHE_DIR and
                is created if missing.
            merge_segments: Whether to merge adjacent segments of the same speaker
            min_segment_duration: Minimum segment duration threshold (noise filtering)
            merge_gap: Time-gap threshold for merging segments

        Raises:
            ValueError: If ``device`` is not "cpu", "cuda" or "auto".

        NOTE(review): merge_segments, min_segment_duration and merge_gap are
        stored but not read anywhere in this class, and cache_dir is only
        created and printed — it is not passed to AutoModel, so presumably the
        module-level MODELSCOPE_CACHE setting decides where models go. Confirm.
        """
        self.model_name = model_name
        self.cache_dir = cache_dir or MODEL_CACHE_DIR
        self.merge_segments = merge_segments
        self.min_segment_duration = min_segment_duration
        self.merge_gap = merge_gap

        os.makedirs(self.cache_dir, exist_ok=True)
        # Resolve "auto" into a concrete device string.
        self.device = self._get_device(device)

        # Populated lazily by _load_model().
        self._model = None

    def _get_device(self, device: str) -> str:
        """Resolve the requested device.

        "auto" becomes "cuda" when ``torch.cuda.is_available()`` and "cpu"
        otherwise (the choice is printed); "cpu" and "cuda" pass through.

        Raises:
            ValueError: For any other value.
        """
        import torch

        if device == "auto":
            if torch.cuda.is_available():
                device = "cuda"
                print(f"检测到 GPU: {torch.cuda.get_device_name(0)}")
            else:
                device = "cpu"
                print("未检测到 GPU,使用 CPU 运行")
        elif device not in ["cpu", "cuda"]:
            raise ValueError(f"不支持的设备: {device},请选择 'cpu', 'cuda' 或 'auto'")

        return device

    def _load_model(self):
        """Lazily create the FunASR AutoModel for ``self.model_name``; no-op if already loaded.

        Raises:
            ImportError: If funasr is not installed.
            ValueError: If ``self.model_name`` is not a supported model.
        """
        if self._model is not None:
            return

        try:
            from funasr import AutoModel
        except ImportError as err:
            raise ImportError("请安装 FunASR: pip install funasr") from err

        print(f"正在加载模型: {self.model_name}")
        print(f"设备: {self.device}")
        print(f"模型缓存目录: {self.cache_dir}")

        if self.model_name == "paraformer-zh":
            # ASR model plus VAD, punctuation and speaker (CAM++) models.
            self._model = AutoModel(
                model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
                spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                device=self.device,
                ncpu=4,
                disable_pbar=True,
                disable_log=True,
            )
        elif self.model_name == "SenseVoice":
            self._model = AutoModel(
                model="iic/SenseVoiceSmall",
                vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                # Cap on a single VAD segment (presumably milliseconds).
                vad_kwargs={"max_single_segment_time": 30000},
                device=self.device,
                disable_pbar=True,
                disable_log=True,
            )
        else:
            raise ValueError(f"不支持的模型: {self.model_name}")

        print("模型加载完成!")

    def recognize(
        self,
        audio_path: Union[str, Path],
        batch_size_s: int = 300,
        return_raw: bool = False,
    ) -> "Union[List[Sentence], Dict]":
        """
        Recognize an audio file.

        Args:
            audio_path: Path to the audio file
            batch_size_s: Batch duration in seconds, forwarded to ``generate``
            return_raw: Whether to return the raw ``generate`` output

        Returns:
            List[Sentence]: Parsed recognition results (default)
            Dict: Raw result (if return_raw=True)

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist.
            RuntimeError: If the model is unexpectedly still unloaded.
        """
        audio_path = Path(audio_path)
        # Check the input before paying for a (potentially slow) model load.
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")

        self._load_model()

        print(f"正在识别: {audio_path}")

        if self._model is None:
            raise RuntimeError("模型加载失败,无法执行识别")

        result = self._model.generate(
            input=str(audio_path),
            batch_size_s=batch_size_s,
            return_raw_text=True,
            return_spk_res=True,
        )

        if return_raw:
            return result

        return self._parse_result(result)

    def _parse_result(self, result: List[Dict]) -> "List[Sentence]":
        """Convert FunASR ``generate`` output into Sentence objects.

        Only the first element is used when ``result`` is a list. If it has
        ``sentence_info``, each entry becomes a Sentence whose ``start``/``end``
        are divided by 1000 (i.e. treated as milliseconds). Otherwise a single
        untimed Sentence is built from ``text``. In both cases, entries whose
        text is empty (or None) after stripping are skipped. Every sentence is
        labelled "speaker_0".
        """
        sentences = []
        if not result:
            return sentences

        res = result[0] if isinstance(result, list) else result

        if "sentence_info" in res:
            for sent_info in res["sentence_info"]:
                # `or ""` also guards against an explicit None text.
                text = (sent_info.get("text") or "").strip()
                if not text:
                    continue
                sentences.append(Sentence(
                    speaker="speaker_0",  # all sentences use speaker_0
                    text=text,
                    begin_time=sent_info.get("start", 0) / 1000.0,
                    end_time=sent_info.get("end", 0) / 1000.0
                ))
        elif "text" in res:
            text = (res["text"] or "").strip()
            # Skip empty transcripts, consistent with the sentence_info branch.
            if text:
                sentences.append(Sentence(
                    speaker="speaker_0",  # all sentences use speaker_0
                    text=text,
                    begin_time=0.0,
                    end_time=0.0
                ))

        return sentences

    def recognize_batch(
        self,
        audio_paths: List[Union[str, Path]],
        batch_size_s: int = 300,
        use_3d_speaker: bool = False
    ) -> "List[List[Sentence]]":
        """Recognize several audio files sequentially.

        Best-effort: a failing file has its error printed and contributes an
        empty list, so the output stays aligned with ``audio_paths``.

        NOTE(review): ``use_3d_speaker`` is accepted for compatibility but unused.
        """
        results = []
        for audio_path in audio_paths:
            try:
                result = self.recognize(audio_path, batch_size_s)
            except Exception as e:
                print(f"识别失败 [{audio_path}]: {e}")
                result = []
            results.append(result)
        return results

    def export_to_json(
        self,
        sentences: "List[Sentence]",
        output_path: Union[str, Path]
    ):
        """Write sentences to a UTF-8 JSON file (parent directories are created).

        The file contains ``total_sentences`` and the ``to_dict()`` form of each sentence.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        data = {
            "total_sentences": len(sentences),
            "sentences": [s.to_dict() for s in sentences]
        }

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        print(f"结果已保存: {output_path}")

    @staticmethod
    def _format_srt_time(seconds: float) -> str:
        """Format ``seconds`` as an SRT timestamp ``HH:MM:SS,mmm``.

        The value is rounded to the nearest millisecond before splitting.
        Splitting the float directly truncates, e.g. 2.3 % 1 * 1000 gives
        299.99..., which becomes ",299". Negative inputs are clamped to 0.
        """
        total_ms = max(0, int(round(seconds * 1000)))
        total_secs, millis = divmod(total_ms, 1000)
        total_mins, secs = divmod(total_secs, 60)
        hours, minutes = divmod(total_mins, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def export_to_srt(
        self,
        sentences: "List[Sentence]",
        output_path: Union[str, Path]
    ):
        """Write sentences as an SRT subtitle file (parent directories are created).

        Cues are numbered from 1, and each cue's text is prefixed with ``[speaker]``.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            for i, sentence in enumerate(sentences, 1):
                start = self._format_srt_time(sentence.begin_time)
                end = self._format_srt_time(sentence.end_time)
                f.write(f"{i}\n")
                f.write(f"{start} --> {end}\n")
                f.write(f"[{sentence.speaker}] {sentence.text}\n\n")

        print(f"字幕已保存: {output_path}")


def recognize_audio(
    audio_path: Union[str, Path],
    model_name: str = "paraformer-zh",
    device: str = "auto",
    use_3d_speaker: bool = False
) -> List[Sentence]:
    """Recognize one audio file using a freshly constructed ASRService.

    ``use_3d_speaker`` is accepted but not used.
    """
    sentences = ASRService(model_name=model_name, device=device).recognize(
        audio_path,
        return_raw=False,
    )
    # With return_raw=False, recognize returns a list; the assert narrows the Union type.
    assert isinstance(sentences, list)
    return sentences


if __name__ == "__main__":
    # Print usage instructions when the module is run as a script.
    print("=" * 60)
    print("FunASR 语音识别服务")
    print("=" * 60)
    print("\n使用方法:")
    print(' from asr_service import ASRService')
    print(' service = ASRService()')
    print(' results = service.recognize("your_audio.wav")')
    print(' for sent in results:')
    print(' print(sent)')
    # recognize() has no use_3d_speaker parameter, so the previously advertised
    # call raised TypeError; show the export helpers that do exist instead.
    print("\n导出结果:")
    print(' service.export_to_json(results, "result.json")')
    print(' service.export_to_srt(results, "result.srt")')