SpeechRecognition/test_model_load.py

96 lines
2.7 KiB
Python

"""
测试模型加载(不使用多进程)
用于诊断是否是模型本身的问题
"""
import sys
import torch
print("=" * 60)
print("模型加载测试(单进程模式)")
print("=" * 60)
print(f"Python 版本:{sys.version}")
print(f"PyTorch 版本:{torch.__version__}")
print(f"CUDA 可用:{torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"CUDA 版本:{torch.version.cuda}")
print()
try:
# 测试 ASR 模型
print("=" * 60)
print("测试 1: 加载 ASR 模型 (Paraformer)")
print("=" * 60)
from asr_service import ASRService
asr_service = ASRService(model_name="paraformer-zh", device="auto")
print("✓ ASRService 初始化完成")
asr_service._load_model()
print("✓ ASR 模型加载完成")
print(f" 模型类型:{type(asr_service._model)}")
print()
# 测试说话人分离模型
print("=" * 60)
print("测试 2: 加载说话人分离模型 (3D-Speaker)")
print("=" * 60)
from diarization_service import DiarizationService
diar_service = DiarizationService(
embedding_model="eres2netv2",
device="auto",
cluster_threshold=0.5,
min_cluster_size=10
)
print("✓ DiarizationService 初始化完成")
diar_service._load_model()
print("✓ 说话人分离模型加载完成")
print(f" 模型类型:{type(diar_service.model)}")
print()
# 测试音频处理
print("=" * 60)
print("测试 3: 测试音频处理")
print("=" * 60)
# 检查是否有测试音频
from pathlib import Path
test_audio = Path("test.wav")
if test_audio.exists():
print(f"找到测试音频:{test_audio}")
print("执行 ASR 识别...")
sentences = asr_service.recognize(str(test_audio))
print(f"✓ ASR 识别完成,共 {len(sentences)}")
if sentences:
print(f" 第一句:{sentences[0]}")
print("执行说话人分离...")
segments = diar_service.diarize(str(test_audio))
print(f"✓ 说话人分离完成,共 {len(segments)} 个片段")
if segments:
print(f" 第一个片段:{segments[0]}")
else:
print("⚠️ 未找到测试音频 (test.wav),跳过处理测试")
print()
print("=" * 60)
print("✓ 所有测试通过!模型工作正常")
print("=" * 60)
except Exception as e:
print()
print("=" * 60)
print("✗ 测试失败!")
print("=" * 60)
print(f"错误:{e}")
print()
import traceback
traceback.print_exc()
sys.exit(1)