ai/SpeechRecognition/diarization_service.py

294 lines
9.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
3D-Speaker 说话人分离服务
支持:说话人分离、可调聚类参数、自动人数检测
"""
import os
import sys
import json
from pathlib import Path
from typing import List, Dict, Union, Optional
from dataclasses import dataclass
# Directory containing this file; every path below is resolved relative to it.
_HERE = os.path.dirname(os.path.abspath(__file__))

# Make a sibling "3D-Speaker" checkout importable (provides `speakerlab`), if present.
diarization_path = os.path.join(_HERE, "..", "3D-Speaker")
if os.path.exists(diarization_path):
    sys.path.insert(0, diarization_path)

# Local model cache, created eagerly at import time and exported to ModelScope.
MODEL_CACHE_DIR = os.path.join(_HERE, "models")
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
# NOTE(review): this overwrites any MODELSCOPE_CACHE value set by the caller.
os.environ.update(MODELSCOPE_CACHE=MODEL_CACHE_DIR)

import warnings
# NOTE(review): silences *all* warnings process-wide, not just this module's.
warnings.filterwarnings(action="ignore")
@dataclass
class DiarizationSegment:
    """A single speaker turn: who spoke, and from when to when (seconds)."""

    speaker: str  # speaker label, e.g. "speaker_0"
    begin_time: float  # start of the turn
    end_time: float  # end of the turn

    def to_dict(self) -> Dict:
        """Return a JSON-friendly dict with times (and the derived duration) rounded to 2 decimals."""
        duration = self.end_time - self.begin_time
        return dict(
            speaker=self.speaker,
            begin_time=round(self.begin_time, 2),
            end_time=round(self.end_time, 2),
            duration=round(duration, 2),
        )
class DiarizationService:
    """
    Speaker diarization service built on 3D-Speaker.

    Features (per the original design):
        1. Speaker diarization.
        2. Tunable clustering parameters.
        3. Multi-speaker conversations.
        4. Automatic speaker-count detection (when ``speaker_num`` is None).

    Supported speaker-embedding model names:
        - campplus:   CAM++ (fast)
        - eres2net:   ERes2Net (more accurate)
        - eres2netv2: ERes2NetV2 (newest)

    The underlying ``Diarization3Dspeaker`` pipeline is created lazily on the
    first call to ``diarize``.

    NOTE(review): ``embedding_model``, ``min_speakers``, ``max_speakers``,
    ``cluster_threshold`` and ``min_cluster_size`` are stored on the instance
    but are NOT passed to ``Diarization3Dspeaker`` in ``_load_model`` (only
    device, include_overlap, hf_access_token and model_cache_dir are). Confirm
    against the 3D-Speaker version in use whether these settings take effect.
    """

    # ModelScope IDs for each supported embedding-model name. Kept for
    # reference; not currently passed to the backend (see class NOTE).
    EMBEDDING_MODEL_IDS = {
        "campplus": "iic/speech_campplus_sv_zh_en_16k-common_advanced",
        "eres2net": "iic/speech_eres2net_sv_zh-cn_16k-common",
        "eres2netv2": "iic/speech_eres2netv2_sv_zh-cn_16k-common",
    }

    def __init__(
        self,
        embedding_model: str = "eres2net",
        device: str = "auto",
        include_overlap: bool = False,
        hf_access_token: Optional[str] = None,
        cache_dir: Optional[str] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
        cluster_threshold: float = 0.8,
        min_cluster_size: int = 4
    ):
        """
        Configure the service. No model is loaded here.

        Args:
            embedding_model: speaker-embedding model name:
                "campplus" (CAM++), "eres2net" (ERes2Net) or "eres2netv2" (ERes2NetV2).
            device: "cpu", "cuda" or "auto"; "auto" picks "cuda" when torch
                reports CUDA as available, otherwise "cpu".
            include_overlap: whether to detect overlapping speech
                (requires ``hf_access_token``).
            hf_access_token: HuggingFace access token (used for overlap detection).
            cache_dir: model cache directory; defaults to ``MODEL_CACHE_DIR``.
            min_speakers: minimum number of speakers.
            max_speakers: maximum number of speakers.
            cluster_threshold: clustering similarity threshold (0.0-1.0).
                Higher is stricter and may split into more speakers;
                lower is looser and merges more speakers.
            min_cluster_size: minimum number of segments per speaker.
        """
        self.embedding_model = embedding_model
        self.device = self._get_device(device)
        self.include_overlap = include_overlap
        self.hf_access_token = hf_access_token
        self.cache_dir = cache_dir or MODEL_CACHE_DIR
        self.min_speakers = min_speakers
        self.max_speakers = max_speakers
        self.cluster_threshold = cluster_threshold
        self.min_cluster_size = min_cluster_size
        self.model = None  # created lazily by _load_model()

    def _get_device(self, device: str) -> str:
        """Resolve "auto" to "cuda"/"cpu" (falling back to "cpu" without torch); pass other values through."""
        if device == "auto":
            try:
                import torch
                device = "cuda" if torch.cuda.is_available() else "cpu"
            except ImportError:
                device = "cpu"
        return device

    def _load_model(self):
        """
        Construct the 3D-Speaker diarization pipeline once (no-op if already loaded).

        Raises:
            Exception: anything raised while importing or constructing
                ``Diarization3Dspeaker``; the traceback is printed and the
                exception is re-raised unchanged.
        """
        if self.model is not None:
            return
        print("正在加载 3D-Speaker 说话人分离模型...")
        print(f"设备:{self.device}")
        print(f"说话人嵌入模型:{self.embedding_model}")
        print(f"聚类参数threshold={self.cluster_threshold}, min_cluster_size={self.min_cluster_size}")
        sys.stdout.flush()  # make progress visible immediately
        try:
            from speakerlab.bin.infer_diarization import Diarization3Dspeaker  # type: ignore
            print(" - 导入 Diarization3Dspeaker 完成")
            sys.stdout.flush()
            self.model = Diarization3Dspeaker(
                device=self.device,
                include_overlap=self.include_overlap,
                hf_access_token=self.hf_access_token,
                model_cache_dir=self.cache_dir,
            )
            print(" - 模型实例化完成")
            sys.stdout.flush()
            print("模型加载完成!")
            sys.stdout.flush()
        except Exception as e:
            print(f"\n✗ 模型加载失败:{e}")
            sys.stdout.flush()
            import traceback
            traceback.print_exc()
            sys.stdout.flush()
            raise

    def diarize(
        self,
        audio_path: Union[str, Path],
        speaker_num: Optional[int] = None,
    ) -> List[DiarizationSegment]:
        """
        Run speaker diarization on an audio file.

        Args:
            audio_path: path to the audio file.
            speaker_num: preset number of speakers (optional); when None the
                count is detected automatically by the backend.

        Returns:
            One DiarizationSegment per (begin, end, speaker_id) triple returned
            by the backend, labelled "speaker_<id>".

        Raises:
            FileNotFoundError: if ``audio_path`` does not exist.
            RuntimeError: if the model is still missing after loading.
        """
        audio_path = Path(audio_path)
        # Validate the input before the (potentially slow) model load so a bad
        # path fails fast instead of after model initialization.
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")
        self._load_model()
        print(f"正在执行说话人分离: {audio_path}")
        if self.model is None:
            raise RuntimeError("模型未正确加载,无法执行说话人分离")
        result = self.model(
            wav=str(audio_path),
            speaker_num=speaker_num
        )
        # float() normalizes backend scalar types so export_to_json can
        # always serialize the times.
        segments = [
            DiarizationSegment(
                speaker=f"speaker_{speaker_id}",
                begin_time=float(begin_time),
                end_time=float(end_time),
            )
            for begin_time, end_time, speaker_id in result
        ]
        unique_speakers = len({s.speaker for s in segments})
        print(f"分离完成,检测到 {unique_speakers} 个说话人")
        return segments

    def export_to_json(
        self,
        segments: List[DiarizationSegment],
        output_path: Union[str, Path]
    ):
        """
        Write segments to a UTF-8 JSON file, creating parent directories.

        The file holds "total_segments", "speaker_count" (distinct labels)
        and "segments" (each segment's to_dict()).
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "total_segments": len(segments),
            "speaker_count": len({s.speaker for s in segments}),
            "segments": [s.to_dict() for s in segments],
        }
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        print(f"结果已保存: {output_path}")

    def export_to_rttm(
        self,
        segments: List[DiarizationSegment],
        output_path: Union[str, Path],
        wav_id: str = "default"
    ):
        """
        Write segments as RTTM "SPEAKER" lines, creating parent directories.

        Each line is: SPEAKER <wav_id> 0 <begin> <duration> <NA> <NA> <id> <NA> <NA>,
        with times formatted to 3 decimals and the "speaker_" label prefix removed.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        prefix = "speaker_"
        with open(output_path, "w", encoding="utf-8") as f:
            for seg in segments:
                # Strip only the leading label prefix added by diarize(); a plain
                # replace() would also remove occurrences elsewhere in the label.
                if seg.speaker.startswith(prefix):
                    speaker_id = seg.speaker[len(prefix):]
                else:
                    speaker_id = seg.speaker
                duration = seg.end_time - seg.begin_time
                f.write(
                    f"SPEAKER {wav_id} 0 {seg.begin_time:.3f} {duration:.3f} "
                    f"<NA> <NA> {speaker_id} <NA> <NA>\n"
                )
        print(f"RTTM 结果已保存: {output_path}")
def create_diarization_service(
    embedding_model: str = "eres2netv2",
    device: str = "auto",
    cluster_threshold: float = 0.5,
    min_cluster_size: int = 10,
    **service_kwargs
) -> DiarizationService:
    """
    Build a DiarizationService with this module's conversation-oriented defaults.

    Args:
        embedding_model: speaker-embedding model (campplus/eres2net/eres2netv2).
        device: target device ("cpu", "cuda" or "auto").
        cluster_threshold: clustering threshold (0.0-1.0).
            Lower makes speakers merge more easily (suits few-speaker talks);
            higher makes speakers split more easily (suits many-speaker talks).
        min_cluster_size: minimum number of segments per speaker.
        **service_kwargs: any other DiarizationService keyword argument
            (e.g. include_overlap, hf_access_token, cache_dir, min_speakers,
            max_speakers), forwarded unchanged.

    Returns:
        A configured DiarizationService instance (model not loaded yet).

    NOTE(review): DiarizationService in this file stores embedding_model,
    cluster_threshold and min_cluster_size but does not pass them to the
    Diarization3Dspeaker backend; confirm they actually take effect.
    """
    return DiarizationService(
        embedding_model=embedding_model,
        device=device,
        cluster_threshold=cluster_threshold,
        min_cluster_size=min_cluster_size,
        **service_kwargs,
    )
if __name__ == "__main__":
    # Command-line entry point: diarize one file, save JSON, preview results.
    import argparse

    cli = argparse.ArgumentParser(description='3D-Speaker 说话人分离')
    cli.add_argument('--wav', type=str, required=True, help='输入音频文件')
    cli.add_argument('--out', type=str, default='./diarization_result.json', help='输出文件')
    cli.add_argument(
        '--model',
        type=str,
        default='eres2netv2',
        choices=['campplus', 'eres2net', 'eres2netv2'],
        help='说话人嵌入模型',
    )
    cli.add_argument('--device', type=str, default='auto', help='设备 (cpu/cuda/auto)')
    cli.add_argument('--speaker_num', type=int, default=None, help='预设说话人数量')
    cli.add_argument('--threshold', type=float, default=0.5, help='聚类阈值 (0.0-1.0)')
    cli.add_argument('--min_cluster_size', type=int, default=10, help='每个说话人最少片段数')
    opts = cli.parse_args()

    service = DiarizationService(
        embedding_model=opts.model,
        device=opts.device,
        cluster_threshold=opts.threshold,
        min_cluster_size=opts.min_cluster_size,
    )
    found = service.diarize(opts.wav, speaker_num=opts.speaker_num)
    service.export_to_json(found, opts.out)

    # Preview only the first ten segments on stdout.
    print("\n分离结果:")
    for turn in found[:10]:
        print(f" [{turn.begin_time:.2f}s - {turn.end_time:.2f}s] {turn.speaker}")