323 lines
9.7 KiB
Python
323 lines
9.7 KiB
Python
"""
|
||
Web API Server for ASR and Speaker Diarization
|
||
提供语音识别和说话人分离的 REST API 服务
|
||
"""
|
||
|
||
import os
|
||
import subprocess
|
||
import sys
|
||
import gc
|
||
import json
|
||
import shutil
|
||
import signal
|
||
import threading
|
||
from pathlib import Path
|
||
from flask import Flask, request, jsonify
|
||
from werkzeug.utils import secure_filename
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
|
||
from lib.convert import convert_to_h264
|
||
|
||
import os
|
||
|
||
# Directory that contains the ffmpeg binaries. A raw string keeps the
# backslashes and spaces of the Windows path intact.
ffmpeg_dir = r"C:\Program Files (x86)\Lenovo\LegionZone\2.0.23.3251\SEGamingAI\services\editor"

# Prepend the directory to PATH so child processes can find ffmpeg.
# os.pathsep is ";" on Windows (unchanged behaviour there) and ":" on POSIX;
# a hard-coded ";" would fuse this directory with the first existing PATH
# entry on non-Windows hosts and break that entry.
os.environ["PATH"] = ffmpeg_dir + os.pathsep + os.environ.get("PATH", "")

# Optional: when pydub is installed, point it at the same ffmpeg executable.
try:
    import pydub
    pydub.AudioSegment.converter = os.path.join(ffmpeg_dir, "ffmpeg.exe")
except ImportError:
    # pydub is an optional dependency; nothing to configure without it.
    pass
|
||
def make_response(status="success", data=None, errors=None, message=None, extra=None):
    """Build the uniform payload used by every API endpoint.

    Args:
        status: Outcome marker, "success" or "error".
        data: Payload to return; an empty dict is used when it is None.
        errors: List of error descriptions; an empty list is used when None.
        message: Human-readable message. When falsy, a default is chosen
            from ``status``.
        extra: Optional mapping merged into the payload last, so its keys
            override the standard ones.

    Returns:
        A dict with the keys ``status``, ``data``, ``errors``, ``message``
        and ``timestamp`` (current UTC time, second precision, "Z" suffix),
        plus whatever ``extra`` contributes.
    """
    if message:
        text = message
    elif status == "success":
        text = "操作成功"
    else:
        text = "操作失败"

    now_utc = datetime.now(timezone.utc)
    payload = dict(
        status=status,
        data={} if data is None else data,
        errors=[] if errors is None else errors,
        message=text,
        timestamp=now_utc.strftime("%Y-%m-%dT%H:%M:%SZ"),
    )

    if extra:
        payload.update(extra)
    return payload
|
||
|
||
app = Flask(__name__)
app.config.update(
    # Upper bound for a request body: 500 MiB.
    MAX_CONTENT_LENGTH=500 * 1024 * 1024,
    UPLOAD_FOLDER='uploads',
    RESULT_FOLDER='results',
    # NOTE(review): the original comment described this as a request timeout
    # (seconds) for long-running tasks. In Flask this key is documented as the
    # default cache max-age for files served by Flask - confirm the intent.
    SEND_FILE_MAX_AGE_DEFAULT=300,
)
# Directory the /api/result endpoint reads "<stem>_result.json" files from.
OUTPUT_DIR = Path('output')
|
||
|
||
# Add the CORS response headers by hand.
@app.after_request
def after_request(response):
    """Attach CORS headers to every outgoing response.

    When the request carries an ``Origin`` header, that origin is echoed back
    and credentials are allowed. Without one, the wildcard origin is sent and
    the credentials header is omitted, because browsers reject the combination
    of ``*`` with ``Access-Control-Allow-Credentials: true``.

    Args:
        response: The response object Flask is about to send.

    Returns:
        The same response object with the CORS headers set.
    """
    origin = request.headers.get('Origin')

    # Item assignment replaces any value already present, so the single-valued
    # CORS headers are never duplicated.
    if origin:
        # NOTE(review): echoing an arbitrary origin together with credentials
        # trusts every site; consider checking against an allow-list.
        response.headers['Access-Control-Allow-Origin'] = origin
        response.headers['Access-Control-Allow-Credentials'] = 'true'
        # The body is the same but the headers depend on Origin; tell caches.
        response.headers.add('Vary', 'Origin')
    else:
        response.headers['Access-Control-Allow-Origin'] = '*'

    response.headers['Access-Control-Allow-Headers'] = 'Content-Type,Authorization'
    response.headers['Access-Control-Allow-Methods'] = 'GET,PUT,POST,DELETE,OPTIONS'
    return response
|
||
|
||
|
||
# Answer the OPTIONS preflight request.
@app.route('/config', methods=['OPTIONS'])
def config_options():
    """Reply to the OPTIONS preflight for ``/config`` with an empty 200 response."""
    empty_body = ''
    return empty_body, 200
