from flask import Flask, request, jsonify
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import joblib
import re
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict
import threading

# Configuration
MAX_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_DIR = "cate1"  # model directory
LABEL_ENCODER_PATH = "cate1/label_encoder.pkl"  # path to the fitted label encoder

app = Flask(__name__)
app.config['JSON_AS_ASCII'] = False
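# Note: JSON_AS_ASCII=False keeps non-ASCII characters (e.g. Chinese category
# labels) readable in responses. The config key was removed in Flask 2.3;
# on newer Flask the equivalent is `app.json.ensure_ascii = False`.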

# Lock guarding creation of the singleton predictor
model_lock = threading.Lock()


class Predictor:
    """Singleton that manages the lifecycle of the model and tokenizer."""
    _instance = None

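    # Double-checked locking: test outside the lock for speed, then re-test
    # inside the lock so two racing threads cannot both create an instance.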
    def __new__(cls):
        if cls._instance is None:
            with model_lock:
                if cls._instance is None:  # double-checked locking
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        # Load the model and tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(SAVE_DIR)
        self.model = BertForSequenceClassification.from_pretrained(SAVE_DIR).to(DEVICE)
        self.model.eval()

        # Load the label encoder and build an index -> label map
        self.label_encoder = joblib.load(LABEL_ENCODER_PATH)
        self.label_map = {i: label for i, label in enumerate(self.label_encoder.classes_)}

        # Thread pool for concurrent prediction requests
        self.executor = ThreadPoolExecutor(max_workers=4)
        self._initialized = True

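    # Note: lru_cache on an instance method also hashes `self` (harmless for a
    # singleton, though it pins the instance in memory), and unhashable
    # arguments raise TypeError before the isinstance guard below runs.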
    @lru_cache(maxsize=1000)
    def clean_text(self, text: str) -> str:
        """Clean the input text, with LRU caching."""
        if not isinstance(text, str):
            return ""
        # Keep CJK Unified Ideographs only; everything else is stripped
        cleaned_text = re.sub(r'[^\u4e00-\u9fa5]', '', text)
        return cleaned_text.strip()

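    # Because clean_text drops all non-CJK characters, a purely non-Chinese
    # input cleans to "" and predict_single returns the empty default below.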
    def predict_single(self, text: str) -> Dict:
        """Predict the category of a single text."""
        with torch.no_grad():
            cleaned_text = self.clean_text(text)
            if not cleaned_text:
                return {"category": "", "confidence": 0.0}

            inputs = self.tokenizer(
                cleaned_text,
                return_tensors="pt",
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH
            ).to(DEVICE)

            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1).cpu()
            top_prob, top_idx = torch.topk(probs, k=1)

            # Dropping the confidence (argmax on raw logits) is 10%-30% faster:
            # top_idx = torch.argmax(outputs.logits, dim=1)
            return {
                "category": self.label_map[top_idx.item()],
                "confidence": top_prob.item()
            }

    def batch_predict(self, texts: List[str]) -> List[Dict]:
        """Batch prediction, processed concurrently."""
        if not texts:
            return []

        # Fan the texts out to the thread pool
        futures = [self.executor.submit(self.predict_single, text) for text in texts]
        return [future.result() for future in futures]


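# Note: batch_predict runs one forward pass per text on the thread pool.
# Tokenizing the whole list in a single tokenizer call and running one batched
# forward pass would usually be faster, especially on GPU; the per-text
# approach is kept here for simplicity and to reuse predict_single.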
# Initialize the predictor (loads the model once, at import time)
predictor = Predictor()


@app.route('/predict', methods=['POST'])
def predict():
    """Prediction endpoint."""
    data = request.get_json()
    if not data or 'product_names' not in data:
        return jsonify({"error": "Invalid request, 'product_names' is required"}), 400

    product_names = data['product_names']
    if not isinstance(product_names, list):
        product_names = [product_names]

    # Batch predict
    results = predictor.batch_predict(product_names)

    return jsonify({
        "status": "success",
        "predictions": results
    })


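# Example request against the endpoint above (placeholder values, assuming the
# server runs locally on its default port):
#   curl -X POST http://localhost:5002/predict \
#        -H 'Content-Type: application/json' \
#        -d '{"product_names": ["..."]}'
# Response: {"status": "success", "predictions": [{"category": "...", "confidence": ...}]}
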
@app.route('/health', methods=['GET'])
def health_check():
    """Health-check endpoint."""
    return jsonify({"status": "healthy"})


if __name__ == '__main__':
    # Development server only; threaded=True lets it serve requests concurrently
    app.run(host='0.0.0.0', port=5002, threaded=True)
else:
    application = app  # expose the WSGI callable (e.g. for Gunicorn)
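# Example production launch (module name assumed to be app.py; adjust to the
# actual filename). Each Gunicorn worker loads its own copy of the model:
#   gunicorn -w 1 -b 0.0.0.0:5002 app:application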