179 lines
4.9 KiB
Python
179 lines
4.9 KiB
Python
"""
FunASR speech recognition test script.

Features exercised: sentence-level timestamps and speaker diarization.
"""
|
||
|
||
import os
|
||
import sys
|
||
import argparse
|
||
from pathlib import Path
|
||
|
||
|
||
def print_banner():
    """Print the welcome banner: title, feature list and a trailing blank line."""
    rule = "=" * 70
    feature_lines = [
        " • 句级时间戳(开始时间 - 结束时间)",
        " • 说话人分离(自动区分不同说话人)",
        " • 抗噪处理(VAD 语音活动检测)",
        " • 支持中文、方言、多语言",
    ]
    banner = [rule, " FunASR 语音识别测试工具", rule, "功能特性:"]
    banner.extend(feature_lines)
    banner.append(rule)
    # The final empty entry yields the blank line that closes the banner.
    banner.append("")
    print("\n".join(banner))
|
||
|
||
|
||
def test_single_audio(audio_path: str, model_name: str = "paraformer-zh"):
    """Recognize a single audio file, print the sentences and export them.

    Args:
        audio_path: Path of the audio file to transcribe.
        model_name: Model identifier passed to ``ASRService``.

    Returns:
        None. Results are printed and written to ``output/<stem>_result.json``
        and ``output/<stem>_result.srt``. A missing file or a recognition
        failure is reported on stdout and the function returns early.
    """
    # Validate the input before importing the ASR stack, so a bad path is
    # reported immediately. ``isfile`` also rejects directories, which
    # ``exists`` would let through to the recognizer.
    if not os.path.isfile(audio_path):
        print(f"❌ 错误: 文件不存在 - {audio_path}")
        return

    from asr_service import ASRService

    # Initialise the service.
    print(f"🔄 正在初始化模型: {model_name}")
    print(f"📝 音频文件: {audio_path}")
    print("-" * 70)

    service = ASRService(model_name=model_name)

    # Run recognition.
    try:
        sentences = service.recognize(audio_path)
    except Exception as e:
        print(f"❌ 识别失败: {e}")
        return

    # Show the result.
    print("\n✅ 识别完成!")
    print("=" * 70)
    print(f"共识别出 {len(sentences)} 句话\n")

    for i, sent in enumerate(sentences, 1):
        print(f"[{i}] {sent}")

    # Export the result. The exporters are not visible from here, so make sure
    # the target directory exists rather than relying on them to create it.
    Path("output").mkdir(parents=True, exist_ok=True)
    base_name = Path(audio_path).stem

    json_path = f"output/{base_name}_result.json"
    service.export_to_json(sentences, json_path)  # type: ignore

    srt_path = f"output/{base_name}_result.srt"
    service.export_to_srt(sentences, srt_path)  # type: ignore

    print("\n" + "=" * 70)
    print("📁 输出文件:")
    print(f" • JSON: {json_path}")
    print(f" • SRT: {srt_path}")
    print("=" * 70)


