# 82 lines · 2.5 KiB · Python (file-viewer header captured when this file was copied)
import re
from typing import List

from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from modelscope import AutoModelForCausalLM, AutoTokenizer
from pydantic import BaseModel

app = FastAPI()

# The tokenizer and model are loaded once, at import time, into module-level
# globals; every request handled by this process reuses the same instances.
model_name = "Qwen/Qwen3-0.6B"


def _load_tokenizer_and_model(name):
    """Load the tokenizer and causal-LM weights for *name* via modelscope.

    The tokenizer is loaded first, then the model. ``torch_dtype="auto"`` and
    ``device_map="auto"`` ask the loader to pick the weight dtype and device
    placement itself.
    """
    tok = AutoTokenizer.from_pretrained(name)
    lm = AutoModelForCausalLM.from_pretrained(
        name,
        torch_dtype="auto",
        device_map="auto",
    )
    return tok, lm


tokenizer, model = _load_tokenizer_and_model(model_name)
class BrandExtractionRequest(BaseModel):
    """Request body for POST /extract_brand/."""

    # Product name/title the brand is extracted from.
    goods: str
    # Candidate brands as a single comma-separated string (changed to string
    # format), e.g. "apple,samsung,huawei".
    brand_list: str
# Qwen3's `</think>` token id (per the Qwen3 model card). With
# enable_thinking=False the reply normally has no thinking section, but
# everything up to and including the last `</think>` is discarded just in case.
_THINK_END_TOKEN_ID = 151668

# A single brand name (or "失败") needs only a handful of tokens; the cap keeps
# a runaway completion from generating thousands of tokens per request.
_MAX_NEW_TOKENS = 64

# Value returned when no candidate brand can be matched.
_NOT_FOUND = "失败"

# Wrapping characters the model may put around its answer (whitespace, quotes,
# brackets, a trailing full stop); stripped before matching.
_ANSWER_STRIP_CHARS = " \t\r\n'\"“”‘’「」《》。"


def _parse_brand_list(raw: str) -> List[str]:
    """Split a comma-separated brand string into unique, non-empty names.

    Both ASCII "," and full-width "，" separate entries. Surrounding whitespace
    is removed, and empty entries (from "a,,b" or a trailing comma) are
    dropped. An empty entry would otherwise match every product, because
    ``"" in goods`` is always true. Duplicates are removed while the caller's
    order is kept, so prompts and fallback matching are deterministic.
    """
    names = (part.strip() for part in re.split(r"[,，]", raw))
    return list(dict.fromkeys(name for name in names if name))


def _build_messages(goods: str, brands: List[str]) -> List[dict]:
    """Build the chat messages asking the model to pick one of *brands*."""
    # Join explicitly instead of interpolating a set repr, whose element order
    # changes between processes.
    brand_text = ",".join(brands)
    prompt = (
        f"商品名称:{goods}\n"
        "-只需要返回一个品牌的名字,去掉多余的描述\n"
        f"-请在以下品牌中选择:{brand_text}"
    )
    return [
        {"role": "system",
         "content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
        {"role": "user", "content": prompt},
    ]


def _generate_answer(messages: List[dict]) -> str:
    """Run the chat model on *messages* and return its decoded, stripped reply.

    This call blocks: it tokenizes, runs ``model.generate`` and decodes on the
    calling thread, so async callers should run it in a worker thread.
    """
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=_MAX_NEW_TOKENS,
    )
    # Keep only the newly generated tokens and drop the echoed prompt.
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
    try:
        # Start just after the last `</think>`, if there is one.
        start = len(output_ids) - output_ids[::-1].index(_THINK_END_TOKEN_ID)
    except ValueError:
        start = 0
    return tokenizer.decode(output_ids[start:], skip_special_tokens=True).strip()


def _resolve_brand(answer: str, brands: List[str], goods: str) -> str:
    """Map the model's *answer* onto a brand spelled exactly as in *brands*.

    1. If the answer, with wrapping quotes/whitespace removed, equals a brand
       (case-insensitively), return that brand as listed.
    2. Otherwise return the first brand, in list order, that occurs in *goods*
       (case-insensitive substring match).
    3. Otherwise return "失败".
    """
    by_key = {}
    for brand in brands:
        by_key.setdefault(brand.casefold(), brand)

    cleaned = answer.strip(_ANSWER_STRIP_CHARS).casefold()
    if cleaned in by_key:
        return by_key[cleaned]

    goods_key = goods.casefold()
    for brand in brands:
        if brand.casefold() in goods_key:
            return brand
    return _NOT_FOUND


@app.post("/extract_brand/")
async def extract_brand(request: BrandExtractionRequest):
    """Return which candidate brand a product name refers to.

    ``request.goods`` is the product name. ``request.brand_list`` is a
    comma-separated string of candidate brands. The model's answer is accepted
    only if it matches a candidate; otherwise the candidates are searched for
    directly in the product name.

    Responds with ``{"extracted_brand": <brand from the list, or "失败">}``.
    """
    goods = request.goods
    brands = _parse_brand_list(request.brand_list)
    if not brands:
        # Nothing to choose from, so skip the expensive model call entirely.
        return {"extracted_brand": _NOT_FOUND}

    # model.generate is blocking. Running it in the threadpool keeps the event
    # loop free to serve other requests while inference is in progress.
    answer = await run_in_threadpool(_generate_answer, _build_messages(goods, brands))
    return {"extracted_brand": _resolve_brand(answer, brands, goods)}

# Optional GET endpoint, handy for checking that the service is up.
@app.get("/")
async def root():
    """Report that the API process is running."""
    payload = {"message": "Brand Extraction API is running"}
    return payload