qwen_brand/app.py

82 lines
2.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI
from pydantic import BaseModel
from modelscope import AutoModelForCausalLM, AutoTokenizer
import re
from typing import List
# Model and tokenizer are module-level globals, loaded once at import time and
# shared by every request handler below.
model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

app = FastAPI()
class BrandExtractionRequest(BaseModel):
    """JSON body accepted by the ``POST /extract_brand/`` endpoint."""

    goods: str  # product title to extract the brand from
    brand_list: str  # candidate brands as one comma-separated string, e.g. "apple,samsung,huawei"
# Cap on generated tokens: the expected answer is a single brand name, so a
# small limit bounds latency if the model fails to stop on its own.
_MAX_NEW_TOKENS = 64

# Value returned when no candidate brand can be matched.
_NOT_FOUND = "失败"

# Quote characters the model may wrap its answer in (the system prompt itself
# shows "失败" in quotes); trimmed before the answer is validated.
_ANSWER_QUOTES = "\"'“”‘’「」『』"


def _parse_brand_list(raw: str) -> List[str]:
    """Split a comma-separated brand string into unique, non-empty names.

    First-appearance order is preserved so the prompt and the fallback scan
    are deterministic (a ``set`` of strings iterates in an order that varies
    between processes because of hash randomization). Empty entries, e.g. from
    a trailing comma, are dropped: an empty name would otherwise "occur" in
    every product title during the substring fallback.
    """
    names = (part.strip() for part in raw.split(","))
    return list(dict.fromkeys(name for name in names if name))


def _match_brand(answer: str, goods: str, brands: List[str]) -> str:
    """Validate the model's *answer* against the candidate *brands*.

    Returns *answer* when it is exactly one of *brands*. Otherwise falls back
    to the longest brand that occurs verbatim in *goods* (longest, so a brand
    that is a substring of a longer candidate does not shadow it; ties go to
    the earlier entry). Returns "失败" when nothing matches.
    """
    if answer in brands:
        return answer
    hits = [brand for brand in brands if brand in goods]
    return max(hits, key=len) if hits else _NOT_FOUND


# Plain ``def`` rather than ``async def``: ``model.generate`` blocks, and FastAPI
# runs sync handlers in its worker thread pool instead of on the event loop, so
# a slow generation no longer stalls every other endpoint.
# NOTE(review): this also lets concurrent requests call ``generate`` on the
# shared model in parallel threads — confirm GPU memory headroom.
@app.post("/extract_brand/")
def extract_brand(request: BrandExtractionRequest):
    """Extract the brand of ``request.goods`` from ``request.brand_list``.

    The model is asked to pick one candidate; its answer is accepted only if it
    is in the candidate list, otherwise a substring search of the product title
    is used. Returns ``{"extracted_brand": brand}`` where ``brand`` is one of
    the candidates or "失败".
    """
    goods = request.goods
    brands = _parse_brand_list(request.brand_list)
    if not brands:
        # Nothing to choose from; skip the expensive model call.
        return {"extracted_brand": _NOT_FOUND}

    prompt = (
        f"商品名称:{goods}\n"
        "-只需要返回一个品牌的名字,去掉多余的描述\n"
        f"-请在以下品牌中选择:{','.join(brands)}"
    )
    messages = [
        {"role": "system",
         "content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=_MAX_NEW_TOKENS)
    # Keep only the newly generated tokens, dropping the echoed prompt.
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

    # Skip past the last 151668 token if present. Per the Qwen3 model card this
    # id is "</think>" (TODO confirm for this tokenizer); it is expected to be
    # absent with enable_thinking=False, so this is only a safeguard.
    try:
        index = len(output_ids) - output_ids[::-1].index(151668)
    except ValueError:
        index = 0
    answer = tokenizer.decode(output_ids[index:], skip_special_tokens=True)
    answer = answer.strip().strip(_ANSWER_QUOTES).strip()
    return {"extracted_brand": _match_brand(answer, goods, brands)}
# Optional GET endpoint, handy for manual testing.
@app.get("/")
async def root():
    """Report that the service is up."""
    payload = {"message": "Brand Extraction API is running"}
    return payload