|
||
|
||
# Make sure the configured upload and result directories exist.
for _folder_key in ('UPLOAD_FOLDER', 'RESULT_FOLDER'):
    os.makedirs(app.config[_folder_key], exist_ok=True)
del _folder_key  # do not leave the loop variable in the module namespace

# Model service handles and their "loaded" flags, each pair with its own lock.
# NOTE(review): none of these are referenced in the visible part of the file;
# presumably they are populated lazily elsewhere - confirm.
GLOBAL_ASR_SERVICE = GLOBAL_DIAR_SERVICE = None
ASR_MODEL_LOADED = DIAR_MODEL_LOADED = False
ASR_MODEL_LOCK = threading.Lock()
DIAR_MODEL_LOCK = threading.Lock()

# Global state used to control task execution.
task_running = dict()  # task_id -> True while the task runs, False afterwards
task_timeout = 600  # 10-minute timeout, in seconds
|
||
|
||
|
||
@app.route('/api/recognize', methods=['GET'])
def recognize():
    """Run file inference by calling ``main.main`` with the requested path.

    Query parameters:
        path: Value handed unchanged to ``main.main``.
            NOTE(review): the original comment says only a file name or a
            path relative to the input directory is accepted, but nothing in
            this function enforces that.

    Side effects:
        Changes the process working directory to the directory of this file
        and records the task in ``task_running`` (True while running, False
        once the request is finished, whatever the outcome).

    Returns:
        A ``(json, status)`` pair: 200 on success, 400 when ``path`` is
        missing, 504 when the task raises ``TimeoutError``, 500 for any other
        exception.
    """
    task_id = str(uuid.uuid4())
    task_running[task_id] = True
    use_alarm = False
    previous_handler = None

    try:
        # Path comes straight from the query string.
        path = request.args.get('path', '')

        if not path:
            return jsonify(make_response(
                status="error",
                message="请提供文件路径",
                errors=["缺少必要参数:path"]
            )), 400

        print(f"\n{'='*60}")
        print(f"API 收到请求:path={path}, task_id={task_id}")
        print(f"{'='*60}")
        print("开始调用 main 函数...")

        # Switch to the directory that contains this file.
        server_dir = Path(__file__).parent.absolute()
        os.chdir(server_dir)

        print(f"当前工作目录:{os.getcwd()}")

        def timeout_handler(signum, frame):
            """SIGALRM handler: mark the task as stopped and abort it."""
            task_running[task_id] = False
            raise TimeoutError(f"任务执行超时 ({task_timeout}秒)")

        # SIGALRM is missing on Windows (AttributeError) and signal.signal()
        # refuses to run outside the main thread (ValueError). In both cases
        # the task simply runs without a timeout.
        try:
            previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
            signal.alarm(task_timeout)
            use_alarm = True
        except (AttributeError, ValueError):
            # The previous message announced a timeout-watch thread that was
            # never started; say what actually happens instead.
            print("注意:当前环境无法使用 SIGALRM(Windows 或非主线程),本次任务不启用超时控制")

        try:
            # Imported lazily, exactly as before, right before the call.
            from main import main
            main(path)
        finally:
            if use_alarm:
                # Cancel the pending alarm and put back the handler that was
                # installed before this request, so nothing leaks out of it.
                signal.alarm(0)
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)

        print("main 函数执行完成")

        return jsonify(make_response(
            status="success",
            message="文件推理完成",
            data={"path": path, "task_id": task_id}
        )), 200

    except TimeoutError as e:
        print(f"任务超时:{e}")
        return jsonify(make_response(
            status="error",
            message=str(e),
            errors=[str(e)]
        )), 504  # Gateway Timeout

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify(make_response(
            status="error",
            message=str(e),
            errors=[str(e)]
        )), 500

    finally:
        # Runs on every exit path, including the 400 early return, so a task
        # entry can no longer stay flagged as running forever.
        task_running[task_id] = False
|
||
|
||
|
||
@app.route('/api/result', methods=['GET'])
def result():
    """Return the stored inference result that belongs to a file.

    Query parameters:
        path: Path of the processed file; only its stem is used to locate
            ``<stem>_result.json`` inside ``OUTPUT_DIR``.

    Returns:
        A ``(json, status)`` pair: 200 with the parsed JSON as ``data``,
        400 when ``path`` is missing, 404 when the result file does not
        exist, 500 on any exception.
    """
    try:
        path = request.args.get('path', '')
        if not path:
            missing = make_response(
                status="error",
                message="请提供文件路径",
                errors=["缺少必要参数:path"],
            )
            return jsonify(missing), 400

        # Locate the result file by the stem of the requested path.
        stem = Path(path).stem
        print(stem)
        result_file = OUTPUT_DIR / f"{stem}_result.json"

        if not result_file.exists():
            not_found = make_response(
                status="error",
                message="处理完成但未找到结果文件",
                errors=["结果文件不存在"],
            )
            return jsonify(not_found), 404

        with open(result_file, 'r', encoding='utf-8') as handle:
            result_data = json.load(handle)

        found = make_response(
            status="success",
            message="获取成功",
            data=result_data,
        )
        return jsonify(found), 200
    except Exception as exc:
        import traceback
        traceback.print_exc()
        reason = str(exc)
        failure = make_response(status="error", message=reason, errors=[reason])
        return jsonify(failure), 500
