294 lines
9.8 KiB
Python
294 lines
9.8 KiB
Python
"""
|
||
3D-Speaker speaker diarization service.

Supports speaker diarization, tunable clustering parameters, and automatic speaker-count detection.
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
from pathlib import Path
|
||
from typing import List, Dict, Union, Optional
|
||
from dataclasses import dataclass
|
||
|
||
# Directory that contains this module; both paths below are derived from it.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))

# If a "3D-Speaker" directory exists one level above this module's directory,
# put it at the front of sys.path (presumably so that `speakerlab` can be
# imported from that checkout — confirm against the deployment layout).
diarization_path = os.path.join(_MODULE_DIR, "..", "3D-Speaker")
if os.path.exists(diarization_path):
    sys.path.insert(0, diarization_path)

# Models are cached in a "models" directory next to this module. The directory
# is created at import time and exported through the MODELSCOPE_CACHE
# environment variable.
MODEL_CACHE_DIR = os.path.join(_MODULE_DIR, "models")
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
os.environ["MODELSCOPE_CACHE"] = MODEL_CACHE_DIR

import warnings

# Suppress every warning for the whole process.
warnings.filterwarnings('ignore')
|
||
|
||
|
||
@dataclass
class DiarizationSegment:
    """A single diarization result: one speaker label and its time span."""

    speaker: str
    begin_time: float
    end_time: float

    def to_dict(self) -> Dict:
        """Return the segment as a plain dictionary.

        ``begin_time`` and ``end_time`` are rounded to two decimals.
        ``duration`` is computed from the unrounded values and then rounded
        to two decimals.
        """
        start, stop = self.begin_time, self.end_time
        payload: Dict = {"speaker": self.speaker}
        payload["begin_time"] = round(start, 2)
        payload["end_time"] = round(stop, 2)
        payload["duration"] = round(stop - start, 2)
        return payload
|
||
|
||
|
||
class DiarizationService:
    """Speaker diarization service backed by 3D-Speaker.

    The backend (``speakerlab.bin.infer_diarization.Diarization3Dspeaker``)
    is imported and instantiated lazily by the first :meth:`diarize` call.
    Results can be written out as JSON or RTTM.

    Embedding-model keys listed in ``EMBEDDING_MODELS``:
        - campplus: CAM++
        - eres2net: ERes2Net
        - eres2netv2: ERes2NetV2

    NOTE(review): ``embedding_model``, ``min_speakers``, ``max_speakers``,
    ``cluster_threshold`` and ``min_cluster_size`` are stored on the instance
    (some are printed while loading) but nothing in this class forwards them
    to ``Diarization3Dspeaker`` — confirm whether the backend should receive
    them.
    """

    # Model identifiers for the known embedding-model keys (presumably
    # ModelScope ids — confirm). Reference only: nothing in this class passes
    # the selected id to the backend.
    EMBEDDING_MODELS = {
        "campplus": "iic/speech_campplus_sv_zh_en_16k-common_advanced",
        "eres2net": "iic/speech_eres2net_sv_zh-cn_16k-common",
        "eres2netv2": "iic/speech_eres2netv2_sv_zh-cn_16k-common",
    }

    def __init__(
        self,
        embedding_model: str = "eres2net",
        device: str = "auto",
        include_overlap: bool = False,
        hf_access_token: Optional[str] = None,
        cache_dir: Optional[str] = None,
        min_speakers: int = 1,
        max_speakers: int = 10,
        cluster_threshold: float = 0.8,
        min_cluster_size: int = 4
    ):
        """Store the configuration; the model itself is not loaded here.

        Args:
            embedding_model: Speaker-embedding model key ("campplus",
                "eres2net" or "eres2netv2"). Stored only; see the class note.
            device: "cpu", "cuda" or "auto". "auto" is resolved by
                :meth:`_get_device`.
            include_overlap: Forwarded to the backend as ``include_overlap``
                (overlapped-speech detection; the original documentation says
                it requires ``hf_access_token``).
            hf_access_token: HuggingFace access token, forwarded to the backend.
            cache_dir: Model cache directory. Falls back to the module-level
                ``MODEL_CACHE_DIR`` when empty or ``None``.
            min_speakers: Minimum number of speakers. Stored only.
            max_speakers: Maximum number of speakers. Stored only.
            cluster_threshold: Clustering similarity threshold. The original
                documentation gives the range 0.0-1.0 and says higher values
                tend to split into more speakers while lower values merge
                more. Stored only.
            min_cluster_size: Minimum number of segments per speaker.
                Stored only.
        """
        self.embedding_model = embedding_model
        self.device = self._get_device(device)
        self.include_overlap = include_overlap
        self.hf_access_token = hf_access_token
        self.cache_dir = cache_dir or MODEL_CACHE_DIR

        self.min_speakers = min_speakers
        self.max_speakers = max_speakers
        self.cluster_threshold = cluster_threshold
        self.min_cluster_size = min_cluster_size

        # Created on demand by _load_model().
        self.model = None

    def _get_device(self, device: str) -> str:
        """Resolve "auto" to a concrete device name.

        "auto" becomes "cuda" when ``torch`` is importable and reports an
        available CUDA device, otherwise "cpu". Any other value is returned
        unchanged.
        """
        if device == "auto":
            try:
                import torch
                device = "cuda" if torch.cuda.is_available() else "cpu"
            except ImportError:
                device = "cpu"
        return device

    def _load_model(self):
        """Instantiate the 3D-Speaker backend once; later calls are no-ops.

        Progress is printed to stdout and flushed immediately.

        Raises:
            Exception: Whatever the import or the constructor raised. It is
                printed together with a traceback and then re-raised.
        """
        if self.model is not None:
            return

        print("正在加载 3D-Speaker 说话人分离模型...")
        print(f"设备:{self.device}")
        print(f"说话人嵌入模型:{self.embedding_model}")
        print(f"聚类参数:threshold={self.cluster_threshold}, min_cluster_size={self.min_cluster_size}")
        sys.stdout.flush()  # make the progress output appear immediately

        try:
            from speakerlab.bin.infer_diarization import Diarization3Dspeaker  # type: ignore

            print(" - 导入 Diarization3Dspeaker 完成")
            sys.stdout.flush()

            self.model = Diarization3Dspeaker(
                device=self.device,
                include_overlap=self.include_overlap,
                hf_access_token=self.hf_access_token,
                model_cache_dir=self.cache_dir
            )

            print(" - 模型实例化完成")
            sys.stdout.flush()
            print("模型加载完成!")
            sys.stdout.flush()

        except Exception as e:
            # Report the failure, then let the caller see the original error.
            print(f"\n✗ 模型加载失败:{e}")
            sys.stdout.flush()
            import traceback
            traceback.print_exc()
            raise

    def diarize(
        self,
        audio_path: Union[str, Path],
        speaker_num: Optional[int] = None,
    ) -> "List[DiarizationSegment]":
        """Run speaker diarization on one audio file.

        Args:
            audio_path: Path of the audio file.
            speaker_num: Optional preset number of speakers, passed to the
                backend as ``speaker_num``. The original documentation says
                the count is detected automatically when omitted.

        Returns:
            One ``DiarizationSegment`` per item the backend yields, with the
            speaker labelled ``speaker_<id>``.

        Raises:
            FileNotFoundError: If ``audio_path`` does not exist. This is
                checked before the model is loaded so a bad path fails fast.
            RuntimeError: If the model is still unset after loading.
        """
        audio_path = Path(audio_path)
        if not audio_path.exists():
            raise FileNotFoundError(f"音频文件不存在: {audio_path}")

        self._load_model()
        if self.model is None:
            raise RuntimeError("模型未正确加载,无法执行说话人分离")

        print(f"正在执行说话人分离: {audio_path}")

        result = self.model(
            wav=str(audio_path),
            speaker_num=speaker_num
        )

        # Each backend item is unpacked as (begin_time, end_time, speaker_id).
        segments = [
            DiarizationSegment(
                speaker=f"speaker_{speaker_id}",
                begin_time=begin_time,
                end_time=end_time,
            )
            for begin_time, end_time, speaker_id in result
        ]

        unique_speakers = len({s.speaker for s in segments})
        print(f"分离完成,检测到 {unique_speakers} 个说话人")
        return segments

    def export_to_json(
        self,
        segments: "List[DiarizationSegment]",
        output_path: Union[str, Path]
    ):
        """Write the segments to a UTF-8 JSON file.

        The file holds ``total_segments``, ``speaker_count`` (number of
        distinct speaker labels) and ``segments`` (each segment's
        ``to_dict()``). Missing parent directories are created.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        data = {
            "total_segments": len(segments),
            "speaker_count": len({s.speaker for s in segments}),
            "segments": [s.to_dict() for s in segments]
        }

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        print(f"结果已保存: {output_path}")

    def export_to_rttm(
        self,
        segments: "List[DiarizationSegment]",
        output_path: Union[str, Path],
        wav_id: str = "default"
    ):
        """Write the segments to an RTTM file, one ``SPEAKER`` line each.

        Each line carries ``wav_id``, the begin time and the duration (both
        with three decimals) and the speaker label with every ``speaker_``
        occurrence removed. Missing parent directories are created.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            for seg in segments:
                speaker_id = seg.speaker.replace("speaker_", "")
                duration = seg.end_time - seg.begin_time
                line = f"SPEAKER {wav_id} 0 {seg.begin_time:.3f} {duration:.3f} <NA> <NA> {speaker_id} <NA> <NA>\n"
                f.write(line)

        print(f"RTTM 结果已保存: {output_path}")
|
||
|
||
|
||
def create_diarization_service(
    embedding_model: str = "eres2netv2",
    device: str = "auto",
    cluster_threshold: float = 0.5,
    min_cluster_size: int = 10
) -> DiarizationService:
    """Factory that builds a ``DiarizationService`` from four options.

    All four arguments are passed through as keyword arguments of the same
    name. Every other constructor option keeps its own default.

    Args:
        embedding_model: Speaker-embedding model key
            (campplus/eres2net/eres2netv2).
        device: Device to run on.
        cluster_threshold: Clustering threshold. The original documentation
            gives the range 0.0-1.0 and says lower values merge speakers more
            readily (few-speaker conversations) while higher values separate
            them more readily (many-speaker conversations).
        min_cluster_size: Minimum number of segments per speaker.

    Returns:
        The newly constructed ``DiarizationService``.
    """
    options = {
        "embedding_model": embedding_model,
        "device": device,
        "cluster_threshold": cluster_threshold,
        "min_cluster_size": min_cluster_size,
    }
    return DiarizationService(**options)
|
||
|
||
|
||
if __name__ == "__main__":
    import argparse

    # Command-line entry point: diarize one audio file, save the result as
    # JSON, then print the first ten segments.
    cli = argparse.ArgumentParser(description='3D-Speaker 说话人分离')

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        ('--wav', dict(type=str, required=True, help='输入音频文件')),
        ('--out', dict(type=str, default='./diarization_result.json', help='输出文件')),
        ('--model', dict(type=str, default='eres2netv2',
                         choices=['campplus', 'eres2net', 'eres2netv2'], help='说话人嵌入模型')),
        ('--device', dict(type=str, default='auto', help='设备 (cpu/cuda/auto)')),
        ('--speaker_num', dict(type=int, default=None, help='预设说话人数量')),
        ('--threshold', dict(type=float, default=0.5, help='聚类阈值 (0.0-1.0)')),
        ('--min_cluster_size', dict(type=int, default=10, help='每个说话人最少片段数')),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    service = DiarizationService(
        embedding_model=opts.model,
        device=opts.device,
        cluster_threshold=opts.threshold,
        min_cluster_size=opts.min_cluster_size
    )

    found = service.diarize(opts.wav, speaker_num=opts.speaker_num)
    service.export_to_json(found, opts.out)

    # Only the first ten segments are echoed to the console.
    print("\n分离结果:")
    for item in found[:10]:
        print(f" [{item.begin_time:.2f}s - {item.end_time:.2f}s] {item.speaker}")
|