renzhiyuan 2026-01-07 17:24:30 +08:00
parent 1e3aa9857a
commit 1ff2ef5762
7 changed files with 233 additions and 49 deletions

.python-version Normal file (1 line)

@@ -0,0 +1 @@
3.12

README.md Normal file (empty)

app.py (179 changed lines)

@@ -1,22 +1,49 @@
-from fastapi import FastAPI
+import requests
+from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from modelscope import AutoModelForCausalLM, AutoTokenizer
+from typing import List, Optional

 app = FastAPI()

-# Initialize the model and tokenizer (globals, loaded only once)
-model_name = "Qwen/Qwen3-0.6B"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype="auto",
-    device_map="auto",
-)
+# Ollama configuration
+OLLAMA_BASE_URL = "http://localhost:11434"  # default Ollama address
+MODEL_NAME = "qwen3-coder:480b-cloud"  # or any other model you have installed

 class BrandExtractionRequest(BaseModel):
     goods: str
-    brand_list: str  # changed to a string format, e.g. "apple,samsung,huawei"
+    brand_list: str  # string format, e.g. "apple,samsung,huawei"

+class OllamaRequest(BaseModel):
+    model: str
+    prompt: str
+    stream: bool = False
+    options: Optional[dict] = None

+def call_ollama(prompt: str, model: str = MODEL_NAME) -> str:
+    """Call the Ollama API to generate text."""
+    url = f"{OLLAMA_BASE_URL}/api/generate"
+    payload = {
+        "model": model,
+        "prompt": prompt,
+        "stream": False,
+        "options": {
+            "temperature": 0.1,  # low temperature for more deterministic output
+            "max_tokens": 100,   # cap the generated length
+            "stop": ["\n", "", ""]  # stop sequences
+        }
+    }
+    try:
+        response = requests.post(url, json=payload, timeout=30)
+        response.raise_for_status()
+        result = response.json()
+        return result.get("response", "").strip()
+    except requests.exceptions.RequestException as e:
+        raise HTTPException(status_code=500, detail=f"Ollama API call failed: {str(e)}")

 @app.post("/extract_brand/")
@@ -24,58 +51,112 @@ async def extract_brand(request: BrandExtractionRequest):
     goods = request.goods
     brand_set = set([brand.strip() for brand in request.brand_list.split(",")])  # parse the comma-separated string and strip whitespace

-    # Build the prompt
-    prompt = (
-        f"商品名称:{goods}\n"
-        "-只需要返回品牌名字,去掉多余的描述\n"
-        #f"-请在以下品牌中选择:{brand_set}"
-    )
-    messages = [
-        {"role": "system",
-         "content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
-        {"role": "user", "content": prompt}
-    ]
-
-    # Generate the text
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True,
-        enable_thinking=False
-    )
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-    generated_ids = model.generate(
-        **model_inputs,
-        max_new_tokens=32768
-    )
-    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
-
-    # Parse the output
-    try:
-        index = len(output_ids) - output_ids[::-1].index(151668)  # 151668 is a special token in the tokenizer
-    except ValueError:
-        index = 0
-    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
+    # Build a clearer prompt
+    brands_str = ", ".join(sorted(brand_set))
+    prompt = f"""
+请从商品名称中提取品牌名称。
+
+商品名称:{goods}
+可选的品牌列表:{brands_str}
+
+要求:
+1. 仔细检查商品名称是否包含上述品牌列表中的任何一个
+2. 如果包含,只返回品牌名称本身
+3. 如果不包含任何品牌,返回"失败"
+4. 不要添加任何解释或额外文本
+
+示例:
+商品名称:苹果手机iPhone 13 → apple
+商品名称:三星Galaxy S23 → samsung
+商品名称:华为Mate 60 → huawei
+商品名称:普通智能手机 → 失败
+
+请分析商品名称:"{goods}"
+"""
+
+    # Call Ollama
+    extracted_brand = call_ollama(prompt)
+
+    # Clean up the response
+    extracted_brand = extracted_brand.strip()
+    # Remove possible quotes, periods, and other punctuation
+    extracted_brand = extracted_brand.strip('"').strip("'").strip("").strip()

     # Post-process: make sure the returned brand is in brand_set, otherwise return "失败" (failure)
-    extracted_brand = content.strip()
     print(extracted_brand)
-    if extracted_brand not in brand_set:
-        # Try to find the brand name directly in goods (simple match)
+    if extracted_brand.lower() not in [b.lower() for b in brand_set]:
+        # Try to find the brand name directly in goods (case-insensitive)
         for brand in brand_set:
-            if brand in goods:
+            if brand.lower() in goods.lower():
                 extracted_brand = brand
                 break
         else:
             extracted_brand = "失败"

-    return {"extracted_brand": extracted_brand}
+    return {
+        "extracted_brand": extracted_brand,
+        "goods": goods,
+        "available_brands": list(brand_set)
+    }

-# Test GET endpoint (optional)
+# Endpoint to test the Ollama connection
+@app.get("/test_ollama")
+async def test_ollama():
+    """Test the Ollama connection and model availability."""
+    try:
+        # Run a simple request
+        test_prompt = "请回答:你好,世界!"
+        response = call_ollama(test_prompt)
+
+        # Also fetch the list of models installed in Ollama
+        models_url = f"{OLLAMA_BASE_URL}/api/tags"
+        models_response = requests.get(models_url)
+        models_data = models_response.json() if models_response.status_code == 200 else {}
+
+        return {
+            "status": "success",
+            "ollama_response": response[:50] + "..." if len(response) > 50 else response,
+            "available_models": [model.get("name") for model in models_data.get("models", [])],
+            "current_model": MODEL_NAME
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Ollama test failed: {str(e)}")
+
+# Model-switching endpoint (optional)
+@app.post("/switch_model/")
+async def switch_model(model_name: str):
+    """Dynamically switch the Ollama model."""
+    global MODEL_NAME
+    old_model = MODEL_NAME
+
+    # Verify that the model exists
+    try:
+        url = f"{OLLAMA_BASE_URL}/api/tags"
+        response = requests.get(url)
+        models = [model.get("name") for model in response.json().get("models", [])]
+
+        if model_name not in models:
+            raise HTTPException(status_code=400, detail=f"Model '{model_name}' does not exist. Available models: {models}")
+
+        MODEL_NAME = model_name
+        return {
+            "status": "success",
+            "message": f"Switched model from '{old_model}' to '{model_name}'"
+        }
+    except requests.exceptions.RequestException as e:
+        raise HTTPException(status_code=500, detail=f"Failed to fetch the model list: {str(e)}")
+
+# Test GET endpoint
 @app.get("/")
 async def root():
-    return {"message": "Brand Extraction API is running"}
+    return {"message": "Brand Extraction API is running (Ollama version)"}
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=5001)
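
For quick verification of the new Ollama-backed endpoints, here is a minimal client sketch. It is not part of this commit: it assumes the service was started locally with the uvicorn settings above, and the sample goods string is purely illustrative.

# client_check.py (hypothetical helper, not part of this commit)
# Exercises the endpoints defined in app.py above; assumes the server
# is reachable at the host/port passed to uvicorn.run().
import requests

BASE_URL = "http://localhost:5001"

# Health check against the root endpoint
print(requests.get(f"{BASE_URL}/").json())

# Confirm Ollama connectivity and list installed models
print(requests.get(f"{BASE_URL}/test_ollama").json())

# Brand extraction: brand_list is a single comma-separated string,
# matching the BrandExtractionRequest model
payload = {"goods": "苹果手机iPhone 13", "brand_list": "apple,samsung,huawei"}
resp = requests.post(f"{BASE_URL}/extract_brand/", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json())  # expected: {"extracted_brand": "apple", ...}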

app.py.bak Normal file (81 lines)

@@ -0,0 +1,81 @@
from fastapi import FastAPI
from pydantic import BaseModel
from modelscope import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

# Initialize the model and tokenizer (globals, loaded only once)
model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)

class BrandExtractionRequest(BaseModel):
    goods: str
    brand_list: str  # changed to a string format, e.g. "apple,samsung,huawei"

@app.post("/extract_brand/")
async def extract_brand(request: BrandExtractionRequest):
    goods = request.goods
    brand_set = set([brand.strip() for brand in request.brand_list.split(",")])  # parse the comma-separated string and strip whitespace

    # Build the prompt
    prompt = (
        f"商品名称:{goods}\n"
        "-只需要返回品牌名字,去掉多余的描述\n"
        #f"-请在以下品牌中选择:{brand_set}"
    )
    messages = [
        {"role": "system",
         "content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
        {"role": "user", "content": prompt}
    ]

    # Generate the text
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=32768
    )
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

    # Parse the output
    try:
        index = len(output_ids) - output_ids[::-1].index(151668)  # 151668 is a special token in the tokenizer
    except ValueError:
        index = 0
    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

    # Post-process: make sure the returned brand is in brand_set, otherwise return "失败" (failure)
    extracted_brand = content.strip()
    print(extracted_brand)
    if extracted_brand not in brand_set:
        # Try to find the brand name directly in goods (simple match)
        for brand in brand_set:
            if brand in goods:
                extracted_brand = brand
                break
        else:
            extracted_brand = "失败"

    return {"extracted_brand": extracted_brand}

# Test GET endpoint (optional)
@app.get("/")
async def root():
    return {"message": "Brand Extraction API is running"}

main.py Normal file (6 lines)

@@ -0,0 +1,6 @@
def main():
    print("Hello from brandselect!")


if __name__ == "__main__":
    main()

pyproject.toml Normal file (7 lines)

@@ -0,0 +1,7 @@
[project]
name = "brandselect"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = []

uv.lock Normal file (8 lines)

@@ -0,0 +1,8 @@
version = 1
revision = 3
requires-python = ">=3.12"
[[package]]
name = "brandselect"
version = "0.1.0"
source = { virtual = "." }