|
||
|
||
@app.route('/api/convert', methods=['GET'])
def convert():
    """Transcode a video file via ``convert_to_h264``.

    Query parameters:
        path: Value handed unchanged to ``convert_to_h264``.

    Returns:
        A ``(json, status)`` pair: 200 with the value returned by
        ``convert_to_h264`` under ``data.path``, 400 when ``path`` is
        missing, 500 on any exception.
    """
    try:
        path = request.args.get('path', '')
        if not path:
            missing = make_response(
                status="error",
                message="请提供文件路径",
                errors=["缺少必要参数:path"],
            )
            return jsonify(missing), 400

        # Transcode the video file.
        output_path = convert_to_h264(path)

        done = make_response(
            status="success",
            message="视频文件转码完成",
            data={"path": output_path},
        )
        return jsonify(done), 200
    except Exception as exc:
        import traceback
        traceback.print_exc()
        reason = str(exc)
        failure = make_response(status="error", message=reason, errors=[reason])
        return jsonify(failure), 500
|
||
|
||
@app.route('/api/getVidUrl', methods=['GET'])
def getVidUrl():
    """Return the URL of the transcoded (H.264) video that belongs to a file.

    Query parameters:
        path: Path of the source video; only its stem is used to derive the
            name ``<stem>_h264.mp4``.

    Returns:
        A ``(json, status)`` pair: 200 with ``data.url``, 400 when ``path``
        is missing, 404 when the transcoded file does not exist, 500 on any
        exception.
    """
    from urllib.parse import quote

    try:
        path = request.args.get('path', '')

        if not path:
            return jsonify(make_response(
                status="error",
                message="请提供文件路径",
                errors=["缺少必要参数:path"]
            )), 400

        video_name = f"{Path(path).stem}_h264.mp4"

        # Look in the vid_h264 directory next to this file, which is the
        # directory the startup code hands to the Caddy file server. A bare
        # relative path would depend on the current working directory.
        video_dir = Path(__file__).parent.absolute() / "vid_h264"
        if not (video_dir / video_name).exists():
            return jsonify(make_response(
                status="error",
                message="视频文件不存在",
                errors=["视频文件不存在"]
            )), 404

        # Percent-encode the file name so spaces, '#', '?', '%' and non-ASCII
        # characters yield a valid URL. Names made only of unreserved
        # characters produce the same URL as before.
        url = f"http://localhost:8086/{quote(video_name, safe='')}"
        print(url)

        return jsonify(make_response(
            status="success",
            message="获取成功",
            data={"url": url}
        )), 200
    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify(make_response(
            status="error",
            message=str(e),
            errors=[str(e)]
        )), 500
|
||
|
||
if __name__ == '__main__':
    # Port Waitress listens on. The banner below reuses it; previously the
    # banner announced 5176 while the server was bound to 5000.
    server_port = 5000

    print("=" * 60)
    print(" ASR & Speaker Diarization API Server")
    print("=" * 60)
    print("\nAPI 接口:")
    print(" GET /api/recognize - 文件推理")
    print(" GET /api/result - 获取文件推理结果")
    print(" GET /api/convert - 转码视频文件")
    print(" GET /api/getVidUrl - 获取视频文件URL")
    print("\n" + "=" * 60)
    print(f"启动服务: http://localhost:{server_port}")
    print("使用 Waitress WSGI 服务器(无超时限制)")
    print("=" * 60)

    # Start the Caddy file server in the background; it serves vid_h264 on
    # port 8086. Absolute paths keep the executable resolvable even though
    # the child runs with a different working directory.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    caddy_dir = os.path.join(base_dir, "vid_h264")
    caddy_exe = os.path.join(base_dir, "lib", "caddy_windows_amd64.exe")
    # Popen fails when cwd does not exist, e.g. before the first conversion.
    os.makedirs(caddy_dir, exist_ok=True)
    # An argument list needs no shell; running the executable directly also
    # makes the Popen handle refer to Caddy itself, so it can be stopped.
    caddy_proc = subprocess.Popen(
        [caddy_exe, "file-server", "--listen", ":8086", "--browse"],
        cwd=caddy_dir,
    )

    try:
        from waitress import serve
        serve(app, host='0.0.0.0', port=server_port, threads=4, connection_limit=100)
    finally:
        # Do not leave an orphaned Caddy holding port 8086 after shutdown.
        caddy_proc.terminate()
|