"""
Web API Server for ASR and Speaker Diarization

REST API service for speech recognition and speaker diarization.
"""

import gc
import json
import os
import threading
import traceback
import uuid

from flask import Flask, request, jsonify, send_file
from werkzeug.utils import secure_filename

app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024  # 500 MiB upload limit
app.config['UPLOAD_FOLDER'] = 'uploads'
app.config['RESULT_FOLDER'] = 'results'

os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)

# Lazily created service singletons and their load state.
GLOBAL_ASR_SERVICE = None
GLOBAL_DIAR_SERVICE = None
ASR_MODEL_LOADED = False
DIAR_MODEL_LOADED = False

# Each lock guards the matching singleton and its *_LOADED flag.
ASR_MODEL_LOCK = threading.Lock()
DIAR_MODEL_LOCK = threading.Lock()


def get_asr_service():
    """Return the shared ASRService, creating it on first use.

    Callers must hold ASR_MODEL_LOCK.
    """
    global GLOBAL_ASR_SERVICE
    if GLOBAL_ASR_SERVICE is None:
        from asr_service import ASRService  # import deferred until first use
        GLOBAL_ASR_SERVICE = ASRService()
    return GLOBAL_ASR_SERVICE


def get_diar_service():
    """Return the shared DiarizationService, creating it on first use.

    Callers must hold DIAR_MODEL_LOCK.
    """
    global GLOBAL_DIAR_SERVICE
    if GLOBAL_DIAR_SERVICE is None:
        from diarization_service import DiarizationService
        GLOBAL_DIAR_SERVICE = DiarizationService()
    return GLOBAL_DIAR_SERVICE


def ensure_asr_loaded():
    """Return the ASR service with its model loaded (thread-safe)."""
    global ASR_MODEL_LOADED
    with ASR_MODEL_LOCK:
        service = get_asr_service()
        if not ASR_MODEL_LOADED:
            service._load_model()
            ASR_MODEL_LOADED = True
        return service


def parse_recognize_options(form):
    """Read the options shared by the single and batch endpoints.

    Raises ValueError if a numeric option cannot be parsed.
    """
    return {
        'use_3d_speaker': form.get('use_3d_speaker', 'false').lower() == 'true',
        'embedding_model': form.get('embedding_model', 'eres2netv2'),
        'cluster_threshold': float(form.get('cluster_threshold', 0.5)),
        'min_cluster_size': int(form.get('min_cluster_size', 10)),
    }


@app.route('/health', methods=['GET'])
def health_check():
    """Health check."""
    return jsonify({
        'status': 'ok',
        'asr_loaded': ASR_MODEL_LOADED,
        'diar_loaded': DIAR_MODEL_LOADED
    })


# The route reads a JSON body and the startup banner documents it as POST,
# so it is registered as POST rather than GET.
@app.route('/api/asr/load', methods=['POST'])
def load_asr_model():
    """Load the ASR model."""
    global ASR_MODEL_LOADED
    with ASR_MODEL_LOCK:
        if ASR_MODEL_LOADED:
            return jsonify({'message': 'ASR model already loaded', 'loaded': True})
        try:
            # silent=True tolerates a missing or non-JSON request body.
            data = request.get_json(silent=True) or {}
            # NOTE: model_name and device are only logged and echoed back;
            # they are not passed to ASRService.
            model_name = data.get('model_name', 'paraformer-zh')
            device = data.get('device', 'auto')
            print(f"Loading ASR model: {model_name}, device: {device}")

            service = get_asr_service()
            service._load_model()
            ASR_MODEL_LOADED = True
            return jsonify({
                'message': 'ASR model loaded successfully',
                'loaded': True,
                'model': model_name,
                'device': device
            })
        except Exception as e:
            return jsonify({'error': str(e)}), 500


@app.route('/api/diar/load', methods=['POST'])
def load_diar_model():
    """Load the 3D-Speaker model."""
    global DIAR_MODEL_LOADED
    with DIAR_MODEL_LOCK:
        if DIAR_MODEL_LOADED:
            return jsonify({'message': '3D-Speaker model already loaded', 'loaded': True})
        try:
            data = request.get_json(silent=True) or {}
            # NOTE: embedding_model and device are only logged and echoed back;
            # they are not passed to DiarizationService.
            embedding_model = data.get('embedding_model', 'eres2netv2')
            device = data.get('device', 'auto')
            print(f"Loading 3D-Speaker model: {embedding_model}, device: {device}")

            service = get_diar_service()
            service._load_model()
            DIAR_MODEL_LOADED = True
            return jsonify({
                'message': '3D-Speaker model loaded successfully',
                'loaded': True,
                'embedding_model': embedding_model,
                'device': device
            })
        except Exception as e:
            return jsonify({'error': str(e)}), 500


@app.route('/api/asr/unload', methods=['POST'])
def unload_asr_model():
    """Unload the ASR model."""
    global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED
    # Hold the lock so an unload cannot interleave with a concurrent load.
    with ASR_MODEL_LOCK:
        try:
            GLOBAL_ASR_SERVICE = None
            ASR_MODEL_LOADED = False
            gc.collect()
            return jsonify({'message': 'ASR model unloaded'})
        except Exception as e:
            return jsonify({'error': str(e)}), 500


