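from fastapi import FastAPI
from pydantic import BaseModel
from modelscope import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

# Initialize the model and tokenizer once at import time (module-level globals).
model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)

# Token id of "</think>" in the Qwen3 tokenizer; the final answer follows it.
THINK_END_TOKEN_ID = 151668
# Value returned when no brand can be matched ("失败" means "failed").
NOT_FOUND = "失败"


class BrandExtractionRequest(BaseModel):
    goods: str
    brand_list: str  # Comma-separated string, e.g. "apple,samsung,huawei"


# A plain `def` (not `async def`) so FastAPI runs the blocking model.generate()
# call in its thread pool instead of stalling the event loop.
@app.post("/extract_brand/")
def extract_brand(request: BrandExtractionRequest):
    goods = request.goods
    # Split on commas and strip whitespace. Empty entries are dropped because an
    # empty brand would match every product name; duplicates are removed while
    # keeping the caller's order, so the fallback below is deterministic.
    stripped = (b.strip() for b in request.brand_list.split(","))
    brands = list(dict.fromkeys(b for b in stripped if b))

    # Build the prompt. In English:
    #   "Product name: {goods}
    #    - Return only one brand name, without any extra description
    #    - Choose from the following brands: {brands}"
    prompt = (
        f"商品名称:{goods}\n"
        "-只需要返回一个品牌的名字,去掉多余的描述\n"
        f"-请在以下品牌中选择:{', '.join(brands)}"
    )
    messages = [
        # System instruction, in English: 'Extract the brand name from the product
        # name; it must exactly match the predefined brand list. If no matching
        # brand is found, return "失败".'
        {"role": "system", "content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
        {"role": "user", "content": prompt},
    ]

    # Render the chat template with thinking mode disabled, then generate.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=64,  # A brand name is short; a huge limit only risks long runs.
    )
    # Keep only the newly generated tokens (drop the prompt tokens).
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

    # Parse the output: skip past the last "</think>" token if the model emitted one.
    try:
        index = len(output_ids) - output_ids[::-1].index(THINK_END_TOKEN_ID)
    except ValueError:
        index = 0
    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip()

    # Post-process: the answer must be one of the allowed brands. Otherwise fall
    # back to a plain substring search in the product name, then to NOT_FOUND.
    extracted_brand = content
    if extracted_brand not in brands:
        for brand in brands:
            if brand in goods:
                extracted_brand = brand
                break
        else:
            extracted_brand = NOT_FOUND

    return {"extracted_brand": extracted_brand}


# Optional GET endpoint for a quick health check.
@app.get("/")
async def root():
    return {"message": "Brand Extraction API is running"}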
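

# Optional variant (a sketch, not part of the original service): the substring
# fallback in extract_brand returns whichever listed brand it checks first that
# occurs in the product name, so with overlapping names such as "Li" and
# "Li-Ning" it can return the shorter one. The hypothetical helper below instead
# prefers the longest matching brand and ignores letter case. Swap it into
# extract_brand only if that matching policy suits your data.
def longest_brand_match(goods: str, brands: list[str], not_found: str = "失败") -> str:
    """Return the longest brand (case-insensitive) contained in `goods`, else `not_found`."""
    goods_folded = goods.casefold()
    # Ignore empty entries: "" is a substring of every string.
    matches = [b for b in brands if b and b.casefold() in goods_folded]
    # max() with key=len keeps the first of several equally long matches.
    return max(matches, key=len) if matches else not_found


# Usage: longest_brand_match("Li-Ning running shoes", ["Li", "Li-Ning"]) returns "Li-Ning".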