This commit is contained in:
commit
5fc9026222
|
@ -0,0 +1,8 @@
|
|||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="brand" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
|
@ -0,0 +1,14 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="1">
|
||||
<item index="0" class="java.lang.String" itemvalue="betterproto" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
|
@ -0,0 +1,6 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="311" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="brand" project-jdk-type="Python SDK" />
|
||||
</project>
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/brandSelect.iml" filepath="$PROJECT_DIR$/.idea/brandSelect.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,6 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
|
@ -0,0 +1,15 @@
|
|||
FROM python:3.11-slim

WORKDIR /app

# Copy requirements.txt first and install dependencies before the rest of the
# sources, so the dependency layer is reused from the Docker layer cache.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/

# Copy the whole project
COPY . .

EXPOSE 5001

# The "module:instance" argument must match the FastAPI application object
# (app:app here). `--reload` is intentionally not passed: it is a development
# flag, and the sources are baked into the image so there is nothing to reload.
# Add it back on the command line for local development with a mounted volume.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "5001"]
|
Binary file not shown.
|
@ -0,0 +1,82 @@
|
|||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from modelscope import AutoModelForCausalLM, AutoTokenizer
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
# FastAPI application object; the route decorators below register on it.
app = FastAPI()

# Load the tokenizer and the model once, at import time, as module-level
# globals so every request reuses the same instances.
model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
|
||||
|
||||
|
||||
class BrandExtractionRequest(BaseModel):
    # Request body of the POST /extract_brand/ endpoint.

    # Text the brand is extracted from (used as the product name in the prompt).
    goods: str

    # Candidate brands as one comma-separated string, e.g. "apple,samsung,huawei".
    brand_list: str
|
||||
|
||||
|
||||
def _parse_brand_list(raw: str) -> List[str]:
    """Split a comma-separated brand string into unique, non-empty names, keeping input order."""
    stripped = (name.strip() for name in raw.split(","))
    # dict.fromkeys de-duplicates while preserving first-seen order, so the
    # prompt and the fallback scan are deterministic across processes.
    return list(dict.fromkeys(name for name in stripped if name))


def _resolve_brand(candidate: str, goods: str, brands: List[str]) -> str:
    """Return `candidate` if it is an allowed brand, else the first brand found in `goods`, else "失败"."""
    if candidate in brands:
        return candidate
    # Fallback: plain substring match of each allowed brand against the goods text.
    for brand in brands:
        if brand in goods:
            return brand
    return "失败"


@app.post("/extract_brand/")
async def extract_brand(request: BrandExtractionRequest):
    """Extract one brand name from a product name, restricted to a given brand list.

    The model is asked to pick a brand from ``request.brand_list``; its answer
    is accepted only if it is exactly one of the listed brands. Otherwise the
    first listed brand that occurs as a substring of ``request.goods`` is used.

    Args:
        request: Body with ``goods`` (product name) and ``brand_list``
            (comma-separated candidate brands).

    Returns:
        ``{"extracted_brand": <brand>}`` where ``<brand>`` is one of the listed
        brands, or ``"失败"`` when no brand matches or the list is empty.
    """
    goods = request.goods
    brands = _parse_brand_list(request.brand_list)
    if not brands:
        # Nothing to choose from: skip the model call entirely. Blank entries
        # must never count as brands, because "" is a substring of every string.
        return {"extracted_brand": "失败"}

    # Build the prompt
    brand_options = ",".join(brands)
    prompt = (
        f"商品名称:{goods}\n"
        "-只需要返回一个品牌的名字,去掉多余的描述\n"
        f"-请在以下品牌中选择:{brand_options}"
    )

    messages = [
        {"role": "system",
         "content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
        {"role": "user", "content": prompt},
    ]

    # Render the chat template to text, with the model's thinking mode disabled
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # NOTE(review): generate() is a blocking call inside an async handler, so
    # other requests wait while it runs — confirm this serialization is intended.
    # The expected answer is a single brand name, so the generation budget is
    # kept small to bound request latency if the model rambles.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=64,
    )
    # Keep only the newly generated tokens (drop the echoed prompt).
    prompt_length = len(model_inputs.input_ids[0])
    output_ids = generated_ids[0][prompt_length:].tolist()

    # Parse the output: decode only what follows the last occurrence of token
    # id 151668, or everything if it is absent. The original author describes
    # 151668 as a tokenizer special token; presumably it is Qwen3's end-of-thinking
    # marker — TODO confirm against the tokenizer vocabulary.
    try:
        index = len(output_ids) - output_ids[::-1].index(151668)
    except ValueError:
        index = 0

    candidate = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip()

    # Post-process: only ever return a brand from the list, or "失败".
    return {"extracted_brand": _resolve_brand(candidate, goods, brands)}
|
||||
|
||||
|
||||
# Optional GET endpoint for testing: returns a fixed status message.
@app.get("/")
async def root():
    status = {"message": "Brand Extraction API is running"}
    return status
|
|
@ -0,0 +1,21 @@
|
|||
modelscope
filelock
transformers
torch<=2.3
fastapi
numpy<=1.26.4
starlette
torchaudio
uvicorn
addict
datasets==2.21.0
pillow
simplejson
sortedcontainers
loguru
accelerate

# Shell commands are not valid requirement lines and make `pip install -r` fail.
# To install a single package manually from the Tsinghua mirror, run:
#   pip install accelerate -i https://pypi.tuna.tsinghua.edu.cn/simple/
|
Loading…
Reference in New Issue