完善服务端

This commit is contained in:
yueliuli 2026-05-06 15:57:34 +08:00
parent bdfce46d7c
commit 1507093701
4 changed files with 220 additions and 265 deletions

View File

@ -26,6 +26,12 @@ D:\ProgramData\miniconda3\envs\audio\python.exe
# FunASR 语音识别服务
基于阿里达摩院 [FunASR](https://github.com/alibaba-damo-academy/FunASR) 的本地语音识别解决方案。

62
main.py
View File

@ -12,6 +12,7 @@
import gc
import json
import os
import shutil
import subprocess
import time
@ -23,8 +24,7 @@ BASE_DIR = Path(__file__).parent.absolute()
TEMP_DIR = BASE_DIR / "temp"
OUTPUT_DIR = BASE_DIR / "output"
# 视频文件夹路径(全局变量)
VIDEO_FOLDER = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\宁波北仑区鼎邦杰西雅服饰有限公司"
# 支持的视频格式
SUPPORTED_VIDEO_FORMATS = ["*.mp4", "*.avi", "*.mkv", "*.mov", "*.flv", "*.wmv", "*.m4v"]
@ -271,7 +271,12 @@ def process_batch_asr(video_paths, diar_results, max_workers=1):
print("加载 ASR 模型...")
from asr_service import ASRService
asr_service = ASRService(model_name="paraformer-zh", device="auto")
asr_service = ASRService(
model_name="paraformer-zh",
device="auto",
min_segment_duration=0.5, # 增加最小片段时长
merge_gap=0.3 # 减少合并间隔
)
asr_service._load_model()
print("✓ ASR 模型已加载")
print()
@ -489,34 +494,54 @@ def process_batch_asr(video_paths, diar_results, max_workers=1):
return results
def main():
def main(path: str):
"""主函数"""
import torch
# 拼接根目录
path = os.path.join(os.path.dirname(__file__), "input", path)
print(f"开始处理路径:{path}")
print("\n" + "=" * 60)
print(" 并发批量语音识别处理系统")
print("=" * 60)
print()
# 0. 检查输入路径是否存在
if not Path(path).exists():
print(f"✗ 错误:输入路径不存在 - {path}")
return "输入路径不存在"
# 文件夹为空
if Path(path).is_dir() and not list(Path(path).iterdir()):
print("✗ 错误:文件夹为空")
return "文件夹为空"
# 1. 清空 temp 目录
clear_temp_dir()
# 2. 确保 output 目录存在
ensure_output_dir()
# 3. 准备视频列表(从 VIDEO_FOLDER 自动获取)
video_folder = Path(VIDEO_FOLDER)
if not video_folder.exists():
print(f"✗ 错误:视频文件夹不存在 - {video_folder}")
return
# 3. 准备视频列表(从 path 自动获取)
# 如果是文件夹,递归获取所有视频文件
# 如果是文件,直接添加到列表
video_paths = []
if Path(path).is_dir():
video_folder = Path(path)
if not video_folder.exists():
print(f"✗ 错误:视频文件夹不存在 - {video_folder}")
return
video_paths = get_video_list(video_folder)
video_paths = get_video_list(video_folder)
if not video_paths:
print("✗ 错误:未找到任何视频文件")
print(f"支持格式:{', '.join(SUPPORTED_VIDEO_FORMATS)}")
print(f"请检查文件夹:{video_folder}")
return
else:
video_paths.append(Path(path))
if not video_paths:
print("✗ 错误:未找到任何视频文件")
print(f"支持格式:{', '.join(SUPPORTED_VIDEO_FORMATS)}")
print(f"请检查文件夹:{video_folder}")
return
video_paths = list(set(video_paths)) # 去重
print(f"找到 {len(video_paths)} 个视频文件")
for vp in video_paths:
@ -569,4 +594,9 @@ def main():
if __name__ == "__main__":
main()
# 视频文件夹路径(全局变量)
# PATH = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\宁波北仑区鼎邦杰西雅服饰有限公司"
# PATH = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\temp"
# PATH = r"D:\Userfile\Projects\AnzezxianxHazardInspectAI\Code\audio2\input\temp\VID_20251104_085655_024.AVI"
PATH = r"VID_20251104_085655_024.AVI"
main(PATH)

View File

@ -59,5 +59,6 @@ jieba>=0.42.0
# onnxruntime-gpu>=1.23.0
# ---------- 可选:Web API 服务 ----------
# Flask>=3.1.0
Flask>=3.1.0
waitress>=3.0.0
# SQLAlchemy>=2.0.0

414
server.py
View File

@ -6,16 +6,71 @@ Web API Server for ASR and Speaker Diarization
import os
import sys
import gc
from pathlib import Path
from flask import Flask, request, jsonify, send_file
from werkzeug.utils import secure_filename
import json
import shutil
import signal
import threading
from pathlib import Path
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import uuid
from datetime import datetime, timezone
def make_response(status="success", data=None, errors=None, message=None, extra=None):
    """Build the unified API response envelope as a plain dict.

    Args:
        status: Outcome marker, "success" or "error".
        data: Payload to return; an empty dict is substituted when None.
        errors: List of error descriptions; an empty list is substituted
            when None.
        message: Human-readable message; when falsy, a default derived from
            ``status`` is used instead.
        extra: Optional mapping merged into the envelope last, so its keys
            override the standard ones.

    Returns:
        A dict with the keys ``status``, ``data``, ``errors``, ``message`` and
        ``timestamp`` (current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``), plus any
        keys supplied through ``extra``.
    """
    if not message:
        message = "操作成功" if status == "success" else "操作失败"

    payload = {} if data is None else data
    error_list = [] if errors is None else errors
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    envelope = dict(
        status=status,
        data=payload,
        errors=error_list,
        message=message,
        timestamp=stamp,
    )
    if extra:
        envelope.update(extra)
    return envelope
# Flask application object and its static configuration.
app = Flask(__name__)
# Upper bound on the size of an incoming request body: 500 MiB.
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024
# Directory names for uploads and results (both are created with os.makedirs below).
app.config['UPLOAD_FOLDER'] = 'uploads'
app.config['RESULT_FOLDER'] = 'results'
# NOTE(review): the original comment described this as a longer request timeout
# for long-running tasks. Per the Flask configuration docs,
# SEND_FILE_MAX_AGE_DEFAULT is the default cache max-age (in seconds) for files
# served through send_file / static files; it does not limit or extend how long
# a request may run. Confirm whether this setting is still wanted.
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 300  # 300 seconds = 5 minutes
# Relative directory from which /api/result reads "<stem>_result.json" files.
OUTPUT_DIR = Path('output')
# Attach CORS response headers to every response by hand.
@app.after_request
def after_request(response):
    """Add CORS headers to ``response`` and return it.

    Behaviour:
        * When the request carries an ``Origin`` header, that origin is echoed
          back in ``Access-Control-Allow-Origin`` and credentials are allowed.
          ``Vary: Origin`` is added because the response headers now depend on
          the request's origin, so shared caches must not reuse them across
          origins.
        * When there is no ``Origin`` header, ``*`` is sent WITHOUT
          ``Access-Control-Allow-Credentials``: browsers reject the
          combination of a wildcard origin and credentials.
        * If ``app.config['CORS_ALLOWED_ORIGINS']`` is set (any container of
          origin strings), only those origins receive the allow-origin and
          allow-credentials headers. When it is unset, every origin is
          reflected, as before.

    Args:
        response: The Flask response object about to be sent.

    Returns:
        The same response object with the CORS headers applied.
    """
    origin = request.headers.get('Origin')
    allowed_origins = app.config.get('CORS_ALLOWED_ORIGINS')

    # Item assignment replaces any value set earlier, so these single-valued
    # headers can never be emitted twice.
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization'
    response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS'

    if origin:
        response.headers.add('Vary', 'Origin')
        if allowed_origins is None or origin in allowed_origins:
            # NOTE(review): with no allow-list configured, any origin is
            # reflected together with credentials. Set CORS_ALLOWED_ORIGINS in
            # production if cookies / Authorization headers are in use.
            response.headers['Access-Control-Allow-Origin'] = origin
            response.headers['Access-Control-Allow-Credentials'] = 'true'
    elif allowed_origins is None:
        response.headers['Access-Control-Allow-Origin'] = '*'

    return response
# Answer the CORS preflight (OPTIONS) request for /config.
@app.route('/config', methods=['OPTIONS'])
def config_options():
    """Reply to the OPTIONS preflight on /config with an empty body and HTTP 200."""
    empty_body, status_code = '', 200
    return empty_body, status_code
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULT_FOLDER'], exist_ok=True)
@ -27,285 +82,148 @@ DIAR_MODEL_LOADED = False
ASR_MODEL_LOCK = threading.Lock()
DIAR_MODEL_LOCK = threading.Lock()
def get_asr_service():
    """Return the shared ASRService instance, constructing it on first use.

    The instance is cached in the module-level ``GLOBAL_ASR_SERVICE``. The
    ``asr_service`` module is imported only when an instance actually has to
    be built. This function only constructs the service object; it does not
    call ``_load_model`` and does not change ``ASR_MODEL_LOADED``.
    """
    global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED

    # Fast path: an instance was already created by an earlier call.
    if GLOBAL_ASR_SERVICE is not None:
        return GLOBAL_ASR_SERVICE

    from asr_service import ASRService
    GLOBAL_ASR_SERVICE = ASRService()
    return GLOBAL_ASR_SERVICE
# 全局变量用于控制任务执行
task_running = {}
task_timeout = 600 # 10 分钟超时(秒)
def get_diar_service():
    """Return the shared DiarizationService instance, constructing it on first use.

    The instance is cached in the module-level ``GLOBAL_DIAR_SERVICE``. The
    ``diarization_service`` module is imported only when an instance actually
    has to be built. This function only constructs the service object; it does
    not call ``_load_model`` and does not change ``DIAR_MODEL_LOADED``.
    """
    global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED

    # Fast path: an instance was already created by an earlier call.
    if GLOBAL_DIAR_SERVICE is not None:
        return GLOBAL_DIAR_SERVICE

    from diarization_service import DiarizationService
    GLOBAL_DIAR_SERVICE = DiarizationService()
    return GLOBAL_DIAR_SERVICE
@app.route('/health', methods=['GET'])
def health_check():
    """Health check: report 'ok' plus the current ASR / diarization loaded flags."""
    payload = {
        'status': 'ok',
        'asr_loaded': ASR_MODEL_LOADED,
        'diar_loaded': DIAR_MODEL_LOADED,
    }
    return jsonify(payload)
@app.route('/api/asr/load', methods=['GET', 'POST'])
def load_asr_model():
    """Load the ASR model once and flag it as loaded.

    Accepts GET (the original method) and POST (the method advertised in the
    startup banner). An optional JSON object body may carry ``model_name``
    (default ``'paraformer-zh'``) and ``device`` (default ``'auto'``).

    NOTE(review): ``model_name`` and ``device`` are only logged and echoed in
    the response. The service comes from ``get_asr_service()``, which builds
    ``ASRService()`` with its own defaults, so the requested values are not
    applied to the model. Confirm whether they should be passed through.

    Returns:
        A JSON response describing the load state, or ``({'error': ...}, 500)``
        when loading raises.
    """
    global ASR_MODEL_LOADED

    # Serialise loading so concurrent requests cannot load the model twice.
    with ASR_MODEL_LOCK:
        if ASR_MODEL_LOADED:
            return jsonify({'message': 'ASR 模型已加载', 'loaded': True})

        try:
            # silent=True: a request without a JSON content type or with an
            # unparsable body yields None instead of an HTTP error (which the
            # except below would turn into a 500). This matters for body-less
            # GET requests.
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                # Missing body or a non-object JSON value: use the defaults.
                data = {}
            model_name = data.get('model_name', 'paraformer-zh')
            device = data.get('device', 'auto')

            print(f"正在加载 ASR 模型: {model_name}, 设备: {device}")

            service = get_asr_service()
            service._load_model()
            ASR_MODEL_LOADED = True

            return jsonify({
                'message': 'ASR 模型加载成功',
                'loaded': True,
                'model': model_name,
                'device': device
            })
        except Exception as e:
            # Top-level request boundary: report the failure to the client.
            return jsonify({'error': str(e)}), 500
@app.route('/api/diar/load', methods=['POST'])
def load_diar_model():
    """Load the 3D-Speaker diarization model once and flag it as loaded.

    An optional JSON object body may carry ``embedding_model`` (default
    ``'eres2netv2'``) and ``device`` (default ``'auto'``).

    NOTE(review): ``embedding_model`` and ``device`` are only logged and
    echoed in the response. The service comes from ``get_diar_service()``,
    which builds ``DiarizationService()`` with its own defaults. The request
    fields ``cluster_threshold`` and ``min_cluster_size`` were previously read
    into locals that were never used; those reads have been dropped. Confirm
    whether any of these values should be passed through to the service.

    Returns:
        A JSON response describing the load state, or ``({'error': ...}, 500)``
        when loading raises.
    """
    global DIAR_MODEL_LOADED

    # Serialise loading so concurrent requests cannot load the model twice.
    with DIAR_MODEL_LOCK:
        if DIAR_MODEL_LOADED:
            return jsonify({'message': '3D-Speaker 模型已加载', 'loaded': True})

        try:
            # silent=True: a POST without a JSON content type or with an
            # unparsable body yields None instead of an HTTP error (which the
            # except below would turn into a 500).
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                # Missing body or a non-object JSON value: use the defaults.
                data = {}
            embedding_model = data.get('embedding_model', 'eres2netv2')
            device = data.get('device', 'auto')

            print(f"正在加载 3D-Speaker 模型: {embedding_model}, 设备: {device}")

            service = get_diar_service()
            service._load_model()
            DIAR_MODEL_LOADED = True

            return jsonify({
                'message': '3D-Speaker 模型加载成功',
                'loaded': True,
                'embedding_model': embedding_model,
                'device': device
            })
        except Exception as e:
            # Top-level request boundary: report the failure to the client.
            return jsonify({'error': str(e)}), 500
@app.route('/api/asr/unload', methods=['POST'])
def unload_asr_model():
"""卸载 ASR 模型"""
global GLOBAL_ASR_SERVICE, ASR_MODEL_LOADED
@app.route('/api/recognize', methods=['POST'])
def recognize():
"""文件推理 - 调用 main.py 的 main 函数"""
task_id = str(uuid.uuid4())
task_running[task_id] = True
try:
GLOBAL_ASR_SERVICE = None
ASR_MODEL_LOADED = False
gc.collect()
# 从请求参数获取路径(只接受文件名或 input 目录下的相对路径)
data = request.json or {}
path = data.get('path', '')
return jsonify({'message': 'ASR 模型已卸载'})
except Exception as e:
return jsonify({'error': str(e)}), 500
if not path:
return jsonify(make_response(
status="error",
message="请提供文件路径",
errors=["缺少必要参数path"]
)), 400
print(f"\n{'='*60}")
print(f"API 收到请求path={path}, task_id={task_id}")
print(f"{'='*60}")
print(f"开始调用 main 函数...")
@app.route('/api/diar/unload', methods=['POST'])
def unload_diar_model():
"""卸载 3D-Speaker 模型"""
global GLOBAL_DIAR_SERVICE, DIAR_MODEL_LOADED
# 切换到 server.py 所在目录
server_dir = Path(__file__).parent.absolute()
os.chdir(server_dir)
try:
GLOBAL_DIAR_SERVICE = None
DIAR_MODEL_LOADED = False
gc.collect()
print(f"当前工作目录:{os.getcwd()}")
return jsonify({'message': '3D-Speaker 模型已卸载'})
except Exception as e:
return jsonify({'error': str(e)}), 500
# 设置超时处理
def timeout_handler(signum, frame):
task_running[task_id] = False
raise TimeoutError(f"任务执行超时 ({task_timeout}秒)")
@app.route('/api/recognize/single', methods=['POST'])
def recognize_single():
"""单文件推理"""
try:
if 'file' not in request.files:
return jsonify({'error': '请上传音频文件'}), 400
file = request.files['file']
if file.filename == '':
return jsonify({'error': '文件名不能为空'}), 400
data = request.form or {}
use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true'
embedding_model = data.get('embedding_model', 'eres2netv2')
cluster_threshold = float(data.get('cluster_threshold', 0.5))
min_cluster_size = int(data.get('min_cluster_size', 10))
output_format = data.get('format', 'json')
filename = secure_filename(file.filename) if file.filename is not None else secure_filename("unnamed_file")
task_id = str(uuid.uuid4())[:8]
audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{task_id}_{filename}")
file.save(audio_path)
# 注册信号处理器(SIGALRM 仅在 Unix/Linux/macOS 的主线程中可用;Windows 不支持 SIGALRM,会进入下面的 except 分支)
try:
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(task_timeout)
use_alarm = True
except (AttributeError, ValueError):
# Windows 不支持 SIGALRM
use_alarm = False
print("注意Windows 系统,使用超时检测线程")
try:
global ASR_MODEL_LOADED
service = get_asr_service()
if not ASR_MODEL_LOADED:
service._load_model()
ASR_MODEL_LOADED = True
# 调用 main.py 的 main 函数
from main import main
main(path)
sentences = service.recognize(
audio_path
)
# 取消超时
if use_alarm:
signal.alarm(0)
result = {
'file': filename,
'total_sentences': len(sentences),
'sentences': [s.to_dict() for s in sentences]
}
print(f"main 函数执行完成")
if output_format == 'json':
result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.json")
with open(result_path, 'w', encoding='utf-8') as f:
import json
json.dump(result, f, ensure_ascii=False, indent=2)
return jsonify(result)
else:
srt_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_result.srt")
service.export_to_srt([s for s in sentences], srt_path)
return send_file(srt_path, as_attachment=True, download_name=f"{task_id}_result.srt")
return jsonify(make_response(
status="success",
message="文件推理完成",
data={"path": path, "task_id": task_id}
)), 200
except TimeoutError as e:
print(f"任务超时:{e}")
return jsonify(make_response(
status="error",
message=str(e),
errors=[str(e)]
)), 504 # Gateway Timeout
finally:
if os.path.exists(audio_path):
os.remove(audio_path)
# 清理信号处理器
if use_alarm:
signal.alarm(0)
task_running[task_id] = False
except Exception as e:
import traceback
traceback.print_exc()
return jsonify({'error': str(e)}), 500
task_running[task_id] = False
return jsonify(make_response(
status="error",
message=str(e),
errors=[str(e)]
)), 500
@app.route('/api/recognize/batch', methods=['POST'])
def recognize_batch():
"""批量推理"""
@app.route('/api/result', methods=['GET'])
def result():
"""获取文件推理结果"""
try:
if 'files' not in request.files:
return jsonify({'error': '请上传音频文件'}), 400
# 从请求参数获取路径
path = request.args.get('path', '')
files = request.files.getlist('files')
if not files or len(files) == 0:
return jsonify({'error': '文件列表为空'}), 400
data = request.form or {}
use_3d_speaker = data.get('use_3d_speaker', 'false').lower() == 'true'
embedding_model = data.get('embedding_model', 'eres2netv2')
cluster_threshold = float(data.get('cluster_threshold', 0.5))
min_cluster_size = int(data.get('min_cluster_size', 10))
task_id = str(uuid.uuid4())[:8]
audio_paths = []
for f in files:
if f.filename:
filename = secure_filename(f.filename)
audio_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{task_id}_{filename}")
f.save(audio_path)
audio_paths.append(audio_path)
try:
global ASR_MODEL_LOADED
service = get_asr_service()
if not ASR_MODEL_LOADED:
service._load_model()
ASR_MODEL_LOADED = True
results = []
for audio_path in audio_paths:
try:
sentences = service.recognize(
audio_path,
)
results.append({
'file': os.path.basename(audio_path),
'total_sentences': len(sentences),
'sentences': [s.to_dict() for s in sentences]
})
except Exception as e:
results.append({
'file': os.path.basename(audio_path),
'error': str(e)
})
result_path = os.path.join(app.config['RESULT_FOLDER'], f"{task_id}_batch_result.json")
import json
with open(result_path, 'w', encoding='utf-8') as f:
json.dump({
'task_id': task_id,
'total_files': len(results),
'results': results
}, f, ensure_ascii=False, indent=2)
return jsonify({
'task_id': task_id,
'total_files': len(results),
'results': results,
'result_file': result_path
})
finally:
for audio_path in audio_paths:
if os.path.exists(audio_path):
os.remove(audio_path)
if not path:
return jsonify(make_response(
status="error",
message="请提供文件路径",
errors=["缺少必要参数path"]
)), 400
# 读取结果
print(Path(path).stem)
result_file = OUTPUT_DIR / f"{Path(path).stem}_result.json"
if result_file.exists():
with open(result_file, 'r', encoding='utf-8') as f:
result_data = json.load(f)
return jsonify(make_response(
status="success",
message="获取成功",
data=result_data
)), 200
else:
return jsonify(make_response(
status="error",
message="处理完成但未找到结果文件",
errors=["结果文件不存在"]
)), 404
except Exception as e:
import traceback
traceback.print_exc()
return jsonify({'error': str(e)}), 500
@app.route('/api/status', methods=['GET'])
def get_status():
    """Report the model loaded flags together with the fixed model names."""
    status_payload = {
        'asr_loaded': ASR_MODEL_LOADED,
        'diar_loaded': DIAR_MODEL_LOADED,
        'asr_model': 'paraformer-zh',
        'diar_model': '3D-Speaker',
    }
    return jsonify(status_payload)
return jsonify(make_response(
status="error",
message=str(e),
errors=[str(e)]
)), 500
if __name__ == '__main__':
print("=" * 60)
print(" ASR & Speaker Diarization API Server")
print("=" * 60)
print("\nAPI 接口:")
print(" GET /health - 健康检查")
print(" GET /api/status - 获取模型状态")
print(" POST /api/asr/load - 加载 ASR 模型")
print(" POST /api/diar/load - 加载 3D-Speaker 模型")
print(" POST /api/asr/unload - 卸载 ASR 模型")
print(" POST /api/diar/unload - 卸载 3D-Speaker 模型")
print(" POST /api/recognize/single - 单文件推理")
print(" POST /api/recognize/batch - 批量推理")
print(" POST /api/recognize - 文件推理")
print(" POST /api/result - 获取文件推理结果")
print("\n" + "=" * 60)
print("启动服务: http://localhost:5000")
print("启动服务http://localhost:5000")
print("使用 Waitress WSGI 服务器(无超时限制)")
print("=" * 60)
app.run(host='0.0.0.0', port=5000, debug=False)
from waitress import serve
serve(app, host='0.0.0.0', port=5000, threads=4, connection_limit=100)