"""Flask service that classifies product names with a fine-tuned BERT model.

Endpoints:
    POST /predict  -- body ``{"product_names": <str or list of str>}``;
                      returns one ``{"category", "confidence"}`` per name.
    GET  /health   -- liveness probe.
"""

import re
import threading
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import Dict, List

import joblib
import torch
from flask import Flask, jsonify, request
from transformers import BertForSequenceClassification, BertTokenizer

# Configuration
MAX_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_DIR = "./cate1"  # directory holding the model and tokenizer
LABEL_ENCODER_PATH = "cate1/label_encoder.pkl"  # fitted label encoder

app = Flask(__name__)
# Emit non-ASCII (Chinese) labels verbatim instead of \uXXXX escapes.
# The config key is honoured by older Flask releases; newer releases read the
# setting from the JSON provider instead, so configure both when available.
app.config['JSON_AS_ASCII'] = False
if hasattr(app, "json"):
    app.json.ensure_ascii = False

# Guards creation and initialisation of the Predictor singleton.
model_lock = threading.Lock()

# Matches every character outside the range U+4E00..U+9FA5.
_NON_CHINESE_RE = re.compile(r'[^\u4e00-\u9fa5]')


@lru_cache(maxsize=1000)
def _strip_non_chinese(text: str) -> str:
    """Return *text* with every character outside U+4E00..U+9FA5 removed.

    Cached at module level (not on the method) so the cache neither keys on
    nor keeps alive a ``Predictor`` instance.
    """
    return _NON_CHINESE_RE.sub('', text).strip()


class Predictor:
    """Process-wide singleton owning the tokenizer, model and label map."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            with model_lock:
                if cls._instance is None:  # re-check once the lock is held
                    instance = super().__new__(cls)
                    # Set the flag before publishing the instance so no other
                    # thread can observe an object lacking `_initialized`.
                    instance._initialized = False
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        # `__new__` has released the lock by now; take it again so that two
        # concurrent first calls cannot both load the model.
        with model_lock:
            if self._initialized:
                return

            # Load tokenizer and model, then switch the model to eval mode.
            self.tokenizer = BertTokenizer.from_pretrained(SAVE_DIR)
            self.model = BertForSequenceClassification.from_pretrained(
                SAVE_DIR
            ).to(DEVICE)
            self.model.eval()

            # NOTE(security): joblib.load unpickles the file; only point
            # LABEL_ENCODER_PATH at a trusted artifact.
            self.label_encoder = joblib.load(LABEL_ENCODER_PATH)
            self.label_map = dict(enumerate(self.label_encoder.classes_))

            # Worker pool used by batch_predict.
            self.executor = ThreadPoolExecutor(max_workers=4)

            self._initialized = True

    def clean_text(self, text: str) -> str:
        """Return *text* reduced to characters in U+4E00..U+9FA5.

        Non-string input yields ``""``. The type check runs before the cached
        helper is called, so unhashable values (e.g. a list or dict decoded
        from JSON) cannot raise ``TypeError`` inside the cache lookup.
        """
        if not isinstance(text, str):
            return ""
        return _strip_non_chinese(text)

    def predict_single(self, text: str) -> Dict:
        """Classify one text.

        Returns:
            ``{"category": <label>, "confidence": <top softmax probability>}``,
            or ``{"category": "", "confidence": 0.0}`` when nothing is left of
            *text* after cleaning.
        """
        cleaned_text = self.clean_text(text)
        if not cleaned_text:
            return {"category": "", "confidence": 0.0}

        with torch.no_grad():
            inputs = self.tokenizer(
                cleaned_text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
            ).to(DEVICE)
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1).cpu()
            top_prob, top_idx = torch.topk(probs, k=1)
            # NOTE: per the original author, dropping the confidence and using
            # torch.argmax(outputs.logits, dim=1) is 10%-30% faster.

        return {
            "category": self.label_map[top_idx.item()],
            "confidence": top_prob.item(),
        }

    def batch_predict(self, texts: List[str]) -> List[Dict]:
        """Classify *texts* concurrently on the worker pool.

        Results are returned in the same order as *texts*; an empty input
        returns ``[]``. An exception raised for any item propagates to the
        caller when its result is collected.
        """
        if not texts:
            return []

        futures = [
            self.executor.submit(self.predict_single, text) for text in texts
        ]
        return [future.result() for future in futures]


# Build the predictor at import time so the model is loaded before serving.
predictor = Predictor()


@app.route('/predict', methods=['POST'])
def predict():
    """Prediction endpoint: classify one product name or a list of them."""
    data = request.get_json()
    # Require a JSON object; a bare list or string body would otherwise slip
    # past the membership test and fail on the key lookup below.
    if not isinstance(data, dict) or 'product_names' not in data:
        return jsonify({"error": "Invalid request, 'product_names' is required"}), 400

    product_names = data['product_names']
    if not isinstance(product_names, list):
        product_names = [product_names]

    results = predictor.batch_predict(product_names)

    return jsonify({
        "status": "success",
        "predictions": results
    })


@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint."""
    return jsonify({"status": "healthy"})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000, threaded=True)
else:
    application = app  # WSGI entry point name expected by servers like Gunicorn