ai/SpeechRecognition/main.py

603 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
批量视频语音识别 + 说话人分离
功能:
1. 自动扫描视频目录
2. 两阶段处理:
- 阶段 1: 说话人分离
- 阶段 2: ASR 识别 + 合并结果
3. 每个视频只加载一次模型
4. 顺序处理,避免多进程 CUDA 共享问题
"""
import gc
import json
import os
import shutil
import subprocess
import time
from pathlib import Path
from typing import Dict, List, Optional
# Path configuration: all working directories sit next to this script.
BASE_DIR = Path(__file__).parent.absolute()
# Scratch space for extracted WAV files and intermediate diarization JSON.
TEMP_DIR = BASE_DIR.joinpath("temp")
# Per-video result JSON files and the batch summary are written here.
OUTPUT_DIR = BASE_DIR.joinpath("output")

# Glob patterns for the video containers the batch picks up.
SUPPORTED_VIDEO_FORMATS = [
    f"*.{ext}" for ext in ("mp4", "avi", "mkv", "mov", "flv", "wmv", "m4v")
]
def get_video_list(folder_path: Path, patterns: Optional[List[str]] = None) -> List[Path]:
    """
    Collect the video files directly inside a folder, sorted by file name.

    Only the folder's own entries are scanned (no recursion), and only regular
    files are returned. Matching is case-insensitive on every platform, so
    "*.avi" also matches "VID_20251104_085655_024.AVI".

    Args:
        folder_path: Folder to scan.
        patterns: Glob-style name patterns to accept; defaults to
            SUPPORTED_VIDEO_FORMATS.

    Returns:
        Matching paths sorted by file name. File names that embed a timestamp
        (e.g. VID_20251031_132320_019.mp4) therefore come out in chronological
        order.
    """
    import fnmatch

    if patterns is None:
        patterns = SUPPORTED_VIDEO_FORMATS
    # Path.glob is case-sensitive on POSIX, which silently skipped upper-case
    # extensions. Lower-case both sides and match each entry once, which also
    # avoids duplicates when several patterns match the same file.
    lowered = [pat.lower() for pat in patterns]
    video_paths = [
        entry
        for entry in folder_path.iterdir()
        if entry.is_file()
        and any(fnmatch.fnmatchcase(entry.name.lower(), pat) for pat in lowered)
    ]
    # Lexicographic order on the name; timestamped names sort chronologically.
    video_paths.sort(key=lambda p: p.name)
    return video_paths
def clear_temp_dir():
    """
    Delete TEMP_DIR with everything in it, then recreate it empty.

    A failure while deleting is printed, not raised. The directory is
    (re)created afterwards in either case.
    """
    banner = "=" * 60
    for line in (banner, "清空临时目录...", banner):
        print(line)

    if TEMP_DIR.exists():
        try:
            shutil.rmtree(TEMP_DIR)
        except Exception as exc:
            print(f"✗ 删除失败:{exc}")
        else:
            print(f"✓ 已删除:{TEMP_DIR}")

    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    print(f"✓ 已创建:{TEMP_DIR}")
    print()
def ensure_output_dir():
    """Create OUTPUT_DIR (and any missing parents), then print its location."""
    OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
    # end="\n\n" reproduces the message followed by an empty line.
    print(f"✓ 输出目录:{OUTPUT_DIR}", end="\n\n")
def extract_wav(video_path: Path, temp_dir: Path, timeout: float = 300) -> Optional[Path]:
    """
    Extract a 16 kHz, mono, 16-bit PCM WAV track from a video using ffmpeg.

    Args:
        video_path: Path of the source video.
        temp_dir: Directory the WAV is written to; created if missing. The file
            is named ``<video stem>.wav``, and an existing file of that name is
            overwritten.
        timeout: Seconds to wait for ffmpeg before giving up (default 300).

    Returns:
        Path of the WAV file, or None when extraction failed for any reason
        (ffmpeg not found, non-zero exit, timeout, no output produced).
    """
    wav_path = temp_dir / f"{video_path.stem}.wav"
    cmd = [
        "ffmpeg",
        "-i", str(video_path),
        "-vn",                   # drop the video stream
        "-acodec", "pcm_s16le",  # 16-bit PCM
        "-ar", "16000",          # 16 kHz sample rate
        "-ac", "1",              # mono
        "-y",                    # overwrite an existing output file
        str(wav_path),
    ]
    try:
        # ffmpeg does not create missing output directories itself.
        temp_dir.mkdir(parents=True, exist_ok=True)
        subprocess.run(
            cmd,
            check=True,
            capture_output=True,
            # ffmpeg reads interactive commands from stdin. Detach it so a
            # batch run cannot block on, or eat input from, the parent terminal.
            stdin=subprocess.DEVNULL,
            timeout=timeout,
        )
        if wav_path.exists():
            print(f"✓ 提取音频:{video_path.name} -> {wav_path.name}")
            return wav_path
        print(f"✗ 提取失败:{video_path.name}")
        return None
    except subprocess.TimeoutExpired:
        print(f"✗ 提取超时:{video_path.name}")
        return None
    except subprocess.CalledProcessError as e:
        # str(e) only carries the exit status. ffmpeg's real diagnostic is in
        # stderr (captured above), so show its last few lines as well.
        print(f"✗ 提取错误:{video_path.name} - {e}")
        stderr = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        if stderr:
            print("\n".join(stderr.splitlines()[-5:]))
        return None
    except Exception as e:
        # e.g. FileNotFoundError when ffmpeg is not on PATH.
        print(f"✗ 提取错误:{video_path.name} - {e}")
        return None
def _diarize_video(diar_service, video_path):
    """
    Extract one video's audio and run speaker diarization on it.

    Returns:
        ``(diar_json_path, None)`` on success, or ``(None, error_message)`` when
        audio extraction fails or diarization yields no segments. Any other
        exception propagates to the caller.
    """
    wav_path = extract_wav(video_path, TEMP_DIR)
    if wav_path is None:
        print(f" ✗ 音频提取失败")
        return None, "音频提取失败"

    diar_segments = diar_service.diarize(wav_path)
    if not diar_segments:
        print(f" ✗ 说话人分离结果为空")
        return None, "说话人分离结果为空"

    # Save the segments as JSON in TEMP_DIR. The WAV is deliberately kept:
    # the ASR stage looks for it before extracting the audio again.
    from map_speaker import save_json

    temp_diar_path = TEMP_DIR / f"{video_path.stem}_diar.json"
    save_json(temp_diar_path, {"segments": [seg.to_dict() for seg in diar_segments]})
    return str(temp_diar_path), None


def process_batch_diarization(video_paths, max_workers=1):
    """
    Stage 1: run speaker diarization on every video, one after another.

    The diarization model is loaded once and reused for all videos. For each
    video, the audio is extracted into TEMP_DIR and the segments are saved
    there as ``<stem>_diar.json``. Both files are left in place for stage 2.

    Args:
        video_paths: List of video paths to process.
        max_workers: Kept for interface compatibility; unused, because
            processing is always sequential in this process.

    Returns:
        Dict mapping each video path to a dict with these keys:
        ``success`` (bool), ``diar_result`` (str path of the JSON, or None),
        ``error`` (str, or None on success) and ``process_time`` (seconds).
    """
    total = len(video_paths)
    print("=" * 60)
    print("第一阶段:批量说话人分离")
    print("=" * 60)
    print(f"视频数量:{total}")
    print(f"处理模式:顺序处理(单进程)")
    print()

    # Load the diarization model once; it is reused for every video.
    print("加载说话人分离模型...")
    from diarization_service import DiarizationService

    diar_service = DiarizationService(
        embedding_model="eres2netv2",
        device="auto",
        cluster_threshold=0.5,
        min_cluster_size=10,
    )
    diar_service._load_model()
    print("✓ 说话人分离模型已加载")
    print()

    results = {}
    start_time = time.time()
    for i, video_path in enumerate(video_paths, 1):
        video_start_time = time.time()
        print(f"\n[{i}/{total}] 处理:{video_path.name}")
        try:
            diar_path, error = _diarize_video(diar_service, video_path)
        except Exception as e:
            import traceback

            print(f" ✗ 处理失败:{e}")
            traceback.print_exc()
            diar_path, error = None, str(e)

        # Exactly one result record per video, whichever path it took.
        video_process_time = time.time() - video_start_time
        results[video_path] = {
            "success": error is None,
            "diar_result": diar_path,
            "error": error,
            "process_time": video_process_time,
        }
        if error is None:
            print(f" ✓ 说话人分离完成 (耗时:{video_process_time:.1f}s)")

        # Progress is reported after every video, failed ones included.
        elapsed = time.time() - start_time
        remaining = (total - i) * (elapsed / i)
        print(f"\n进度:{i}/{total}")
        print(f"已用:{elapsed:.1f}s, 预计剩余:{remaining:.1f}s")

    total_time = time.time() - start_time
    print(f"\n✓ 第一阶段完成,耗时:{total_time:.1f}s")
    print()
    return results
def _match_speaker(begin, end, segments):
    """
    Return the speaker of the diarization segment that overlaps [begin, end) most.

    ``segments`` are dicts with ``begin_time``, ``end_time`` and ``speaker`` keys.
    Ties keep the earliest segment. Returns None when no segment overlaps.
    """
    best_speaker = None
    best_overlap = 0.0
    for seg in segments:
        overlap = min(end, seg["end_time"]) - max(begin, seg["begin_time"])
        if overlap > best_overlap:
            best_overlap = overlap
            best_speaker = seg["speaker"]
    return best_speaker


def _remove_temp_file(path):
    """Delete a temporary file on a best-effort basis; report failures instead of raising."""
    if path is None:
        return
    try:
        Path(path).unlink()
    except FileNotFoundError:
        pass
    except OSError as e:
        print(f" ⚠ 临时文件删除失败:{path} - {e}")


def _asr_video(asr_service, video_path, diar_info, tag):
    """
    Run ASR for one video, label each sentence with a speaker and export the result.

    Args:
        asr_service: Loaded ASR service.
        video_path: Video being processed.
        diar_info: This video's stage-1 result dict, or None.
        tag: ``"[i/N]"`` prefix for log lines.

    Returns:
        The result dict recorded for this video. If the video gets past the
        stage-1 checks, its WAV and diarization JSON are removed afterwards,
        whether or not ASR succeeds.
    """
    video = str(video_path)
    video_start_time = time.time()

    # Skip videos without a usable stage-1 result.
    if not diar_info:
        print(f"\n{tag} 跳过 {video_path.name}(无说话人分离结果)")
        return {"video": video, "success": False,
                "error": "无说话人分离结果", "process_time": 0.0}
    if not diar_info.get("success"):
        stage1_error = diar_info.get("error")
        print(f"\n{tag} 跳过 {video_path.name}(第一阶段失败:{stage1_error})")
        return {"video": video, "success": False,
                "error": f"说话人分离失败:{stage1_error}",
                "process_time": diar_info.get("process_time", 0.0)}
    diar_path = diar_info.get("diar_result")
    if not diar_path:
        print(f"\n{tag} 跳过 {video_path.name}(无说话人分离结果文件)")
        return {"video": video, "success": False,
                "error": "说话人分离结果文件不存在", "process_time": 0.0}

    wav_path = None
    try:
        print(f"\n{tag} 处理:{video_path.name}")
        # Reuse the WAV left by stage 1; re-extract only if it is gone.
        wav_path = TEMP_DIR / f"{video_path.stem}.wav"
        if not wav_path.exists():
            wav_path = extract_wav(video_path, TEMP_DIR)
            if wav_path is None:
                print(f" ✗ 音频提取失败")
                return {"video": video, "success": False, "error": "音频提取失败",
                        "process_time": time.time() - video_start_time}

        from map_speaker import load_json

        diar_result = load_json(diar_path)

        # Speakers come only from diarization, not from the ASR model.
        asr_sentences = asr_service.recognize(wav_path)
        # Normalise a {"sentences": [...]} return value BEFORE using it. The
        # loop below reads .begin_time, which a dict's string keys lack.
        if isinstance(asr_sentences, dict):
            asr_sentences = asr_sentences.get("sentences", [])
        if not asr_sentences:
            print(f" ✗ ASR 识别结果为空")
            return {"video": video, "success": False, "error": "ASR 识别结果为空",
                    "process_time": time.time() - video_start_time}

        print(f" 合并结果...")
        segments = diar_result["segments"]
        for sentence in asr_sentences:
            speaker = _match_speaker(sentence.begin_time, sentence.end_time, segments)
            # `is not None`, so a falsy label such as 0 or "" is still honoured.
            sentence.speaker = speaker if speaker is not None else "speaker_0"

        output_file = OUTPUT_DIR / f"{video_path.stem}_result.json"
        asr_service.export_to_json(asr_sentences, output_file)

        speaker_counts = {}
        for sentence in asr_sentences:
            speaker_counts[sentence.speaker] = speaker_counts.get(sentence.speaker, 0) + 1

        video_process_time = time.time() - video_start_time
        print(f" ✓ 处理完成 (耗时:{video_process_time:.1f}s)")
        print(f" - 句子数:{len(asr_sentences)}")
        print(f" - 说话人:{speaker_counts}")
        return {
            "video": video,
            "success": True,
            "asr_result": [s.to_dict() for s in asr_sentences],
            "merged_result": str(output_file),
            "speaker_counts": speaker_counts,
            "total_sentences": len(asr_sentences),
            "process_time": video_process_time,
        }
    except Exception as e:
        import traceback

        video_process_time = time.time() - video_start_time
        print(f" ✗ 处理失败:{e}")
        traceback.print_exc()
        return {"video": video, "success": False, "error": str(e),
                "process_time": video_process_time}
    finally:
        _remove_temp_file(wav_path)
        _remove_temp_file(diar_path)


def _write_asr_summary(video_paths, results, total_time):
    """Print the batch totals and write them to OUTPUT_DIR/batch_summary.json."""
    total = len(video_paths)
    success_count = sum(1 for r in results if r["success"])
    # Guard the average so an empty batch does not raise ZeroDivisionError.
    avg_time = total_time / total if total else 0.0
    print("\n" + "=" * 60)
    print("处理完成汇总")
    print("=" * 60)
    print(f"总耗时:{total_time:.1f}s")
    print(f"平均每个视频:{avg_time:.1f}s")
    print(f"成功:{success_count}/{total}")
    print(f"失败:{total - success_count}")

    summary = {
        "total_videos": total,
        "success_count": success_count,
        "failed_count": total - success_count,
        "total_time_seconds": round(total_time, 2),
        "results": [
            {
                "video": Path(r["video"]).name,
                "success": r["success"],
                "output": r.get("merged_result"),
                "total_sentences": r.get("total_sentences", 0),
                "speaker_counts": r.get("speaker_counts", {}),
                "process_time_seconds": round(r.get("process_time", 0.0), 2),
                "error": r.get("error"),
            }
            for r in results
        ],
    }
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    summary_path = OUTPUT_DIR / "batch_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f"\n汇总报告:{summary_path}")
    print("=" * 60)


def process_batch_asr(video_paths, diar_results, max_workers=1):
    """
    Stage 2: run ASR on every video and attach the stage-1 speaker labels.

    The ASR model is loaded once. Each sentence gets the speaker of the
    diarization segment it overlaps most, or "speaker_0" when none overlaps.
    Per-video results go to OUTPUT_DIR/<stem>_result.json, and a batch summary
    goes to OUTPUT_DIR/batch_summary.json.

    Args:
        video_paths: List of video paths to process.
        diar_results: Mapping of video path -> stage-1 result dict
            (see process_batch_diarization).
        max_workers: Kept for interface compatibility; unused, because
            processing is always sequential in this process.

    Returns:
        List with one result dict per video, in input order.
    """
    total = len(video_paths)
    print("=" * 60)
    print("第二阶段:批量语音识别 + 合并结果")
    print("=" * 60)
    print(f"视频数量:{total}")
    print(f"处理模式:顺序处理(单进程)")
    print()

    # Load the ASR model once; it is reused for every video.
    print("加载 ASR 模型...")
    from asr_service import ASRService

    asr_service = ASRService(
        model_name="paraformer-zh",
        device="auto",
        min_segment_duration=0.5,  # drop very short segments
        merge_gap=0.3,             # smaller gap when merging adjacent segments
    )
    asr_service._load_model()
    print("✓ ASR 模型已加载")
    print()

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    results = []
    start_time = time.time()
    for i, video_path in enumerate(video_paths, 1):
        results.append(
            _asr_video(asr_service, video_path, diar_results.get(video_path), f"[{i}/{total}]")
        )
        # Progress is reported after every video, skipped ones included.
        elapsed = time.time() - start_time
        remaining = (total - i) * (elapsed / i)
        print(f"\n进度:{i}/{total}")
        print(f"已用:{elapsed:.1f}s, 预计剩余:{remaining:.1f}s")

    total_time = time.time() - start_time
    print(f"\n✓ 第二阶段完成,耗时:{total_time:.1f}s")
    print()

    _write_asr_summary(video_paths, results, total_time)
    return results
def main(path: str):
    """
    Run the whole pipeline on one video or on a folder of videos.

    Stage 1 diarizes every video. GPU memory is then released, and stage 2
    runs ASR and merges in the speaker labels. The temp directory is cleared
    before and after the run.

    Args:
        path: Video file or folder, joined onto ``<script dir>/input``. An
            absolute path replaces that prefix (``os.path.join`` semantics).

    Returns:
        None once stage 2 has run; individual videos may still have failed,
        and batch_summary.json records which. A short error message (str) is
        returned when the run stops early: missing or empty input, no videos
        found, or every video failed stage 1.
    """
    # Imported up front so a missing torch fails before any work is done.
    import torch

    path = os.path.join(os.path.dirname(__file__), "input", path)
    input_path = Path(path)
    print(f"开始处理路径:{path}")
    print("\n" + "=" * 60)
    print(" 并发批量语音识别处理系统")
    print("=" * 60)
    print()

    # 0. Validate the input.
    if not input_path.exists():
        print(f"✗ 错误:输入路径不存在 - {path}")
        return "输入路径不存在"
    if input_path.is_dir() and not any(input_path.iterdir()):
        print("✗ 错误:文件夹为空")
        return "文件夹为空"

    # 1. Start from an empty temp directory.
    clear_temp_dir()
    # 2. Make sure the output directory exists.
    ensure_output_dir()

    # 3. Build the video list: the top-level video files of a folder
    #    (non-recursive), or the single file that was given.
    if input_path.is_dir():
        video_paths = get_video_list(input_path)
        if not video_paths:
            print("✗ 错误:未找到任何视频文件")
            print(f"支持格式:{', '.join(SUPPORTED_VIDEO_FORMATS)}")
            print(f"请检查文件夹:{input_path}")
            return "未找到任何视频文件"
    else:
        video_paths = [input_path]
    # De-duplicate while keeping the filename (time) order; set() would
    # scramble the order that get_video_list established.
    video_paths = list(dict.fromkeys(video_paths))
    print(f"找到 {len(video_paths)} 个视频文件")
    for vp in video_paths:
        print(f" - {vp.name}")
    print()

    # 4. Two stages, so only one model is resident at a time.
    print()
    print("=" * 60)
    print("处理策略")
    print("=" * 60)
    print("阶段 1: 批量说话人分离(加载说话人分离模型)")
    print(" ↓ 释放内存")
    print("阶段 2: 批量语音识别 + 合并结果(加载 ASR 模型)")
    print("=" * 60)
    print()

    # Stage 1: speaker diarization.
    diar_results = process_batch_diarization(video_paths, max_workers=1)
    # Each value is a non-empty dict, so its truthiness says nothing about
    # success; the "success" flag must be checked explicitly.
    success_count = sum(1 for r in diar_results.values() if r and r.get("success"))
    if success_count == 0:
        print("✗ 错误:第一阶段全部失败")
        return "第一阶段全部失败"
    print(f"✓ 第一阶段成功:{success_count}/{len(video_paths)}")
    print()

    # Free the stage-1 model's memory before loading the ASR model.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("✓ 已清理 CUDA 缓存,准备第二阶段")
    print()

    # Stage 2: ASR + merge.
    process_batch_asr(video_paths, diar_results, max_workers=1)

    # 5. Final cleanup.
    print("\n清理临时文件...")
    clear_temp_dir()
    print("\n✓ 全部完成!")
    print(f"输出目录:{OUTPUT_DIR}")
    return None
if __name__ == "__main__":
    import sys

    # The input is a video file or a folder, relative to ./input (absolute
    # paths also work). Pass it as the first command-line argument; without
    # one, the default below is used.
    DEFAULT_PATH = r"VID_20251104_085655_024.AVI"
    PATH = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_PATH
    # main() returns an error message when it stops early; signal that to the
    # shell with a non-zero exit status.
    if main(PATH) is not None:
        sys.exit(1)