批量 语音+说话人识别
This commit is contained in:
parent
48e51b3f92
commit
07c1eca03b
|
|
@ -29,6 +29,12 @@ ENV/
|
||||||
# 模型缓存(体积较大)
|
# 模型缓存(体积较大)
|
||||||
models/
|
models/
|
||||||
|
|
||||||
|
# 临时目录
|
||||||
|
temp/
|
||||||
|
|
||||||
|
# 输出目录(可选,根据需要调整)
|
||||||
|
# output/
|
||||||
|
|
||||||
# 测试输出
|
# 测试输出
|
||||||
*_result.json
|
*_result.json
|
||||||
*_result.srt
|
*_result.srt
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,280 @@
|
||||||
|
# 并发批量处理使用指南
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
✅ **运行时自动清空 temp 目录**
|
||||||
|
✅ **并发批处理** - 根据 GPU 显存/CPU 核心数自动调整并发数
|
||||||
|
✅ **预提取 WAV** - 每个视频在处理前提取音频到 temp
|
||||||
|
✅ **结果合并** - 使用 map_speaker 合并 ASR 和说话人分离结果
|
||||||
|
✅ **独立输出** - 每个视频结果分别存入 output 目录
|
||||||
|
|
||||||
|
## 使用方法
|
||||||
|
|
||||||
|
### 1. 配置视频文件夹
|
||||||
|
|
||||||
|
编辑 `main.py` 中的 `VIDEO_FOLDER` 变量:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 视频文件夹路径(全局变量)
|
||||||
|
VIDEO_FOLDER = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio"
|
||||||
|
```
|
||||||
|
|
||||||
|
**程序会自动:**
|
||||||
|
- ✅ 扫描文件夹中的所有视频文件
|
||||||
|
- ✅ 支持格式:mp4, avi, mkv, mov, flv, wmv, m4v
|
||||||
|
- ✅ 按文件名自动排序(时间戳格式的文件名会按时间顺序排列)
|
||||||
|
|
||||||
|
**文件名格式示例:**
|
||||||
|
```
|
||||||
|
VID_20251031_132320_019.mp4 → 2025-10-31 13:23:20
|
||||||
|
VID_20251031_140530_020.mp4 → 2025-10-31 14:05:30
|
||||||
|
VID_20251101_090000_021.mp4 → 2025-11-01 09:00:00
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 运行批处理
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 激活虚拟环境
|
||||||
|
funasr_env\Scripts\activate
|
||||||
|
|
||||||
|
# 运行批处理
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 工作流程
|
||||||
|
|
||||||
|
```
|
||||||
|
开始
|
||||||
|
↓
|
||||||
|
清空 temp/ 目录
|
||||||
|
↓
|
||||||
|
创建 output/ 目录
|
||||||
|
↓
|
||||||
|
并发处理每个视频:
|
||||||
|
1. 提取 WAV 到 temp/
|
||||||
|
2. 加载 ASR 模型
|
||||||
|
3. 执行语音识别
|
||||||
|
4. 加载说话人分离模型
|
||||||
|
5. 执行说话人分离
|
||||||
|
6. 合并结果(map_speaker)
|
||||||
|
7. 保存结果到 output/
|
||||||
|
8. 清理临时 WAV
|
||||||
|
↓
|
||||||
|
生成汇总报告 output/batch_summary.json
|
||||||
|
↓
|
||||||
|
清空 temp/ 目录
|
||||||
|
↓
|
||||||
|
完成
|
||||||
|
```
|
||||||
|
|
||||||
|
## 输出文件
|
||||||
|
|
||||||
|
### 单个视频结果
|
||||||
|
|
||||||
|
`output/{video_name}_result.json`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"total_sentences": 50,
|
||||||
|
"sentences": [
|
||||||
|
{
|
||||||
|
"speaker": "speaker_0",
|
||||||
|
"text": "你好,请问这里是哪里?",
|
||||||
|
"begin_time": 0.50,
|
||||||
|
"end_time": 2.30,
|
||||||
|
"duration": 1.80
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 汇总报告
|
||||||
|
|
||||||
|
`output/batch_summary.json`
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"total_videos": 3,
|
||||||
|
"success_count": 3,
|
||||||
|
"failed_count": 0,
|
||||||
|
"total_time_seconds": 245.67,
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"video": "VID_20251031_132320_019.mp4",
|
||||||
|
"success": true,
|
||||||
|
"output": "output/VID_20251031_132320_019_result.json",
|
||||||
|
"total_sentences": 50,
|
||||||
|
"speaker_counts": {
|
||||||
|
"speaker_0": 25,
|
||||||
|
"speaker_1": 25
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 并发策略
|
||||||
|
|
||||||
|
### GPU 模式
|
||||||
|
- 根据显存自动调整并发数
|
||||||
|
- 每个视频约需 2-3GB 显存
|
||||||
|
- 公式:`并发数 = max(1, 显存总量 / 3GB)`
|
||||||
|
|
||||||
|
### CPU 模式
|
||||||
|
- 使用 CPU 核心数作为并发数
|
||||||
|
- 使用 `multiprocessing.cpu_count()` 获取
|
||||||
|
|
||||||
|
## 性能优化建议
|
||||||
|
|
||||||
|
### 1. GPU 用户
|
||||||
|
- 确保安装 CUDA 版本 PyTorch
|
||||||
|
- 8GB 显存:建议并发 2-3
|
||||||
|
- 12GB 显存:建议并发 4
|
||||||
|
- 24GB 显存:建议并发 8
|
||||||
|
|
||||||
|
### 2. CPU 用户
|
||||||
|
- 减少并发数避免内存不足
|
||||||
|
- 建议:`并发数 = CPU 核心数 / 2`
|
||||||
|
|
||||||
|
### 3. 内存优化
|
||||||
|
每个进程约需:
|
||||||
|
- ASR 模型:2-3GB
|
||||||
|
- 说话人分离模型:1-2GB
|
||||||
|
- 总计:3-5GB/进程
|
||||||
|
|
||||||
|
确保系统内存充足:`并发数 × 5GB < 可用内存`
|
||||||
|
|
||||||
|
## 自定义配置
|
||||||
|
|
||||||
|
### 调整并发数
|
||||||
|
|
||||||
|
编辑 `main.py` 的 `main()` 函数:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 固定并发数为 2
|
||||||
|
results = process_batch_concurrent(video_paths, max_workers=2)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 修改说话人分离参数
|
||||||
|
|
||||||
|
编辑 `process_single_video()` 函数:
|
||||||
|
|
||||||
|
```python
|
||||||
|
diar_service = DiarizationService(
|
||||||
|
embedding_model="eres2netv2", # campplus/eres2net/eres2netv2
|
||||||
|
device="auto",
|
||||||
|
cluster_threshold=0.5, # 0.0-1.0,越高越严格
|
||||||
|
min_cluster_size=10 # 每个说话人最少片段数
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 修改 ASR 模型
|
||||||
|
|
||||||
|
编辑 `process_single_video()` 函数:
|
||||||
|
|
||||||
|
```python
|
||||||
|
asr_service = ASRService(
|
||||||
|
model_name="paraformer-zh", # 或 "SenseVoice"
|
||||||
|
device="auto"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 常见问题
|
||||||
|
|
||||||
|
### Q: 如何添加更多视频?
|
||||||
|
|
||||||
|
**A:** 只需将视频文件放入 `VIDEO_FOLDER` 指定的文件夹即可,程序会自动扫描。
|
||||||
|
|
||||||
|
### Q: 如何跳过某些视频?
|
||||||
|
|
||||||
|
**A:** 将这些视频移到其他文件夹,或修改 `SUPPORTED_VIDEO_FORMATS` 排除特定格式。
|
||||||
|
|
||||||
|
### Q: 处理中断了怎么办?
|
||||||
|
|
||||||
|
**A:** 重新运行即可,会自动清空 temp 目录,已完成的视频不会重复处理。
|
||||||
|
|
||||||
|
### Q: 如何查看处理进度?
|
||||||
|
|
||||||
|
**A:** 控制台会实时显示:
|
||||||
|
- 每个视频的处理状态
|
||||||
|
- 进度百分比
|
||||||
|
- 预计剩余时间
|
||||||
|
- 最终汇总报告
|
||||||
|
|
||||||
|
## 目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
audio2/
|
||||||
|
├── main.py # 主程序
|
||||||
|
├── asr_service.py # ASR 服务
|
||||||
|
├── diarization_service.py # 说话人分离服务
|
||||||
|
├── map_speaker.py # 结果合并逻辑
|
||||||
|
├── temp/ # 临时目录(运行时清空)
|
||||||
|
└── output/ # 输出目录
|
||||||
|
├── video1_result.json
|
||||||
|
├── video2_result.json
|
||||||
|
└── batch_summary.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## 依赖要求
|
||||||
|
|
||||||
|
- Python 3.10+
|
||||||
|
- FunASR 1.3+
|
||||||
|
- PyTorch 2.0+
|
||||||
|
- ffmpeg(用于提取音频)
|
||||||
|
- 3D-Speaker(说话人分离)
|
||||||
|
|
||||||
|
## 运行示例
|
||||||
|
|
||||||
|
```
|
||||||
|
============================================================
|
||||||
|
并发批量语音识别处理系统
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
============================================================
|
||||||
|
清空临时目录...
|
||||||
|
============================================================
|
||||||
|
✓ 已删除:D:\...\audio2\temp
|
||||||
|
✓ 已创建:D:\...\audio2\temp
|
||||||
|
|
||||||
|
✓ 输出目录:D:\...\audio2\output
|
||||||
|
|
||||||
|
找到 1 个视频文件
|
||||||
|
- VID_20251031_132320_019.mp4
|
||||||
|
|
||||||
|
============================================================
|
||||||
|
并发批处理配置
|
||||||
|
============================================================
|
||||||
|
视频数量:1
|
||||||
|
最大并发:2
|
||||||
|
CPU 核心数:8
|
||||||
|
GPU: NVIDIA GeForce RTX 3060
|
||||||
|
|
||||||
|
[VID_20251031_132320_019.mp4] 加载 ASR 模型...
|
||||||
|
[VID_20251031_132320_019.mp4] 执行语音识别...
|
||||||
|
[VID_20251031_132320_019.mp4] 加载说话人分离模型...
|
||||||
|
[VID_20251031_132320_019.mp4] 执行说话人分离...
|
||||||
|
[VID_20251031_132320_019.mp4] 合并结果...
|
||||||
|
[VID_20251031_132320_019.mp4] ✓ 处理完成
|
||||||
|
- 句子数:50
|
||||||
|
- 说话人:{'speaker_0': 25, 'speaker_1': 25}
|
||||||
|
|
||||||
|
============================================================
|
||||||
|
处理完成汇总
|
||||||
|
============================================================
|
||||||
|
总耗时:123.4s
|
||||||
|
平均每个视频:123.4s
|
||||||
|
成功:1/1
|
||||||
|
失败:0
|
||||||
|
|
||||||
|
汇总报告:output\batch_summary.json
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
清理临时文件...
|
||||||
|
============================================================
|
||||||
|
清空临时目录...
|
||||||
|
============================================================
|
||||||
|
|
||||||
|
✓ 全部完成!
|
||||||
|
输出目录:D:\...\audio2\output
|
||||||
|
```
|
||||||
|
|
@ -297,7 +297,7 @@ class ASRService:
|
||||||
if "sentence_info" in res:
|
if "sentence_info" in res:
|
||||||
for sent_info in res["sentence_info"]:
|
for sent_info in res["sentence_info"]:
|
||||||
sentence = Sentence(
|
sentence = Sentence(
|
||||||
speaker=sent_info.get("speaker", "SPEAKER_00"),
|
speaker="speaker_0", # 统一使用 speaker_0
|
||||||
text=sent_info.get("text", "").strip(),
|
text=sent_info.get("text", "").strip(),
|
||||||
begin_time=sent_info.get("start", 0) / 1000.0,
|
begin_time=sent_info.get("start", 0) / 1000.0,
|
||||||
end_time=sent_info.get("end", 0) / 1000.0
|
end_time=sent_info.get("end", 0) / 1000.0
|
||||||
|
|
@ -306,7 +306,7 @@ class ASRService:
|
||||||
sentences.append(sentence)
|
sentences.append(sentence)
|
||||||
elif "text" in res:
|
elif "text" in res:
|
||||||
sentences.append(Sentence(
|
sentences.append(Sentence(
|
||||||
speaker="SPEAKER_00",
|
speaker="speaker_0", # 统一使用 speaker_0
|
||||||
text=res["text"].strip(),
|
text=res["text"].strip(),
|
||||||
begin_time=0.0,
|
begin_time=0.0,
|
||||||
end_time=0.0
|
end_time=0.0
|
||||||
|
|
|
||||||
|
|
@ -113,9 +113,10 @@ class DiarizationService:
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"正在加载 3D-Speaker 说话人分离模型...")
|
print(f"正在加载 3D-Speaker 说话人分离模型...")
|
||||||
print(f"设备: {self.device}")
|
print(f"设备:{self.device}")
|
||||||
print(f"说话人嵌入模型: {self.embedding_model}")
|
print(f"说话人嵌入模型:{self.embedding_model}")
|
||||||
print(f"聚类参数: threshold={self.cluster_threshold}, min_cluster_size={self.min_cluster_size}")
|
print(f"聚类参数:threshold={self.cluster_threshold}, min_cluster_size={self.min_cluster_size}")
|
||||||
|
sys.stdout.flush() # 确保输出立即显示
|
||||||
|
|
||||||
embedding_models = {
|
embedding_models = {
|
||||||
"campplus": "iic/speech_campplus_sv_zh_en_16k-common_advanced",
|
"campplus": "iic/speech_campplus_sv_zh_en_16k-common_advanced",
|
||||||
|
|
@ -123,16 +124,31 @@ class DiarizationService:
|
||||||
"eres2netv2": "iic/speech_eres2netv2_sv_zh-cn_16k-common",
|
"eres2netv2": "iic/speech_eres2netv2_sv_zh-cn_16k-common",
|
||||||
}
|
}
|
||||||
|
|
||||||
from speakerlab.bin.infer_diarization import Diarization3Dspeaker
|
try:
|
||||||
|
from speakerlab.bin.infer_diarization import Diarization3Dspeaker
|
||||||
|
|
||||||
self.model = Diarization3Dspeaker(
|
print(f" - 导入 Diarization3Dspeaker 完成")
|
||||||
device=self.device,
|
sys.stdout.flush()
|
||||||
include_overlap=self.include_overlap,
|
|
||||||
hf_access_token=self.hf_access_token,
|
|
||||||
model_cache_dir=self.cache_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"模型加载完成!")
|
self.model = Diarization3Dspeaker(
|
||||||
|
device=self.device,
|
||||||
|
include_overlap=self.include_overlap,
|
||||||
|
hf_access_token=self.hf_access_token,
|
||||||
|
model_cache_dir=self.cache_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f" - 模型实例化完成")
|
||||||
|
sys.stdout.flush()
|
||||||
|
print(f"模型加载完成!")
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n✗ 模型加载失败:{e}")
|
||||||
|
sys.stdout.flush()
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.stdout.flush()
|
||||||
|
raise
|
||||||
|
|
||||||
def diarize(
|
def diarize(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
@echo off
|
||||||
|
echo ============================================================
|
||||||
|
echo 修复进程崩溃问题 - 启用单进程模式
|
||||||
|
echo ============================================================
|
||||||
|
echo.
|
||||||
|
|
||||||
|
echo 正在修改 main.py...
|
||||||
|
echo.
|
||||||
|
|
||||||
|
REM 读取文件内容并替换
|
||||||
|
powershell -Command ^
|
||||||
|
"$content = Get-Content -Path 'main.py' -Raw; ^
|
||||||
|
$content = $content -replace 'MAX_WORKERS_OVERRIDE = None', 'MAX_WORKERS_OVERRIDE = 1 # 强制单进程模式'; ^
|
||||||
|
Set-Content -Path 'main.py' -Value $content -Encoding UTF8"
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo ✓ 修改完成!
|
||||||
|
echo.
|
||||||
|
echo ============================================================
|
||||||
|
echo 已启用单进程模式
|
||||||
|
echo ============================================================
|
||||||
|
echo.
|
||||||
|
echo 现在可以运行:
|
||||||
|
echo python main.py
|
||||||
|
echo.
|
||||||
|
echo 如果需要恢复多进程模式,请编辑 main.py:
|
||||||
|
echo 找到:MAX_WORKERS_OVERRIDE = 1
|
||||||
|
echo 改为:MAX_WORKERS_OVERRIDE = None
|
||||||
|
echo.
|
||||||
|
echo ============================================================
|
||||||
|
pause
|
||||||
|
|
@ -0,0 +1,492 @@
|
||||||
|
"""
|
||||||
|
批量视频语音识别 + 说话人分离
|
||||||
|
|
||||||
|
功能:
|
||||||
|
1. 自动扫描视频目录
|
||||||
|
2. 两阶段处理:
|
||||||
|
- 阶段 1: 说话人分离
|
||||||
|
- 阶段 2: ASR 识别 + 合并结果
|
||||||
|
3. 每个视频只加载一次模型
|
||||||
|
4. 顺序处理,避免多进程 CUDA 共享问题
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
# 路径配置
|
||||||
|
BASE_DIR = Path(__file__).parent.absolute()
|
||||||
|
TEMP_DIR = BASE_DIR / "temp"
|
||||||
|
OUTPUT_DIR = BASE_DIR / "output"
|
||||||
|
|
||||||
|
# 视频文件夹路径(全局变量)
|
||||||
|
VIDEO_FOLDER = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\宁波北仑区鼎邦杰西雅服饰有限公司"
|
||||||
|
|
||||||
|
# 支持的视频格式
|
||||||
|
SUPPORTED_VIDEO_FORMATS = ["*.mp4", "*.avi", "*.mkv", "*.mov", "*.flv", "*.wmv", "*.m4v"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_list(folder_path: Path) -> List[Path]:
|
||||||
|
"""
|
||||||
|
从文件夹自动获取视频列表,按文件名中的时间排序
|
||||||
|
|
||||||
|
Args:
|
||||||
|
folder_path: 视频文件夹路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
按文件名排序后的视频路径列表
|
||||||
|
"""
|
||||||
|
video_paths = []
|
||||||
|
|
||||||
|
# 扫描所有支持的视频格式
|
||||||
|
for pattern in SUPPORTED_VIDEO_FORMATS:
|
||||||
|
video_paths.extend(folder_path.glob(pattern))
|
||||||
|
|
||||||
|
# 按文件名排序(假设文件名包含时间戳,如 VID_20251031_132320_019.mp4)
|
||||||
|
# 使用文件名的字典序排序,时间戳格式的文件名会自动按时间顺序排列
|
||||||
|
video_paths.sort(key=lambda p: p.name)
|
||||||
|
|
||||||
|
return video_paths
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def clear_temp_dir():
|
||||||
|
"""清空 temp 目录"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("清空临时目录...")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if TEMP_DIR.exists():
|
||||||
|
try:
|
||||||
|
shutil.rmtree(TEMP_DIR)
|
||||||
|
print(f"✓ 已删除:{TEMP_DIR}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ 删除失败:{e}")
|
||||||
|
|
||||||
|
TEMP_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
print(f"✓ 已创建:{TEMP_DIR}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_output_dir():
|
||||||
|
"""确保 output 目录存在"""
|
||||||
|
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
print(f"✓ 输出目录:{OUTPUT_DIR}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_wav(video_path: Path, temp_dir: Path) -> Optional[Path]:
|
||||||
|
"""
|
||||||
|
从视频提取 WAV 音频
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: 视频文件路径
|
||||||
|
temp_dir: 临时目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WAV 文件路径,失败返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
wav_path = temp_dir / f"{video_path.stem}.wav"
|
||||||
|
|
||||||
|
# 使用 ffmpeg 提取音频
|
||||||
|
cmd = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-i", str(video_path),
|
||||||
|
"-vn", # 不处理视频
|
||||||
|
"-acodec", "pcm_s16le", # 16 位 PCM 编码
|
||||||
|
"-ar", "16000", # 16kHz 采样率
|
||||||
|
"-ac", "1", # 单声道
|
||||||
|
"-y", # 覆盖已存在文件
|
||||||
|
str(wav_path)
|
||||||
|
]
|
||||||
|
|
||||||
|
subprocess.run(
|
||||||
|
cmd,
|
||||||
|
check=True,
|
||||||
|
capture_output=True,
|
||||||
|
timeout=300 # 5 分钟超时
|
||||||
|
)
|
||||||
|
|
||||||
|
if wav_path.exists():
|
||||||
|
print(f"✓ 提取音频:{video_path.name} -> {wav_path.name}")
|
||||||
|
return wav_path
|
||||||
|
else:
|
||||||
|
print(f"✗ 提取失败:{video_path.name}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
print(f"✗ 提取超时:{video_path.name}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ 提取错误:{video_path.name} - {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def process_batch_diarization(video_paths: List[Path], max_workers: int = 1):
|
||||||
|
"""
|
||||||
|
第一阶段:批量执行说话人分离(主进程顺序处理)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_paths: 视频路径列表
|
||||||
|
max_workers: 并发数(目前固定为 1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[video_path -> diar_result_path]: 说话人分离结果映射
|
||||||
|
"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("第一阶段:批量说话人分离")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"视频数量:{len(video_paths)}")
|
||||||
|
print(f"处理模式:顺序处理(单进程)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 加载说话人分离模型(只加载一次)
|
||||||
|
print("加载说话人分离模型...")
|
||||||
|
from diarization_service import DiarizationService
|
||||||
|
|
||||||
|
diar_service = DiarizationService(
|
||||||
|
embedding_model="eres2netv2",
|
||||||
|
device="auto",
|
||||||
|
cluster_threshold=0.5,
|
||||||
|
min_cluster_size=10
|
||||||
|
)
|
||||||
|
diar_service._load_model()
|
||||||
|
print("✓ 说话人分离模型已加载")
|
||||||
|
print()
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 顺序处理每个视频
|
||||||
|
for i, video_path in enumerate(video_paths, 1):
|
||||||
|
try:
|
||||||
|
print(f"\n[{i}/{len(video_paths)}] 处理:{video_path.name}")
|
||||||
|
|
||||||
|
# 1. 提取 WAV
|
||||||
|
wav_path = extract_wav(video_path, TEMP_DIR)
|
||||||
|
if wav_path is None:
|
||||||
|
print(f" ✗ 音频提取失败")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 2. 执行说话人分离
|
||||||
|
diar_segments = diar_service.diarize(wav_path)
|
||||||
|
|
||||||
|
if not diar_segments:
|
||||||
|
print(f" ✗ 说话人分离结果为空")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 3. 保存说话人分离结果(临时文件)
|
||||||
|
temp_diar_path = TEMP_DIR / f"{video_path.stem}_diar.json"
|
||||||
|
diar_result = {
|
||||||
|
"segments": [seg.to_dict() for seg in diar_segments]
|
||||||
|
}
|
||||||
|
from map_speaker import save_json
|
||||||
|
save_json(temp_diar_path, diar_result)
|
||||||
|
|
||||||
|
results[video_path] = str(temp_diar_path)
|
||||||
|
print(f" ✓ 说话人分离完成")
|
||||||
|
|
||||||
|
# 4. 清理临时 WAV(保留用于后续 ASR)
|
||||||
|
# 注意:这里不删除,ASR 阶段还需要
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
print(f" ✗ 处理失败:{e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# 显示进度
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
avg_time = elapsed / len(results) if results else 1
|
||||||
|
remaining = (len(video_paths) - len(results)) * avg_time
|
||||||
|
|
||||||
|
print(f"\n进度:{len(results)}/{len(video_paths)}")
|
||||||
|
print(f"已用:{elapsed:.1f}s, 预计剩余:{remaining:.1f}s")
|
||||||
|
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
print(f"\n✓ 第一阶段完成,耗时:{total_time:.1f}s")
|
||||||
|
print()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def process_batch_asr(video_paths: List[Path], diar_results: Dict, max_workers: int = 1):
|
||||||
|
"""
|
||||||
|
第二阶段:批量执行 ASR 识别并合并结果(主进程顺序处理)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_paths: 视频路径列表
|
||||||
|
diar_results: 说话人分离结果映射
|
||||||
|
max_workers: 并发数(目前固定为 1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict]: 最终结果列表
|
||||||
|
"""
|
||||||
|
print("=" * 60)
|
||||||
|
print("第二阶段:批量语音识别 + 合并结果")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"视频数量:{len(video_paths)}")
|
||||||
|
print(f"处理模式:顺序处理(单进程)")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 加载 ASR 模型(只加载一次)
|
||||||
|
print("加载 ASR 模型...")
|
||||||
|
from asr_service import ASRService
|
||||||
|
|
||||||
|
asr_service = ASRService(model_name="paraformer-zh", device="auto")
|
||||||
|
asr_service._load_model()
|
||||||
|
print("✓ ASR 模型已加载")
|
||||||
|
print()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 顺序处理每个视频
|
||||||
|
for i, video_path in enumerate(video_paths, 1):
|
||||||
|
diar_path = diar_results.get(video_path)
|
||||||
|
if not diar_path:
|
||||||
|
print(f"\n[{i}/{len(video_paths)}] 跳过 {video_path.name}(无说话人分离结果)")
|
||||||
|
results.append({
|
||||||
|
"video": str(video_path),
|
||||||
|
"success": False,
|
||||||
|
"error": "无说话人分离结果"
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
print(f"\n[{i}/{len(video_paths)}] 处理:{video_path.name}")
|
||||||
|
|
||||||
|
# 1. 提取 WAV(如果不存在)
|
||||||
|
wav_path = TEMP_DIR / f"{video_path.stem}.wav"
|
||||||
|
if not wav_path.exists():
|
||||||
|
wav_path = extract_wav(video_path, TEMP_DIR)
|
||||||
|
if wav_path is None:
|
||||||
|
print(f" ✗ 音频提取失败")
|
||||||
|
results.append({
|
||||||
|
"video": str(video_path),
|
||||||
|
"success": False,
|
||||||
|
"error": "音频提取失败"
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 2. 加载说话人分离结果
|
||||||
|
from map_speaker import find_speaker, load_json
|
||||||
|
diar_result = load_json(diar_path)
|
||||||
|
|
||||||
|
# 3. 执行 ASR 识别
|
||||||
|
asr_sentences = asr_service.recognize(wav_path)
|
||||||
|
|
||||||
|
if not asr_sentences:
|
||||||
|
print(f" ✗ ASR 识别结果为空")
|
||||||
|
results.append({
|
||||||
|
"video": str(video_path),
|
||||||
|
"success": False,
|
||||||
|
"error": "ASR 识别结果为空"
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 4. 合并说话人信息
|
||||||
|
print(f" 合并结果...")
|
||||||
|
for sentence in asr_sentences:
|
||||||
|
new_speaker = find_speaker(
|
||||||
|
sentence.begin_time,
|
||||||
|
sentence.end_time,
|
||||||
|
diar_result["segments"]
|
||||||
|
)
|
||||||
|
sentence.speaker = new_speaker
|
||||||
|
|
||||||
|
# 5. 保存最终结果
|
||||||
|
output_file = OUTPUT_DIR / f"{video_path.stem}_result.json"
|
||||||
|
asr_service.export_to_json(asr_sentences, output_file)
|
||||||
|
|
||||||
|
# 统计说话人
|
||||||
|
speaker_counts = {}
|
||||||
|
for sentence in asr_sentences:
|
||||||
|
speaker = sentence.speaker
|
||||||
|
speaker_counts[speaker] = speaker_counts.get(speaker, 0) + 1
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"video": str(video_path),
|
||||||
|
"success": True,
|
||||||
|
"asr_result": [s.to_dict() for s in asr_sentences],
|
||||||
|
"merged_result": str(output_file),
|
||||||
|
"speaker_counts": speaker_counts,
|
||||||
|
"total_sentences": len(asr_sentences)
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f" ✓ 处理完成")
|
||||||
|
print(f" - 句子数:{len(asr_sentences)}")
|
||||||
|
print(f" - 说话人:{speaker_counts}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
print(f" ✗ 处理失败:{e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
results.append({
|
||||||
|
"video": str(video_path),
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
})
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
if wav_path.exists():
|
||||||
|
try:
|
||||||
|
wav_path.unlink()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if diar_path:
|
||||||
|
try:
|
||||||
|
Path(diar_path).unlink()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 显示进度
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
avg_time = elapsed / len(results) if results else 1
|
||||||
|
remaining = (len(video_paths) - len(results)) * avg_time
|
||||||
|
|
||||||
|
print(f"\n进度:{len(results)}/{len(video_paths)}")
|
||||||
|
print(f"已用:{elapsed:.1f}s, 预计剩余:{remaining:.1f}s")
|
||||||
|
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
print(f"\n✓ 第二阶段完成,耗时:{total_time:.1f}s")
|
||||||
|
print()
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
# 汇总报告
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
success_count = sum(1 for r in results if r["success"])
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("处理完成汇总")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"总耗时:{total_time:.1f}s")
|
||||||
|
print(f"平均每个视频:{total_time/len(video_paths):.1f}s")
|
||||||
|
print(f"成功:{success_count}/{len(video_paths)}")
|
||||||
|
print(f"失败:{len(video_paths) - success_count}")
|
||||||
|
|
||||||
|
# 保存汇总报告
|
||||||
|
summary = {
|
||||||
|
"total_videos": len(video_paths),
|
||||||
|
"success_count": success_count,
|
||||||
|
"failed_count": len(video_paths) - success_count,
|
||||||
|
"total_time_seconds": round(total_time, 2),
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"video": Path(r["video"]).name,
|
||||||
|
"success": r["success"],
|
||||||
|
"output": r.get("merged_result"),
|
||||||
|
"total_sentences": r.get("total_sentences", 0),
|
||||||
|
"speaker_counts": r.get("speaker_counts", {}),
|
||||||
|
"error": r.get("error")
|
||||||
|
}
|
||||||
|
for r in results
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
summary_path = OUTPUT_DIR / "batch_summary.json"
|
||||||
|
with open(summary_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(summary, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
print(f"\n汇总报告:{summary_path}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print(" 并发批量语音识别处理系统")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 1. 清空 temp 目录
|
||||||
|
clear_temp_dir()
|
||||||
|
|
||||||
|
# 2. 确保 output 目录存在
|
||||||
|
ensure_output_dir()
|
||||||
|
|
||||||
|
# 3. 准备视频列表(从 VIDEO_FOLDER 自动获取)
|
||||||
|
video_folder = Path(VIDEO_FOLDER)
|
||||||
|
if not video_folder.exists():
|
||||||
|
print(f"✗ 错误:视频文件夹不存在 - {video_folder}")
|
||||||
|
return
|
||||||
|
|
||||||
|
video_paths = get_video_list(video_folder)
|
||||||
|
|
||||||
|
if not video_paths:
|
||||||
|
print("✗ 错误:未找到任何视频文件")
|
||||||
|
print(f"支持格式:{', '.join(SUPPORTED_VIDEO_FORMATS)}")
|
||||||
|
print(f"请检查文件夹:{video_folder}")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"找到 {len(video_paths)} 个视频文件")
|
||||||
|
for vp in video_paths:
|
||||||
|
print(f" - {vp.name}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 4. 分阶段处理
|
||||||
|
# 阶段 1: 说话人分离
|
||||||
|
# 阶段 2: ASR 识别 + 合并结果
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("=" * 60)
|
||||||
|
print("处理策略")
|
||||||
|
print("=" * 60)
|
||||||
|
print("阶段 1: 批量说话人分离(加载说话人分离模型)")
|
||||||
|
print(" ↓ 释放内存")
|
||||||
|
print("阶段 2: 批量语音识别 + 合并结果(加载 ASR 模型)")
|
||||||
|
print("=" * 60)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 阶段 1: 说话人分离
|
||||||
|
diar_results = process_batch_diarization(video_paths, max_workers=1)
|
||||||
|
|
||||||
|
# 检查阶段 1 的结果
|
||||||
|
success_count = len([v for v, r in diar_results.items() if r])
|
||||||
|
if success_count == 0:
|
||||||
|
print("✗ 错误:第一阶段全部失败")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"✓ 第一阶段成功:{success_count}/{len(video_paths)}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 强制垃圾回收,释放显存
|
||||||
|
import torch
|
||||||
|
gc.collect()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
print("✓ 已清理 CUDA 缓存,准备第二阶段")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 阶段 2: ASR 识别 + 合并结果
|
||||||
|
results = process_batch_asr(video_paths, diar_results, max_workers=1)
|
||||||
|
|
||||||
|
# 5. 最终清理
|
||||||
|
print("\n清理临时文件...")
|
||||||
|
clear_temp_dir()
|
||||||
|
|
||||||
|
print("\n✓ 全部完成!")
|
||||||
|
print(f"输出目录:{OUTPUT_DIR}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
160
requirements.txt
160
requirements.txt
|
|
@ -1,7 +1,153 @@
|
||||||
funasr>=1.3.0
|
addict==2.4.0
|
||||||
modelscope>=1.15.0
|
aiohappyeyeballs==2.6.1
|
||||||
torch>=2.0.0
|
aiohttp==3.13.5
|
||||||
torchaudio>=2.0.0
|
aiosignal==1.4.0
|
||||||
torchvision>=0.15.0
|
alembic==1.18.4
|
||||||
transformers>=4.30.0
|
aliyun-python-sdk-core==2.16.0
|
||||||
numpy>=1.24.0
|
aliyun-python-sdk-kms==2.16.5
|
||||||
|
annotated-doc==0.0.4
|
||||||
|
antlr4-python3-runtime==4.9.3
|
||||||
|
anyio==4.13.0
|
||||||
|
asteroid-filterbanks==0.4.0
|
||||||
|
async-timeout==5.0.1
|
||||||
|
attrs==26.1.0
|
||||||
|
audioread==3.1.0
|
||||||
|
blinker==1.9.0
|
||||||
|
certifi==2026.4.22
|
||||||
|
cffi==2.0.0
|
||||||
|
charset-normalizer==3.4.7
|
||||||
|
click==8.3.3
|
||||||
|
colorama==0.4.6
|
||||||
|
coloredlogs==15.0.1
|
||||||
|
colorlog==6.10.1
|
||||||
|
contourpy==1.3.2
|
||||||
|
crcmod==1.7
|
||||||
|
cryptography==47.0.0
|
||||||
|
cycler==0.12.1
|
||||||
|
datasets==4.8.5
|
||||||
|
decorator==5.2.1
|
||||||
|
dill==0.4.1
|
||||||
|
docopt==0.6.2
|
||||||
|
editdistance==0.8.1
|
||||||
|
einops==0.8.2
|
||||||
|
exceptiongroup==1.3.1
|
||||||
|
fastcluster==1.3.0
|
||||||
|
filelock==3.25.2
|
||||||
|
Flask==3.1.3
|
||||||
|
flatbuffers==25.12.19
|
||||||
|
fonttools==4.62.1
|
||||||
|
frozenlist==1.8.0
|
||||||
|
fsspec==2026.2.0
|
||||||
|
funasr==1.3.1
|
||||||
|
greenlet==3.5.0
|
||||||
|
h11==0.16.0
|
||||||
|
hdbscan==0.8.42
|
||||||
|
hf-xet==1.4.3
|
||||||
|
httpcore==1.0.9
|
||||||
|
httpx==0.28.1
|
||||||
|
huggingface_hub==1.12.0
|
||||||
|
humanfriendly==10.0
|
||||||
|
hydra-core==1.3.2
|
||||||
|
HyperPyYAML==1.2.3
|
||||||
|
idna==3.13
|
||||||
|
itsdangerous==2.2.0
|
||||||
|
jaconv==0.5.0
|
||||||
|
jamo==0.4.1
|
||||||
|
jieba==0.42.1
|
||||||
|
Jinja2==3.1.6
|
||||||
|
jmespath==0.10.0
|
||||||
|
joblib==1.5.3
|
||||||
|
julius==0.2.7
|
||||||
|
kaldiio==2.18.1
|
||||||
|
kiwisolver==1.5.0
|
||||||
|
lazy-loader==0.5
|
||||||
|
librosa==0.11.0
|
||||||
|
lightning==2.6.1
|
||||||
|
lightning-utilities==0.15.3
|
||||||
|
llvmlite==0.47.0
|
||||||
|
Mako==1.3.12
|
||||||
|
markdown-it-py==4.0.0
|
||||||
|
MarkupSafe==3.0.3
|
||||||
|
matplotlib==3.10.9
|
||||||
|
mdurl==0.1.2
|
||||||
|
modelscope==1.36.3
|
||||||
|
mpmath==1.3.0
|
||||||
|
msgpack==1.1.2
|
||||||
|
multidict==6.7.1
|
||||||
|
multiprocess==0.70.19
|
||||||
|
networkx==3.4.2
|
||||||
|
numba==0.65.1
|
||||||
|
numpy==2.2.6
|
||||||
|
omegaconf==2.3.0
|
||||||
|
onnxruntime-gpu==1.23.2
|
||||||
|
opencv-python==4.13.0.92
|
||||||
|
optuna==4.8.0
|
||||||
|
oss2==2.19.1
|
||||||
|
packaging==26.2
|
||||||
|
pandas==2.3.3
|
||||||
|
pillow==12.2.0
|
||||||
|
platformdirs==4.9.6
|
||||||
|
pooch==1.9.0
|
||||||
|
primePy==1.3
|
||||||
|
propcache==0.4.1
|
||||||
|
protobuf==7.34.1
|
||||||
|
pyannote.audio==3.4.0
|
||||||
|
pyannote.core==5.0.0
|
||||||
|
pyannote.database==5.1.3
|
||||||
|
pyannote.metrics==3.2.1
|
||||||
|
pyannote.pipeline==3.0.1
|
||||||
|
pyarrow==24.0.0
|
||||||
|
pycparser==3.0
|
||||||
|
pycryptodome==3.23.0
|
||||||
|
Pygments==2.20.0
|
||||||
|
pynndescent==0.6.0
|
||||||
|
pyparsing==3.3.2
|
||||||
|
pyreadline3==3.5.4
|
||||||
|
python-dateutil==2.9.0.post0
|
||||||
|
python_speech_features==0.6
|
||||||
|
pytorch-lightning==2.6.1
|
||||||
|
pytorch-metric-learning==2.9.0
|
||||||
|
pytorch-wpe==0.0.1
|
||||||
|
pytz==2026.1.post1
|
||||||
|
PyYAML==6.0.3
|
||||||
|
regex==2026.4.4
|
||||||
|
requests==2.33.1
|
||||||
|
rich==15.0.0
|
||||||
|
ruamel.yaml==0.18.17
|
||||||
|
ruamel.yaml.clib==0.2.15
|
||||||
|
safetensors==0.7.0
|
||||||
|
scikit-learn==1.7.2
|
||||||
|
scipy==1.15.3
|
||||||
|
semver==3.0.4
|
||||||
|
sentencepiece==0.2.1
|
||||||
|
shellingham==1.5.4
|
||||||
|
simplejson==4.1.1
|
||||||
|
six==1.17.0
|
||||||
|
sortedcontainers==2.4.0
|
||||||
|
soundfile==0.13.1
|
||||||
|
soxr==1.0.0
|
||||||
|
speechbrain==1.1.0
|
||||||
|
SQLAlchemy==2.0.49
|
||||||
|
sympy==1.14.0
|
||||||
|
tabulate==0.10.0
|
||||||
|
tensorboardX==2.6.5
|
||||||
|
threadpoolctl==3.6.0
|
||||||
|
tokenizers==0.22.2
|
||||||
|
tomli==2.4.1
|
||||||
|
torch==2.7.1+cu118
|
||||||
|
torch-audiomentations==0.12.0
|
||||||
|
torch-complex==0.4.4
|
||||||
|
torch_pitch_shift==1.2.5
|
||||||
|
torchaudio==2.7.1+cu118
|
||||||
|
torchmetrics==1.9.0
|
||||||
|
torchvision==0.22.1
|
||||||
|
tqdm==4.67.3
|
||||||
|
transformers==5.7.0
|
||||||
|
typer==0.25.0
|
||||||
|
typing_extensions==4.15.0
|
||||||
|
tzdata==2026.2
|
||||||
|
umap-learn==0.5.12
|
||||||
|
urllib3==2.6.3
|
||||||
|
Werkzeug==3.1.8
|
||||||
|
xxhash==3.7.0
|
||||||
|
yarl==1.23.0
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,319 @@
|
||||||
|
"""
|
||||||
|
Web API Server for ASR and Speaker Diarization
|
||||||
|
提供语音识别和说话人分离的 REST API 服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import gc
|
||||||
|
from pathlib import Path
|
||||||
|
from flask import Flask, request, jsonify, send_file
|
||||||
|
from werkzeug.utils import secure_filename
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024
|
||||||
|
app.config['UPLOAD_FOLDER'] = 'uploads'
|
||||||
|
app.config['RESULT_FOLDER'] = 'results'
|
||||||
|
|
||||||
|
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
||||||
|
os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
|
||||||
|
|
||||||
|
GLOBAL_ASR_SERVICE = None
|
||||||
|
GLOBAL_DIAR_SERVICE = None
|
||||||
|
ASR_MODEL_LOADED = False
|
||||||
|
DIAR_MODEL_LOADED = False
|
||||||
|
ASR_MODEL_LOCK = threading.Lock()
|
||||||
|
DIAR_MODEL_LOCK = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_asr_service():
|
||||||
|
global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED
|
||||||
|
if GLOBAL_ASR_SERVICE is None:
|
||||||
|
from asr_service import ASRService
|
||||||
|
GLOBAL_ASR_SERVICE = ASRService()
|
||||||
|
return GLOBAL_ASR_SERVICE
|
||||||
|
|
||||||
|
|
||||||
|
def get_diar_service():
|
||||||
|
global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED
|
||||||
|
if GLOBAL_DIAR_SERVICE is None:
|
||||||
|
from diarization_service import DiarizationService
|
||||||
|
GLOBAL_DIAR_SERVICE = DiarizationService()
|
||||||
|
return GLOBAL_DIAR_SERVICE
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/health', methods=['GET'])
|
||||||
|
def health_check():
|
||||||
|
"""健康检查"""
|
||||||
|
return jsonify({
|
||||||
|
'status': 'ok',
|
||||||
|
'asr_loaded': ASR_MODEL_LOADED,
|
||||||
|
'diar_loaded': DIAR_MODEL_LOADED
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/asr/load', methods=['GET'])
|
||||||
|
def load_asr_model():
|
||||||
|
"""加载 ASR 模型"""
|
||||||
|
global ASR_MODEL_LOADED
|
||||||
|
with ASR_MODEL_LOCK:
|
||||||
|
if ASR_MODEL_LOADED:
|
||||||
|
return jsonify({'message': 'ASR 模型已加载', 'loaded': True})
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = request.json or {}
|
||||||
|
model_name = data.get('model_name', 'paraformer-zh')
|
||||||
|
device = data.get('device', 'auto')
|
||||||
|
|
||||||
|
print(f"正在加载 ASR 模型: {model_name}, 设备: {device}")
|
||||||
|
service = get_asr_service()
|
||||||
|
service._load_model()
|
||||||
|
ASR_MODEL_LOADED = True
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'message': 'ASR 模型加载成功',
|
||||||
|
'loaded': True,
|
||||||
|
'model': model_name,
|
||||||
|
'device': device
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({'error': str(e)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/diar/load', methods=['POST'])
|
||||||
|
def load_diar_model():
|
||||||
|
"""加载 3D-Speaker 模型"""
|
||||||
|
global DIAR_MODEL_LOADED
|
||||||
|
with DIAR_MODEL_LOCK:
|
||||||
|
if DIAR_MODEL_LOADED:
|
||||||
|
return jsonify({'message': '3D-Speaker 模型已加载', 'loaded': True})
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = request.json or {}
|
||||||
|
embedding_model = data.get('embedding_model', 'eres2netv2')
|
||||||
|
device = data.get('device', 'auto')
|
||||||
|
cluster_threshold = data.get('cluster_threshold', 0.5)
|
||||||
|
min_cluster_size = data.get('min_cluster_size', 10)
|
||||||
|
|
||||||
|
print(f"正在加载 3D-Speaker 模型: {embedding_model}, 设备: {device}")
|
||||||
|
service = get_diar_service()
|
||||||
|
service._load_model()
|
||||||
|
DIAR_MODEL_LOADED = True
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'message': '3D-Speaker 模型加载成功',
|
||||||
|
'loaded': True,
|
||||||
|
'embedding_model': embedding_model,
|
||||||
|
'device': device
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({'error': str(e)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/asr/unload', methods=['POST'])
|
||||||
|
def unload_asr_model():
|
||||||
|
"""卸载 ASR 模型"""
|
||||||
|
global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED
|
||||||
|
|
||||||
|
try:
|
||||||
|
GLOBAL_ASR_SERVICE = None
|
||||||
|
ASR_MODEL_LOADED = False
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return jsonify({'message': 'ASR 模型已卸载'})
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({'error': str(e)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/diar/unload', methods=['POST'])
|
||||||
|
def unload_diar_model():
|
||||||
|
"""卸载 3D-Speaker 模型"""
|
||||||
|
global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED
|
||||||
|
|
||||||
|
try:
|
||||||
|
GLOBAL_DIAR_SERVICE = None
|
||||||
|
DIAR_MODEL_LOADED = False
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return jsonify({'message': '3D-Speaker 模型已卸载'})
|
||||||
|
except Exception as e:
|
||||||
|
return jsonify({'error': str(e)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/recognize/single', methods=['POST'])
|
||||||
|
def recognize_single():
|
||||||
|
"""单文件推理"""
|
||||||
|
try:
|
||||||
|
if 'file' not in request.files:
|
||||||
|
return jsonify({'error': '请上传音频文件'}), 400
|
||||||
|
|
||||||
|
file = request.files['file']
|
||||||
|
if file.filename == '':
|
||||||
|
return jsonify({'error': '文件名不能为空'}), 400
|
||||||
|
|
||||||
|
data = request.form or {}
|
||||||
|
use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true'
|
||||||
|
embedding_model = data.get('embedding_model', 'eres2netv2')
|
||||||
|
cluster_threshold = float(data.get('cluster_threshold', 0.5))
|
||||||
|
min_cluster_size = int(data.get('min_cluster_size', 10))
|
||||||
|
output_format = data.get('format', 'json')
|
||||||
|
|
||||||
|
filename = secure_filename(file.filename)
|
||||||
|
task_id = str(uuid.uuid4())[:8]
|
||||||
|
audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{task_id}_{filename}")
|
||||||
|
file.save(audio_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
global ASR_MODEL_LOADED
|
||||||
|
service = get_asr_service()
|
||||||
|
if not ASR_MODEL_LOADED:
|
||||||
|
service._load_model()
|
||||||
|
ASR_MODEL_LOADED = True
|
||||||
|
|
||||||
|
sentences = service.recognize(
|
||||||
|
audio_path,
|
||||||
|
use_3d_speaker=use_3d_speaker,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
cluster_threshold=cluster_threshold,
|
||||||
|
min_cluster_size=min_cluster_size
|
||||||
|
)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
'file': filename,
|
||||||
|
'total_sentences': len(sentences),
|
||||||
|
'sentences': [s.to_dict() for s in sentences]
|
||||||
|
}
|
||||||
|
|
||||||
|
if output_format == 'json':
|
||||||
|
result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.json")
|
||||||
|
with open(result_path, 'w', encoding='utf-8') as f:
|
||||||
|
import json
|
||||||
|
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||||
|
return jsonify(result)
|
||||||
|
else:
|
||||||
|
srt_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.srt")
|
||||||
|
service.export_to_srt(sentences, srt_path)
|
||||||
|
return send_file(srt_path, as_attachment=True, download_name=f"{task_id}_result.srt")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if os.path.exists(audio_path):
|
||||||
|
os.remove(audio_path)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return jsonify({'error': str(e)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/recognize/batch', methods=['POST'])
|
||||||
|
def recognize_batch():
|
||||||
|
"""批量推理"""
|
||||||
|
try:
|
||||||
|
if 'files' not in request.files:
|
||||||
|
return jsonify({'error': '请上传音频文件'}), 400
|
||||||
|
|
||||||
|
files = request.files.getlist('files')
|
||||||
|
if not files or len(files) == 0:
|
||||||
|
return jsonify({'error': '文件列表为空'}), 400
|
||||||
|
|
||||||
|
data = request.form or {}
|
||||||
|
use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true'
|
||||||
|
embedding_model = data.get('embedding_model', 'eres2netv2')
|
||||||
|
cluster_threshold = float(data.get('cluster_threshold', 0.5))
|
||||||
|
min_cluster_size = int(data.get('min_cluster_size', 10))
|
||||||
|
|
||||||
|
task_id = str(uuid.uuid4())[:8]
|
||||||
|
audio_paths = []
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
if f.filename:
|
||||||
|
filename = secure_filename(f.filename)
|
||||||
|
audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{task_id}_{filename}")
|
||||||
|
f.save(audio_path)
|
||||||
|
audio_paths.append(audio_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
global ASR_MODEL_LOADED
|
||||||
|
service = get_asr_service()
|
||||||
|
if not ASR_MODEL_LOADED:
|
||||||
|
service._load_model()
|
||||||
|
ASR_MODEL_LOADED = True
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for audio_path in audio_paths:
|
||||||
|
try:
|
||||||
|
sentences = service.recognize(
|
||||||
|
audio_path,
|
||||||
|
use_3d_speaker=use_3d_speaker,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
cluster_threshold=cluster_threshold,
|
||||||
|
min_cluster_size=min_cluster_size
|
||||||
|
)
|
||||||
|
results.append({
|
||||||
|
'file': os.path.basename(audio_path),
|
||||||
|
'total_sentences': len(sentences),
|
||||||
|
'sentences': [s.to_dict() for s in sentences]
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
results.append({
|
||||||
|
'file': os.path.basename(audio_path),
|
||||||
|
'error': str(e)
|
||||||
|
})
|
||||||
|
|
||||||
|
result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_batch_result.json")
|
||||||
|
import json
|
||||||
|
with open(result_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump({
|
||||||
|
'task_id': task_id,
|
||||||
|
'total_files': len(results),
|
||||||
|
'results': results
|
||||||
|
}, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
return jsonify({
|
||||||
|
'task_id': task_id,
|
||||||
|
'total_files': len(results),
|
||||||
|
'results': results,
|
||||||
|
'result_file': result_path
|
||||||
|
})
|
||||||
|
|
||||||
|
finally:
|
||||||
|
for audio_path in audio_paths:
|
||||||
|
if os.path.exists(audio_path):
|
||||||
|
os.remove(audio_path)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return jsonify({'error': str(e)}), 500
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/api/status', methods=['GET'])
|
||||||
|
def get_status():
|
||||||
|
"""获取模型状态"""
|
||||||
|
return jsonify({
|
||||||
|
'asr_loaded': ASR_MODEL_LOADED,
|
||||||
|
'diar_loaded': DIAR_MODEL_LOADED,
|
||||||
|
'asr_model': 'paraformer-zh',
|
||||||
|
'diar_model': '3D-Speaker'
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print("=" * 60)
|
||||||
|
print(" ASR & Speaker Diarization API Server")
|
||||||
|
print("=" * 60)
|
||||||
|
print("\nAPI 接口:")
|
||||||
|
print(" GET /health - 健康检查")
|
||||||
|
print(" GET /api/status - 获取模型状态")
|
||||||
|
print(" POST /api/asr/load - 加载 ASR 模型")
|
||||||
|
print(" POST /api/diar/load - 加载 3D-Speaker 模型")
|
||||||
|
print(" POST /api/asr/unload - 卸载 ASR 模型")
|
||||||
|
print(" POST /api/diar/unload - 卸载 3D-Speaker 模型")
|
||||||
|
print(" POST /api/recognize/single - 单文件推理")
|
||||||
|
print(" POST /api/recognize/batch - 批量推理")
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("启动服务: http://localhost:5000")
|
||||||
|
print("=" * 60)
|
||||||
|
app.run(host='0.0.0.0', port=5000, debug=False)
|
||||||
|
|
@ -0,0 +1,95 @@
|
||||||
|
"""
|
||||||
|
测试模型加载(不使用多进程)
|
||||||
|
用于诊断是否是模型本身的问题
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("模型加载测试(单进程模式)")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"Python 版本:{sys.version}")
|
||||||
|
print(f"PyTorch 版本:{torch.__version__}")
|
||||||
|
print(f"CUDA 可用:{torch.cuda.is_available()}")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
print(f"CUDA 版本:{torch.version.cuda}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 测试 ASR 模型
|
||||||
|
print("=" * 60)
|
||||||
|
print("测试 1: 加载 ASR 模型 (Paraformer)")
|
||||||
|
print("=" * 60)
|
||||||
|
from asr_service import ASRService
|
||||||
|
|
||||||
|
asr_service = ASRService(model_name="paraformer-zh", device="auto")
|
||||||
|
print("✓ ASRService 初始化完成")
|
||||||
|
|
||||||
|
asr_service._load_model()
|
||||||
|
print("✓ ASR 模型加载完成")
|
||||||
|
print(f" 模型类型:{type(asr_service._model)}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 测试说话人分离模型
|
||||||
|
print("=" * 60)
|
||||||
|
print("测试 2: 加载说话人分离模型 (3D-Speaker)")
|
||||||
|
print("=" * 60)
|
||||||
|
from diarization_service import DiarizationService
|
||||||
|
|
||||||
|
diar_service = DiarizationService(
|
||||||
|
embedding_model="eres2netv2",
|
||||||
|
device="auto",
|
||||||
|
cluster_threshold=0.5,
|
||||||
|
min_cluster_size=10
|
||||||
|
)
|
||||||
|
print("✓ DiarizationService 初始化完成")
|
||||||
|
|
||||||
|
diar_service._load_model()
|
||||||
|
print("✓ 说话人分离模型加载完成")
|
||||||
|
print(f" 模型类型:{type(diar_service.model)}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 测试音频处理
|
||||||
|
print("=" * 60)
|
||||||
|
print("测试 3: 测试音频处理")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 检查是否有测试音频
|
||||||
|
from pathlib import Path
|
||||||
|
test_audio = Path("test.wav")
|
||||||
|
if test_audio.exists():
|
||||||
|
print(f"找到测试音频:{test_audio}")
|
||||||
|
|
||||||
|
print("执行 ASR 识别...")
|
||||||
|
sentences = asr_service.recognize(str(test_audio))
|
||||||
|
print(f"✓ ASR 识别完成,共 {len(sentences)} 句")
|
||||||
|
|
||||||
|
if sentences:
|
||||||
|
print(f" 第一句:{sentences[0]}")
|
||||||
|
|
||||||
|
print("执行说话人分离...")
|
||||||
|
segments = diar_service.diarize(str(test_audio))
|
||||||
|
print(f"✓ 说话人分离完成,共 {len(segments)} 个片段")
|
||||||
|
|
||||||
|
if segments:
|
||||||
|
print(f" 第一个片段:{segments[0]}")
|
||||||
|
else:
|
||||||
|
print("⚠️ 未找到测试音频 (test.wav),跳过处理测试")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("=" * 60)
|
||||||
|
print("✓ 所有测试通过!模型工作正常")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print()
|
||||||
|
print("=" * 60)
|
||||||
|
print("✗ 测试失败!")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f"错误:{e}")
|
||||||
|
print()
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""
|
||||||
|
测试分阶段处理逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# 导入 main.py 中的函数
|
||||||
|
from main import (
|
||||||
|
VIDEO_DIR,
|
||||||
|
OUTPUT_DIR,
|
||||||
|
TEMP_DIR,
|
||||||
|
process_batch_diarization,
|
||||||
|
process_batch_asr,
|
||||||
|
get_video_list
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_staged_processing():
|
||||||
|
"""测试分阶段处理"""
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("分阶段处理测试")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 获取视频列表(只取前 2 个进行测试)
|
||||||
|
video_paths = get_video_list(VIDEO_DIR)
|
||||||
|
if not video_paths:
|
||||||
|
print("✗ 未找到视频文件")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 只测试前 2 个
|
||||||
|
test_videos = video_paths[:2]
|
||||||
|
print(f"测试视频:{len(test_videos)}")
|
||||||
|
for v in test_videos:
|
||||||
|
print(f" - {v.name}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 阶段 1: 说话人分离
|
||||||
|
print("=" * 60)
|
||||||
|
print("阶段 1: 说话人分离")
|
||||||
|
print("=" * 60)
|
||||||
|
diar_results = process_batch_diarization(test_videos, max_workers=1)
|
||||||
|
|
||||||
|
print(f"\n阶段 1 结果:{len(diar_results)}/{len(test_videos)} 成功")
|
||||||
|
for video, result_path in diar_results.items():
|
||||||
|
status = "✓" if result_path else "✗"
|
||||||
|
print(f" {status} {video.name}: {result_path}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 阶段 2: ASR + 合并
|
||||||
|
print("=" * 60)
|
||||||
|
print("阶段 2: ASR + 合并")
|
||||||
|
print("=" * 60)
|
||||||
|
results = process_batch_asr(test_videos, diar_results, max_workers=1)
|
||||||
|
|
||||||
|
print(f"\n阶段 2 结果:{len(results)}/{len(test_videos)} 完成")
|
||||||
|
for result in results:
|
||||||
|
status = "✓" if result.get("success") else "✗"
|
||||||
|
print(f" {status} {Path(result['video']).name}")
|
||||||
|
if result.get("speaker_counts"):
|
||||||
|
print(f" 说话人:{result['speaker_counts']}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("✓ 测试完成!")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_staged_processing()
|
||||||
Loading…
Reference in New Issue