# The usage text names this script test_asr.py, which matches pytest's default
# file pattern; opt this CLI helper out of collection so pytest does not try to
# resolve ``audio_path`` as a fixture.
test_single_audio.__test__ = False
|
||
|
||
|
||
def test_batch(audio_dir: str, model_name: str = "paraformer-zh"):
    """Recognize every supported audio file in a directory and export JSON.

    Args:
        audio_dir: Directory scanned (non-recursively) for audio files.
        model_name: Model identifier passed to ``ASRService``.

    Returns:
        None. Progress is printed; each result is written to
        ``output/<stem>_result.json``. A failure on one file is reported and
        the remaining files are still processed.
    """
    # Supported audio formats.
    audio_extensions = {".wav", ".mp3", ".m4a", ".flac", ".ogg", ".wma"}

    # Scan the directory. Matching on the lowercased suffix also picks up
    # files such as "A.WAV", and sorting gives a stable processing order.
    source_dir = Path(audio_dir)
    audio_files = []
    if source_dir.is_dir():
        audio_files = sorted(
            entry
            for entry in source_dir.iterdir()
            if entry.is_file() and entry.suffix.lower() in audio_extensions
        )

    if not audio_files:
        # Sort the extensions so the message is identical on every run.
        print(f"❌ 未找到音频文件(支持格式: {', '.join(sorted(audio_extensions))})")
        return

    print(f"🔄 找到 {len(audio_files)} 个音频文件")
    print("-" * 70)

    # Import only once there is work to do.
    from asr_service import ASRService

    # Initialise the service once and reuse it for every file.
    service = ASRService(model_name=model_name)

    # The exporter is not visible from here, so make sure the target
    # directory exists rather than relying on it to create one.
    Path("output").mkdir(parents=True, exist_ok=True)

    # Batch recognition.
    for audio_path in audio_files:
        print(f"\n处理: {audio_path.name}")
        try:
            # Pass a plain string, matching how the single-file path calls
            # ``recognize``.
            sentences = service.recognize(str(audio_path))
            print(f" ✓ 识别出 {len(sentences)} 句话")

            # Export.
            base_name = audio_path.stem
            service.export_to_json(sentences, f"output/{base_name}_result.json")  # type: ignore
        except Exception as e:
            # Best effort: report the failure and continue with the next file.
            print(f" ✗ 失败: {e}")

    print("\n✅ 批量处理完成!")


# The usage text names this script test_asr.py, which matches pytest's default
# file pattern; opt this CLI helper out of collection so pytest does not try to
# resolve ``audio_dir`` as a fixture.
test_batch.__test__ = False
|
||
|
||
|
||
def download_test_audio():
    """Print hints on where to get sample audio; nothing is downloaded."""
    guidance = (
        "📝 请准备测试音频文件",
        "支持的格式: wav, mp3, m4a, flac, ogg, wma",
        "\n示例音频来源:",
        " • 自行录制会议/对话音频",
        " • AISHELL 开源数据集: https://www.openslr.org/33/",
        " • 魔搭社区示例: https://modelscope.cn/models/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    )
    for message in guidance:
        print(message)
|
||
|
||
|
||
def main(argv=None):
    """Parse command-line options and run the selected action.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing ``main()`` callers
            behave as before.

    The first matching option wins, in this order: ``--download-sample``,
    ``-f/--file``, ``-d/--directory``. With none of them, the help text and a
    short usage hint are printed.
    """
    parser = argparse.ArgumentParser(
        description="FunASR 语音识别测试工具",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例用法:
# 识别单个文件
python test_asr.py -f your_audio.wav

# 使用 SenseVoice 模型(多语言)
python test_asr.py -f your_audio.wav -m SenseVoice

# 批量识别目录
python test_asr.py -d ./audio_files/
""",
    )

    parser.add_argument(
        "-f", "--file",
        help="要识别的音频文件路径",
    )
    parser.add_argument(
        "-d", "--directory",
        help="要批量识别的音频目录",
    )
    parser.add_argument(
        "-m", "--model",
        default="paraformer-zh",
        choices=["paraformer-zh", "SenseVoice"],
        help="选择模型 (默认: paraformer-zh)",
    )
    parser.add_argument(
        "--download-sample",
        action="store_true",
        help="显示测试音频下载信息",
    )

    # parse_args(None) falls back to sys.argv[1:].
    args = parser.parse_args(argv)

    print_banner()

    if args.download_sample:
        download_test_audio()
    elif args.file:
        test_single_audio(args.file, args.model)
    elif args.directory:
        test_batch(args.directory, args.model)
    else:
        # Nothing to do: show the help and a short hint.
        parser.print_help()
        print("\n" + "=" * 70)
        print("提示: 使用 -f 指定音频文件,或 -d 指定音频目录")
        print("=" * 70)
|
||
|
||
|
||
# Run the command-line interface only when this file is executed as a script;
# importing the module must not start it.
if __name__ == "__main__":
    main()
|