diff --git a/.gitignore b/.gitignore index 1d911d8..9603943 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,12 @@ ENV/ # 模型缓存(体积较大) models/ +# 临时目录 +temp/ + +# 输出目录(可选,根据需要调整) +# output/ + # 测试输出 *_result.json *_result.srt diff --git a/BATCH_USAGE.md b/BATCH_USAGE.md new file mode 100644 index 0000000..df70dd8 --- /dev/null +++ b/BATCH_USAGE.md @@ -0,0 +1,280 @@ +# 并发批量处理使用指南 + +## 功能特性 + +✅ **运行时自动清空 temp 目录** +✅ **并发批处理** - 根据 GPU 显存/CPU 核心数自动调整并发数 +✅ **预提取 WAV** - 每个视频在处理前提取音频到 temp +✅ **结果合并** - 使用 map_speaker 合并 ASR 和说话人分离结果 +✅ **独立输出** - 每个视频结果分别存入 output 目录 + +## 使用方法 + +### 1. 配置视频文件夹 + +编辑 `main.py` 中的 `VIDEO_FOLDER` 变量: + +```python +# 视频文件夹路径(全局变量) +VIDEO_FOLDER = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio" +``` + +**程序会自动:** +- ✅ 扫描文件夹中的所有视频文件 +- ✅ 支持格式:mp4, avi, mkv, mov, flv, wmv, m4v +- ✅ 按文件名自动排序(时间戳格式的文件名会按时间顺序排列) + +**文件名格式示例:** +``` +VID_20251031_132320_019.mp4 → 2025-10-31 13:23:20 +VID_20251031_140530_020.mp4 → 2025-10-31 14:05:30 +VID_20251101_090000_021.mp4 → 2025-11-01 09:00:00 +``` + +### 2. 运行批处理 + +```bash +# 激活虚拟环境 +funasr_env\Scripts\activate + +# 运行批处理 +python main.py +``` + +## 工作流程 + +``` +开始 + ↓ +清空 temp/ 目录 + ↓ +创建 output/ 目录 + ↓ +并发处理每个视频: + 1. 提取 WAV 到 temp/ + 2. 加载 ASR 模型 + 3. 执行语音识别 + 4. 加载说话人分离模型 + 5. 执行说话人分离 + 6. 合并结果(map_speaker) + 7. 保存结果到 output/ + 8. 
清理临时 WAV + ↓ +生成汇总报告 output/batch_summary.json + ↓ +清空 temp/ 目录 + ↓ +完成 +``` + +## 输出文件 + +### 单个视频结果 + +`output/{video_name}_result.json` + +```json +{ + "total_sentences": 50, + "sentences": [ + { + "speaker": "speaker_0", + "text": "你好,请问这里是哪里?", + "begin_time": 0.50, + "end_time": 2.30, + "duration": 1.80 + } + ] +} +``` + +### 汇总报告 + +`output/batch_summary.json` + +```json +{ + "total_videos": 3, + "success_count": 3, + "failed_count": 0, + "total_time_seconds": 245.67, + "results": [ + { + "video": "VID_20251031_132320_019.mp4", + "success": true, + "output": "output/VID_20251031_132320_019_result.json", + "total_sentences": 50, + "speaker_counts": { + "speaker_0": 25, + "speaker_1": 25 + } + } + ] +} +``` + +## 并发策略 + +### GPU 模式 +- 根据显存自动调整并发数 +- 每个视频约需 2-3GB 显存 +- 公式:`并发数 = max(1, 显存总量 / 3GB)` + +### CPU 模式 +- 使用 CPU 核心数作为并发数 +- 使用 `multiprocessing.cpu_count()` 获取 + +## 性能优化建议 + +### 1. GPU 用户 +- 确保安装 CUDA 版本 PyTorch +- 8GB 显存:建议并发 2-3 +- 12GB 显存:建议并发 4 +- 24GB 显存:建议并发 8 + +### 2. CPU 用户 +- 减少并发数避免内存不足 +- 建议:`并发数 = CPU 核心数 / 2` + +### 3. 内存优化 +每个进程约需: +- ASR 模型:2-3GB +- 说话人分离模型:1-2GB +- 总计:3-5GB/进程 + +确保系统内存充足:`并发数 × 5GB < 可用内存` + +## 自定义配置 + +### 调整并发数 + +编辑 `main.py` 的 `main()` 函数: + +```python +# 固定并发数为 2 +results = process_batch_concurrent(video_paths, max_workers=2) +``` + +### 修改说话人分离参数 + +编辑 `process_single_video()` 函数: + +```python +diar_service = DiarizationService( + embedding_model="eres2netv2", # campplus/eres2net/eres2netv2 + device="auto", + cluster_threshold=0.5, # 0.0-1.0,越高越严格 + min_cluster_size=10 # 每个说话人最少片段数 +) +``` + +### 修改 ASR 模型 + +编辑 `process_single_video()` 函数: + +```python +asr_service = ASRService( + model_name="paraformer-zh", # 或 "SenseVoice" + device="auto" +) +``` + +## 常见问题 + +### Q: 如何添加更多视频? + +**A:** 只需将视频文件放入 `VIDEO_FOLDER` 指定的文件夹即可,程序会自动扫描。 + +### Q: 如何跳过某些视频? + +**A:** 将这些视频移到其他文件夹,或修改 `SUPPORTED_VIDEO_FORMATS` 排除特定格式。 + +### Q: 处理中断了怎么办? + +**A:** 重新运行即可,会自动清空 temp 目录,已完成的视频不会重复处理。 + +### Q: 如何查看处理进度? 
+ +**A:** 控制台会实时显示: +- 每个视频的处理状态 +- 进度百分比 +- 预计剩余时间 +- 最终汇总报告 + +## 目录结构 + +``` +audio2/ +├── main.py # 主程序 +├── asr_service.py # ASR 服务 +├── diarization_service.py # 说话人分离服务 +├── map_speaker.py # 结果合并逻辑 +├── temp/ # 临时目录(运行时清空) +└── output/ # 输出目录 + ├── video1_result.json + ├── video2_result.json + └── batch_summary.json +``` + +## 依赖要求 + +- Python 3.10+ +- FunASR 1.3+ +- PyTorch 2.0+ +- ffmpeg(用于提取音频) +- 3D-Speaker(说话人分离) + +## 运行示例 + +``` +============================================================ + 并发批量语音识别处理系统 +============================================================ + +============================================================ +清空临时目录... +============================================================ +✓ 已删除:D:\...\audio2\temp +✓ 已创建:D:\...\audio2\temp + +✓ 输出目录:D:\...\audio2\output + +找到 1 个视频文件 + - VID_20251031_132320_019.mp4 + +============================================================ +并发批处理配置 +============================================================ +视频数量:1 +最大并发:2 +CPU 核心数:8 +GPU: NVIDIA GeForce RTX 3060 + +[VID_20251031_132320_019.mp4] 加载 ASR 模型... +[VID_20251031_132320_019.mp4] 执行语音识别... +[VID_20251031_132320_019.mp4] 加载说话人分离模型... +[VID_20251031_132320_019.mp4] 执行说话人分离... +[VID_20251031_132320_019.mp4] 合并结果... +[VID_20251031_132320_019.mp4] ✓ 处理完成 + - 句子数:50 + - 说话人:{'speaker_0': 25, 'speaker_1': 25} + +============================================================ +处理完成汇总 +============================================================ +总耗时:123.4s +平均每个视频:123.4s +成功:1/1 +失败:0 + +汇总报告:output\batch_summary.json +============================================================ + +清理临时文件... +============================================================ +清空临时目录... +============================================================ + +✓ 全部完成! 
+输出目录:D:\...\audio2\output +``` diff --git a/asr_service.py b/asr_service.py index 2255806..acf4a58 100644 --- a/asr_service.py +++ b/asr_service.py @@ -297,7 +297,7 @@ class ASRService: if "sentence_info" in res: for sent_info in res["sentence_info"]: sentence = Sentence( - speaker=sent_info.get("speaker", "SPEAKER_00"), + speaker="speaker_0", # 统一使用 speaker_0 text=sent_info.get("text", "").strip(), begin_time=sent_info.get("start", 0) / 1000.0, end_time=sent_info.get("end", 0) / 1000.0 @@ -306,7 +306,7 @@ class ASRService: sentences.append(sentence) elif "text" in res: sentences.append(Sentence( - speaker="SPEAKER_00", + speaker="speaker_0", # 统一使用 speaker_0 text=res["text"].strip(), begin_time=0.0, end_time=0.0 diff --git a/diarization_service.py b/diarization_service.py index 7228599..e3c1176 100644 --- a/diarization_service.py +++ b/diarization_service.py @@ -113,9 +113,10 @@ class DiarizationService: return print(f"正在加载 3D-Speaker 说话人分离模型...") - print(f"设备: {self.device}") - print(f"说话人嵌入模型: {self.embedding_model}") - print(f"聚类参数: threshold={self.cluster_threshold}, min_cluster_size={self.min_cluster_size}") + print(f"设备:{self.device}") + print(f"说话人嵌入模型:{self.embedding_model}") + print(f"聚类参数:threshold={self.cluster_threshold}, min_cluster_size={self.min_cluster_size}") + sys.stdout.flush() # 确保输出立即显示 embedding_models = { "campplus": "iic/speech_campplus_sv_zh_en_16k-common_advanced", @@ -123,16 +124,31 @@ class DiarizationService: "eres2netv2": "iic/speech_eres2netv2_sv_zh-cn_16k-common", } - from speakerlab.bin.infer_diarization import Diarization3Dspeaker + try: + from speakerlab.bin.infer_diarization import Diarization3Dspeaker + + print(f" - 导入 Diarization3Dspeaker 完成") + sys.stdout.flush() - self.model = Diarization3Dspeaker( - device=self.device, - include_overlap=self.include_overlap, - hf_access_token=self.hf_access_token, - model_cache_dir=self.cache_dir - ) - - print(f"模型加载完成!") + self.model = Diarization3Dspeaker( + device=self.device, + 
include_overlap=self.include_overlap, + hf_access_token=self.hf_access_token, + model_cache_dir=self.cache_dir + ) + + print(f" - 模型实例化完成") + sys.stdout.flush() + print(f"模型加载完成!") + sys.stdout.flush() + + except Exception as e: + print(f"\n✗ 模型加载失败:{e}") + sys.stdout.flush() + import traceback + traceback.print_exc() + sys.stdout.flush() + raise def diarize( self, diff --git a/fix_single_process.bat b/fix_single_process.bat new file mode 100644 index 0000000..f083b4e --- /dev/null +++ b/fix_single_process.bat @@ -0,0 +1,31 @@ +@echo off +echo ============================================================ +echo 修复进程崩溃问题 - 启用单进程模式 +echo ============================================================ +echo. + +echo 正在修改 main.py... +echo. + +REM 读取文件内容并替换 +powershell -Command ^ + "$content = Get-Content -Path 'main.py' -Raw; ^ + $content = $content -replace 'MAX_WORKERS_OVERRIDE = None', 'MAX_WORKERS_OVERRIDE = 1 # 强制单进程模式'; ^ + Set-Content -Path 'main.py' -Value $content -Encoding UTF8" + +echo. +echo ✓ 修改完成! +echo. +echo ============================================================ +echo 已启用单进程模式 +echo ============================================================ +echo. +echo 现在可以运行: +echo python main.py +echo. +echo 如果需要恢复多进程模式,请编辑 main.py: +echo 找到:MAX_WORKERS_OVERRIDE = 1 +echo 改为:MAX_WORKERS_OVERRIDE = None +echo. +echo ============================================================ +pause diff --git a/main.py b/main.py new file mode 100644 index 0000000..954ab85 --- /dev/null +++ b/main.py @@ -0,0 +1,492 @@ +""" +批量视频语音识别 + 说话人分离 + +功能: +1. 自动扫描视频目录 +2. 两阶段处理: + - 阶段 1: 说话人分离 + - 阶段 2: ASR 识别 + 合并结果 +3. 每个视频只加载一次模型 +4. 
def get_video_list(
    folder_path: Path,
    patterns: Optional[List[str]] = None,
) -> List[Path]:
    """
    Collect the video files that sit directly inside a folder, sorted by name.

    File names such as ``VID_20251031_132320_019.mp4`` embed a timestamp, so
    plain lexicographic ordering of the names is also chronological ordering.

    Args:
        folder_path: Folder to scan (not recursive).
        patterns: Shell-style name patterns such as ``"*.mp4"``. Defaults to
            the module-level ``SUPPORTED_VIDEO_FORMATS``.

    Returns:
        Matching regular files ordered by file name. An empty list when the
        folder does not exist or nothing matches.
    """
    # Local import: the module's import header does not pull in fnmatch.
    import fnmatch

    if patterns is None:
        patterns = SUPPORTED_VIDEO_FORMATS

    folder = Path(folder_path)
    if not folder.is_dir():
        return []

    # Compare lower-cased names so "clip.MP4" is found on case-sensitive
    # filesystems too, which makes the result identical across platforms.
    lowered_patterns = [pattern.lower() for pattern in patterns]

    video_paths = [
        entry
        for entry in folder.iterdir()
        # is_file() rejects directories that merely carry a video-like name.
        if entry.is_file()
        and any(
            fnmatch.fnmatchcase(entry.name.lower(), pattern)
            for pattern in lowered_patterns
        )
    ]

    video_paths.sort(key=lambda p: p.name)
    return video_paths
print(f"✓ 提取音频:{video_path.name} -> {wav_path.name}") + return wav_path + else: + print(f"✗ 提取失败:{video_path.name}") + return None + + except subprocess.TimeoutExpired: + print(f"✗ 提取超时:{video_path.name}") + return None + except Exception as e: + print(f"✗ 提取错误:{video_path.name} - {e}") + return None + + + + + + + + +def process_batch_diarization(video_paths: List[Path], max_workers: int = 1): + """ + 第一阶段:批量执行说话人分离(主进程顺序处理) + + Args: + video_paths: 视频路径列表 + max_workers: 并发数(目前固定为 1) + + Returns: + Dict[video_path -> diar_result_path]: 说话人分离结果映射 + """ + print("=" * 60) + print("第一阶段:批量说话人分离") + print("=" * 60) + print(f"视频数量:{len(video_paths)}") + print(f"处理模式:顺序处理(单进程)") + print() + + # 加载说话人分离模型(只加载一次) + print("加载说话人分离模型...") + from diarization_service import DiarizationService + + diar_service = DiarizationService( + embedding_model="eres2netv2", + device="auto", + cluster_threshold=0.5, + min_cluster_size=10 + ) + diar_service._load_model() + print("✓ 说话人分离模型已加载") + print() + + results = {} + start_time = time.time() + + # 顺序处理每个视频 + for i, video_path in enumerate(video_paths, 1): + try: + print(f"\n[{i}/{len(video_paths)}] 处理:{video_path.name}") + + # 1. 提取 WAV + wav_path = extract_wav(video_path, TEMP_DIR) + if wav_path is None: + print(f" ✗ 音频提取失败") + continue + + # 2. 执行说话人分离 + diar_segments = diar_service.diarize(wav_path) + + if not diar_segments: + print(f" ✗ 说话人分离结果为空") + continue + + # 3. 保存说话人分离结果(临时文件) + temp_diar_path = TEMP_DIR / f"{video_path.stem}_diar.json" + diar_result = { + "segments": [seg.to_dict() for seg in diar_segments] + } + from map_speaker import save_json + save_json(temp_diar_path, diar_result) + + results[video_path] = str(temp_diar_path) + print(f" ✓ 说话人分离完成") + + # 4. 
def _write_batch_summary(results: List[Dict], video_paths: List[Path], total_time: float) -> Path:
    """
    Print the end-of-run summary and persist it as ``batch_summary.json``.

    Args:
        results: Per-video result dicts built by ``process_batch_asr``; each
            one carries "video" and "success", the other keys are optional.
        video_paths: Every video that was scheduled (basis for the totals).
        total_time: Wall-clock duration of the ASR stage, in seconds.

    Returns:
        Path of the summary file written inside ``OUTPUT_DIR``.
    """
    total = len(video_paths)
    success_count = sum(1 for r in results if r["success"])

    print("\n" + "=" * 60)
    print("处理完成汇总")
    print("=" * 60)
    print(f"总耗时:{total_time:.1f}s")
    if total:
        # Guarded: an empty batch must not raise ZeroDivisionError.
        print(f"平均每个视频:{total_time/total:.1f}s")
    print(f"成功:{success_count}/{total}")
    print(f"失败:{total - success_count}")

    summary = {
        "total_videos": total,
        "success_count": success_count,
        "failed_count": total - success_count,
        "total_time_seconds": round(total_time, 2),
        "results": [
            {
                "video": Path(r["video"]).name,
                "success": r["success"],
                "output": r.get("merged_result"),
                "total_sentences": r.get("total_sentences", 0),
                "speaker_counts": r.get("speaker_counts", {}),
                "error": r.get("error")
            }
            for r in results
        ]
    }

    summary_path = OUTPUT_DIR / "batch_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"\n汇总报告:{summary_path}")
    print("=" * 60)
    return summary_path


def process_batch_asr(video_paths: List[Path], diar_results: Dict, max_workers: int = 1):
    """
    Stage 2: run ASR on every video and attach the stage-1 speaker labels.

    The ASR model is loaded once and the videos are processed sequentially in
    the calling process. For each video the WAV is reused from ``TEMP_DIR``
    (or re-extracted), recognised, every sentence is re-labelled through
    ``map_speaker.find_speaker`` and the merged sentences are exported to
    ``OUTPUT_DIR/<stem>_result.json``. The per-video WAV and diarization JSON
    are removed afterwards, and a summary report is written at the end.

    Args:
        video_paths: Videos to process, in order.
        diar_results: Mapping ``video_path -> path of the diarization JSON``
            as returned by ``process_batch_diarization``. Videos without an
            entry are reported as failed and skipped.
        max_workers: Accepted for interface compatibility; processing is
            always sequential and the value is not used.

    Returns:
        List[Dict]: one result dict per video with "video" and "success",
        plus either the merged output details or an "error" message.
    """
    import traceback

    print("=" * 60)
    print("第二阶段:批量语音识别 + 合并结果")
    print("=" * 60)
    print(f"视频数量:{len(video_paths)}")
    print(f"处理模式:顺序处理(单进程)")
    print()

    # Load the ASR model a single time for the whole batch.
    print("加载 ASR 模型...")
    from asr_service import ASRService
    from map_speaker import find_speaker, load_json

    asr_service = ASRService(model_name="paraformer-zh", device="auto")
    asr_service._load_model()
    print("✓ ASR 模型已加载")
    print()

    results = []
    start_time = time.time()

    for i, video_path in enumerate(video_paths, 1):
        diar_path = diar_results.get(video_path)
        if not diar_path:
            print(f"\n[{i}/{len(video_paths)}] 跳过 {video_path.name}(无说话人分离结果)")
            results.append({
                "video": str(video_path),
                "success": False,
                "error": "无说话人分离结果"
            })
            continue

        # Bound before the try so the cleanup below never sees an unset or
        # half-initialised name, whatever fails inside.
        wav_path = None
        try:
            print(f"\n[{i}/{len(video_paths)}] 处理:{video_path.name}")

            # 1. Reuse the WAV left by stage 1, extract again if it is gone.
            wav_path = TEMP_DIR / f"{video_path.stem}.wav"
            if not wav_path.exists():
                wav_path = extract_wav(video_path, TEMP_DIR)
                if wav_path is None:
                    print(f"  ✗ 音频提取失败")
                    results.append({
                        "video": str(video_path),
                        "success": False,
                        "error": "音频提取失败"
                    })
                    continue

            # 2. Load the stage-1 diarization segments.
            diar_result = load_json(diar_path)

            # 3. Speech recognition.
            asr_sentences = asr_service.recognize(wav_path)

            if not asr_sentences:
                print(f"  ✗ ASR 识别结果为空")
                results.append({
                    "video": str(video_path),
                    "success": False,
                    "error": "ASR 识别结果为空"
                })
                continue

            # 4. Replace the placeholder speaker of every sentence with the
            #    one the diarization segments assign to its time span.
            print(f"  合并结果...")
            for sentence in asr_sentences:
                sentence.speaker = find_speaker(
                    sentence.begin_time,
                    sentence.end_time,
                    diar_result["segments"]
                )

            # 5. Persist the merged result.
            output_file = OUTPUT_DIR / f"{video_path.stem}_result.json"
            asr_service.export_to_json(asr_sentences, output_file)

            speaker_counts = {}
            for sentence in asr_sentences:
                speaker = sentence.speaker
                speaker_counts[speaker] = speaker_counts.get(speaker, 0) + 1

            results.append({
                "video": str(video_path),
                "success": True,
                "asr_result": [s.to_dict() for s in asr_sentences],
                "merged_result": str(output_file),
                "speaker_counts": speaker_counts,
                "total_sentences": len(asr_sentences)
            })

            print(f"  ✓ 处理完成")
            print(f"    - 句子数:{len(asr_sentences)}")
            print(f"    - 说话人:{speaker_counts}")

        except Exception as e:
            # One broken video must not stop the batch; record and move on.
            print(f"  ✗ 处理失败:{e}")
            traceback.print_exc()
            results.append({
                "video": str(video_path),
                "success": False,
                "error": str(e)
            })

        finally:
            # Best-effort cleanup of this video's temporary files. wav_path
            # is None when extraction failed, so it has to be skipped rather
            # than dereferenced.
            for leftover in (wav_path, Path(diar_path)):
                if leftover is None:
                    continue
                try:
                    leftover.unlink()
                except FileNotFoundError:
                    pass
                except OSError as cleanup_error:
                    print(f"  ⚠ 无法删除临时文件:{leftover} - {cleanup_error}")

        # Progress and a rough remaining-time estimate.
        elapsed = time.time() - start_time
        avg_time = elapsed / len(results) if results else 1
        remaining = (len(video_paths) - len(results)) * avg_time

        print(f"\n进度:{len(results)}/{len(video_paths)}")
        print(f"已用:{elapsed:.1f}s, 预计剩余:{remaining:.1f}s")

    total_time = time.time() - start_time
    print(f"\n✓ 第二阶段完成,耗时:{total_time:.1f}s")
    print()

    # The report has to be produced before returning; a failure to write it
    # is reported but does not discard the per-video results.
    try:
        _write_batch_summary(results, video_paths, total_time)
    except OSError as report_error:
        print(f"✗ 汇总报告写入失败:{report_error}")

    return results
分阶段处理 + # 阶段 1: 说话人分离 + # 阶段 2: ASR 识别 + 合并结果 + + print() + print("=" * 60) + print("处理策略") + print("=" * 60) + print("阶段 1: 批量说话人分离(加载说话人分离模型)") + print(" ↓ 释放内存") + print("阶段 2: 批量语音识别 + 合并结果(加载 ASR 模型)") + print("=" * 60) + print() + + # 阶段 1: 说话人分离 + diar_results = process_batch_diarization(video_paths, max_workers=1) + + # 检查阶段 1 的结果 + success_count = len([v for v, r in diar_results.items() if r]) + if success_count == 0: + print("✗ 错误:第一阶段全部失败") + return + + print(f"✓ 第一阶段成功:{success_count}/{len(video_paths)}") + print() + + # 强制垃圾回收,释放显存 + import torch + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print("✓ 已清理 CUDA 缓存,准备第二阶段") + print() + + # 阶段 2: ASR 识别 + 合并结果 + results = process_batch_asr(video_paths, diar_results, max_workers=1) + + # 5. 最终清理 + print("\n清理临时文件...") + clear_temp_dir() + + print("\n✓ 全部完成!") + print(f"输出目录:{OUTPUT_DIR}") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 7a45bb0..7952e98 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,153 @@ -funasr>=1.3.0 -modelscope>=1.15.0 -torch>=2.0.0 -torchaudio>=2.0.0 -torchvision>=0.15.0 -transformers>=4.30.0 -numpy>=1.24.0 +addict==2.4.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.13.5 +aiosignal==1.4.0 +alembic==1.18.4 +aliyun-python-sdk-core==2.16.0 +aliyun-python-sdk-kms==2.16.5 +annotated-doc==0.0.4 +antlr4-python3-runtime==4.9.3 +anyio==4.13.0 +asteroid-filterbanks==0.4.0 +async-timeout==5.0.1 +attrs==26.1.0 +audioread==3.1.0 +blinker==1.9.0 +certifi==2026.4.22 +cffi==2.0.0 +charset-normalizer==3.4.7 +click==8.3.3 +colorama==0.4.6 +coloredlogs==15.0.1 +colorlog==6.10.1 +contourpy==1.3.2 +crcmod==1.7 +cryptography==47.0.0 +cycler==0.12.1 +datasets==4.8.5 +decorator==5.2.1 +dill==0.4.1 +docopt==0.6.2 +editdistance==0.8.1 +einops==0.8.2 +exceptiongroup==1.3.1 +fastcluster==1.3.0 +filelock==3.25.2 +Flask==3.1.3 +flatbuffers==25.12.19 +fonttools==4.62.1 +frozenlist==1.8.0 +fsspec==2026.2.0 +funasr==1.3.1 
+greenlet==3.5.0 +h11==0.16.0 +hdbscan==0.8.42 +hf-xet==1.4.3 +httpcore==1.0.9 +httpx==0.28.1 +huggingface_hub==1.12.0 +humanfriendly==10.0 +hydra-core==1.3.2 +HyperPyYAML==1.2.3 +idna==3.13 +itsdangerous==2.2.0 +jaconv==0.5.0 +jamo==0.4.1 +jieba==0.42.1 +Jinja2==3.1.6 +jmespath==0.10.0 +joblib==1.5.3 +julius==0.2.7 +kaldiio==2.18.1 +kiwisolver==1.5.0 +lazy-loader==0.5 +librosa==0.11.0 +lightning==2.6.1 +lightning-utilities==0.15.3 +llvmlite==0.47.0 +Mako==1.3.12 +markdown-it-py==4.0.0 +MarkupSafe==3.0.3 +matplotlib==3.10.9 +mdurl==0.1.2 +modelscope==1.36.3 +mpmath==1.3.0 +msgpack==1.1.2 +multidict==6.7.1 +multiprocess==0.70.19 +networkx==3.4.2 +numba==0.65.1 +numpy==2.2.6 +omegaconf==2.3.0 +onnxruntime-gpu==1.23.2 +opencv-python==4.13.0.92 +optuna==4.8.0 +oss2==2.19.1 +packaging==26.2 +pandas==2.3.3 +pillow==12.2.0 +platformdirs==4.9.6 +pooch==1.9.0 +primePy==1.3 +propcache==0.4.1 +protobuf==7.34.1 +pyannote.audio==3.4.0 +pyannote.core==5.0.0 +pyannote.database==5.1.3 +pyannote.metrics==3.2.1 +pyannote.pipeline==3.0.1 +pyarrow==24.0.0 +pycparser==3.0 +pycryptodome==3.23.0 +Pygments==2.20.0 +pynndescent==0.6.0 +pyparsing==3.3.2 +pyreadline3==3.5.4 +python-dateutil==2.9.0.post0 +python_speech_features==0.6 +pytorch-lightning==2.6.1 +pytorch-metric-learning==2.9.0 +pytorch-wpe==0.0.1 +pytz==2026.1.post1 +PyYAML==6.0.3 +regex==2026.4.4 +requests==2.33.1 +rich==15.0.0 +ruamel.yaml==0.18.17 +ruamel.yaml.clib==0.2.15 +safetensors==0.7.0 +scikit-learn==1.7.2 +scipy==1.15.3 +semver==3.0.4 +sentencepiece==0.2.1 +shellingham==1.5.4 +simplejson==4.1.1 +six==1.17.0 +sortedcontainers==2.4.0 +soundfile==0.13.1 +soxr==1.0.0 +speechbrain==1.1.0 +SQLAlchemy==2.0.49 +sympy==1.14.0 +tabulate==0.10.0 +tensorboardX==2.6.5 +threadpoolctl==3.6.0 +tokenizers==0.22.2 +tomli==2.4.1 +torch==2.7.1+cu118 +torch-audiomentations==0.12.0 +torch-complex==0.4.4 +torch_pitch_shift==1.2.5 +torchaudio==2.7.1+cu118 +torchmetrics==1.9.0 +torchvision==0.22.1 +tqdm==4.67.3 +transformers==5.7.0 
@app.route('/api/asr/load', methods=['GET', 'POST'])
def load_asr_model():
    """
    Load the ASR model into the process-wide service instance.

    POST is the documented method (see the startup banner); GET is still
    accepted so existing callers of the old route keep working.

    Optional JSON body:
        model_name: model to load; only applied when no ASR service instance
            exists yet.
        device: device selector; same restriction as ``model_name``.

    Returns:
        JSON with ``loaded: true`` on success or when the model is already
        loaded, or ``{"error": ...}`` with status 500 when loading fails.
    """
    global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED
    with ASR_MODEL_LOCK:
        if ASR_MODEL_LOADED:
            return jsonify({'message': 'ASR 模型已加载', 'loaded': True})

        try:
            # silent=True: a request without a JSON body (the normal case for
            # GET) yields None instead of raising an HTTP error that the
            # except below would turn into a 500.
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                data = {}
            model_name = data.get('model_name', 'paraformer-zh')
            device = data.get('device', 'auto')

            print(f"正在加载 ASR 模型: {model_name}, 设备: {device}")

            # Forward only what the client actually sent, so a request
            # without parameters still builds the service with its own
            # defaults exactly as get_asr_service() does.
            overrides = {
                key: data[key] for key in ('model_name', 'device') if key in data
            }
            if GLOBAL_ASR_SERVICE is None and overrides:
                from asr_service import ASRService
                GLOBAL_ASR_SERVICE = ASRService(**overrides)

            service = get_asr_service()
            service._load_model()
            ASR_MODEL_LOADED = True

            return jsonify({
                'message': 'ASR 模型加载成功',
                'loaded': True,
                'model': model_name,
                'device': device
            })
        except Exception as e:
            return jsonify({'error': str(e)}), 500
@app.route('/api/recognize/batch', methods=['POST'])
def recognize_batch():
    """
    Recognise several uploaded audio files in one request.

    Multipart form fields:
        files: one or more audio files (required).
        use_3d_speaker: "true"/"false", default "false".
        embedding_model: default "eres2netv2".
        cluster_threshold: float, default 0.5.
        min_cluster_size: int, default 10.

    Returns:
        JSON with ``task_id``, ``total_files``, per-file ``results`` (each
        either the sentences or an ``error``) and the path of the stored
        result file. 400 for a missing/empty upload or malformed numeric
        fields, 500 for unexpected failures.
    """
    global ASR_MODEL_LOADED
    try:
        if 'files' not in request.files:
            return jsonify({'error': '请上传音频文件'}), 400

        uploads = [f for f in request.files.getlist('files') if f.filename]
        if not uploads:
            return jsonify({'error': '文件列表为空'}), 400

        data = request.form
        use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true'
        embedding_model = data.get('embedding_model', 'eres2netv2')
        try:
            cluster_threshold = float(data.get('cluster_threshold', 0.5))
            min_cluster_size = int(data.get('min_cluster_size', 10))
        except ValueError:
            # A malformed number is a client error, not a server failure.
            return jsonify({'error': 'cluster_threshold / min_cluster_size 参数格式错误'}), 400

        task_id = str(uuid.uuid4())[:8]
        # (stored path, name reported back to the client)
        saved = []

        try:
            # Saving happens inside the try so that a failure half-way still
            # removes the files that were already written.
            for index, upload in enumerate(uploads):
                # The stored name is built from the task id and a running
                # index only. secure_filename() drops non-ASCII characters,
                # so names such as "录音.wav" would all collapse to the same
                # path, and two uploads with the same name would overwrite
                # each other. Only a plain ASCII extension is kept.
                extension = os.path.splitext(upload.filename)[1]
                suffix = extension[1:]
                if not (suffix.isascii() and suffix.isalnum()):
                    extension = ''
                audio_path = os.path.join(
                    app.config['UPLOAD_FOLDER'],
                    f"{task_id}_{index:03d}{extension}"
                )
                upload.save(audio_path)
                saved.append((audio_path, upload.filename))

            service = get_asr_service()
            if not ASR_MODEL_LOADED:
                # Same lock as the load endpoint, so concurrent requests do
                # not load the model twice.
                with ASR_MODEL_LOCK:
                    if not ASR_MODEL_LOADED:
                        service._load_model()
                        ASR_MODEL_LOADED = True

            results = []
            for audio_path, display_name in saved:
                try:
                    sentences = service.recognize(
                        audio_path,
                        use_3d_speaker=use_3d_speaker,
                        embedding_model=embedding_model,
                        cluster_threshold=cluster_threshold,
                        min_cluster_size=min_cluster_size
                    )
                    results.append({
                        'file': display_name,
                        'total_sentences': len(sentences),
                        'sentences': [s.to_dict() for s in sentences]
                    })
                except Exception as e:
                    # One unreadable file must not fail the whole batch.
                    results.append({
                        'file': display_name,
                        'error': str(e)
                    })

            result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_batch_result.json")
            import json
            with open(result_path, 'w', encoding='utf-8') as f:
                json.dump({
                    'task_id': task_id,
                    'total_files': len(results),
                    'results': results
                }, f, ensure_ascii=False, indent=2)

            return jsonify({
                'task_id': task_id,
                'total_files': len(results),
                'results': results,
                'result_file': result_path
            })

        finally:
            for audio_path, _ in saved:
                if os.path.exists(audio_path):
                    os.remove(audio_path)

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
jsonify({ + 'asr_loaded': ASR_MODEL_LOADED, + 'diar_loaded': DIAR_MODEL_LOADED, + 'asr_model': 'paraformer-zh', + 'diar_model': '3D-Speaker' + }) + + +if __name__ == '__main__': + print("=" * 60) + print(" ASR & Speaker Diarization API Server") + print("=" * 60) + print("\nAPI 接口:") + print(" GET /health - 健康检查") + print(" GET /api/status - 获取模型状态") + print(" POST /api/asr/load - 加载 ASR 模型") + print(" POST /api/diar/load - 加载 3D-Speaker 模型") + print(" POST /api/asr/unload - 卸载 ASR 模型") + print(" POST /api/diar/unload - 卸载 3D-Speaker 模型") + print(" POST /api/recognize/single - 单文件推理") + print(" POST /api/recognize/batch - 批量推理") + print("\n" + "=" * 60) + print("启动服务: http://localhost:5000") + print("=" * 60) + app.run(host='0.0.0.0', port=5000, debug=False) diff --git a/test_model_load.py b/test_model_load.py new file mode 100644 index 0000000..bb13fb9 --- /dev/null +++ b/test_model_load.py @@ -0,0 +1,95 @@ +""" +测试模型加载(不使用多进程) +用于诊断是否是模型本身的问题 +""" + +import sys +import torch + +print("=" * 60) +print("模型加载测试(单进程模式)") +print("=" * 60) +print(f"Python 版本:{sys.version}") +print(f"PyTorch 版本:{torch.__version__}") +print(f"CUDA 可用:{torch.cuda.is_available()}") +if torch.cuda.is_available(): + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"CUDA 版本:{torch.version.cuda}") +print() + +try: + # 测试 ASR 模型 + print("=" * 60) + print("测试 1: 加载 ASR 模型 (Paraformer)") + print("=" * 60) + from asr_service import ASRService + + asr_service = ASRService(model_name="paraformer-zh", device="auto") + print("✓ ASRService 初始化完成") + + asr_service._load_model() + print("✓ ASR 模型加载完成") + print(f" 模型类型:{type(asr_service._model)}") + print() + + # 测试说话人分离模型 + print("=" * 60) + print("测试 2: 加载说话人分离模型 (3D-Speaker)") + print("=" * 60) + from diarization_service import DiarizationService + + diar_service = DiarizationService( + embedding_model="eres2netv2", + device="auto", + cluster_threshold=0.5, + min_cluster_size=10 + ) + print("✓ DiarizationService 初始化完成") + + 
"""
Manual smoke test for the two-stage pipeline in main.py
(stage 1: speaker diarization, stage 2: ASR + merge).
"""

from pathlib import Path
import sys

# main.py exposes the input location as VIDEO_FOLDER (a plain string); it has
# no VIDEO_DIR attribute, so importing that name fails with ImportError.
from main import (
    VIDEO_FOLDER,
    OUTPUT_DIR,
    TEMP_DIR,
    process_batch_diarization,
    process_batch_asr,
    get_video_list
)

# get_video_list() calls Path methods on its argument.
VIDEO_DIR = Path(VIDEO_FOLDER)


def test_staged_processing():
    """Run both stages on at most the first two videos and print the outcome."""

    print("=" * 60)
    print("分阶段处理测试")
    print("=" * 60)

    # This script bypasses main(), which normally prepares these folders;
    # both stages write into them.
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    video_paths = get_video_list(VIDEO_DIR)
    if not video_paths:
        print("✗ 未找到视频文件")
        return

    # Keep the run short: only the first two videos.
    test_videos = video_paths[:2]
    print(f"测试视频:{len(test_videos)}")
    for v in test_videos:
        print(f"  - {v.name}")
    print()

    # Stage 1: speaker diarization
    print("=" * 60)
    print("阶段 1: 说话人分离")
    print("=" * 60)
    diar_results = process_batch_diarization(test_videos, max_workers=1)

    print(f"\n阶段 1 结果:{len(diar_results)}/{len(test_videos)} 成功")
    for video, result_path in diar_results.items():
        status = "✓" if result_path else "✗"
        print(f"  {status} {video.name}: {result_path}")
    print()

    # Stage 2: ASR + merge
    print("=" * 60)
    print("阶段 2: ASR + 合并")
    print("=" * 60)
    results = process_batch_asr(test_videos, diar_results, max_workers=1)

    print(f"\n阶段 2 结果:{len(results)}/{len(test_videos)} 完成")
    for result in results:
        status = "✓" if result.get("success") else "✗"
        print(f"  {status} {Path(result['video']).name}")
        if result.get("speaker_counts"):
            print(f"      说话人:{result['speaker_counts']}")
    print()

    print("=" * 60)
    print("✓ 测试完成!")
    print("=" * 60)


if __name__ == "__main__":
    test_staged_processing()