603 lines
19 KiB
Python
603 lines
19 KiB
Python
"""
|
||
批量视频语音识别 + 说话人分离
|
||
|
||
功能:
|
||
1. 自动扫描视频目录
|
||
2. 两阶段处理:
|
||
- 阶段 1: 说话人分离
|
||
- 阶段 2: ASR 识别 + 合并结果
|
||
3. 每个视频只加载一次模型
|
||
4. 顺序处理,避免多进程 CUDA 共享问题
|
||
"""
|
||
|
||
import gc
|
||
import json
|
||
import os
|
||
import shutil
|
||
import subprocess
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Dict, List, Optional
|
||
|
||
# Path configuration: everything lives next to this script.
BASE_DIR = Path(__file__).parent.absolute()
# Scratch space for extracted WAVs and diarization JSON; wiped by clear_temp_dir().
TEMP_DIR = BASE_DIR / "temp"
# Final per-video results and the batch summary are written here.
OUTPUT_DIR = BASE_DIR / "output"


# Supported video formats, written as glob patterns. get_video_list() matches
# them case-insensitively, so "*.avi" also accepts "clip.AVI".
SUPPORTED_VIDEO_FORMATS = ["*.mp4", "*.avi", "*.mkv", "*.mov", "*.flv", "*.wmv", "*.m4v"]


def get_video_list(folder_path: Path) -> List[Path]:
    """
    Collect the video files directly inside a folder, sorted by file name.

    Only the folder itself is scanned (not sub-folders). A file counts as a
    video when its extension, compared case-insensitively, matches one of
    SUPPORTED_VIDEO_FORMATS; directories are never returned.

    Args:
        folder_path: Folder containing the videos.

    Returns:
        Video paths sorted by file name (empty if the folder does not exist).
        Names are expected to embed a timestamp such as
        VID_20251031_132320_019.mp4, so lexicographic order is chronological.
    """
    folder = Path(folder_path)
    if not folder.is_dir():
        # Same result as globbing a missing folder: nothing found.
        return []

    # Derive ".mp4", ".avi", ... from the glob patterns so the constant above
    # stays the single source of truth. Comparing lower-cased suffixes makes
    # the match independent of the filesystem's case sensitivity
    # (Path.glob is case-sensitive on Linux/macOS, so "*.avi" missed "*.AVI").
    suffixes = {pattern.lstrip("*").lower() for pattern in SUPPORTED_VIDEO_FORMATS}

    video_paths = [
        entry
        for entry in folder.iterdir()
        if entry.is_file() and entry.suffix.lower() in suffixes
    ]

    # Timestamped file names sort chronologically in plain dictionary order.
    video_paths.sort(key=lambda p: p.name)
    return video_paths
|
||
|
||
|
||
|
||
|
||
|
||
def clear_temp_dir():
    """
    Remove TEMP_DIR with all its contents (if present), then recreate it empty.

    A failed removal is reported but not raised; the directory is (re)created
    in every case and each step is printed.
    """
    banner = "=" * 60
    print(banner)
    print("清空临时目录...")
    print(banner)

    if not TEMP_DIR.exists():
        pass
    else:
        try:
            shutil.rmtree(TEMP_DIR)
            print(f"✓ 已删除:{TEMP_DIR}")
        except Exception as err:
            print(f"✗ 删除失败:{err}")

    TEMP_DIR.mkdir(exist_ok=True, parents=True)
    print(f"✓ 已创建:{TEMP_DIR}")
    print()
|
||
|
||
|
||
def ensure_output_dir():
    """Create OUTPUT_DIR (including missing parents) if needed and print its location."""
    target = OUTPUT_DIR
    target.mkdir(exist_ok=True, parents=True)
    print(f"✓ 输出目录:{target}")
    print()
