SpeechRecognition/main.py

517 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
批量视频语音识别 + 说话人分离
功能:
1. 自动扫描视频目录
2. 两阶段处理:
- 阶段 1: 说话人分离
- 阶段 2: ASR 识别 + 合并结果
3. 每个视频只加载一次模型
4. 顺序处理,避免多进程 CUDA 共享问题
"""
import gc
import json
import shutil
import subprocess
import time
from pathlib import Path
from typing import Dict, List, Optional
# Path configuration: the working directories live next to this file.
BASE_DIR = Path(__file__).parent.absolute()
# Scratch space for extracted WAV files and intermediate diarization JSON.
TEMP_DIR = BASE_DIR / "temp"
# Destination for the per-video result files.
OUTPUT_DIR = BASE_DIR / "output"
# Folder that is scanned for input videos (module-level setting read by main()).
VIDEO_FOLDER = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\宁波北仑区鼎邦杰西雅服饰有限公司"
# Glob patterns of the video formats picked up by get_video_list().
SUPPORTED_VIDEO_FORMATS = ["*.mp4", "*.avi", "*.mkv", "*.mov", "*.flv", "*.wmv", "*.m4v"]
def get_video_list(folder_path: Path) -> List[Path]:
    """
    Collect the video files directly inside a folder, ordered by file name.

    File names are matched against SUPPORTED_VIDEO_FORMATS case-insensitively,
    so "CLIP.MP4" is found on case-sensitive file systems as well, and only
    regular files are returned (a directory named "x.mp4" is ignored).

    Args:
        folder_path: Folder to scan (not recursive).
    Returns:
        Matching paths sorted lexicographically by file name. For timestamped
        names such as VID_20251031_132320_019.mp4 this is chronological order.
        An empty list is returned when the folder does not exist.
    """
    from fnmatch import fnmatchcase

    # Path.glob() quietly yields nothing for a missing folder; keep that contract.
    if not folder_path.is_dir():
        return []

    patterns = [pattern.lower() for pattern in SUPPORTED_VIDEO_FORMATS]
    video_paths = [
        entry
        for entry in folder_path.iterdir()
        if entry.is_file()
        and any(fnmatchcase(entry.name.lower(), pattern) for pattern in patterns)
    ]
    video_paths.sort(key=lambda p: p.name)
    return video_paths
def clear_temp_dir():
    """Delete the temp directory (if present) and recreate it empty.

    A failed deletion is reported on stdout but does not abort; the directory
    is (re)created afterwards in either case.
    """
    banner = "=" * 60
    print(banner)
    print("清空临时目录...")
    print(banner)

    if TEMP_DIR.exists():
        try:
            shutil.rmtree(TEMP_DIR)
            print(f"✓ 已删除:{TEMP_DIR}")
        except Exception as exc:
            print(f"✗ 删除失败:{exc}")

    TEMP_DIR.mkdir(exist_ok=True, parents=True)
    print(f"✓ 已创建:{TEMP_DIR}")
    print()
def ensure_output_dir():
    """Create the output directory (including parents) if needed and print its path."""
    OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
    # One call emits the path line followed by the blank separator line.
    print(f"✓ 输出目录:{OUTPUT_DIR}", end="\n\n")
def extract_wav(video_path: Path, temp_dir: Path) -> Optional[Path]:
    """
    Extract the audio track of a video into a WAV file using ffmpeg.

    The output is 16-bit PCM, 16 kHz, mono, written to
    ``temp_dir / "<video stem>.wav"`` (an existing file is overwritten).

    Args:
        video_path: Path of the source video.
        temp_dir: Directory that receives the WAV file.
    Returns:
        Path of the WAV file, or None when extraction failed or timed out.
        On failure any partially written WAV file is removed.
    """
    wav_path = temp_dir / f"{video_path.stem}.wav"
    cmd = [
        "ffmpeg",
        "-i", str(video_path),
        "-vn",                   # drop the video stream
        "-acodec", "pcm_s16le",  # 16-bit PCM
        "-ar", "16000",          # 16 kHz sample rate
        "-ac", "1",              # mono
        "-y",                    # overwrite an existing output file
        str(wav_path),
    ]
    try:
        subprocess.run(
            cmd,
            check=True,
            capture_output=True,
            # ffmpeg is interactive by default; never let it read our stdin.
            stdin=subprocess.DEVNULL,
            timeout=300,  # 5 minutes
        )
    except subprocess.TimeoutExpired:
        print(f"✗ 提取超时:{video_path.name}")
    except subprocess.CalledProcessError as e:
        print(f"✗ 提取错误:{video_path.name} - {e}")
        # The exception text only carries the exit code; the actual reason is
        # on the last line of ffmpeg's captured stderr.
        detail = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        if detail:
            print(f"  ffmpeg: {detail.splitlines()[-1]}")
    except Exception as e:
        # e.g. ffmpeg is not installed / not on PATH.
        print(f"✗ 提取错误:{video_path.name} - {e}")
    else:
        if wav_path.exists():
            print(f"✓ 提取音频:{video_path.name} -> {wav_path.name}")
            return wav_path
        print(f"✗ 提取失败:{video_path.name}")
        return None

    # Failure path: a truncated WAV must not be mistaken for a valid one later.
    try:
        wav_path.unlink()
    except OSError:
        pass  # best-effort cleanup; usually the file was never created
    return None
def process_batch_diarization(video_paths: List[Path], max_workers: int = 1):
    """
    Stage 1: run speaker diarization for every video, sequentially.

    The diarization model is loaded once and reused. For each video the audio
    is extracted to TEMP_DIR, diarized, and the segments are saved to
    ``TEMP_DIR / "<video stem>_diar.json"``. The WAV file is kept on purpose
    because the ASR stage needs it again.

    Args:
        video_paths: Videos to process.
        max_workers: Accepted for interface compatibility; currently unused
            (processing is always sequential).
    Returns:
        Dict mapping each successfully processed video path to the path
        (as str) of its diarization JSON. Failed videos are absent.
    """
    import traceback

    total = len(video_paths)
    print("=" * 60)
    print("第一阶段:批量说话人分离")
    print("=" * 60)
    print(f"视频数量:{total}")
    print(f"处理模式:顺序处理(单进程)")
    print()

    # Load the diarization model once for the whole batch.
    print("加载说话人分离模型...")
    from diarization_service import DiarizationService
    from map_speaker import save_json  # loop-invariant, imported once up front
    diar_service = DiarizationService(
        embedding_model="eres2netv2",
        device="auto",
        cluster_threshold=0.5,
        min_cluster_size=10
    )
    diar_service._load_model()
    print("✓ 说话人分离模型已加载")
    print()

    results = {}
    start_time = time.time()

    for i, video_path in enumerate(video_paths, 1):
        try:
            print(f"\n[{i}/{total}] 处理:{video_path.name}")

            # 1. Extract the audio track.
            wav_path = extract_wav(video_path, TEMP_DIR)
            if wav_path is None:
                print(f" ✗ 音频提取失败")
                continue

            # 2. Run diarization.
            diar_segments = diar_service.diarize(wav_path)
            if not diar_segments:
                print(f" ✗ 说话人分离结果为空")
                continue

            # 3. Persist the segments as a temporary JSON file.
            temp_diar_path = TEMP_DIR / f"{video_path.stem}_diar.json"
            diar_result = {
                "segments": [seg.to_dict() for seg in diar_segments]
            }
            save_json(temp_diar_path, diar_result)
            results[video_path] = str(temp_diar_path)
            print(f" ✓ 说话人分离完成")
            # The WAV file is intentionally not deleted: stage 2 reuses it.
        except Exception as e:
            print(f" ✗ 处理失败:{e}")
            traceback.print_exc()

        # Progress / ETA are based on the number of videos handled so far
        # (successful or not); dividing by the success count alone would
        # inflate the estimate after every failure.
        elapsed = time.time() - start_time
        avg_time = elapsed / i
        remaining = (total - i) * avg_time
        print(f"\n进度:{i}/{total}")
        print(f"已用:{elapsed:.1f}s, 预计剩余:{remaining:.1f}s")

    total_time = time.time() - start_time
    print(f"\n✓ 第一阶段完成,耗时:{total_time:.1f}s")
    print()
    return results
def _best_overlap_speaker(begin_time, end_time, segments):
    """Return the speaker of the segment that overlaps [begin_time, end_time] the longest, or None if none overlaps."""
    matched_speaker = None
    best_overlap = 0.0
    for seg in segments:
        overlap = min(end_time, seg['end_time']) - max(begin_time, seg['begin_time'])
        # Strict '>' keeps the first segment on ties and ignores non-overlapping ones.
        if overlap > best_overlap:
            best_overlap = overlap
            matched_speaker = seg['speaker']
    return matched_speaker


def _remove_temp_file(path):
    """Best-effort deletion of a temporary file; a falsy path or a missing file is ignored."""
    if not path:
        return
    try:
        Path(path).unlink()
    except OSError:
        pass  # cleanup only; must never mask the real processing outcome


def _write_batch_summary(video_paths, results, total_time):
    """Print the batch summary and save it to OUTPUT_DIR / "batch_summary.json"."""
    total = len(video_paths)
    success_count = sum(1 for r in results if r["success"])
    print("\n" + "=" * 60)
    print("处理完成汇总")
    print("=" * 60)
    print(f"总耗时:{total_time:.1f}s")
    # max(..., 1) avoids ZeroDivisionError for an empty batch.
    print(f"平均每个视频:{total_time/max(total, 1):.1f}s")
    print(f"成功:{success_count}/{total}")
    print(f"失败:{total - success_count}")

    summary = {
        "total_videos": total,
        "success_count": success_count,
        "failed_count": total - success_count,
        "total_time_seconds": round(total_time, 2),
        "results": [
            {
                "video": Path(r["video"]).name,
                "success": r["success"],
                "output": r.get("merged_result"),
                "total_sentences": r.get("total_sentences", 0),
                "speaker_counts": r.get("speaker_counts", {}),
                "error": r.get("error")
            }
            for r in results
        ]
    }
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    summary_path = OUTPUT_DIR / "batch_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f"\n汇总报告:{summary_path}")
    print("=" * 60)


def process_batch_asr(video_paths: List[Path], diar_results: Dict, max_workers: int = 1):
    """
    Stage 2: run ASR for every video and attach speaker labels, sequentially.

    The ASR model is loaded once. For each video that has a diarization result
    the WAV is reused from TEMP_DIR (or re-extracted), recognized, and every
    sentence is labelled with the speaker of the diarization segment it
    overlaps the longest ("speaker_0" when nothing overlaps). The labelled
    sentences are exported to ``OUTPUT_DIR / "<video stem>_result.json"``.
    The temporary WAV and diarization JSON are removed after each video.
    Finally a summary is printed and written to
    ``OUTPUT_DIR / "batch_summary.json"``.

    NOTE(review): sentence and segment times are compared directly; this
    assumes both services report them in the same unit - confirm.

    Args:
        video_paths: Videos to process (same list as stage 1).
        diar_results: Mapping video path -> diarization JSON path from stage 1.
        max_workers: Accepted for interface compatibility; currently unused.
    Returns:
        One dict per video with at least "video" and "success"; successful
        entries also carry "asr_result", "merged_result", "speaker_counts"
        and "total_sentences", failed ones carry "error".
    """
    import traceback

    total = len(video_paths)
    print("=" * 60)
    print("第二阶段:批量语音识别 + 合并结果")
    print("=" * 60)
    print(f"视频数量:{total}")
    print(f"处理模式:顺序处理(单进程)")
    print()

    # Load the ASR model once for the whole batch.
    print("加载 ASR 模型...")
    from asr_service import ASRService
    from map_speaker import load_json  # loop-invariant, imported once up front
    asr_service = ASRService(model_name="paraformer-zh", device="auto")
    asr_service._load_model()
    print("✓ ASR 模型已加载")
    print()

    results = []
    start_time = time.time()

    for i, video_path in enumerate(video_paths, 1):
        diar_path = diar_results.get(video_path)
        if not diar_path:
            print(f"\n[{i}/{total}] 跳过 {video_path.name}(无说话人分离结果)")
            results.append({
                "video": str(video_path),
                "success": False,
                "error": "无说话人分离结果"
            })
            continue

        wav_path = None
        try:
            print(f"\n[{i}/{total}] 处理:{video_path.name}")

            # 1. Reuse the WAV from stage 1, or extract it again if it is gone.
            wav_path = TEMP_DIR / f"{video_path.stem}.wav"
            if not wav_path.exists():
                wav_path = extract_wav(video_path, TEMP_DIR)
            if wav_path is None:
                print(f" ✗ 音频提取失败")
                results.append({
                    "video": str(video_path),
                    "success": False,
                    "error": "音频提取失败"
                })
                continue

            # 2. Load the diarization segments saved by stage 1.
            diar_result = load_json(diar_path)

            # 3. Run ASR (the ASR's own speaker labels are not used).
            asr_sentences = asr_service.recognize(wav_path)
            # Normalize a dict-shaped result BEFORE iterating over sentences;
            # iterating the dict itself would yield its keys.
            if isinstance(asr_sentences, dict):
                asr_sentences = asr_sentences.get("sentences", [])
            if not asr_sentences:
                print(f" ✗ ASR 识别结果为空")
                results.append({
                    "video": str(video_path),
                    "success": False,
                    "error": "ASR 识别结果为空"
                })
                continue

            # 4. Attach the diarization speaker to every sentence.
            print(f" 合并结果...")
            segments = diar_result["segments"]
            for sentence in asr_sentences:
                matched_speaker = _best_overlap_speaker(
                    sentence.begin_time, sentence.end_time, segments
                )
                # 'is not None' so that a falsy but valid label (e.g. 0) is kept.
                if matched_speaker is not None:
                    sentence.speaker = matched_speaker
                else:
                    sentence.speaker = "speaker_0"

            # 5. Export the labelled sentences.
            output_file = OUTPUT_DIR / f"{video_path.stem}_result.json"
            asr_service.export_to_json(asr_sentences, output_file)

            # Count sentences per speaker.
            speaker_counts = {}
            for sentence in asr_sentences:
                speaker = sentence.speaker
                speaker_counts[speaker] = speaker_counts.get(speaker, 0) + 1

            results.append({
                "video": str(video_path),
                "success": True,
                "asr_result": [s.to_dict() for s in asr_sentences],
                "merged_result": str(output_file),
                "speaker_counts": speaker_counts,
                "total_sentences": len(asr_sentences)
            })
            print(f" ✓ 处理完成")
            print(f" - 句子数:{len(asr_sentences)}")
            print(f" - 说话人:{speaker_counts}")
        except Exception as e:
            print(f" ✗ 处理失败:{e}")
            traceback.print_exc()
            results.append({
                "video": str(video_path),
                "success": False,
                "error": str(e)
            })
        finally:
            # Remove this video's temporary files whatever the outcome.
            _remove_temp_file(wav_path)
            _remove_temp_file(diar_path)

        # Every handled video appended exactly one result, so len(results) == i.
        elapsed = time.time() - start_time
        avg_time = elapsed / i
        remaining = (total - i) * avg_time
        print(f"\n进度:{len(results)}/{total}")
        print(f"已用:{elapsed:.1f}s, 预计剩余:{remaining:.1f}s")

    total_time = time.time() - start_time
    print(f"\n✓ 第二阶段完成,耗时:{total_time:.1f}s")
    print()

    # The summary used to sit after an early 'return' and never ran.
    _write_batch_summary(video_paths, results, total_time)
    return results
def main():
    """
    Entry point: run the two-stage batch pipeline over VIDEO_FOLDER.

    Steps: reset the temp directory, ensure the output directory exists, list
    the videos, run speaker diarization (stage 1), free cached GPU memory, run
    ASR + merging (stage 2), and finally clear the temp directory again.
    Returns early (after printing an error) when the folder is missing, holds
    no supported video, or stage 1 failed for every video.
    """
    # Imported once, up front, so a missing torch install fails before any work.
    import torch

    print("\n" + "=" * 60)
    print(" 并发批量语音识别处理系统")
    print("=" * 60)
    print()

    # 1. Start from an empty temp directory.
    clear_temp_dir()

    # 2. Make sure the output directory exists.
    ensure_output_dir()

    # 3. Build the video list from VIDEO_FOLDER.
    video_folder = Path(VIDEO_FOLDER)
    if not video_folder.exists():
        print(f"✗ 错误:视频文件夹不存在 - {video_folder}")
        return
    video_paths = get_video_list(video_folder)
    if not video_paths:
        print("✗ 错误:未找到任何视频文件")
        print(f"支持格式:{', '.join(SUPPORTED_VIDEO_FORMATS)}")
        print(f"请检查文件夹:{video_folder}")
        return
    print(f"找到 {len(video_paths)} 个视频文件")
    for vp in video_paths:
        print(f" - {vp.name}")
    print()

    # 4. Announce the staged strategy.
    print()
    print("=" * 60)
    print("处理策略")
    print("=" * 60)
    print("阶段 1: 批量说话人分离(加载说话人分离模型)")
    print(" ↓ 释放内存")
    print("阶段 2: 批量语音识别 + 合并结果(加载 ASR 模型)")
    print("=" * 60)
    print()

    # Stage 1: speaker diarization.
    diar_results = process_batch_diarization(video_paths, max_workers=1)

    # Stage 2 is pointless if no video produced a diarization result.
    success_count = sum(1 for r in diar_results.values() if r)
    if success_count == 0:
        print("✗ 错误:第一阶段全部失败")
        return
    print(f"✓ 第一阶段成功:{success_count}/{len(video_paths)}")
    print()

    # Force a garbage collection and release cached GPU memory before the
    # ASR model is loaded.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("✓ 已清理 CUDA 缓存,准备第二阶段")
    print()

    # Stage 2: ASR + merging; its return value is not needed here.
    process_batch_asr(video_paths, diar_results, max_workers=1)

    # 5. Final cleanup.
    print("\n清理临时文件...")
    clear_temp_dir()
    print("\n✓ 全部完成!")
    print(f"输出目录:{OUTPUT_DIR}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()