96 lines
2.7 KiB
Python
96 lines
2.7 KiB
Python
"""
|
|
测试模型加载(不使用多进程)
|
|
用于诊断是否是模型本身的问题
|
|
"""
|
|
|
|
import sys
|
|
import torch
|
|
|
|
print("=" * 60)
|
|
print("模型加载测试(单进程模式)")
|
|
print("=" * 60)
|
|
print(f"Python 版本:{sys.version}")
|
|
print(f"PyTorch 版本:{torch.__version__}")
|
|
print(f"CUDA 可用:{torch.cuda.is_available()}")
|
|
if torch.cuda.is_available():
|
|
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
|
print(f"CUDA 版本:{torch.version.cuda}")
|
|
print()
|
|
|
|
try:
|
|
# 测试 ASR 模型
|
|
print("=" * 60)
|
|
print("测试 1: 加载 ASR 模型 (Paraformer)")
|
|
print("=" * 60)
|
|
from asr_service import ASRService
|
|
|
|
asr_service = ASRService(model_name="paraformer-zh", device="auto")
|
|
print("✓ ASRService 初始化完成")
|
|
|
|
asr_service._load_model()
|
|
print("✓ ASR 模型加载完成")
|
|
print(f" 模型类型:{type(asr_service._model)}")
|
|
print()
|
|
|
|
# 测试说话人分离模型
|
|
print("=" * 60)
|
|
print("测试 2: 加载说话人分离模型 (3D-Speaker)")
|
|
print("=" * 60)
|
|
from diarization_service import DiarizationService
|
|
|
|
diar_service = DiarizationService(
|
|
embedding_model="eres2netv2",
|
|
device="auto",
|
|
cluster_threshold=0.5,
|
|
min_cluster_size=10
|
|
)
|
|
print("✓ DiarizationService 初始化完成")
|
|
|
|
diar_service._load_model()
|
|
print("✓ 说话人分离模型加载完成")
|
|
print(f" 模型类型:{type(diar_service.model)}")
|
|
print()
|
|
|
|
# 测试音频处理
|
|
print("=" * 60)
|
|
print("测试 3: 测试音频处理")
|
|
print("=" * 60)
|
|
|
|
# 检查是否有测试音频
|
|
from pathlib import Path
|
|
test_audio = Path("test.wav")
|
|
if test_audio.exists():
|
|
print(f"找到测试音频:{test_audio}")
|
|
|
|
print("执行 ASR 识别...")
|
|
sentences = asr_service.recognize(str(test_audio))
|
|
print(f"✓ ASR 识别完成,共 {len(sentences)} 句")
|
|
|
|
if sentences:
|
|
print(f" 第一句:{sentences[0]}")
|
|
|
|
print("执行说话人分离...")
|
|
segments = diar_service.diarize(str(test_audio))
|
|
print(f"✓ 说话人分离完成,共 {len(segments)} 个片段")
|
|
|
|
if segments:
|
|
print(f" 第一个片段:{segments[0]}")
|
|
else:
|
|
print("⚠️ 未找到测试音频 (test.wav),跳过处理测试")
|
|
|
|
print()
|
|
print("=" * 60)
|
|
print("✓ 所有测试通过!模型工作正常")
|
|
print("=" * 60)
|
|
|
|
except Exception as e:
|
|
print()
|
|
print("=" * 60)
|
|
print("✗ 测试失败!")
|
|
print("=" * 60)
|
|
print(f"错误:{e}")
|
|
print()
|
|
import traceback
|
|
traceback.print_exc()
|
|
sys.exit(1)
|