This commit is contained in:
commit
5fc9026222
|
@ -0,0 +1,8 @@
|
||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="jdk" jdkName="brand" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
|
@ -0,0 +1,14 @@
|
||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredPackages">
|
||||||
|
<value>
|
||||||
|
<list size="1">
|
||||||
|
<item index="0" class="java.lang.String" itemvalue="betterproto" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
</profile>
|
||||||
|
</component>
|
|
@ -0,0 +1,6 @@
|
||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
|
@ -0,0 +1,7 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="311" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="brand" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/brandSelect.iml" filepath="$PROJECT_DIR$/.idea/brandSelect.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
|
@ -0,0 +1,6 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
|
@ -0,0 +1,15 @@
|
||||||
|
FROM python:3.11-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 复制 requirements.txt 并优先安装依赖(利用 Docker 层缓存)
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
|
||||||
|
|
||||||
|
# 复制整个项目
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
EXPOSE 5001
|
||||||
|
|
||||||
|
# 确保模块名和 Flask 实例名正确(默认是 app:app)
|
||||||
|
CMD ["uvicorn", "app:app", "--reload", "--host", "0.0.0.0", "--port", "5001"]
|
Binary file not shown.
|
@ -0,0 +1,82 @@
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from modelscope import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# 初始化模型和tokenizer(全局变量,只加载一次)
|
||||||
|
model_name = "Qwen/Qwen3-0.6B"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype="auto",
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BrandExtractionRequest(BaseModel):
|
||||||
|
goods: str
|
||||||
|
brand_list: str # 改为字符串格式,如 "apple,samsung,huawei"
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/extract_brand/")
|
||||||
|
async def extract_brand(request: BrandExtractionRequest):
|
||||||
|
goods = request.goods
|
||||||
|
brand_set = set([brand.strip() for brand in request.brand_list.split(",")]) # 解析逗号分隔的字符串并去空格
|
||||||
|
|
||||||
|
# 构建prompt
|
||||||
|
prompt = (
|
||||||
|
f"商品名称:{goods}\n"
|
||||||
|
"-只需要返回一个品牌的名字,去掉多余的描述\n"
|
||||||
|
f"-请在以下品牌中选择:{brand_set}"
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system",
|
||||||
|
"content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
]
|
||||||
|
|
||||||
|
# 生成文本
|
||||||
|
text = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=False
|
||||||
|
)
|
||||||
|
|
||||||
|
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||||
|
generated_ids = model.generate(
|
||||||
|
**model_inputs,
|
||||||
|
max_new_tokens=32768
|
||||||
|
)
|
||||||
|
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
|
||||||
|
|
||||||
|
# 解析输出
|
||||||
|
try:
|
||||||
|
index = len(output_ids) - output_ids[::-1].index(151668) # 151668是tokenizer中的特殊token
|
||||||
|
except ValueError:
|
||||||
|
index = 0
|
||||||
|
|
||||||
|
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
|
||||||
|
|
||||||
|
# 后处理:确保返回的品牌在brand_set中,否则返回"失败"
|
||||||
|
extracted_brand = content.strip()
|
||||||
|
if extracted_brand not in brand_set:
|
||||||
|
# 尝试在goods中直接查找品牌名(简单匹配)
|
||||||
|
for brand in brand_set:
|
||||||
|
if brand in goods:
|
||||||
|
extracted_brand = brand
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
extracted_brand = "失败"
|
||||||
|
|
||||||
|
return {"extracted_brand": extracted_brand}
|
||||||
|
|
||||||
|
|
||||||
|
# 测试用的GET端点(可选)
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
return {"message": "Brand Extraction API is running"}
|
|
@ -0,0 +1,21 @@
|
||||||
|
modelscope
|
||||||
|
filelock
|
||||||
|
transformers
|
||||||
|
torch<=2.3
|
||||||
|
fastapi
|
||||||
|
numpy<=1.26.4
|
||||||
|
starlette
|
||||||
|
torch<=2.3
|
||||||
|
torchaudio
|
||||||
|
uvicorn
|
||||||
|
addict
|
||||||
|
datasets==2.21.0
|
||||||
|
pillow
|
||||||
|
simplejson
|
||||||
|
sortedcontainers
|
||||||
|
loguru
|
||||||
|
accelerate
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
pip install accelerate -i https://pypi.tuna.tsinghua.edu.cn/simple/
|
Loading…
Reference in New Issue