|
||
|
||
|
||
def extract_wav(video_path: Path, temp_dir: Path, timeout: float = 300) -> Optional[Path]:
    """
    Extract the audio track of a video as a WAV file using ffmpeg.

    The audio is converted to 16-bit PCM, 16 kHz, mono and written to
    ``temp_dir / f"{video_path.stem}.wav"``, overwriting any existing file.

    Args:
        video_path: Path of the source video.
        temp_dir: Directory for the WAV file; created if it does not exist.
        timeout: Seconds to wait for ffmpeg before giving up (default 300).

    Returns:
        Path of the extracted WAV file, or None on any failure. Failures are
        never raised; the reason is printed instead (for ffmpeg errors, the
        last line of ffmpeg's stderr).
    """
    video_path = Path(video_path)
    temp_dir = Path(temp_dir)
    wav_path = temp_dir / f"{video_path.stem}.wav"

    cmd = [
        "ffmpeg",
        "-i", str(video_path),
        "-vn",                   # drop the video stream
        "-acodec", "pcm_s16le",  # 16-bit PCM
        "-ar", "16000",          # 16 kHz sample rate
        "-ac", "1",              # mono
        "-y",                    # overwrite an existing output file
        str(wav_path),
    ]

    try:
        # ffmpeg does not create missing parent directories for its output.
        temp_dir.mkdir(parents=True, exist_ok=True)
        subprocess.run(
            cmd,
            check=True,
            capture_output=True,
            # ffmpeg reads interactive commands from stdin; detach it so a
            # batch run never blocks on, or consumes, the parent's input.
            stdin=subprocess.DEVNULL,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        print(f"✗ 提取超时:{video_path.name}")
        return None
    except subprocess.CalledProcessError as e:
        # str(e) only says "returned non-zero exit status"; the useful
        # diagnostic is the tail of ffmpeg's captured stderr.
        stderr = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        detail = stderr.splitlines()[-1] if stderr else f"exit status {e.returncode}"
        print(f"✗ 提取错误:{video_path.name} - {detail}")
        return None
    except Exception as e:
        # e.g. FileNotFoundError when ffmpeg is not on PATH, or an OSError
        # while creating temp_dir.
        print(f"✗ 提取错误:{video_path.name} - {e}")
        return None

    if wav_path.exists():
        print(f"✓ 提取音频:{video_path.name} -> {wav_path.name}")
        return wav_path

    print(f"✗ 提取失败:{video_path.name}")
    return None
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
def process_batch_diarization(video_paths, max_workers=1):
    """
    Stage 1: speaker diarization for a batch of videos, run sequentially in this process.

    The diarization model is built and loaded once, then reused for every
    video. For each video the audio is extracted into TEMP_DIR and the
    diarization segments are saved to ``TEMP_DIR / f"{stem}_diar.json"``.
    The extracted WAV is intentionally left in TEMP_DIR because stage 2
    (ASR) reads it again.

    Args:
        video_paths: List of video paths (``.name`` and ``.stem`` are used).
        max_workers: Concurrency level; currently unused, processing is
            always sequential.

    Returns:
        Dict mapping each video path to a record with keys ``success``
        (bool), ``diar_result`` (str path of the saved JSON, or None),
        ``error`` (message, or None) and ``process_time`` (seconds).
    """
    video_count = len(video_paths)
    rule = "=" * 60

    print(rule)
    print("第一阶段:批量说话人分离")
    print(rule)
    print(f"视频数量:{video_count}")
    print("处理模式:顺序处理(单进程)")
    print()

    # Build and load the diarization model a single time for the whole batch.
    print("加载说话人分离模型...")
    from diarization_service import DiarizationService

    service = DiarizationService(
        embedding_model="eres2netv2",
        device="auto",
        cluster_threshold=0.5,
        min_cluster_size=10,
    )
    service._load_model()
    print("✓ 说话人分离模型已加载")
    print()

    def make_record(ok, diar_file, error, seconds):
        # One stage-1 record, in the shape process_batch_asr() and main() read.
        return {
            "success": ok,
            "diar_result": diar_file,
            "error": error,
            "process_time": seconds,
        }

    outcome_by_video = {}
    batch_started = time.time()

    for position, video in enumerate(video_paths, start=1):
        video_started = time.time()
        try:
            print(f"\n[{position}/{video_count}] 处理:{video.name}")

            # Step 1: pull the audio track out of the video.
            wav_file = extract_wav(video, TEMP_DIR)
            if wav_file is None:
                print(" ✗ 音频提取失败")
                outcome_by_video[video] = make_record(
                    False, None, "音频提取失败", time.time() - video_started
                )
                continue

            # Step 2: diarize the extracted audio.
            segments = service.diarize(wav_file)
            if not segments:
                print(" ✗ 说话人分离结果为空")
                outcome_by_video[video] = make_record(
                    False, None, "说话人分离结果为空", time.time() - video_started
                )
                continue

            # Step 3: persist the segments as a temporary JSON file for stage 2.
            diar_json = TEMP_DIR / f"{video.stem}_diar.json"
            payload = {"segments": [segment.to_dict() for segment in segments]}
            from map_speaker import save_json

            save_json(diar_json, payload)

            seconds = time.time() - video_started
            outcome_by_video[video] = make_record(True, str(diar_json), None, seconds)
            print(f" ✓ 说话人分离完成 (耗时:{seconds:.1f}s)")

            # The WAV is deliberately NOT deleted here: stage 2 (ASR) still needs it.

        except Exception as exc:
            import traceback

            seconds = time.time() - video_started
            print(f" ✗ 处理失败:{exc}")
            traceback.print_exc()
            outcome_by_video[video] = make_record(False, None, str(exc), seconds)

        # Progress report; the early-failure paths above skip it via `continue`.
        spent = time.time() - batch_started
        done = len(outcome_by_video)
        mean_per_video = spent / done if done else 1
        eta = (video_count - done) * mean_per_video

        print(f"\n进度:{done}/{video_count}")
        print(f"已用:{spent:.1f}s, 预计剩余:{eta:.1f}s")

    batch_seconds = time.time() - batch_started
    print(f"\n✓ 第一阶段完成,耗时:{batch_seconds:.1f}s")
    print()

    return outcome_by_video
|
||
|
||
|
||
def _match_speaker(begin_time, end_time, segments):
    """
    Return the speaker of the segment that overlaps ``[begin_time, end_time]`` the most.

    Args:
        begin_time: Start of the interval (same time unit as the segments).
        end_time: End of the interval.
        segments: Iterable of dicts with "begin_time", "end_time" and "speaker"
            keys, as stored in the stage-1 diarization JSON.

    Returns:
        The ``speaker`` value of the segment with the largest positive overlap
        (the earliest such segment on ties), or None if nothing overlaps.
    """
    best_speaker = None
    best_overlap = 0.0
    for seg in segments:
        # A non-positive value means the intervals do not overlap; it can
        # never beat best_overlap, which starts at 0.
        overlap = min(end_time, seg["end_time"]) - max(begin_time, seg["begin_time"])
        if overlap > best_overlap:
            best_overlap = overlap
            best_speaker = seg["speaker"]
    return best_speaker


def _asr_failure(video_path, error, process_time):
    """Build the stage-2 result record for a video that could not be processed."""
    return {
        "video": str(video_path),
        "success": False,
        "error": error,
        "process_time": process_time,
    }


def _save_asr_summary(video_paths, results, total_time):
    """
    Print the stage-2 batch summary and save it to ``OUTPUT_DIR / "batch_summary.json"``.

    Args:
        video_paths: All videos of the batch.
        results: Per-video records produced by process_batch_asr().
        total_time: Wall-clock seconds spent in stage 2.

    Returns:
        Path of the written summary file.
    """
    video_count = len(video_paths)
    success_count = sum(1 for r in results if r["success"])
    # Guard the average against an empty batch (division by zero).
    average_time = total_time / video_count if video_count else 0.0

    print("\n" + "=" * 60)
    print("处理完成汇总")
    print("=" * 60)
    print(f"总耗时:{total_time:.1f}s")
    print(f"平均每个视频:{average_time:.1f}s")
    print(f"成功:{success_count}/{video_count}")
    print(f"失败:{video_count - success_count}")

    summary = {
        "total_videos": video_count,
        "success_count": success_count,
        "failed_count": video_count - success_count,
        "total_time_seconds": round(total_time, 2),
        "results": [
            {
                "video": Path(r["video"]).name,
                "success": r["success"],
                "output": r.get("merged_result"),
                "total_sentences": r.get("total_sentences", 0),
                "speaker_counts": r.get("speaker_counts", {}),
                "process_time_seconds": round(r.get("process_time", 0.0), 2),
                "error": r.get("error"),
            }
            for r in results
        ],
    }

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    summary_path = OUTPUT_DIR / "batch_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"\n汇总报告:{summary_path}")
    print("=" * 60)
    return summary_path


def process_batch_asr(video_paths, diar_results, max_workers=1):
    """
    Stage 2: run ASR on every video, attach the stage-1 speakers and save the results.

    The ASR model is loaded once and reused. For each video with a successful
    stage-1 record:
    - the WAV in TEMP_DIR is reused (and re-extracted if it is missing);
    - every recognized sentence is labelled with the speaker of the
      diarization segment it overlaps most, or "speaker_0" when none
      overlaps;
    - the sentences are exported to ``OUTPUT_DIR / f"{stem}_result.json"``.
    The video's temporary WAV and diarization JSON are deleted afterwards.
    Finally a summary is printed and written to
    ``OUTPUT_DIR / "batch_summary.json"``.

    Args:
        video_paths: List of video paths.
        diar_results: Mapping returned by process_batch_diarization()
            (video path -> stage-1 record dict).
        max_workers: Concurrency level; currently unused, processing is
            always sequential.

    Returns:
        One result dict per video. Every dict has "video", "success" and
        "process_time". Failures add "error"; successes add "asr_result",
        "merged_result", "speaker_counts" and "total_sentences".
    """
    print("=" * 60)
    print("第二阶段:批量语音识别 + 合并结果")
    print("=" * 60)
    print(f"视频数量:{len(video_paths)}")
    print("处理模式:顺序处理(单进程)")
    print()

    # Load the ASR model a single time for the whole batch.
    print("加载 ASR 模型...")
    from asr_service import ASRService

    asr_service = ASRService(
        model_name="paraformer-zh",
        device="auto",
        min_segment_duration=0.5,  # raised minimum segment duration
        merge_gap=0.3,             # reduced merge gap
    )
    asr_service._load_model()
    print("✓ ASR 模型已加载")
    print()

    results = []
    start_time = time.time()

    for i, video_path in enumerate(video_paths, 1):
        video_start_time = time.time()
        diar_info = diar_results.get(video_path)

        # Skip videos whose stage-1 record is missing, failed, or has no file.
        if not diar_info:
            print(f"\n[{i}/{len(video_paths)}] 跳过 {video_path.name}(无说话人分离结果)")
            results.append(_asr_failure(video_path, "无说话人分离结果", 0.0))
            continue

        if not diar_info.get("success"):
            print(f"\n[{i}/{len(video_paths)}] 跳过 {video_path.name}(第一阶段失败:{diar_info.get('error')})")
            results.append(_asr_failure(
                video_path,
                f"说话人分离失败:{diar_info.get('error')}",
                diar_info.get("process_time", 0.0),
            ))
            continue

        diar_path = diar_info.get("diar_result")
        if not diar_path:
            print(f"\n[{i}/{len(video_paths)}] 跳过 {video_path.name}(无说话人分离结果文件)")
            results.append(_asr_failure(video_path, "说话人分离结果文件不存在", 0.0))
            continue

        wav_path = None
        try:
            print(f"\n[{i}/{len(video_paths)}] 处理:{video_path.name}")

            # 1. Reuse the WAV extracted in stage 1; re-extract if it is gone.
            wav_path = TEMP_DIR / f"{video_path.stem}.wav"
            if not wav_path.exists():
                wav_path = extract_wav(video_path, TEMP_DIR)
                if wav_path is None:
                    print(" ✗ 音频提取失败")
                    results.append(_asr_failure(
                        video_path, "音频提取失败", time.time() - video_start_time
                    ))
                    continue

            # 2. Load the stage-1 diarization segments.
            from map_speaker import load_json

            diar_result = load_json(diar_path)

            # 3. Run ASR (its own speaker labels, if any, are ignored).
            asr_sentences = asr_service.recognize(wav_path)

            # recognize() may return a dict (presumably {"sentences": [...]}).
            # Normalize to a list BEFORE the speaker merge below iterates it;
            # iterating the dict itself would walk its keys instead.
            if isinstance(asr_sentences, dict):
                asr_sentences = asr_sentences.get("sentences", [])

            if not asr_sentences:
                print(" ✗ ASR 识别结果为空")
                results.append(_asr_failure(
                    video_path, "ASR 识别结果为空", time.time() - video_start_time
                ))
                continue

            # 4. Attach speakers from the diarization result only.
            print(" 合并结果...")
            segments = diar_result["segments"]
            for sentence in asr_sentences:
                speaker = _match_speaker(sentence.begin_time, sentence.end_time, segments)
                # Compare against None: a falsy speaker id (e.g. 0) is still a match.
                sentence.speaker = speaker if speaker is not None else "speaker_0"

            # 5. Save the merged result.
            output_file = OUTPUT_DIR / f"{video_path.stem}_result.json"
            asr_service.export_to_json(asr_sentences, output_file)

            # Sentence count per speaker.
            speaker_counts = {}
            for sentence in asr_sentences:
                speaker_counts[sentence.speaker] = speaker_counts.get(sentence.speaker, 0) + 1

            video_process_time = time.time() - video_start_time
            results.append({
                "video": str(video_path),
                "success": True,
                "asr_result": [s.to_dict() for s in asr_sentences],
                "merged_result": str(output_file),
                "speaker_counts": speaker_counts,
                "total_sentences": len(asr_sentences),
                "process_time": video_process_time,
            })

            print(f" ✓ 处理完成 (耗时:{video_process_time:.1f}s)")
            print(f" - 句子数:{len(asr_sentences)}")
            print(f" - 说话人:{speaker_counts}")

        except Exception as e:
            import traceback

            video_process_time = time.time() - video_start_time
            print(f" ✗ 处理失败:{e}")
            traceback.print_exc()
            results.append(_asr_failure(video_path, str(e), video_process_time))

        finally:
            # Best-effort removal of this video's temp files. Only filesystem
            # errors are ignored: leftovers in TEMP_DIR are harmless.
            for leftover in (wav_path, Path(diar_path)):
                if leftover is None:
                    continue
                try:
                    leftover.unlink(missing_ok=True)
                except OSError:
                    pass

        # Progress report; the early `continue` paths above skip it.
        elapsed = time.time() - start_time
        avg_time = elapsed / len(results) if results else 1
        remaining = (len(video_paths) - len(results)) * avg_time

        print(f"\n进度:{len(results)}/{len(video_paths)}")
        print(f"已用:{elapsed:.1f}s, 预计剩余:{remaining:.1f}s")

    total_time = time.time() - start_time
    print(f"\n✓ 第二阶段完成,耗时:{total_time:.1f}s")
    print()

    _save_asr_summary(video_paths, results, total_time)

    return results
|
||
|
||
|
||
def main(path: str):
    """
    Run the two-stage pipeline on one video file or a folder of videos.

    ``path`` is resolved relative to the "input" directory next to this
    script; an absolute path is used as-is.
    - TEMP_DIR is cleared before processing and again at the end.
    - Stage 1 runs speaker diarization for every video; memory is then
      released.
    - Stage 2 runs ASR and merges in the diarization speakers.
    Results are written to OUTPUT_DIR.

    Args:
        path: Video file or folder, relative to the "input" directory.

    Returns:
        An error message string if the input path does not exist
        ("输入路径不存在") or is an empty folder ("文件夹为空"); otherwise None.
    """
    import torch

    # Resolve relative to the "input" folder next to this script
    # (os.path.join keeps an absolute `path` unchanged).
    path = os.path.join(os.path.dirname(__file__), "input", path)
    print(f"开始处理路径:{path}")

    print("\n" + "=" * 60)
    print(" 并发批量语音识别处理系统")
    print("=" * 60)
    print()

    input_path = Path(path)

    # 0. Validate the input path.
    if not input_path.exists():
        print(f"✗ 错误:输入路径不存在 - {path}")
        return "输入路径不存在"

    if input_path.is_dir() and not any(input_path.iterdir()):
        print("✗ 错误:文件夹为空")
        return "文件夹为空"

    # 1. Start from an empty temp directory.
    clear_temp_dir()

    # 2. Make sure the output directory exists.
    ensure_output_dir()

    # 3. Build the video list. A folder is scanned by get_video_list (its
    #    direct entries only); a single file is used as given.
    if input_path.is_dir():
        video_folder = input_path
        video_paths = get_video_list(video_folder)
        if not video_paths:
            print("✗ 错误:未找到任何视频文件")
            print(f"支持格式:{', '.join(SUPPORTED_VIDEO_FORMATS)}")
            print(f"请检查文件夹:{video_folder}")
            return
    else:
        video_paths = [input_path]

    # De-duplicate while keeping get_video_list's name/time order
    # (list(set(...)) would shuffle it).
    video_paths = list(dict.fromkeys(video_paths))

    print(f"找到 {len(video_paths)} 个视频文件")
    for vp in video_paths:
        print(f" - {vp.name}")
    print()

    # 4. Two-stage processing:
    #    stage 1: speaker diarization
    #    stage 2: ASR + merging of the results
    print()
    print("=" * 60)
    print("处理策略")
    print("=" * 60)
    print("阶段 1: 批量说话人分离(加载说话人分离模型)")
    print(" ↓ 释放内存")
    print("阶段 2: 批量语音识别 + 合并结果(加载 ASR 模型)")
    print("=" * 60)
    print()

    # Stage 1: speaker diarization.
    diar_results = process_batch_diarization(video_paths, max_workers=1)

    # Count only videos whose diarization succeeded. Every video gets a
    # non-empty record even on failure, so the "success" flag must be checked.
    success_count = sum(1 for record in diar_results.values() if record.get("success"))
    if success_count == 0:
        print("✗ 错误:第一阶段全部失败")
        return

    print(f"✓ 第一阶段成功:{success_count}/{len(video_paths)}")
    print()

    # Release memory held by the stage-1 model before loading the ASR model.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("✓ 已清理 CUDA 缓存,准备第二阶段")
    print()

    # Stage 2: ASR + merging of the results.
    process_batch_asr(video_paths, diar_results, max_workers=1)

    # 5. Final cleanup.
    print("\n清理临时文件...")
    clear_temp_dir()

    print("\n✓ 全部完成!")
    print(f"输出目录:{OUTPUT_DIR}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 视频文件夹路径(全局变量)
|
||
# PATH = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\宁波北仑区鼎邦杰西雅服饰有限公司"
|
||
# PATH = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\temp"
|
||
# PATH = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\temp\VID_20251104_085655_024.AVI"
|
||
PATH = r"VID_20251104_085655_024.AVI"
|
||
main(PATH)
|