This commit is contained in:
renzhiyuan 2025-08-13 11:45:01 +08:00
commit 5fc9026222
11 changed files with 175 additions and 0 deletions

8
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

8
.idea/brandSelect.iml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="brand" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -0,0 +1,14 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="1">
<item index="0" class="java.lang.String" itemvalue="betterproto" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="311" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="brand" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/brandSelect.iml" filepath="$PROJECT_DIR$/.idea/brandSelect.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

15
Dockerfile Normal file
View File

@ -0,0 +1,15 @@
FROM python:3.11-slim
WORKDIR /app
# Copy requirements.txt on its own and install dependencies first, so the
# dependency layer is reused from Docker's cache when only source files change.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
# Copy the whole project
COPY . .
EXPOSE 5001
# "app:app" is <module>:<attribute>, i.e. the FastAPI instance `app` in app.py.
# --reload is intentionally not passed: it is uvicorn's development auto-reloader
# and would restart the server (and reload the model) on any file change.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "5001"]

Binary file not shown.

82
app.py Normal file
View File

@ -0,0 +1,82 @@
from fastapi import FastAPI
from pydantic import BaseModel
from modelscope import AutoModelForCausalLM, AutoTokenizer
import re
from typing import List
app = FastAPI()

# The tokenizer and model are module-level globals: they are loaded once when
# this module is imported and then shared by every request handler.
model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
class BrandExtractionRequest(BaseModel):
    """Request body accepted by the brand-extraction endpoint."""

    # Product name from which a brand should be extracted.
    goods: str
    # Candidate brands as a single comma-separated string,
    # e.g. "apple,samsung,huawei".
    brand_list: str
def _parse_brand_list(brand_list: str) -> List[str]:
    """Split a comma-separated brand string into unique, non-empty names.

    Surrounding whitespace is stripped from every entry. Empty entries (from
    an empty string, doubled commas or a trailing comma) are dropped, and
    duplicates are removed while keeping first-seen order.
    """
    stripped = (part.strip() for part in brand_list.split(","))
    return list(dict.fromkeys(name for name in stripped if name))


def _resolve_brand(candidate: str, goods: str, brands: List[str]) -> str:
    """Return an allowed brand for ``goods`` or "失败" when none applies.

    ``candidate`` (the model's answer) wins when it is exactly one of
    ``brands``. Otherwise the first brand, in ``brands`` order, that occurs
    as a substring of ``goods`` is used.
    """
    if candidate in brands:
        return candidate
    for brand in brands:
        if brand in goods:
            return brand
    return "失败"


@app.post("/extract_brand/")
async def extract_brand(request: BrandExtractionRequest):
    """Extract the brand of a product name, restricted to a given brand list.

    The model is prompted with the product name and the allowed brands. Its
    answer is accepted only if it exactly equals one of the allowed brands;
    otherwise a plain substring search of the product name is used as a
    fallback.

    Args:
        request: Carries ``goods`` (product name) and ``brand_list``
            (comma-separated candidate brands).

    Returns:
        ``{"extracted_brand": <brand>}`` where ``<brand>`` is one of the
        supplied brands, or "失败" when none matches or no usable brand was
        supplied.
    """
    goods = request.goods
    brands = _parse_brand_list(request.brand_list)
    if not brands:
        # Nothing to choose from. Without this guard an empty entry would
        # "match" every product name, because "" is a substring of any string.
        return {"extracted_brand": "失败"}
    brand_set = set(brands)

    # Build the prompt.
    prompt = (
        f"商品名称:{goods}\n"
        "-只需要返回一个品牌的名字,去掉多余的描述\n"
        f"-请在以下品牌中选择:{brand_set}"
    )
    messages = [
        {"role": "system",
         "content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
        {"role": "user", "content": prompt},
    ]

    # Render the chat template and generate.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # NOTE(review): generate() is a plain synchronous call inside this
    # coroutine, and max_new_tokens is very large for a one-brand answer;
    # consider lowering the cap and/or moving generation off the event loop.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=32768,
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

    # Skip everything up to and including the last occurrence of token 151668
    # (a tokenizer special token; presumably the end-of-thinking marker for
    # this model, to be confirmed). If it is absent, decode the whole output.
    try:
        index = len(output_ids) - output_ids[::-1].index(151668)
    except ValueError:
        index = 0
    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip()

    # Post-process: guarantee the result is an allowed brand or "失败".
    return {"extracted_brand": _resolve_brand(content, goods, brands)}
# Optional GET endpoint, useful for checking that the service is up.
@app.get("/")
async def root():
    """Return a fixed status message for the service root."""
    status = {"message": "Brand Extraction API is running"}
    return status

21
requirements.txt Normal file
View File

@ -0,0 +1,21 @@
modelscope
filelock
transformers
torch<=2.3
fastapi
numpy<=1.26.4
starlette
torchaudio
uvicorn
addict
datasets==2.21.0
pillow
simplejson
sortedcontainers
loguru
accelerate
# A package index mirror is chosen on the pip command line, not in this file, e.g.:
#   pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/