Batch speech recognition + speaker identification
parent 48e51b3f92
commit 07c1eca03b
|
|
@ -29,6 +29,12 @@ ENV/
|
|||
# Model cache (large)
models/

# Temporary directory
temp/

# Output directory (optional; adjust as needed)
# output/

# Test outputs
*_result.json
*_result.srt
|
||||
|
|
|
|||
|
|
@ -0,0 +1,280 @@
|
|||
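# Concurrent Batch Processing Guide

## Features

✅ **Automatic temp cleanup** - the temp directory is emptied at startup
✅ **Concurrent batch processing** - concurrency is chosen automatically from GPU memory or CPU core count
✅ **WAV pre-extraction** - each video's audio is extracted to temp before processing
✅ **Result merging** - map_speaker merges the ASR and speaker diarization results
✅ **Separate outputs** - each video's result is saved to the output directory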
|
||||
|
||||
## Usage

### 1. Configure the video folder

Edit the `VIDEO_FOLDER` variable in `main.py`:

```python
# Video folder path (global variable)
VIDEO_FOLDER = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio"
```

**The program automatically:**
- ✅ Scans the folder for video files
- ✅ Supports mp4, avi, mkv, mov, flv, wmv, and m4v
- ✅ Sorts by file name (timestamp-style names therefore sort chronologically)

**Example file names:**
```
VID_20251031_132320_019.mp4  → 2025-10-31 13:23:20
VID_20251031_140530_020.mp4  → 2025-10-31 14:05:30
VID_20251101_090000_021.mp4  → 2025-11-01 09:00:00
```
|
||||
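The mapping from file name to timestamp shown above can be sketched as follows. This is a minimal illustration, assuming the naming scheme is `VID_<YYYYMMDD>_<HHMMSS>_<sequence>.<ext>`; the helper `parse_recording_time` is hypothetical and is not part of `main.py`, which simply sorts by file name.

```python
import re
from datetime import datetime
from typing import Optional

# Assumed naming scheme: VID_<YYYYMMDD>_<HHMMSS>_<sequence>.<ext>
_VID_PATTERN = re.compile(r"^VID_(\d{8})_(\d{6})_\d+\.\w+$")


def parse_recording_time(file_name: str) -> Optional[datetime]:
    """Return the timestamp embedded in a file name, or None if it does not match."""
    match = _VID_PATTERN.match(file_name)
    if match is None:
        return None
    try:
        return datetime.strptime(match.group(1) + match.group(2), "%Y%m%d%H%M%S")
    except ValueError:
        # Digits were present but do not form a valid date or time
        return None


names = [
    "VID_20251101_090000_021.mp4",
    "VID_20251031_132320_019.mp4",
    "VID_20251031_140530_020.mp4",
]

# For this scheme, plain lexicographic order equals chronological order.
assert sorted(names) == sorted(names, key=parse_recording_time)
print(parse_recording_time("VID_20251031_132320_019.mp4"))  # 2025-10-31 13:23:20
```

Because the date and time fields are zero-padded and ordered from most to least significant, sorting the raw names is enough; parsing is only needed when the timestamp itself is wanted.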
|
||||
### 2. Run the batch job

```bash
# Activate the virtual environment
funasr_env\Scripts\activate

# Run the batch job
python main.py
```
|
||||
|
||||
## Workflow

```
Start
  ↓
Clear temp/
  ↓
Create output/
  ↓
Process each video concurrently:
  1. Extract WAV to temp/
  2. Load the ASR model
  3. Run speech recognition
  4. Load the speaker diarization model
  5. Run speaker diarization
  6. Merge the results (map_speaker)
  7. Save the result to output/
  8. Remove the temporary WAV
  ↓
Write the summary report output/batch_summary.json
  ↓
Clear temp/
  ↓
Done
```
|
||||
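The merge step (step 6) pairs each recognized sentence with a speaker from the diarization output. The actual logic lives in `map_speaker.py`, which is not shown here, so the following is only a sketch of one common approach: pick the speaker whose segments overlap the sentence the longest. The function name `pick_speaker` and the segment keys `start`, `end`, and `speaker` are assumptions made for illustration.

```python
from typing import Dict, List


def pick_speaker(
    begin: float,
    end: float,
    segments: List[Dict],
    default: str = "speaker_0",
) -> str:
    """Return the speaker whose segments overlap [begin, end] the most."""
    overlap_by_speaker: Dict[str, float] = {}
    for seg in segments:
        # Length of the intersection of the two time intervals (negative if disjoint)
        overlap = min(end, seg["end"]) - max(begin, seg["start"])
        if overlap > 0:
            speaker = seg["speaker"]
            overlap_by_speaker[speaker] = overlap_by_speaker.get(speaker, 0.0) + overlap
    if not overlap_by_speaker:
        return default
    return max(overlap_by_speaker, key=overlap_by_speaker.get)


# Hypothetical diarization output, times in seconds
segments = [
    {"start": 0.0, "end": 2.0, "speaker": "speaker_0"},
    {"start": 2.0, "end": 5.0, "speaker": "speaker_1"},
]

print(pick_speaker(0.5, 2.3, segments))  # speaker_0
```

Summing overlap per speaker, rather than taking the single best segment, keeps the choice stable when diarization splits one speaker's turn into several short segments.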
|
||||
## Output Files

### Per-video result

`output/{video_name}_result.json`

```json
{
  "total_sentences": 50,
  "sentences": [
    {
      "speaker": "speaker_0",
      "text": "你好,请问这里是哪里?",
      "begin_time": 0.50,
      "end_time": 2.30,
      "duration": 1.80
    }
  ]
}
```
|
||||
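A result file in this shape is easy to post-process. As a small sketch, assuming only the fields shown in the example above (`sentences`, `speaker`, `begin_time`, `end_time`), the total speaking time per speaker can be computed like this; `talk_time_by_speaker` is a hypothetical helper, not something the project provides.

```python
import json
from collections import defaultdict
from typing import Dict


def talk_time_by_speaker(result: dict) -> Dict[str, float]:
    """Sum (end_time - begin_time) per speaker, rounded to 2 decimals."""
    totals = defaultdict(float)
    for sentence in result["sentences"]:
        totals[sentence["speaker"]] += sentence["end_time"] - sentence["begin_time"]
    return {speaker: round(seconds, 2) for speaker, seconds in totals.items()}


# The sample document from above; a real run would load it from a file
# with json.load() instead.
sample = json.loads("""
{
  "total_sentences": 1,
  "sentences": [
    {
      "speaker": "speaker_0",
      "text": "你好,请问这里是哪里?",
      "begin_time": 0.50,
      "end_time": 2.30,
      "duration": 1.80
    }
  ]
}
""")

print(talk_time_by_speaker(sample))  # {'speaker_0': 1.8}
```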
|
||||
### Summary report

`output/batch_summary.json`

```json
{
  "total_videos": 3,
  "success_count": 3,
  "failed_count": 0,
  "total_time_seconds": 245.67,
  "results": [
    {
      "video": "VID_20251031_132320_019.mp4",
      "success": true,
      "output": "output/VID_20251031_132320_019_result.json",
      "total_sentences": 50,
      "speaker_counts": {
        "speaker_0": 25,
        "speaker_1": 25
      }
    }
  ]
}
```
|
||||
|
||||
## Concurrency Strategy

### GPU mode
- Concurrency is derived from the amount of GPU memory
- Each video needs roughly 2-3 GB of GPU memory
- Formula: `concurrency = max(1, total VRAM / 3 GB)`

### CPU mode
- Concurrency equals the number of CPU cores
- The core count comes from `multiprocessing.cpu_count()`
|
||||
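The two rules above can be written down in a few lines. This is a sketch under stated assumptions: the division is treated as integer division, and the helper names `gpu_workers` and `cpu_workers` are hypothetical. The total VRAM is passed in as a plain number so the example runs without a GPU; with PyTorch it could be read from `torch.cuda.get_device_properties(0).total_memory`.

```python
import multiprocessing

GIB = 1024 ** 3


def gpu_workers(total_vram_bytes: int, per_video_bytes: int = 3 * GIB) -> int:
    """concurrency = max(1, total VRAM / 3 GB), using integer division."""
    return max(1, total_vram_bytes // per_video_bytes)


def cpu_workers() -> int:
    """CPU mode: one worker per core."""
    return multiprocessing.cpu_count()


for vram_gib in (2, 8, 12, 24):
    print(f"{vram_gib} GiB VRAM -> {gpu_workers(vram_gib * GIB)} worker(s)")
# 2 GiB VRAM -> 1 worker(s)
# 8 GiB VRAM -> 2 worker(s)
# 12 GiB VRAM -> 4 worker(s)
# 24 GiB VRAM -> 8 worker(s)
```

The `max(1, ...)` floor means a card with less than 3 GB still gets one worker, so whether that single job actually fits has to be checked separately.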
|
||||
## Performance Tips

### 1. GPU users
- Make sure the CUDA build of PyTorch is installed
- 8 GB VRAM: concurrency of 2-3 recommended
- 12 GB VRAM: concurrency of 4 recommended
- 24 GB VRAM: concurrency of 8 recommended

### 2. CPU users
- Lower the concurrency to avoid running out of memory
- Recommended: `concurrency = CPU cores / 2`

### 3. Memory
Each process needs roughly:
- ASR model: 2-3 GB
- Speaker diarization model: 1-2 GB
- Total: 3-5 GB per process

Make sure there is enough system memory: `concurrency × 5 GB < available memory`
|
||||
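The memory rule can be turned into a cap on the worker count. The sketch below assumes the 5 GB per-process figure from the list above and takes the available memory as an argument, since the project does not show how it would be measured; `memory_bounded_workers` is a hypothetical helper.

```python
GIB = 1024 ** 3


def memory_bounded_workers(
    requested: int,
    available_bytes: int,
    per_process_bytes: int = 5 * GIB,
) -> int:
    """Largest n <= requested with n * per_process_bytes < available_bytes (minimum 1)."""
    if requested < 1:
        raise ValueError("requested must be at least 1")
    # Strict inequality: subtract one byte before the integer division
    fit = (available_bytes - 1) // per_process_bytes
    return max(1, min(requested, fit))


print(memory_bounded_workers(8, 16 * GIB))  # 3
print(memory_bounded_workers(8, 15 * GIB))  # 2
print(memory_bounded_workers(2, 64 * GIB))  # 2
```

The result never drops below 1, so on a machine with less than 5 GB free the function still returns a single worker and the job may run out of memory anyway.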
|
||||
## Custom Configuration

### Adjust the concurrency

Edit the `main()` function in `main.py`:

```python
# Fix the concurrency at 2
results = process_batch_concurrent(video_paths, max_workers=2)
```

### Change the speaker diarization parameters

Edit the `process_single_video()` function:

```python
diar_service = DiarizationService(
    embedding_model="eres2netv2",  # campplus/eres2net/eres2netv2
    device="auto",
    cluster_threshold=0.5,  # 0.0-1.0, higher is stricter
    min_cluster_size=10  # minimum number of segments per speaker
)
```

### Change the ASR model

Edit the `process_single_video()` function:

```python
asr_service = ASRService(
    model_name="paraformer-zh",  # or "SenseVoice"
    device="auto"
)
```
|
||||
|
||||
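## FAQ

### Q: How do I add more videos?

**A:** Put the video files into the folder that `VIDEO_FOLDER` points to; the program scans it automatically.

### Q: How do I skip certain videos?

**A:** Move them to another folder, or edit `SUPPORTED_VIDEO_FORMATS` to exclude specific formats.

### Q: What if processing is interrupted?

**A:** Run it again. The temp directory is cleared automatically, and videos that have already finished are not processed again.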
|
||||
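One way to avoid reprocessing finished videos is to filter the input list against the output directory before starting. This is a minimal sketch, assuming a video counts as finished when `output/{video_name}_result.json` exists; `pending_videos` is a hypothetical helper and the demo uses a throwaway directory.

```python
import tempfile
from pathlib import Path
from typing import List


def pending_videos(video_paths: List[Path], output_dir: Path) -> List[Path]:
    """Keep only the videos that do not have a result file in output_dir yet."""
    return [
        path
        for path in video_paths
        if not (output_dir / f"{path.stem}_result.json").exists()
    ]


# Demo: a.mp4 already has a result file, b.mp4 does not.
with tempfile.TemporaryDirectory() as tmp:
    output_dir = Path(tmp)
    (output_dir / "a_result.json").write_text("{}", encoding="utf-8")
    todo = pending_videos([Path("a.mp4"), Path("b.mp4")], output_dir)
    print([path.name for path in todo])  # ['b.mp4']
```

A check like this only looks at file existence, so a result file left half-written by a crash would still count as finished; writing to a temporary name and renaming on success is a common way to close that gap.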
|
||||
### Q: How do I check progress?

**A:** The console shows, in real time:
- The status of each video
- The progress percentage
- The estimated time remaining
- The final summary report
|
||||
|
||||
## Directory Layout

```
audio2/
├── main.py                 # Main program
├── asr_service.py          # ASR service
├── diarization_service.py  # Speaker diarization service
├── map_speaker.py          # Result merging logic
├── temp/                   # Temporary directory (cleared at run time)
└── output/                 # Output directory
    ├── video1_result.json
    ├── video2_result.json
    └── batch_summary.json
```
|
||||
|
||||
## Requirements

- Python 3.10+
- FunASR 1.3+
- PyTorch 2.0+
- ffmpeg (for audio extraction)
- 3D-Speaker (speaker diarization)
|
||||
|
||||
## Sample Run

```
============================================================
  Concurrent batch speech recognition system
============================================================

============================================================
Clearing temp directory...
============================================================
✓ Removed: D:\...\audio2\temp
✓ Created: D:\...\audio2\temp

✓ Output directory: D:\...\audio2\output

Found 1 video file(s)
  - VID_20251031_132320_019.mp4

============================================================
Concurrent batch configuration
============================================================
Videos: 1
Max concurrency: 2
CPU cores: 8
GPU: NVIDIA GeForce RTX 3060

[VID_20251031_132320_019.mp4] Loading ASR model...
[VID_20251031_132320_019.mp4] Running speech recognition...
[VID_20251031_132320_019.mp4] Loading speaker diarization model...
[VID_20251031_132320_019.mp4] Running speaker diarization...
[VID_20251031_132320_019.mp4] Merging results...
[VID_20251031_132320_019.mp4] ✓ Done
  - Sentences: 50
  - Speakers: {'speaker_0': 25, 'speaker_1': 25}

============================================================
Summary
============================================================
Total time: 123.4s
Average per video: 123.4s
Succeeded: 1/1
Failed: 0

Summary report: output\batch_summary.json
============================================================

Cleaning up temporary files...
============================================================
Clearing temp directory...
============================================================

✓ All done!
Output directory: D:\...\audio2\output
```
|
||||
|
|
@ -297,7 +297,7 @@ class ASRService:
|
|||
 if "sentence_info" in res:
     for sent_info in res["sentence_info"]:
         sentence = Sentence(
-            speaker=sent_info.get("speaker", "SPEAKER_00"),
+            speaker="speaker_0",  # always use speaker_0
             text=sent_info.get("text", "").strip(),
             begin_time=sent_info.get("start", 0) / 1000.0,
             end_time=sent_info.get("end", 0) / 1000.0
|
||||
|
|
@ -306,7 +306,7 @@ class ASRService:
|
|||
         sentences.append(sentence)
 elif "text" in res:
     sentences.append(Sentence(
-        speaker="SPEAKER_00",
+        speaker="speaker_0",  # always use speaker_0
         text=res["text"].strip(),
         begin_time=0.0,
         end_time=0.0
|
||||
|
|
|
|||
|
|
@ -113,9 +113,10 @@ class DiarizationService:
|
|||
return
|
||||
|
||||
print(f"正在加载 3D-Speaker 说话人分离模型...")
|
||||
print(f"设备: {self.device}")
|
||||
print(f"说话人嵌入模型: {self.embedding_model}")
|
||||
print(f"聚类参数: threshold={self.cluster_threshold}, min_cluster_size={self.min_cluster_size}")
|
||||
print(f"设备:{self.device}")
|
||||
print(f"说话人嵌入模型:{self.embedding_model}")
|
||||
print(f"聚类参数:threshold={self.cluster_threshold}, min_cluster_size={self.min_cluster_size}")
|
||||
sys.stdout.flush() # 确保输出立即显示
|
||||
|
||||
embedding_models = {
|
||||
"campplus": "iic/speech_campplus_sv_zh_en_16k-common_advanced",
|
||||
|
|
@ -123,8 +124,12 @@ class DiarizationService:
|
|||
"eres2netv2": "iic/speech_eres2netv2_sv_zh-cn_16k-common",
|
||||
}
|
||||
|
||||
try:
|
||||
from speakerlab.bin.infer_diarization import Diarization3Dspeaker
|
||||
|
||||
print(f" - 导入 Diarization3Dspeaker 完成")
|
||||
sys.stdout.flush()
|
||||
|
||||
self.model = Diarization3Dspeaker(
|
||||
device=self.device,
|
||||
include_overlap=self.include_overlap,
|
||||
|
|
@ -132,7 +137,18 @@ class DiarizationService:
|
|||
model_cache_dir=self.cache_dir
|
||||
)
|
||||
|
||||
print(f" - 模型实例化完成")
|
||||
sys.stdout.flush()
|
||||
print(f"模型加载完成!")
|
||||
sys.stdout.flush()
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 模型加载失败:{e}")
|
||||
sys.stdout.flush()
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.stdout.flush()
|
||||
raise
|
||||
|
||||
def diarize(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
@echo off
echo ============================================================
echo   Crash workaround - enable single-process mode
echo ============================================================
echo.

echo Updating main.py...
echo.

REM Rewrite MAX_WORKERS_OVERRIDE in main.py. The PowerShell command stays on one
REM line because cmd.exe does not treat ^ as a line continuation inside quotes.
REM main.py is read and written as UTF-8 so non-ASCII text survives the round trip.
powershell -NoProfile -Command "$content = Get-Content -Path 'main.py' -Raw -Encoding UTF8; $content = $content -replace 'MAX_WORKERS_OVERRIDE = None', 'MAX_WORKERS_OVERRIDE = 1  # force single-process mode'; Set-Content -Path 'main.py' -Value $content -Encoding UTF8"

echo.
echo ✓ Done!
echo.
echo ============================================================
echo   Single-process mode enabled
echo ============================================================
echo.
echo You can now run:
echo   python main.py
echo.
echo To restore multi-process mode, edit main.py:
echo   Find:    MAX_WORKERS_OVERRIDE = 1
echo   Replace: MAX_WORKERS_OVERRIDE = None
echo.
echo ============================================================
pause
|
||||
|
|
@ -0,0 +1,492 @@
|
"""
Batch video speech recognition + speaker diarization

Features:
1. Scans the video directory automatically
2. Two-stage processing:
   - Stage 1: speaker diarization
   - Stage 2: ASR + result merging
3. Each model is loaded only once
4. Sequential processing, which avoids sharing CUDA across processes
"""

import gc
import json
import shutil
import subprocess
import time
from pathlib import Path
from typing import Dict, List, Optional

# Paths
BASE_DIR = Path(__file__).parent.absolute()
TEMP_DIR = BASE_DIR / "temp"
OUTPUT_DIR = BASE_DIR / "output"

# Video folder path (global variable)
VIDEO_FOLDER = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\宁波北仑区鼎邦杰西雅服饰有限公司"

# Supported video formats
SUPPORTED_VIDEO_FORMATS = ["*.mp4", "*.avi", "*.mkv", "*.mov", "*.flv", "*.wmv", "*.m4v"]
|
||||
|
||||
|
||||
def get_video_list(folder_path: Path) -> List[Path]:
    """
    Collect the videos in a folder, sorted by the timestamp in their file names.

    Args:
        folder_path: path to the video folder

    Returns:
        Video paths sorted by file name
    """
    video_paths = []

    # Scan every supported video format
    for pattern in SUPPORTED_VIDEO_FORMATS:
        video_paths.extend(folder_path.glob(pattern))

    # Sort by file name. Names are assumed to embed a timestamp
    # (e.g. VID_20251031_132320_019.mp4), so lexicographic order is chronological.
    video_paths.sort(key=lambda p: p.name)

    return video_paths
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def clear_temp_dir():
    """Empty the temp directory"""
    print("=" * 60)
    print("Clearing temp directory...")
    print("=" * 60)

    if TEMP_DIR.exists():
        try:
            shutil.rmtree(TEMP_DIR)
            print(f"✓ Removed: {TEMP_DIR}")
        except Exception as e:
            print(f"✗ Failed to remove: {e}")

    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    print(f"✓ Created: {TEMP_DIR}")
    print()


def ensure_output_dir():
    """Make sure the output directory exists"""
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print(f"✓ Output directory: {OUTPUT_DIR}")
    print()
|
||||
|
||||
|
||||
def extract_wav(video_path: Path, temp_dir: Path) -> Optional[Path]:
    """
    Extract the audio track of a video as WAV.

    Args:
        video_path: path to the video file
        temp_dir: temporary directory

    Returns:
        Path to the WAV file, or None on failure
    """
    try:
        wav_path = temp_dir / f"{video_path.stem}.wav"

        # Extract the audio with ffmpeg
        cmd = [
            "ffmpeg",
            "-i", str(video_path),
            "-vn",  # skip the video stream
            "-acodec", "pcm_s16le",  # 16-bit PCM
            "-ar", "16000",  # 16 kHz sample rate
            "-ac", "1",  # mono
            "-y",  # overwrite existing files
            str(wav_path)
        ]

        subprocess.run(
            cmd,
            check=True,
            capture_output=True,
            timeout=300  # 5-minute timeout
        )

        if wav_path.exists():
            print(f"✓ Extracted audio: {video_path.name} -> {wav_path.name}")
            return wav_path
        else:
            print(f"✗ Extraction failed: {video_path.name}")
            return None

    except subprocess.TimeoutExpired:
        print(f"✗ Extraction timed out: {video_path.name}")
        return None
    except Exception as e:
        print(f"✗ Extraction error: {video_path.name} - {e}")
        return None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def process_batch_diarization(video_paths: List[Path], max_workers: int = 1):
    """
    Stage 1: run speaker diarization on every video (sequential, in the main process).

    Args:
        video_paths: list of video paths
        max_workers: concurrency (currently fixed at 1)

    Returns:
        Dict[video_path -> diar_result_path]: mapping to the diarization result files
    """
    import traceback

    from diarization_service import DiarizationService
    from map_speaker import save_json

    print("=" * 60)
    print("Stage 1: batch speaker diarization")
    print("=" * 60)
    print(f"Videos: {len(video_paths)}")
    print("Mode: sequential (single process)")
    print()

    # Load the speaker diarization model once
    print("Loading speaker diarization model...")
    diar_service = DiarizationService(
        embedding_model="eres2netv2",
        device="auto",
        cluster_threshold=0.5,
        min_cluster_size=10
    )
    diar_service._load_model()
    print("✓ Speaker diarization model loaded")
    print()

    results = {}
    start_time = time.time()

    # Process the videos one by one
    for i, video_path in enumerate(video_paths, 1):
        try:
            print(f"\n[{i}/{len(video_paths)}] Processing: {video_path.name}")

            # 1. Extract the WAV
            wav_path = extract_wav(video_path, TEMP_DIR)
            if wav_path is None:
                print("  ✗ Audio extraction failed")
                continue

            # 2. Run speaker diarization
            diar_segments = diar_service.diarize(wav_path)

            if not diar_segments:
                print("  ✗ Speaker diarization returned no segments")
                continue

            # 3. Save the diarization result to a temporary file
            temp_diar_path = TEMP_DIR / f"{video_path.stem}_diar.json"
            diar_result = {
                "segments": [seg.to_dict() for seg in diar_segments]
            }
            save_json(temp_diar_path, diar_result)

            results[video_path] = str(temp_diar_path)
            print("  ✓ Speaker diarization done")

            # 4. The WAV is kept on purpose: the ASR stage reuses it

        except Exception as e:
            print(f"  ✗ Processing failed: {e}")
            traceback.print_exc()

        # Progress is based on videos attempted (i), so failures do not skew the estimate
        elapsed = time.time() - start_time
        avg_time = elapsed / i
        remaining = (len(video_paths) - i) * avg_time

        print(f"\nProgress: {i}/{len(video_paths)} ({len(results)} succeeded)")
        print(f"Elapsed: {elapsed:.1f}s, estimated remaining: {remaining:.1f}s")

    total_time = time.time() - start_time
    print(f"\n✓ Stage 1 finished in {total_time:.1f}s")
    print()

    return results
|
||||
|
||||
|
||||
def process_batch_asr(video_paths: List[Path], diar_results: Dict, max_workers: int = 1):
    """
    Stage 2: run ASR on every video and merge in the speaker labels
    (sequential, in the main process).

    Args:
        video_paths: list of video paths
        diar_results: mapping from video path to diarization result file
        max_workers: concurrency (currently fixed at 1)

    Returns:
        List[Dict]: one result record per video
    """
    import traceback

    from asr_service import ASRService
    from map_speaker import find_speaker, load_json

    print("=" * 60)
    print("Stage 2: batch speech recognition + result merging")
    print("=" * 60)
    print(f"Videos: {len(video_paths)}")
    print("Mode: sequential (single process)")
    print()

    # Load the ASR model once
    print("Loading ASR model...")
    asr_service = ASRService(model_name="paraformer-zh", device="auto")
    asr_service._load_model()
    print("✓ ASR model loaded")
    print()

    results = []
    start_time = time.time()

    # Process the videos one by one
    for i, video_path in enumerate(video_paths, 1):
        diar_path = diar_results.get(video_path)
        if not diar_path:
            print(f"\n[{i}/{len(video_paths)}] Skipping {video_path.name} (no diarization result)")
            results.append({
                "video": str(video_path),
                "success": False,
                "error": "no diarization result"
            })
            continue

        # extract_wav() writes to this same path, so it stays a valid Path
        # for the cleanup below even when extraction fails
        wav_path = TEMP_DIR / f"{video_path.stem}.wav"

        try:
            print(f"\n[{i}/{len(video_paths)}] Processing: {video_path.name}")

            # 1. Extract the WAV if stage 1 did not leave one behind
            if not wav_path.exists() and extract_wav(video_path, TEMP_DIR) is None:
                print("  ✗ Audio extraction failed")
                results.append({
                    "video": str(video_path),
                    "success": False,
                    "error": "audio extraction failed"
                })
                continue

            # 2. Load the diarization result
            diar_result = load_json(diar_path)

            # 3. Run ASR
            asr_sentences = asr_service.recognize(wav_path)

            if not asr_sentences:
                print("  ✗ ASR returned no sentences")
                results.append({
                    "video": str(video_path),
                    "success": False,
                    "error": "ASR returned no sentences"
                })
                continue

            # 4. Attach the speaker labels
            print("  Merging results...")
            for sentence in asr_sentences:
                sentence.speaker = find_speaker(
                    sentence.begin_time,
                    sentence.end_time,
                    diar_result["segments"]
                )

            # 5. Save the final result
            output_file = OUTPUT_DIR / f"{video_path.stem}_result.json"
            asr_service.export_to_json(asr_sentences, output_file)

            # Count sentences per speaker
            speaker_counts = {}
            for sentence in asr_sentences:
                speaker = sentence.speaker
                speaker_counts[speaker] = speaker_counts.get(speaker, 0) + 1

            results.append({
                "video": str(video_path),
                "success": True,
                "asr_result": [s.to_dict() for s in asr_sentences],
                "merged_result": str(output_file),
                "speaker_counts": speaker_counts,
                "total_sentences": len(asr_sentences)
            })

            print("  ✓ Done")
            print(f"    - Sentences: {len(asr_sentences)}")
            print(f"    - Speakers: {speaker_counts}")

        except Exception as e:
            print(f"  ✗ Processing failed: {e}")
            traceback.print_exc()
            results.append({
                "video": str(video_path),
                "success": False,
                "error": str(e)
            })

        finally:
            # Remove this video's temporary files; a missing file is not an error
            for temp_file in (wav_path, Path(diar_path)):
                try:
                    temp_file.unlink()
                except OSError:
                    pass

        # Progress
        elapsed = time.time() - start_time
        avg_time = elapsed / len(results) if results else 1
        remaining = (len(video_paths) - len(results)) * avg_time

        print(f"\nProgress: {len(results)}/{len(video_paths)}")
        print(f"Elapsed: {elapsed:.1f}s, estimated remaining: {remaining:.1f}s")

    total_time = time.time() - start_time
    print(f"\n✓ Stage 2 finished in {total_time:.1f}s")
    print()

    # Summary
    success_count = sum(1 for r in results if r["success"])

    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f"Total time: {total_time:.1f}s")
    print(f"Average per video: {total_time/len(video_paths):.1f}s")
    print(f"Succeeded: {success_count}/{len(video_paths)}")
    print(f"Failed: {len(video_paths) - success_count}")

    # Save the summary report
    summary = {
        "total_videos": len(video_paths),
        "success_count": success_count,
        "failed_count": len(video_paths) - success_count,
        "total_time_seconds": round(total_time, 2),
        "results": [
            {
                "video": Path(r["video"]).name,
                "success": r["success"],
                "output": r.get("merged_result"),
                "total_sentences": r.get("total_sentences", 0),
                "speaker_counts": r.get("speaker_counts", {}),
                "error": r.get("error")
            }
            for r in results
        ]
    }

    summary_path = OUTPUT_DIR / "batch_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"\nSummary report: {summary_path}")
    print("=" * 60)

    return results
|
||||
|
||||
|
||||
def main():
    """Entry point"""
    import torch

    print("\n" + "=" * 60)
    print("  Concurrent batch speech recognition system")
    print("=" * 60)
    print()

    # 1. Empty the temp directory
    clear_temp_dir()

    # 2. Make sure the output directory exists
    ensure_output_dir()

    # 3. Build the video list from VIDEO_FOLDER
    video_folder = Path(VIDEO_FOLDER)
    if not video_folder.exists():
        print(f"✗ Error: video folder does not exist - {video_folder}")
        return

    video_paths = get_video_list(video_folder)

    if not video_paths:
        print("✗ Error: no video files found")
        print(f"Supported formats: {', '.join(SUPPORTED_VIDEO_FORMATS)}")
        print(f"Please check the folder: {video_folder}")
        return

    print(f"Found {len(video_paths)} video file(s)")
    for vp in video_paths:
        print(f"  - {vp.name}")
    print()

    # 4. Process in two stages
    #    Stage 1: speaker diarization
    #    Stage 2: ASR + result merging

    print()
    print("=" * 60)
    print("Processing plan")
    print("=" * 60)
    print("Stage 1: batch speaker diarization (loads the diarization model)")
    print("  ↓ free memory")
    print("Stage 2: batch speech recognition + result merging (loads the ASR model)")
    print("=" * 60)
    print()

    # Stage 1: speaker diarization
    diar_results = process_batch_diarization(video_paths, max_workers=1)

    # Check the stage 1 results
    success_count = sum(1 for r in diar_results.values() if r)
    if success_count == 0:
        print("✗ Error: stage 1 failed for every video")
        return

    print(f"✓ Stage 1 succeeded: {success_count}/{len(video_paths)}")
    print()

    # Force garbage collection and release cached GPU memory
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("✓ CUDA cache cleared, ready for stage 2")
    print()

    # Stage 2: ASR + result merging
    results = process_batch_asr(video_paths, diar_results, max_workers=1)

    # 5. Final cleanup
    print("\nCleaning up temporary files...")
    clear_temp_dir()

    print("\n✓ All done!")
    print(f"Output directory: {OUTPUT_DIR}")


if __name__ == "__main__":
    main()
|
||||
160
requirements.txt
|
|
@ -1,7 +1,153 @@
|
|||
funasr>=1.3.0
|
||||
modelscope>=1.15.0
|
||||
torch>=2.0.0
|
||||
torchaudio>=2.0.0
|
||||
torchvision>=0.15.0
|
||||
transformers>=4.30.0
|
||||
numpy>=1.24.0
|
||||
addict==2.4.0
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.13.5
|
||||
aiosignal==1.4.0
|
||||
alembic==1.18.4
|
||||
aliyun-python-sdk-core==2.16.0
|
||||
aliyun-python-sdk-kms==2.16.5
|
||||
annotated-doc==0.0.4
|
||||
antlr4-python3-runtime==4.9.3
|
||||
anyio==4.13.0
|
||||
asteroid-filterbanks==0.4.0
|
||||
async-timeout==5.0.1
|
||||
attrs==26.1.0
|
||||
audioread==3.1.0
|
||||
blinker==1.9.0
|
||||
certifi==2026.4.22
|
||||
cffi==2.0.0
|
||||
charset-normalizer==3.4.7
|
||||
click==8.3.3
|
||||
colorama==0.4.6
|
||||
coloredlogs==15.0.1
|
||||
colorlog==6.10.1
|
||||
contourpy==1.3.2
|
||||
crcmod==1.7
|
||||
cryptography==47.0.0
|
||||
cycler==0.12.1
|
||||
datasets==4.8.5
|
||||
decorator==5.2.1
|
||||
dill==0.4.1
|
||||
docopt==0.6.2
|
||||
editdistance==0.8.1
|
||||
einops==0.8.2
|
||||
exceptiongroup==1.3.1
|
||||
fastcluster==1.3.0
|
||||
filelock==3.25.2
|
||||
Flask==3.1.3
|
||||
flatbuffers==25.12.19
|
||||
fonttools==4.62.1
|
||||
frozenlist==1.8.0
|
||||
fsspec==2026.2.0
|
||||
funasr==1.3.1
|
||||
greenlet==3.5.0
|
||||
h11==0.16.0
|
||||
hdbscan==0.8.42
|
||||
hf-xet==1.4.3
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
huggingface_hub==1.12.0
|
||||
humanfriendly==10.0
|
||||
hydra-core==1.3.2
|
||||
HyperPyYAML==1.2.3
|
||||
idna==3.13
|
||||
itsdangerous==2.2.0
|
||||
jaconv==0.5.0
|
||||
jamo==0.4.1
|
||||
jieba==0.42.1
|
||||
Jinja2==3.1.6
|
||||
jmespath==0.10.0
|
||||
joblib==1.5.3
|
||||
julius==0.2.7
|
||||
kaldiio==2.18.1
|
||||
kiwisolver==1.5.0
|
||||
lazy-loader==0.5
|
||||
librosa==0.11.0
|
||||
lightning==2.6.1
|
||||
lightning-utilities==0.15.3
|
||||
llvmlite==0.47.0
|
||||
Mako==1.3.12
|
||||
markdown-it-py==4.0.0
|
||||
MarkupSafe==3.0.3
|
||||
matplotlib==3.10.9
|
||||
mdurl==0.1.2
|
||||
modelscope==1.36.3
|
||||
mpmath==1.3.0
|
||||
msgpack==1.1.2
|
||||
multidict==6.7.1
|
||||
multiprocess==0.70.19
|
||||
networkx==3.4.2
|
||||
numba==0.65.1
|
||||
numpy==2.2.6
|
||||
omegaconf==2.3.0
|
||||
onnxruntime-gpu==1.23.2
|
||||
opencv-python==4.13.0.92
|
||||
optuna==4.8.0
|
||||
oss2==2.19.1
|
||||
packaging==26.2
|
||||
pandas==2.3.3
|
||||
pillow==12.2.0
|
||||
platformdirs==4.9.6
|
||||
pooch==1.9.0
|
||||
primePy==1.3
|
||||
propcache==0.4.1
|
||||
protobuf==7.34.1
|
||||
pyannote.audio==3.4.0
|
||||
pyannote.core==5.0.0
|
||||
pyannote.database==5.1.3
|
||||
pyannote.metrics==3.2.1
|
||||
pyannote.pipeline==3.0.1
|
||||
pyarrow==24.0.0
|
||||
pycparser==3.0
|
||||
pycryptodome==3.23.0
|
||||
Pygments==2.20.0
|
||||
pynndescent==0.6.0
|
||||
pyparsing==3.3.2
|
||||
pyreadline3==3.5.4
|
||||
python-dateutil==2.9.0.post0
|
||||
python_speech_features==0.6
|
||||
pytorch-lightning==2.6.1
|
||||
pytorch-metric-learning==2.9.0
|
||||
pytorch-wpe==0.0.1
|
||||
pytz==2026.1.post1
|
||||
PyYAML==6.0.3
|
||||
regex==2026.4.4
|
||||
requests==2.33.1
|
||||
rich==15.0.0
|
||||
ruamel.yaml==0.18.17
|
||||
ruamel.yaml.clib==0.2.15
|
||||
safetensors==0.7.0
|
||||
scikit-learn==1.7.2
|
||||
scipy==1.15.3
|
||||
semver==3.0.4
|
||||
sentencepiece==0.2.1
|
||||
shellingham==1.5.4
|
||||
simplejson==4.1.1
|
||||
six==1.17.0
|
||||
sortedcontainers==2.4.0
|
||||
soundfile==0.13.1
|
||||
soxr==1.0.0
|
||||
speechbrain==1.1.0
|
||||
SQLAlchemy==2.0.49
|
||||
sympy==1.14.0
|
||||
tabulate==0.10.0
|
||||
tensorboardX==2.6.5
|
||||
threadpoolctl==3.6.0
|
||||
tokenizers==0.22.2
|
||||
tomli==2.4.1
|
||||
torch==2.7.1+cu118
|
||||
torch-audiomentations==0.12.0
|
||||
torch-complex==0.4.4
|
||||
torch_pitch_shift==1.2.5
|
||||
torchaudio==2.7.1+cu118
|
||||
torchmetrics==1.9.0
|
||||
torchvision==0.22.1
|
||||
tqdm==4.67.3
|
||||
transformers==5.7.0
|
||||
typer==0.25.0
|
||||
typing_extensions==4.15.0
|
||||
tzdata==2026.2
|
||||
umap-learn==0.5.12
|
||||
urllib3==2.6.3
|
||||
Werkzeug==3.1.8
|
||||
xxhash==3.7.0
|
||||
yarl==1.23.0
|
||||
|
|
|
|||
|
|
@ -0,0 +1,319 @@
|
|||
"""
|
||||
Web API Server for ASR and Speaker Diarization
|
||||
REST API service for speech recognition and speaker diarization
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import gc
|
||||
from pathlib import Path
|
||||
from flask import Flask, request, jsonify, send_file
|
||||
from werkzeug.utils import secure_filename
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024
|
||||
app.config['UPLOAD_FOLDER'] = 'uploads'
|
||||
app.config['RESULT_FOLDER'] = 'results'
|
||||
|
||||
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
|
||||
os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
|
||||
|
||||
GLOBAL_ASR_SERVICE = None
|
||||
GLOBAL_DIAR_SERVICE = None
|
||||
ASR_MODEL_LOADED = False
|
||||
DIAR_MODEL_LOADED = False
|
||||
ASR_MODEL_LOCK = threading.Lock()
|
||||
DIAR_MODEL_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def get_asr_service():
|
||||
global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED
|
||||
if GLOBAL_ASR_SERVICE is None:
|
||||
from asr_service import ASRService
|
||||
GLOBAL_ASR_SERVICE = ASRService()
|
||||
return GLOBAL_ASR_SERVICE
|
||||
|
||||
|
||||
def get_diar_service():
|
||||
global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED
|
||||
if GLOBAL_DIAR_SERVICE is None:
|
||||
from diarization_service import DiarizationService
|
||||
GLOBAL_DIAR_SERVICE = DiarizationService()
|
||||
return GLOBAL_DIAR_SERVICE
|
||||
|
||||
|
||||
@app.route('/health', methods=['GET'])
|
||||
def health_check():
|
||||
"""健康检查"""
|
||||
return jsonify({
|
||||
'status': 'ok',
|
||||
'asr_loaded': ASR_MODEL_LOADED,
|
||||
'diar_loaded': DIAR_MODEL_LOADED
|
||||
})
|
||||
|
||||
|
||||
@app.route('/api/asr/load', methods=['POST'])
def load_asr_model():
    """Load the ASR model"""
    global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED
    with ASR_MODEL_LOCK:
        if ASR_MODEL_LOADED:
            return jsonify({'message': 'ASR model already loaded', 'loaded': True})

        try:
            # silent=True tolerates requests that carry no JSON body
            data = request.get_json(silent=True) or {}
            model_name = data.get('model_name', 'paraformer-zh')
            device = data.get('device', 'auto')

            print(f"Loading ASR model: {model_name}, device: {device}")
            from asr_service import ASRService
            # Build the service from the requested settings so they take effect
            service = ASRService(model_name=model_name, device=device)
            service._load_model()
            GLOBAL_ASR_SERVICE = service
            ASR_MODEL_LOADED = True

            return jsonify({
                'message': 'ASR model loaded',
                'loaded': True,
                'model': model_name,
                'device': device
            })
        except Exception as e:
            return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
@app.route('/api/diar/load', methods=['POST'])
def load_diar_model():
    """Load the 3D-Speaker model"""
    global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED
    with DIAR_MODEL_LOCK:
        if DIAR_MODEL_LOADED:
            return jsonify({'message': '3D-Speaker model already loaded', 'loaded': True})

        try:
            # silent=True tolerates requests that carry no JSON body
            data = request.get_json(silent=True) or {}
            embedding_model = data.get('embedding_model', 'eres2netv2')
            device = data.get('device', 'auto')
            cluster_threshold = data.get('cluster_threshold', 0.5)
            min_cluster_size = data.get('min_cluster_size', 10)

            print(f"Loading 3D-Speaker model: {embedding_model}, device: {device}")
            from diarization_service import DiarizationService
            # Build the service from the requested settings so they take effect
            service = DiarizationService(
                embedding_model=embedding_model,
                device=device,
                cluster_threshold=cluster_threshold,
                min_cluster_size=min_cluster_size
            )
            service._load_model()
            GLOBAL_DIAR_SERVICE = service
            DIAR_MODEL_LOADED = True

            return jsonify({
                'message': '3D-Speaker model loaded',
                'loaded': True,
                'embedding_model': embedding_model,
                'device': device
            })
        except Exception as e:
            return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
@app.route('/api/asr/unload', methods=['POST'])
|
||||
def unload_asr_model():
|
||||
"""卸载 ASR 模型"""
|
||||
global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED
|
||||
|
||||
try:
|
||||
GLOBAL_ASR_SERVICE = None
|
||||
ASR_MODEL_LOADED = False
|
||||
gc.collect()
|
||||
|
||||
return jsonify({'message': 'ASR 模型已卸载'})
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
@app.route('/api/diar/unload', methods=['POST'])
|
||||
def unload_diar_model():
|
||||
"""卸载 3D-Speaker 模型"""
|
||||
global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED
|
||||
|
||||
try:
|
||||
GLOBAL_DIAR_SERVICE = None
|
||||
DIAR_MODEL_LOADED = False
|
||||
gc.collect()
|
||||
|
||||
return jsonify({'message': '3D-Speaker 模型已卸载'})
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
@app.route('/api/recognize/single', methods=['POST'])
def recognize_single():
    """Run inference on a single uploaded file."""
    global ASR_MODEL_LOADED

    try:
        if 'file' not in request.files:
            return jsonify({'error': 'Please upload an audio file'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'File name must not be empty'}), 400

        data = request.form or {}
        use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true'
        embedding_model = data.get('embedding_model', 'eres2netv2')
        cluster_threshold = float(data.get('cluster_threshold', 0.5))
        min_cluster_size = int(data.get('min_cluster_size', 10))
        output_format = data.get('format', 'json')

        filename = secure_filename(file.filename)
        task_id = str(uuid.uuid4())[:8]
        audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{task_id}_{filename}")
        file.save(audio_path)

        try:
            # Load the ASR model lazily on the first request.
            service = get_asr_service()
            if not ASR_MODEL_LOADED:
                service._load_model()
                ASR_MODEL_LOADED = True

            sentences = service.recognize(
                audio_path,
                use_3d_speaker=use_3d_speaker,
                embedding_model=embedding_model,
                cluster_threshold=cluster_threshold,
                min_cluster_size=min_cluster_size
            )

            result = {
                'file': filename,
                'total_sentences': len(sentences),
                'sentences': [s.to_dict() for s in sentences]
            }

            if output_format == 'json':
                import json
                result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.json")
                with open(result_path, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                return jsonify(result)
            else:
                srt_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.srt")
                service.export_to_srt(sentences, srt_path)
                return send_file(srt_path, as_attachment=True, download_name=f"{task_id}_result.srt")

        finally:
            # Always remove the uploaded audio, whether inference succeeded or not.
            if os.path.exists(audio_path):
                os.remove(audio_path)

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


@app.route('/api/recognize/batch', methods=['POST'])
def recognize_batch():
    """Run inference on a batch of uploaded files."""
    global ASR_MODEL_LOADED

    try:
        if 'files' not in request.files:
            return jsonify({'error': 'Please upload audio files'}), 400

        files = request.files.getlist('files')
        if not files:
            return jsonify({'error': 'The file list is empty'}), 400

        data = request.form or {}
        use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true'
        embedding_model = data.get('embedding_model', 'eres2netv2')
        cluster_threshold = float(data.get('cluster_threshold', 0.5))
        min_cluster_size = int(data.get('min_cluster_size', 10))

        task_id = str(uuid.uuid4())[:8]
        audio_paths = []

        for f in files:
            if f.filename:
                filename = secure_filename(f.filename)
                # Prefix the position in the batch so that two uploads whose
                # sanitized names coincide do not overwrite each other.
                audio_path = os.path.join(
                    app.config['UPLOAD_FOLDER'],
                    f"{task_id}_{len(audio_paths)}_{filename}"
                )
                f.save(audio_path)
                audio_paths.append(audio_path)

        try:
            # Load the ASR model lazily on the first request.
            service = get_asr_service()
            if not ASR_MODEL_LOADED:
                service._load_model()
                ASR_MODEL_LOADED = True

            results = []
            for audio_path in audio_paths:
                # A failure on one file is recorded and does not abort the batch.
                try:
                    sentences = service.recognize(
                        audio_path,
                        use_3d_speaker=use_3d_speaker,
                        embedding_model=embedding_model,
                        cluster_threshold=cluster_threshold,
                        min_cluster_size=min_cluster_size
                    )
                    results.append({
                        'file': os.path.basename(audio_path),
                        'total_sentences': len(sentences),
                        'sentences': [s.to_dict() for s in sentences]
                    })
                except Exception as e:
                    results.append({
                        'file': os.path.basename(audio_path),
                        'error': str(e)
                    })

            result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_batch_result.json")
            import json
            with open(result_path, 'w', encoding='utf-8') as f:
                json.dump({
                    'task_id': task_id,
                    'total_files': len(results),
                    'results': results
                }, f, ensure_ascii=False, indent=2)

            return jsonify({
                'task_id': task_id,
                'total_files': len(results),
                'results': results,
                'result_file': result_path
            })

        finally:
            # Always remove the uploaded audio files.
            for audio_path in audio_paths:
                if os.path.exists(audio_path):
                    os.remove(audio_path)

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


@app.route('/api/status', methods=['GET'])
def get_status():
    """Return the load status of both models."""
    return jsonify({
        'asr_loaded': ASR_MODEL_LOADED,
        'diar_loaded': DIAR_MODEL_LOADED,
        'asr_model': 'paraformer-zh',
        'diar_model': '3D-Speaker'
    })


if __name__ == '__main__':
    print("=" * 60)
    print(" ASR & Speaker Diarization API Server")
    print("=" * 60)
    print("\nAPI endpoints:")
    print("  GET  /health                - health check")
    print("  GET  /api/status            - get model status")
    print("  POST /api/asr/load          - load the ASR model")
    print("  POST /api/diar/load         - load the 3D-Speaker model")
    print("  POST /api/asr/unload        - unload the ASR model")
    print("  POST /api/diar/unload       - unload the 3D-Speaker model")
    print("  POST /api/recognize/single  - single-file inference")
    print("  POST /api/recognize/batch   - batch inference")
    print("\n" + "=" * 60)
    print("Starting server: http://localhost:5000")
    print("=" * 60)
    app.run(host='0.0.0.0', port=5000, debug=False)
|
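Review note: both recognize endpoints above read the same form fields, and a non-numeric `cluster_threshold` or `min_cluster_size` raises inside the outer `try`, so the client gets a 500 rather than a 400. A shared parser could be sketched roughly as follows. The helper name `parse_recognize_params` is hypothetical and not part of this commit; the field names and defaults are the ones used in the handlers.

```python
from typing import Any, Dict, Mapping


def parse_recognize_params(form: Mapping[str, str]) -> Dict[str, Any]:
    """Parse the form fields shared by the single and batch endpoints.

    Raises ValueError with a readable message on malformed numbers so the
    caller can answer with HTTP 400 instead of a generic 500.
    """
    try:
        cluster_threshold = float(form.get('cluster_threshold', 0.5))
        min_cluster_size = int(form.get('min_cluster_size', 10))
    except (TypeError, ValueError) as exc:
        raise ValueError(f'invalid numeric parameter: {exc}') from exc

    return {
        # Same convention as the handlers: only the string "true" enables it.
        'use_3d_speaker': str(form.get('use_3d_speaker', 'false')).lower() == 'true',
        'embedding_model': form.get('embedding_model', 'eres2netv2'),
        'cluster_threshold': cluster_threshold,
        'min_cluster_size': min_cluster_size,
        'output_format': form.get('format', 'json'),
    }
```

A handler would call it on `request.form`, catch `ValueError`, and return `jsonify({'error': str(exc)}), 400` before saving any upload.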
|
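Review note: the commit adds no client for the new server, so here is a minimal sketch of calling `POST /api/recognize/single` with only the standard library. The helper names `build_multipart` and `recognize_file` are hypothetical, and the URL assumes the default `localhost:5000` printed at startup. The form fields (`file`, `use_3d_speaker`, `format`) are the ones the handler reads.

```python
import json
import urllib.request
import uuid
from typing import Dict, Tuple


def build_multipart(fields: Dict[str, str], file_field: str,
                    filename: str, payload: bytes) -> Tuple[bytes, str]:
    """Encode form fields plus one file as multipart/form-data.

    Returns the request body and the matching Content-Type header value.
    """
    boundary = uuid.uuid4().hex
    parts = []
    for name, value in fields.items():
        parts.append(
            f'--{boundary}\r\n'
            f'Content-Disposition: form-data; name="{name}"\r\n\r\n'
            f'{value}\r\n'.encode('utf-8')
        )
    parts.append(
        f'--{boundary}\r\n'
        f'Content-Disposition: form-data; name="{file_field}"; filename="{filename}"\r\n'
        f'Content-Type: application/octet-stream\r\n\r\n'.encode('utf-8')
    )
    parts.append(payload)
    parts.append(f'\r\n--{boundary}--\r\n'.encode('utf-8'))
    return b''.join(parts), f'multipart/form-data; boundary={boundary}'


def recognize_file(path: str, base_url: str = 'http://localhost:5000') -> dict:
    """Upload one audio file and return the parsed JSON result."""
    with open(path, 'rb') as fh:
        body, content_type = build_multipart(
            {'use_3d_speaker': 'true', 'format': 'json'},
            'file', path.replace('\\', '/').rsplit('/', 1)[-1], fh.read())
    req = urllib.request.Request(
        f'{base_url}/api/recognize/single',
        data=body, headers={'Content-Type': content_type}, method='POST')
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)
```

With the server running, `recognize_file('sample.wav')['total_sentences']` would give the sentence count for a local file; `sample.wav` is only a placeholder name.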
@ -0,0 +1,95 @@
|
"""
Test model loading (without multiprocessing).
Used to diagnose whether a problem lies in the models themselves.
"""

import sys
import torch

print("=" * 60)
print("Model loading test (single-process mode)")
print("=" * 60)
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
print()

try:
    # Test the ASR model
    print("=" * 60)
    print("Test 1: load the ASR model (Paraformer)")
    print("=" * 60)
    from asr_service import ASRService

    asr_service = ASRService(model_name="paraformer-zh", device="auto")
    print("✓ ASRService initialized")

    asr_service._load_model()
    print("✓ ASR model loaded")
    print(f"  Model type: {type(asr_service._model)}")
    print()

    # Test the speaker diarization model
    print("=" * 60)
    print("Test 2: load the speaker diarization model (3D-Speaker)")
    print("=" * 60)
    from diarization_service import DiarizationService

    diar_service = DiarizationService(
        embedding_model="eres2netv2",
        device="auto",
        cluster_threshold=0.5,
        min_cluster_size=10
    )
    print("✓ DiarizationService initialized")

    diar_service._load_model()
    print("✓ Speaker diarization model loaded")
    print(f"  Model type: {type(diar_service.model)}")
    print()

    # Test audio processing
    print("=" * 60)
    print("Test 3: audio processing")
    print("=" * 60)

    # Check whether a test audio file is available
    from pathlib import Path
    test_audio = Path("test.wav")
    if test_audio.exists():
        print(f"Found test audio: {test_audio}")

        print("Running ASR...")
        sentences = asr_service.recognize(str(test_audio))
        print(f"✓ ASR finished, {len(sentences)} sentences")

        if sentences:
            print(f"  First sentence: {sentences[0]}")

        print("Running speaker diarization...")
        segments = diar_service.diarize(str(test_audio))
        print(f"✓ Speaker diarization finished, {len(segments)} segments")

        if segments:
            print(f"  First segment: {segments[0]}")
    else:
        print("⚠️ Test audio (test.wav) not found, skipping the processing test")

    print()
    print("=" * 60)
    print("✓ All tests passed! The models work correctly")
    print("=" * 60)

except Exception as e:
    print()
    print("=" * 60)
    print("✗ Test failed!")
    print("=" * 60)
    print(f"Error: {e}")
    print()
    import traceback
    traceback.print_exc()
    sys.exit(1)
|
|
@ -0,0 +1,69 @@
|
"""
Test the staged processing logic.
"""

from pathlib import Path

# Import the functions under test from main.py
from main import (
    VIDEO_DIR,
    process_batch_diarization,
    process_batch_asr,
    get_video_list
)

def test_staged_processing():
    """Run both processing stages on a small sample of videos."""

    print("=" * 60)
    print("Staged processing test")
    print("=" * 60)

    # Collect the video list
    video_paths = get_video_list(VIDEO_DIR)
    if not video_paths:
        print("✗ No video files found")
        return

    # Only test the first two videos
    test_videos = video_paths[:2]
    print(f"Test videos: {len(test_videos)}")
    for v in test_videos:
        print(f"  - {v.name}")
    print()

    # Stage 1: speaker diarization
    print("=" * 60)
    print("Stage 1: speaker diarization")
    print("=" * 60)
    diar_results = process_batch_diarization(test_videos, max_workers=1)

    # Count only entries that actually produced a result path; failed videos
    # are reported with a falsy value and marked ✗ below.
    succeeded = sum(1 for path in diar_results.values() if path)
    print(f"\nStage 1 result: {succeeded}/{len(test_videos)} succeeded")
    for video, result_path in diar_results.items():
        status = "✓" if result_path else "✗"
        print(f"  {status} {video.name}: {result_path}")
    print()

    # Stage 2: ASR + merge
    print("=" * 60)
    print("Stage 2: ASR + merge")
    print("=" * 60)
    results = process_batch_asr(test_videos, diar_results, max_workers=1)

    print(f"\nStage 2 result: {len(results)}/{len(test_videos)} finished")
    for result in results:
        status = "✓" if result.get("success") else "✗"
        print(f"  {status} {Path(result['video']).name}")
        if result.get("speaker_counts"):
            print(f"    Speakers: {result['speaker_counts']}")
    print()

    print("=" * 60)
    print("✓ Test finished!")
    print("=" * 60)


if __name__ == "__main__":
    test_staged_processing()