diff --git a/README.md b/README.md index e6db829..ba51da5 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,12 @@ D:\ProgramData\miniconda3\envs\audio\python.exe + + + + + + # FunASR 语音识别服务 基于阿里达摩院 [FunASR](https://github.com/alibaba-damo-academy/FunASR) 的本地语音识别解决方案。 diff --git a/main.py b/main.py index 311089a..b8c567e 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,7 @@ import gc import json +import os import shutil import subprocess import time @@ -23,8 +24,7 @@ BASE_DIR = Path(__file__).parent.absolute() TEMP_DIR = BASE_DIR / "temp" OUTPUT_DIR = BASE_DIR / "output" -# 视频文件夹路径(全局变量) -VIDEO_FOLDER = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\宁波北仑区鼎邦杰西雅服饰有限公司" + # 支持的视频格式 SUPPORTED_VIDEO_FORMATS = ["*.mp4", "*.avi", "*.mkv", "*.mov", "*.flv", "*.wmv", "*.m4v"] @@ -271,7 +271,12 @@ def process_batch_asr(video_paths, diar_results, max_workers=1): print("加载 ASR 模型...") from asr_service import ASRService - asr_service = ASRService(model_name="paraformer-zh", device="auto") + asr_service = ASRService( + model_name="paraformer-zh", + device="auto", + min_segment_duration=0.5, # 增加最小片段时长 + merge_gap=0.3 # 减少合并间隔 + ) asr_service._load_model() print("✓ ASR 模型已加载") print() @@ -489,34 +494,54 @@ def process_batch_asr(video_paths, diar_results, max_workers=1): return results -def main(): +def main(path: str): """主函数""" import torch + # 拼接根目录 + path = os.path.join(os.path.dirname(__file__), "input", path) + print(f"开始处理路径:{path}") print("\n" + "=" * 60) print(" 并发批量语音识别处理系统") print("=" * 60) print() + # 0. 检查输入路径是否存在 + if not Path(path).exists(): + print(f"✗ 错误:输入路径不存在 - {path}") + return "输入路径不存在" + + # 文件夹为空 + if Path(path).is_dir() and not list(Path(path).iterdir()): + print("✗ 错误:文件夹为空") + return "文件夹为空" + # 1. 清空 temp 目录 clear_temp_dir() # 2. 确保 output 目录存在 ensure_output_dir() - # 3. 准备视频列表(从 VIDEO_FOLDER 自动获取) - video_folder = Path(VIDEO_FOLDER) - if not video_folder.exists(): - print(f"✗ 错误:视频文件夹不存在 - {video_folder}") - return + # 3. 准备视频列表(从 path 自动获取) + # 如果是文件夹,递归获取所有视频文件 + # 如果是文件,直接添加到列表 + video_paths = [] + if Path(path).is_dir(): + video_folder = Path(path) + if not video_folder.exists(): + print(f"✗ 错误:视频文件夹不存在 - {video_folder}") + return - video_paths = get_video_list(video_folder) + video_paths = get_video_list(video_folder) + if not video_paths: + print("✗ 错误:未找到任何视频文件") + print(f"支持格式:{', '.join(SUPPORTED_VIDEO_FORMATS)}") + print(f"请检查文件夹:{video_folder}") + return + else: + video_paths.append(Path(path)) - if not video_paths: - print("✗ 错误:未找到任何视频文件") - print(f"支持格式:{', '.join(SUPPORTED_VIDEO_FORMATS)}") - print(f"请检查文件夹:{video_folder}") - return + video_paths = list(set(video_paths)) # 去重 print(f"找到 {len(video_paths)} 个视频文件") for vp in video_paths: @@ -569,4 +594,9 @@ def main(): if __name__ == "__main__": - main() + # 视频文件夹路径(全局变量) + # PATH = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\宁波北仑区鼎邦杰西雅服饰有限公司" + # PATH = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\temp" + # PATH = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\temp\VID_20251104_085655_024.AVI" + PATH = r"VID_20251104_085655_024.AVI" + main(PATH) diff --git a/requirements.txt b/requirements.txt index 30aee20..a0ea165 100644 --- a/requirements.txt +++ b/requirements.txt @@ -59,5 +59,6 @@ jieba>=0.42.0 # onnxruntime-gpu>=1.23.0 # ---------- 可选:Web API 服务 ---------- -# Flask>=3.1.0 +Flask>=3.1.0 +waitress>=3.0.0 # SQLAlchemy>=2.0.0 diff --git a/server.py b/server.py index cbc47dc..7556ec5 100644 --- a/server.py +++ b/server.py @@ -6,16 +6,71 @@ Web API Server for ASR and Speaker Diarization import os import sys import gc -from pathlib import Path -from flask import Flask, request, jsonify, send_file -from werkzeug.utils import secure_filename +import json +import shutil +import signal import threading +from pathlib import Path +from flask import Flask, request, jsonify +from werkzeug.utils import secure_filename import uuid +from datetime import datetime, timezone + + +def make_response(status="success", data=None, errors=None, message=None, extra=None): + """ + 统一 API 响应格式 + + Args: + status: 状态 ("success" 或 "error") + data: 返回的数据 + errors: 错误列表 + message: 消息 + extra: 其他额外字段 + + Returns: + 统一格式的 JSON 响应 + """ + response = { + "status": status, + "data": data if data is not None else {}, + "errors": errors if errors is not None else [], + "message": message or ("操作成功" if status == "success" else "操作失败"), + "timestamp": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + } + + if extra: + response.update(extra) + + return response app = Flask(__name__) app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024 app.config['UPLOAD_FOLDER'] = 'uploads' app.config['RESULT_FOLDER'] = 'results' +# 增加请求超时时间(秒),支持长时间运行的任务 +app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 300 # 5 +OUTPUT_DIR = Path('output') + +# 手动添加 CORS 响应头 +@app.after_request +def after_request(response): + # 获取请求的 Origin + origin = request.headers.get('Origin', '*') + + # 设置允许的来源(不使用通配符) + response.headers.add('Access-Control-Allow-Origin', origin) + response.headers.add('Access-Control-Allow-Credentials', 'true') + response.headers.add('Access-Control-Allow-Headers', 'Content-Type,Authorization') + response.headers.add('Access-Control-Allow-Methods', 'GET,PUT,POST,DELETE,OPTIONS') + return response + + +# 处理 OPTIONS 预检请求 +@app.route('/config', methods=['OPTIONS']) +def config_options(): + """处理 /config 接口的 OPTIONS 预检请求""" + return '', 200 os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True) os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True) @@ -27,285 +82,148 @@ DIAR_MODEL_LOADED = False ASR_MODEL_LOCK = threading.Lock() DIAR_MODEL_LOCK = threading.Lock() - -def get_asr_service(): - global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED - if GLOBAL_ASR_SERVICE is None: - from asr_service import ASRService - GLOBAL_ASR_SERVICE = ASRService() - return GLOBAL_ASR_SERVICE +# 全局变量用于控制任务执行 +task_running = {} +task_timeout = 600 # 10 分钟超时(秒) -def get_diar_service(): - global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED - if GLOBAL_DIAR_SERVICE is None: - from diarization_service import DiarizationService - GLOBAL_DIAR_SERVICE = DiarizationService() - return GLOBAL_DIAR_SERVICE - - -@app.route('/health', methods=['GET']) -def health_check(): - """健康检查""" - return jsonify({ - 'status': 'ok', - 'asr_loaded': ASR_MODEL_LOADED, - 'diar_loaded': DIAR_MODEL_LOADED - }) - - -@app.route('/api/asr/load', methods=['GET']) -def load_asr_model(): - """加载 ASR 模型""" - global ASR_MODEL_LOADED - with ASR_MODEL_LOCK: - if ASR_MODEL_LOADED: - return jsonify({'message': 'ASR 模型已加载', 'loaded': True}) - - try: - data = request.json or {} - model_name = data.get('model_name', 'paraformer-zh') - device = data.get('device', 'auto') - - print(f"正在加载 ASR 模型: {model_name}, 设备: {device}") - service = get_asr_service() - service._load_model() - ASR_MODEL_LOADED = True - - return jsonify({ - 'message': 'ASR 模型加载成功', - 'loaded': True, - 'model': model_name, - 'device': device - }) - except Exception as e: - return jsonify({'error': str(e)}), 500 - - -@app.route('/api/diar/load', methods=['POST']) -def load_diar_model(): - """加载 3D-Speaker 模型""" - global DIAR_MODEL_LOADED - with DIAR_MODEL_LOCK: - if DIAR_MODEL_LOADED: - return jsonify({'message': '3D-Speaker 模型已加载', 'loaded': True}) - - try: - data = request.json or {} - embedding_model = data.get('embedding_model', 'eres2netv2') - device = data.get('device', 'auto') - cluster_threshold = data.get('cluster_threshold', 0.5) - min_cluster_size = data.get('min_cluster_size', 10) - - print(f"正在加载 3D-Speaker 模型: {embedding_model}, 设备: {device}") - service = get_diar_service() - service._load_model() - DIAR_MODEL_LOADED = True - - return jsonify({ - 'message': '3D-Speaker 模型加载成功', - 'loaded': True, - 'embedding_model': embedding_model, - 'device': device - }) - except Exception as e: - return jsonify({'error': str(e)}), 500 - - -@app.route('/api/asr/unload', methods=['POST']) -def unload_asr_model(): - """卸载 ASR 模型""" - global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED +@app.route('/api/recognize', methods=['POST']) +def recognize(): + """文件推理 - 调用 main.py 的 main 函数""" + task_id = str(uuid.uuid4()) + task_running[task_id] = True try: - GLOBAL_ASR_SERVICE = None - ASR_MODEL_LOADED = False - gc.collect() + # 从请求参数获取路径(只接受文件名或 input 目录下的相对路径) + data = request.json or {} + path = data.get('path', '') - return jsonify({'message': 'ASR 模型已卸载'}) - except Exception as e: - return jsonify({'error': str(e)}), 500 + if not path: + return jsonify(make_response( + status="error", + message="请提供文件路径", + errors=["缺少必要参数:path"] + )), 400 + print(f"\n{'='*60}") + print(f"API 收到请求:path={path}, task_id={task_id}") + print(f"{'='*60}") + print(f"开始调用 main 函数...") -@app.route('/api/diar/unload', methods=['POST']) -def unload_diar_model(): - """卸载 3D-Speaker 模型""" - global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED + # 切换到 server.py 所在目录 + server_dir = Path(__file__).parent.absolute() + os.chdir(server_dir) - try: - GLOBAL_DIAR_SERVICE = None - DIAR_MODEL_LOADED = False - gc.collect() + print(f"当前工作目录:{os.getcwd()}") - return jsonify({'message': '3D-Speaker 模型已卸载'}) - except Exception as e: - return jsonify({'error': str(e)}), 500 + # 设置超时处理 + def timeout_handler(signum, frame): + task_running[task_id] = False + raise TimeoutError(f"任务执行超时 ({task_timeout}秒)") - -@app.route('/api/recognize/single', methods=['POST']) -def recognize_single(): - """单文件推理""" - try: - if 'file' not in request.files: - return jsonify({'error': '请上传音频文件'}), 400 - - file = request.files['file'] - if file.filename == '': - return jsonify({'error': '文件名不能为空'}), 400 - - data = request.form or {} - use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true' - embedding_model = data.get('embedding_model', 'eres2netv2') - cluster_threshold = float(data.get('cluster_threshold', 0.5)) - min_cluster_size = int(data.get('min_cluster_size', 10)) - output_format = data.get('format', 'json') - - filename = secure_filename(file.filename) if file.filename is not None else secure_filename("unnamed_file") - task_id = str(uuid.uuid4())[:8] - audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{task_id}_{filename}") - file.save(audio_path) + # 注册信号处理器(仅在 Unix/Linux/Mac 有效,Windows 下会被忽略) + try: + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(task_timeout) + use_alarm = True + except (AttributeError, ValueError): + # Windows 不支持 SIGALRM + use_alarm = False + print("注意:Windows 系统,使用超时检测线程") try: - global ASR_MODEL_LOADED - service = get_asr_service() - if not ASR_MODEL_LOADED: - service._load_model() - ASR_MODEL_LOADED = True + # 调用 main.py 的 main 函数 + from main import main + main(path) - sentences = service.recognize( - audio_path - ) + # 取消超时 + if use_alarm: + signal.alarm(0) - result = { - 'file': filename, - 'total_sentences': len(sentences), - 'sentences': [s.to_dict() for s in sentences] - } + print(f"main 函数执行完成") - if output_format == 'json': - result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.json") - with open(result_path, 'w', encoding='utf-8') as f: - import json - json.dump(result, f, ensure_ascii=False, indent=2) - return jsonify(result) - else: - srt_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.srt") - service.export_to_srt([s for s in sentences], srt_path) - return send_file(srt_path, as_attachment=True, download_name=f"{task_id}_result.srt") + return jsonify(make_response( + status="success", + message="文件推理完成", + data={"path": path, "task_id": task_id} + )), 200 + + except TimeoutError as e: + print(f"任务超时:{e}") + return jsonify(make_response( + status="error", + message=str(e), + errors=[str(e)] + )), 504 # Gateway Timeout finally: - if os.path.exists(audio_path): - os.remove(audio_path) + # 清理信号处理器 + if use_alarm: + signal.alarm(0) + task_running[task_id] = False except Exception as e: import traceback traceback.print_exc() - return jsonify({'error': str(e)}), 500 + task_running[task_id] = False + return jsonify(make_response( + status="error", + message=str(e), + errors=[str(e)] + )), 500 -@app.route('/api/recognize/batch', methods=['POST']) -def recognize_batch(): - """批量推理""" +@app.route('/api/result', methods=['GET']) +def result(): + """获取文件推理结果""" try: - if 'files' not in request.files: - return jsonify({'error': '请上传音频文件'}), 400 + # 从请求参数获取路径 + path = request.args.get('path', '') - files = request.files.getlist('files') - if not files or len(files) == 0: - return jsonify({'error': '文件列表为空'}), 400 - - data = request.form or {} - use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true' - embedding_model = data.get('embedding_model', 'eres2netv2') - cluster_threshold = float(data.get('cluster_threshold', 0.5)) - min_cluster_size = int(data.get('min_cluster_size', 10)) - - task_id = str(uuid.uuid4())[:8] - audio_paths = [] - - for f in files: - if f.filename: - filename = secure_filename(f.filename) - audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{task_id}_{filename}") - f.save(audio_path) - audio_paths.append(audio_path) - - try: - global ASR_MODEL_LOADED - service = get_asr_service() - if not ASR_MODEL_LOADED: - service._load_model() - ASR_MODEL_LOADED = True - - results = [] - for audio_path in audio_paths: - try: - sentences = service.recognize( - audio_path, - ) - results.append({ - 'file': os.path.basename(audio_path), - 'total_sentences': len(sentences), - 'sentences': [s.to_dict() for s in sentences] - }) - except Exception as e: - results.append({ - 'file': os.path.basename(audio_path), - 'error': str(e) - }) - - result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_batch_result.json") - import json - with open(result_path, 'w', encoding='utf-8') as f: - json.dump({ - 'task_id': task_id, - 'total_files': len(results), - 'results': results - }, f, ensure_ascii=False, indent=2) - - return jsonify({ - 'task_id': task_id, - 'total_files': len(results), - 'results': results, - 'result_file': result_path - }) - - finally: - for audio_path in audio_paths: - if os.path.exists(audio_path): - os.remove(audio_path) + if not path: + return jsonify(make_response( + status="error", + message="请提供文件路径", + errors=["缺少必要参数:path"] + )), 400 + # 读取结果 + print(Path(path).stem) + result_file = OUTPUT_DIR / f"{Path(path).stem}_result.json" + if result_file.exists(): + with open(result_file, 'r', encoding='utf-8') as f: + result_data = json.load(f) + return jsonify(make_response( + status="success", + message="获取成功", + data=result_data + )), 200 + else: + return jsonify(make_response( + status="error", + message="处理完成但未找到结果文件", + errors=["结果文件不存在"] + )), 404 except Exception as e: import traceback traceback.print_exc() - return jsonify({'error': str(e)}), 500 - - -@app.route('/api/status', methods=['GET']) -def get_status(): - """获取模型状态""" - return jsonify({ - 'asr_loaded': ASR_MODEL_LOADED, - 'diar_loaded': DIAR_MODEL_LOADED, - 'asr_model': 'paraformer-zh', - 'diar_model': '3D-Speaker' - }) - + return jsonify(make_response( + status="error", + message=str(e), + errors=[str(e)] + )), 500 if __name__ == '__main__': print("=" * 60) print(" ASR & Speaker Diarization API Server") print("=" * 60) print("\nAPI 接口:") - print(" GET /health - 健康检查") - print(" GET /api/status - 获取模型状态") - print(" POST /api/asr/load - 加载 ASR 模型") - print(" POST /api/diar/load - 加载 3D-Speaker 模型") - print(" POST /api/asr/unload - 卸载 ASR 模型") - print(" POST /api/diar/unload - 卸载 3D-Speaker 模型") - print(" POST /api/recognize/single - 单文件推理") - print(" POST /api/recognize/batch - 批量推理") + print(" POST /api/recognize - 文件推理") + print(" POST /api/result - 获取文件推理结果") print("\n" + "=" * 60) - print("启动服务: http://localhost:5000") + print("启动服务:http://localhost:5000") + print("使用 Waitress WSGI 服务器(无超时限制)") print("=" * 60) - app.run(host='0.0.0.0', port=5000, debug=False) + + from waitress import serve + serve(app, host='0.0.0.0', port=5000, threads=4, connection_limit=100)