ai/SpeechRecognition/server.py

323 lines
9.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Web API Server for ASR and Speaker Diarization
提供语音识别和说话人分离的 REST API 服务
"""
import os
import subprocess
import sys
import gc
import json
import shutil
import signal
import threading
from pathlib import Path
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
import uuid
from datetime import datetime, timezone
from lib.convert import convert_to_h264
import os
# Directory that contains the ffmpeg executable. It is prepended to PATH so
# that anything locating ffmpeg through PATH finds this copy first. The raw
# string keeps the Windows backslashes literal; PATH entries need no quoting
# even when they contain spaces. Set FFMPEG_DIR to use a different location.
ffmpeg_dir = os.environ.get(
    "FFMPEG_DIR",
    r"C:\Program Files (x86)\Lenovo\LegionZone\2.0.23.3251\SEGamingAI\services\editor",
)
# os.pathsep is ";" on Windows and ":" on POSIX; a hard-coded ";" would
# corrupt PATH on Linux/macOS.
os.environ["PATH"] = ffmpeg_dir + os.pathsep + os.environ.get("PATH", "")
# If pydub is available, point it at the same ffmpeg binary (optional).
try:
    import pydub
    _ffmpeg_name = "ffmpeg.exe" if os.name == "nt" else "ffmpeg"
    pydub.AudioSegment.converter = os.path.join(ffmpeg_dir, _ffmpeg_name)
except ImportError:
    pass
def make_response(status="success", data=None, errors=None, message=None, extra=None):
    """Build the uniform JSON envelope returned by every API endpoint.

    Args:
        status: "success" or "error".
        data: Response payload; ``None`` is replaced by an empty dict.
        errors: List of error strings; ``None`` is replaced by an empty list.
        message: Human-readable message. When falsy, a default is chosen
            from ``status``: "操作成功" for "success", "操作失败" otherwise.
        extra: Optional mapping merged into the top level of the envelope
            last, so its keys override the standard ones.

    Returns:
        A plain dict (callers wrap it with ``jsonify``) with the keys
        status, data, errors, message and timestamp, where timestamp is the
        current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``.
    """
    if not message:
        message = "操作成功" if status == "success" else "操作失败"
    now_utc = datetime.now(timezone.utc)
    envelope = dict(
        status=status,
        data={} if data is None else data,
        errors=[] if errors is None else errors,
        message=message,
        timestamp=now_utc.strftime("%Y-%m-%dT%H:%M:%SZ"),
    )
    if extra:
        envelope.update(extra)
    return envelope
app = Flask(__name__)
# Flask rejects request bodies larger than this (500 MB) with 413.
app.config['MAX_CONTENT_LENGTH'] = 500 * 1024 * 1024
app.config['UPLOAD_FOLDER'] = 'uploads'
app.config['RESULT_FOLDER'] = 'results'
# Default Cache-Control max-age, in seconds (5 minutes), for files served via
# send_file / send_from_directory. NOTE: this is NOT a request timeout and does
# not affect how long a long-running request may take.
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 300
# Directory holding the "<stem>_result.json" files read by /api/result.
# Anchored to this file's directory, which is the directory /api/recognize
# chdir()s into before calling main(). A cwd-relative Path('output') pointed
# somewhere else until the first recognize call whenever the server was
# launched from another directory.
# NOTE(review): assumes main() writes results to ./output relative to that
# directory; confirm.
OUTPUT_DIR = Path(__file__).parent.absolute() / 'output'
# Manually attach CORS response headers to every response.
@app.after_request
def after_request(response):
    """Add CORS headers to ``response`` and return it.

    The request's ``Origin`` header is echoed back as the allowed origin
    (``*`` when the request has none), and credentials are allowed.

    Headers are *set* rather than appended. Appending produced duplicate,
    conflicting values when a response already carried one of these headers.
    ``Vary: Origin`` is added because the response content now depends on the
    request's Origin, so shared caches must not reuse it across origins.

    NOTE(review): echoing any Origin together with
    ``Access-Control-Allow-Credentials: true`` lets every website issue
    credentialed requests to this API. Consider an explicit origin allowlist.
    """
    origin = request.headers.get('Origin', '*')
    response.headers['Access-Control-Allow-Origin'] = origin
    response.headers['Access-Control-Allow-Credentials'] = 'true'
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization'
    response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS'
    response.vary.add('Origin')
    return response
# CORS preflight handler for the /config endpoint.
@app.route('/config', methods=['OPTIONS'])
def config_options():
    """Answer an OPTIONS preflight for /config with an empty body and 200.

    The CORS headers themselves are added by the after_request hook.
    """
    body, status_code = '', 200
    return body, status_code
# Create the upload and result directories (relative to the cwd) if missing.
for _folder in (app.config['UPLOAD_FOLDER'], app.config['RESULT_FOLDER']):
    os.makedirs(_folder, exist_ok=True)
del _folder

# Model service slots, their "loaded" flags, and one lock per model.
# NOTE(review): none of these are referenced in the visible part of this file.
GLOBAL_ASR_SERVICE = GLOBAL_DIAR_SERVICE = None
ASR_MODEL_LOADED = DIAR_MODEL_LOADED = False
ASR_MODEL_LOCK, DIAR_MODEL_LOCK = threading.Lock(), threading.Lock()

# Task-id -> "still running" flag, written by /api/recognize.
# Entries are never removed, so the dict grows by one entry per request.
task_running = {}
# Maximum run time of a recognize task, in seconds (10 minutes).
task_timeout = 10 * 60
def _run_with_timeout(func, arg, timeout):
    """Call ``func(arg)`` on a worker thread and wait at most ``timeout`` seconds.

    Any exception raised by ``func`` is re-raised in the calling thread.
    Raises TimeoutError if ``func`` is still running after ``timeout`` seconds.
    Python threads cannot be killed, so on timeout the worker keeps running in
    the background; only the caller stops waiting for it.

    This replaces the earlier SIGALRM approach. That approach cannot work on
    Windows (no SIGALRM), and signal.signal() may only be called from the main
    thread, while Waitress runs requests on worker threads. As a result the
    alarm was never armed and no timeout was ever applied.
    """
    failure = []

    def _target():
        try:
            func(arg)
        except BaseException as exc:  # hand every failure back to the caller
            failure.append(exc)

    worker = threading.Thread(target=_target, daemon=True)
    worker.start()
    worker.join(timeout)
    if worker.is_alive():
        raise TimeoutError(f"任务执行超时 ({timeout}秒)")
    if failure:
        raise failure[0]


@app.route('/api/recognize', methods=['GET'])
def recognize():
    """Run file inference by calling ``main.main(path)``.

    Query params:
        path: the file to process. The intended contract is "a file name or a
            path relative to the input directory". NOTE(review): it is not
            validated here; confirm main() rejects absolute paths and ``..``.

    Returns:
        200 with ``{"path", "task_id"}`` on success, 400 when ``path`` is
        missing, 504 when main() has not finished within ``task_timeout``
        seconds, and 500 on any other error.
    """
    task_id = str(uuid.uuid4())
    task_running[task_id] = True
    try:
        path = request.args.get('path', '')
        if not path:
            return jsonify(make_response(
                status="error",
                message="请提供文件路径",
                errors=["缺少必要参数path"]
            )), 400
        print(f"\n{'='*60}")
        print(f"API 收到请求path={path}, task_id={task_id}")
        print(f"{'='*60}")
        print(f"开始调用 main 函数...")
        # Run main() from the directory containing this file.
        # NOTE: chdir is process-wide, so it affects every request thread.
        server_dir = Path(__file__).parent.absolute()
        os.chdir(server_dir)
        print(f"当前工作目录:{os.getcwd()}")
        from main import main
        _run_with_timeout(main, path, task_timeout)
        print(f"main 函数执行完成")
        return jsonify(make_response(
            status="success",
            message="文件推理完成",
            data={"path": path, "task_id": task_id}
        )), 200
    except TimeoutError as e:
        print(f"任务超时:{e}")
        return jsonify(make_response(
            status="error",
            message=str(e),
            errors=[str(e)]
        )), 504  # Gateway Timeout
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify(make_response(
            status="error",
            message=str(e),
            errors=[str(e)]
        )), 500
    finally:
        # Runs on every exit path, including the early 400 return.
        task_running[task_id] = False
@app.route('/api/result', methods=['GET'])
def result():
    """Return the parsed ``<stem>_result.json`` for the ``path`` query parameter.

    Only the stem of ``path`` is used, so directory components and the file
    extension are ignored. Responds with 400 when ``path`` is missing, 404
    when no matching file exists in OUTPUT_DIR, and 500 on any other error
    (for example, malformed JSON).
    """
    try:
        requested = request.args.get('path', '')
        if not requested:
            return jsonify(make_response(
                status="error",
                message="请提供文件路径",
                errors=["缺少必要参数path"]
            )), 400

        stem = Path(requested).stem
        print(stem)
        result_file = OUTPUT_DIR / f"{stem}_result.json"

        if not result_file.exists():
            return jsonify(make_response(
                status="error",
                message="处理完成但未找到结果文件",
                errors=["结果文件不存在"]
            )), 404

        payload = json.loads(result_file.read_text(encoding='utf-8'))
        return jsonify(make_response(
            status="success",
            message="获取成功",
            data=payload
        )), 200
    except Exception as exc:
        import traceback
        traceback.print_exc()
        return jsonify(make_response(
            status="error",
            message=str(exc),
            errors=[str(exc)]
        )), 500
@app.route('/api/convert', methods=['GET'])
def convert():
    """Transcode the video named by the ``path`` query parameter.

    Delegates to ``convert_to_h264`` and returns its result as ``data.path``.
    Responds with 200 on success, 400 when ``path`` is missing, and 500 when
    conversion raises.
    NOTE(review): ``path`` comes straight from the query string; confirm that
    convert_to_h264 validates it.
    """
    try:
        source = request.args.get('path', '')
        if source:
            converted = convert_to_h264(source)
            return jsonify(make_response(
                status="success",
                message="视频文件转码完成",
                data={"path": converted}
            )), 200
        return jsonify(make_response(
            status="error",
            message="请提供文件路径",
            errors=["缺少必要参数path"]
        )), 400
    except Exception as exc:
        import traceback
        traceback.print_exc()
        return jsonify(make_response(
            status="error",
            message=str(exc),
            errors=[str(exc)]
        )), 500
@app.route('/api/getVidUrl', methods=['GET'])
def getVidUrl():
    """Return the URL of the H.264 copy of the video named by ``path``.

    The transcoded file is looked up as ``vid_h264/<stem>_h264.mp4`` next to
    this file. That is the directory the Caddy file server (started when this
    module runs as a script) serves on port 8086.

    The base URL defaults to ``http://localhost:8086``. It can be overridden
    with the ``VIDEO_BASE_URL`` environment variable, which is needed when
    clients are not on this machine. The file name is percent-encoded so that
    stems containing spaces, ``#``, ``?``, ``%`` or non-ASCII characters still
    produce a valid URL.

    Responds with 400 when ``path`` is missing, 404 when the transcoded file
    does not exist, and 500 on any other error.
    """
    from urllib.parse import quote
    try:
        path = request.args.get('path', '')
        if not path:
            return jsonify(make_response(
                status="error",
                message="请提供文件路径",
                errors=["缺少必要参数path"]
            )), 400
        file_name = f"{Path(path).stem}_h264.mp4"
        # Check the directory Caddy actually serves rather than the process's
        # current working directory, which can differ.
        video_file = Path(__file__).parent.absolute() / "vid_h264" / file_name
        if not video_file.exists():
            return jsonify(make_response(
                status="error",
                message="视频文件不存在",
                errors=["视频文件不存在"]
            )), 404
        base_url = os.environ.get('VIDEO_BASE_URL', 'http://localhost:8086').rstrip('/')
        url = f"{base_url}/{quote(file_name)}"
        print(url)
        return jsonify(make_response(
            status="success",
            message="获取成功",
            data={"url": url}
        )), 200
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify(make_response(
            status="error",
            message=str(e),
            errors=[str(e)]
        )), 500
if __name__ == '__main__':
    # Single source of truth for the API port. The banner previously
    # advertised 5176 while the server actually bound 5000.
    api_port = 5000
    print("=" * 60)
    print(" ASR & Speaker Diarization API Server")
    print("=" * 60)
    print("\nAPI 接口:")
    print(" GET /api/recognize - 文件推理")
    print(" GET /api/result - 获取文件推理结果")
    print(" GET /api/convert - 转码视频文件")
    print(" GET /api/getVidUrl - 获取视频文件URL")
    print("\n" + "=" * 60)
    print(f"启动服务: http://localhost:{api_port}")
    print("使用 Waitress WSGI 服务器(无超时限制)")
    print("=" * 60)

    # Start the Caddy static file server for transcoded videos in the background.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    caddy_dir = os.path.join(base_dir, "vid_h264")
    caddy_exe = os.path.join(base_dir, "lib", "caddy_windows_amd64.exe")
    # Popen fails when its cwd does not exist yet.
    os.makedirs(caddy_dir, exist_ok=True)
    caddy_proc = None
    try:
        # Pass the argument list without shell=True. No shell is needed to
        # launch an executable, and on POSIX shell=True with a list would
        # treat only the first element as the command.
        caddy_proc = subprocess.Popen(
            [caddy_exe, "file-server", "--listen", ":8086", "--browse"],
            cwd=caddy_dir,
        )
    except OSError as exc:
        # Video URLs will not resolve, but the API itself can still run.
        print(f"WARNING: failed to start Caddy video file server: {exc}")

    from waitress import serve
    try:
        serve(app, host='0.0.0.0', port=api_port, threads=4, connection_limit=100)
    finally:
        # Do not leave Caddy running (and holding port 8086) after the API exits.
        if caddy_proc is not None and caddy_proc.poll() is None:
            caddy_proc.terminate()