from flask import Flask, request, jsonify
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import joblib
import re
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict
import threading

# Configuration
MAX_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_DIR = "./cate1"  # model directory
LABEL_ENCODER_PATH = "cate1/label_encoder.pkl"  # fitted label encoder

app = Flask(__name__)
# Keep non-ASCII (e.g. Chinese) characters intact in JSON responses.
# Note: this config key was removed in Flask 2.3; on newer Flask use
# app.json.ensure_ascii = False instead.
app.config['JSON_AS_ASCII'] = False

# Lock guarding singleton creation
model_lock = threading.Lock()


class Predictor:
    """Singleton that manages the lifecycle of the model and tokenizer."""
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            with model_lock:
                if cls._instance is None:  # double-checked locking
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        # Load the model and tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(SAVE_DIR)
        self.model = BertForSequenceClassification.from_pretrained(SAVE_DIR).to(DEVICE)
        self.model.eval()

        # Load the fitted label encoder and build an index -> label map
        self.label_encoder = joblib.load(LABEL_ENCODER_PATH)
        self.label_map = {i: label for i, label in enumerate(self.label_encoder.classes_)}

        # Thread pool for concurrent predictions
        self.executor = ThreadPoolExecutor(max_workers=4)
        self._initialized = True

    @lru_cache(maxsize=1000)
    def clean_text(self, text: str) -> str:
        """Text cleaning with an LRU cache.

        Keeps only CJK characters, e.g. "iPhone 15 手机壳" -> "手机壳".
        Since Predictor is a singleton, caching on an instance method is
        safe here (self is part of the cache key).
        """
        if not isinstance(text, str):
            return ""
        cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
        return cleaned_text.strip()

    def predict_single(self, text: str) -> Dict:
        """Predict the category of a single text."""
        with torch.no_grad():
            cleaned_text = self.clean_text(text)
            if not cleaned_text:
                return {"category": "", "confidence": 0.0}

            inputs = self.tokenizer(
                cleaned_text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH
            ).to(DEVICE)

            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1).cpu()
            top_prob, top_idx = torch.topk(probs, k=1)

            # Dropping the confidence computation (argmax on the raw logits
            # instead of softmax + topk) is roughly 10-30% faster:
            # top_idx = torch.argmax(outputs.logits, dim=1)
            return {
                "category": self.label_map[top_idx.item()],
                "confidence": top_prob.item()
            }

    def batch_predict(self, texts: List[str]) -> List[Dict]:
        """Batch prediction (concurrent, one thread-pool task per text)."""
        if not texts:
            return []

        # Fan out to the thread pool; results come back in input order
        futures = [self.executor.submit(self.predict_single, text) for text in texts]
        return [future.result() for future in futures]
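
    # The thread-pool fan-out above runs one forward pass per text. A
    # hypothetical alternative (not part of the original code) would be to
    # tokenize all texts as a single padded batch, which usually amortizes
    # better on GPU. Sketch only; the method name is assumed, and it reuses
    # the tokenizer, model, and label map loaded above.
    def batch_predict_tensorized(self, texts: List[str]) -> List[Dict]:
        cleaned = [self.clean_text(t) for t in texts]
        results = [{"category": "", "confidence": 0.0} for _ in texts]
        nonempty = [(i, t) for i, t in enumerate(cleaned) if t]
        if not nonempty:
            return results
        indices, batch = zip(*nonempty)
        with torch.no_grad():
            inputs = self.tokenizer(
                list(batch),
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH
            ).to(DEVICE)
            probs = torch.softmax(self.model(**inputs).logits, dim=1).cpu()
            top_probs, top_idxs = torch.topk(probs, k=1, dim=1)
        for pos, i in enumerate(indices):
            results[i] = {
                "category": self.label_map[top_idxs[pos].item()],
                "confidence": top_probs[pos].item()
            }
        return results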


# Instantiate the singleton predictor at import time so the model is
# loaded before the first request arrives
predictor = Predictor()


@app.route('/predict', methods=['POST'])
def predict():
    """Prediction endpoint."""
    # silent=True so malformed or non-JSON bodies fall through to the
    # JSON 400 below instead of Flask's default HTML error page
    data = request.get_json(silent=True)
    if not data or 'product_names' not in data:
        return jsonify({"error": "Invalid request, 'product_names' is required"}), 400

    product_names = data['product_names']
    if not isinstance(product_names, list):
        product_names = [product_names]

    # Batch prediction
    results = predictor.batch_predict(product_names)

    return jsonify({
        "status": "success",
        "predictions": results
    })


@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint."""
    return jsonify({"status": "healthy"})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, threaded=True)
else:
    application = app  # WSGI entry point (e.g. for Gunicorn)
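

# Example usage (a sketch, assuming the server runs on localhost:5000; note
# that clean_text keeps only Chinese characters, so inputs should contain
# Chinese text):
#
#   curl -X POST http://localhost:5000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"product_names": ["农夫山泉天然水550ml"]}'
#
# returns a payload of the form:
#
#   {"status": "success",
#    "predictions": [{"category": "<label>", "confidence": 0.97}]}
#
# Under Gunicorn (assuming this file is named app.py):
#
#   gunicorn -w 2 -b 0.0.0.0:5000 app:application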