""" 3D-Speaker 说话人分离服务 支持:说话人分离、可调聚类参数、自动人数检测 """ import os import sys import json from pathlib import Path from typing import List, Dict, Union, Optional from dataclasses import dataclass diarization_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3D-Speaker") if os.path.exists(diarization_path): sys.path.insert(0, diarization_path) MODEL_CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models") os.makedirs(MODEL_CACHE_DIR, exist_ok=True) os.environ["MODELSCOPE_CACHE"] = MODEL_CACHE_DIR import warnings warnings.filterwarnings('ignore') @dataclass class DiarizationSegment: """说话人分离结果片段""" speaker: str begin_time: float end_time: float def to_dict(self) -> Dict: return { "speaker": self.speaker, "begin_time": round(self.begin_time, 2), "end_time": round(self.end_time, 2), "duration": round(self.end_time - self.begin_time, 2) } class DiarizationService: """ 3D-Speaker 说话人分离服务 功能: 1. 说话人分离(Speaker Diarization) 2. 可调节聚类参数 3. 支持多人对话 4. 自动说话人人数检测 支持的说话人嵌入模型: - campplus: CAM++ (默认,快速) - eres2net: ERes2Net (更准确) - eres2netv2: ERes2NetV2 (最新,效果最好) """ def __init__( self, embedding_model: str = "eres2net", device: str = "auto", include_overlap: bool = False, hf_access_token: Optional[str] = None, cache_dir: Optional[str] = None, min_speakers: int = 1, max_speakers: int = 10, cluster_threshold: float = 0.8, min_cluster_size: int = 4 ): """ 初始化说话人分离服务 Args: embedding_model: 说话人嵌入模型 - "campplus": CAM++ 模型 - "eres2net": ERes2Net 模型 - "eres2netv2": ERes2NetV2 模型 device: 运行设备 ("cpu", "cuda", "auto") include_overlap: 是否包含重叠语音检测(需要 hf_access_token) hf_access_token: HuggingFace 访问令牌(用于重叠语音检测) cache_dir: 模型缓存目录 min_speakers: 最少说话人数量 max_speakers: 最多说话人数量 cluster_threshold: 聚类相似度阈值 (0.0-1.0) - 值越高:越严格,可能分成更多说话人 - 值越低:越宽松,会合并更多说话人 min_cluster_size: 每个说话人最少片段数 """ self.embedding_model = embedding_model self.device = self._get_device(device) self.include_overlap = include_overlap self.hf_access_token = hf_access_token self.cache_dir = cache_dir or MODEL_CACHE_DIR self.min_speakers = min_speakers self.max_speakers = max_speakers self.cluster_threshold = cluster_threshold self.min_cluster_size = min_cluster_size self.model = None def _get_device(self, device: str) -> str: if device == "auto": try: import torch device = "cuda" if torch.cuda.is_available() else "cpu" except ImportError: device = "cpu" return device def _load_model(self): """加载 3D-Speaker 说话人分离模型""" if self.model is not None: return print(f"正在加载 3D-Speaker 说话人分离模型...") print(f"设备:{self.device}") print(f"说话人嵌入模型:{self.embedding_model}") print(f"聚类参数:threshold={self.cluster_threshold}, min_cluster_size={self.min_cluster_size}") sys.stdout.flush() # 确保输出立即显示 embedding_models = { "campplus": "iic/speech_campplus_sv_zh_en_16k-common_advanced", "eres2net": "iic/speech_eres2net_sv_zh-cn_16k-common", "eres2netv2": "iic/speech_eres2netv2_sv_zh-cn_16k-common", } try: from speakerlab.bin.infer_diarization import Diarization3Dspeaker print(f" - 导入 Diarization3Dspeaker 完成") sys.stdout.flush() self.model = Diarization3Dspeaker( device=self.device, include_overlap=self.include_overlap, hf_access_token=self.hf_access_token, model_cache_dir=self.cache_dir ) print(f" - 模型实例化完成") sys.stdout.flush() print(f"模型加载完成!") sys.stdout.flush() except Exception as e: print(f"\n✗ 模型加载失败:{e}") sys.stdout.flush() import traceback traceback.print_exc() sys.stdout.flush() raise def diarize( self, audio_path: Union[str, Path], speaker_num: Optional[int] = None, ) -> List[DiarizationSegment]: """ 执行说话人分离 Args: audio_path: 音频文件路径 speaker_num: 预设说话人数量(可选) - 如果不指定,会自动检测 Returns: List[DiarizationSegment]: 说话人分离结果 """ self._load_model() audio_path = Path(audio_path) if not audio_path.exists(): raise FileNotFoundError(f"音频文件不存在: {audio_path}") print(f"正在执行说话人分离: {audio_path}") result = self.model( wav=str(audio_path), speaker_num=speaker_num ) segments = [] for seg in result: begin_time, end_time, speaker_id = seg segments.append(DiarizationSegment( speaker=f"speaker_{speaker_id}", begin_time=begin_time, end_time=end_time )) unique_speakers = len(set(s. speaker for s in segments)) print(f"分离完成,检测到 {unique_speakers} 个说话人") return segments def export_to_json( self, segments: List[DiarizationSegment], output_path: str | Path ): """导出结果为 JSON 文件""" output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) data = { "total_segments": len(segments), "speaker_count": len(set(s.speaker for s in segments)), "segments": [s.to_dict() for s in segments] } with open(output_path, "w", encoding="utf-8") as f: json.dump(data, f, ensure_ascii=False, indent=2) print(f"结果已保存: {output_path}") def export_to_rttm( self, segments: List[DiarizationSegment], output_path: Union[str, Path], wav_id: str = "default" ): """导出结果为 RTTM 文件""" output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: for seg in segments: speaker_id = seg.speaker.replace("speaker_", "") duration = seg.end_time - seg.begin_time line = f"SPEAKER {wav_id} 0 {seg.begin_time:.3f} {duration:.3f} {speaker_id} \n" f.write(line) print(f"RTTM 结果已保存: {output_path}") def create_diarization_service( embedding_model: str = "eres2netv2", device: str = "auto", cluster_threshold: float = 0.5, min_cluster_size: int = 10 ) -> DiarizationService: """ 创建说话人分离服务的工厂函数 Args: embedding_model: 说话人嵌入模型 (campplus/eres2net/eres2netv2) device: 运行设备 cluster_threshold: 聚类阈值 (0.0-1.0) - 值越低 → 越容易合并说话人(适合少人对话) - 值越高 → 越容易分开说话人(适合多人对话) min_cluster_size: 每个说话人最少片段数 Returns: DiarizationService 实例 """ return DiarizationService( embedding_model=embedding_model, device=device, cluster_threshold=cluster_threshold, min_cluster_size=min_cluster_size ) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description='3D-Speaker 说话人分离') parser.add_argument('--wav', type=str, required=True, help='输入音频文件') parser.add_argument('--out', type=str, default='./diarization_result.json', help='输出文件') parser.add_argument('--model', type=str, default='eres2netv2', choices=['campplus', 'eres2net', 'eres2netv2'], help='说话人嵌入模型') parser.add_argument('--device', type=str, default='auto', help='设备 (cpu/cuda/auto)') parser.add_argument('--speaker_num', type=int, default=None, help='预设说话人数量') parser.add_argument('--threshold', type=float, default=0.5, help='聚类阈值 (0.0-1.0)') parser.add_argument('--min_cluster_size', type=int, default=10, help='每个说话人最少片段数') args = parser.parse_args() diarization = DiarizationService( embedding_model=args.model, device=args.device, cluster_threshold=args.threshold, min_cluster_size=args.min_cluster_size ) segments = diarization.diarize(args.wav, speaker_num=args.speaker_num) diarization.export_to_json(segments, args.out) print(f"\n分离结果:") for seg in segments[:10]: print(f" [{seg.begin_time:.2f}s - {seg.end_time:.2f}s] {seg.speaker}")