@app.route('/api/diar/unload', methods=['POST'])
def unload_diar_model():
    """Unload the 3D-Speaker model."""
    global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED
    with DIAR_MODEL_LOCK:
        try:
            GLOBAL_DIAR_SERVICE = None
            DIAR_MODEL_LOADED = False
            gc.collect()
            return jsonify({'message': '3D-Speaker model unloaded'})
        except Exception as e:
            return jsonify({'error': str(e)}), 500


@app.route('/api/recognize/single', methods=['POST'])
def recognize_single():
    """Run recognition on a single file."""
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'Please upload an audio file'}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'Filename must not be empty'}), 400

        try:
            options = parse_recognize_options(request.form)
        except ValueError as e:
            # Malformed numbers are a client error, not a server error.
            return jsonify({'error': f'Invalid parameter: {e}'}), 400
        output_format = request.form.get('format', 'json')

        # secure_filename() can return an empty string, for example when the
        # name contains no ASCII characters, so fall back to a fixed name.
        filename = secure_filename(file.filename) or 'audio'
        task_id = str(uuid.uuid4())[:8]
        audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{task_id}_{filename}")
        file.save(audio_path)

        try:
            service = ensure_asr_loaded()
            sentences = service.recognize(audio_path, **options)
            result = {
                'file': filename,
                'total_sentences': len(sentences),
                'sentences': [s.to_dict() for s in sentences]
            }

            if output_format == 'json':
                result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.json")
                with open(result_path, 'w', encoding='utf-8') as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                return jsonify(result)

            # Any other format value is served as an SRT download.
            srt_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.srt")
            service.export_to_srt(sentences, srt_path)
            return send_file(srt_path, as_attachment=True,
                             download_name=f"{task_id}_result.srt")
        finally:
            # Always delete the uploaded audio, even if recognition failed.
            if os.path.exists(audio_path):
                os.remove(audio_path)
    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


@app.route('/api/recognize/batch', methods=['POST'])
def recognize_batch():
    """Run recognition on a batch of files."""
    try:
        if 'files' not in request.files:
            return jsonify({'error': 'Please upload audio files'}), 400
        files = [f for f in request.files.getlist('files') if f.filename]
        if not files:
            return jsonify({'error': 'File list is empty'}), 400

        try:
            options = parse_recognize_options(request.form)
        except ValueError as e:
            return jsonify({'error': f'Invalid parameter: {e}'}), 400

        task_id = str(uuid.uuid4())[:8]
        audio_paths = []

        try:
            for index, f in enumerate(files):
                filename = secure_filename(f.filename) or 'audio'
                # The index keeps two uploads with the same name from
                # overwriting each other.
                audio_path = os.path.join(app.config['UPLOAD_FOLDER'],
                                          f"{task_id}_{index}_{filename}")
                f.save(audio_path)
                audio_paths.append(audio_path)

            service = ensure_asr_loaded()

            results = []
            for audio_path in audio_paths:
                # A failure on one file is recorded and the batch continues.
                try:
                    sentences = service.recognize(audio_path, **options)
                    results.append({
                        'file': os.path.basename(audio_path),
                        'total_sentences': len(sentences),
                        'sentences': [s.to_dict() for s in sentences]
                    })
                except Exception as e:
                    results.append({
                        'file': os.path.basename(audio_path),
                        'error': str(e)
                    })

            result_path = os.path.join(app.config['RESULT_FOLDER'],
                                       f"{task_id}_batch_result.json")
            with open(result_path, 'w', encoding='utf-8') as out:
                json.dump({
                    'task_id': task_id,
                    'total_files': len(results),
                    'results': results
                }, out, ensure_ascii=False, indent=2)

            return jsonify({
                'task_id': task_id,
                'total_files': len(results),
                'results': results,
                'result_file': result_path
            })
        finally:
            # Always delete the uploaded audio files.
            for audio_path in audio_paths:
                if os.path.exists(audio_path):
                    os.remove(audio_path)
    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


@app.route('/api/status', methods=['GET'])
def get_status():
    """Get the model status."""
    return jsonify({
        'asr_loaded': ASR_MODEL_LOADED,
        'diar_loaded': DIAR_MODEL_LOADED,
        'asr_model': 'paraformer-zh',
        'diar_model': '3D-Speaker'
    })


if __name__ == '__main__':
    print("=" * 60)
    print("  ASR & Speaker Diarization API Server")
    print("=" * 60)
    print("\nAPI endpoints:")
    print("  GET  /health                - Health check")
    print("  GET  /api/status            - Get model status")
    print("  POST /api/asr/load          - Load the ASR model")
    print("  POST /api/diar/load         - Load the 3D-Speaker model")
    print("  POST /api/asr/unload        - Unload the ASR model")
    print("  POST /api/diar/unload       - Unload the 3D-Speaker model")
    print("  POST /api/recognize/single  - Single-file recognition")
    print("  POST /api/recognize/batch   - Batch recognition")
    print("\n" + "=" * 60)
    print("Server address: http://localhost:5000")
    print("=" * 60)

    app.run(host='0.0.0.0', port=5000, debug=False)