This commit is contained in:
parent
1e3aa9857a
commit
1ff2ef5762
|
|
@ -0,0 +1 @@
|
|||
3.12
|
||||
179
app.py
179
app.py
|
|
@ -1,22 +1,49 @@
|
|||
from fastapi import FastAPI
|
||||
import requests
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from modelscope import AutoModelForCausalLM, AutoTokenizer
|
||||
from typing import List, Optional
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 初始化模型和tokenizer(全局变量,只加载一次)
|
||||
model_name = "Qwen/Qwen3-0.6B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
)
|
||||
# Ollama配置
|
||||
OLLAMA_BASE_URL = "http://localhost:11434" # 默认Ollama地址
|
||||
MODEL_NAME = "qwen3-coder:480b-cloud" # 或其他你安装的模型
|
||||
|
||||
|
||||
class BrandExtractionRequest(BaseModel):
|
||||
goods: str
|
||||
brand_list: str # 改为字符串格式,如 "apple,samsung,huawei"
|
||||
brand_list: str # 字符串格式,如 "apple,samsung,huawei"
|
||||
|
||||
|
||||
class OllamaRequest(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
stream: bool = False
|
||||
options: Optional[dict] = None
|
||||
|
||||
|
||||
def call_ollama(prompt: str, model: str = MODEL_NAME) -> str:
|
||||
"""调用Ollama API生成文本"""
|
||||
url = f"{OLLAMA_BASE_URL}/api/generate"
|
||||
|
||||
payload = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"temperature": 0.1, # 低温度以获得更确定的输出
|
||||
"max_tokens": 100, # 限制生成长度
|
||||
"stop": ["\n", "。", "!"] # 停止词
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=payload, timeout=30)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return result.get("response", "").strip()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise HTTPException(status_code=500, detail=f"Ollama API调用失败: {str(e)}")
|
||||
|
||||
|
||||
@app.post("/extract_brand/")
|
||||
|
|
@ -24,58 +51,112 @@ async def extract_brand(request: BrandExtractionRequest):
|
|||
goods = request.goods
|
||||
brand_set = set([brand.strip() for brand in request.brand_list.split(",")]) # 解析逗号分隔的字符串并去空格
|
||||
|
||||
# 构建prompt
|
||||
prompt = (
|
||||
f"商品名称:{goods}\n"
|
||||
"-只需要返回品牌名字,去掉多余的描述\n"
|
||||
#f"-请在以下品牌中选择:{brand_set}"
|
||||
)
|
||||
# 构建更清晰的prompt
|
||||
brands_str = ", ".join(sorted(brand_set))
|
||||
|
||||
messages = [
|
||||
{"role": "system",
|
||||
"content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
prompt = f"""
|
||||
请从商品名称中提取品牌名称。
|
||||
|
||||
# 生成文本
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False
|
||||
)
|
||||
商品名称:{goods}
|
||||
可选的品牌列表:{brands_str}
|
||||
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
generated_ids = model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=32768
|
||||
)
|
||||
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
|
||||
要求:
|
||||
1. 仔细检查商品名称是否包含上述品牌列表中的任何一个
|
||||
2. 如果包含,只返回品牌名称本身
|
||||
3. 如果不包含任何品牌,返回"失败"
|
||||
4. 不要添加任何解释或额外文本
|
||||
|
||||
# 解析输出
|
||||
try:
|
||||
index = len(output_ids) - output_ids[::-1].index(151668) # 151668是tokenizer中的特殊token
|
||||
except ValueError:
|
||||
index = 0
|
||||
示例:
|
||||
商品名称:苹果手机iPhone 13 → apple
|
||||
商品名称:三星Galaxy S23 → samsung
|
||||
商品名称:华为Mate 60 → huawei
|
||||
商品名称:普通智能手机 → 失败
|
||||
|
||||
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
|
||||
请分析商品名称:"{goods}"
|
||||
"""
|
||||
|
||||
# 调用Ollama
|
||||
extracted_brand = call_ollama(prompt)
|
||||
|
||||
# 清理响应
|
||||
extracted_brand = extracted_brand.strip()
|
||||
|
||||
# 移除可能的引号、句号等标点
|
||||
extracted_brand = extracted_brand.strip('"').strip("'").strip("。").strip()
|
||||
|
||||
# 后处理:确保返回的品牌在brand_set中,否则返回"失败"
|
||||
extracted_brand = content.strip()
|
||||
print(extracted_brand)
|
||||
if extracted_brand not in brand_set:
|
||||
# 尝试在goods中直接查找品牌名(简单匹配)
|
||||
if extracted_brand.lower() not in [b.lower() for b in brand_set]:
|
||||
# 尝试在goods中直接查找品牌名(不区分大小写)
|
||||
for brand in brand_set:
|
||||
if brand in goods:
|
||||
if brand.lower() in goods.lower():
|
||||
extracted_brand = brand
|
||||
break
|
||||
else:
|
||||
extracted_brand = "失败"
|
||||
|
||||
return {"extracted_brand": extracted_brand}
|
||||
return {
|
||||
"extracted_brand": extracted_brand,
|
||||
"goods": goods,
|
||||
"available_brands": list(brand_set)
|
||||
}
|
||||
|
||||
|
||||
# 测试用的GET端点(可选)
|
||||
# 添加一个测试Ollama连接的端点
|
||||
@app.get("/test_ollama")
|
||||
async def test_ollama():
|
||||
"""测试Ollama连接和模型可用性"""
|
||||
try:
|
||||
# 测试一个简单的请求
|
||||
test_prompt = "请回答:你好,世界!"
|
||||
response = call_ollama(test_prompt)
|
||||
|
||||
# 同时获取Ollama中的模型列表
|
||||
models_url = f"{OLLAMA_BASE_URL}/api/tags"
|
||||
models_response = requests.get(models_url)
|
||||
models_data = models_response.json() if models_response.status_code == 200 else {}
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"ollama_response": response[:50] + "..." if len(response) > 50 else response,
|
||||
"available_models": [model.get("name") for model in models_data.get("models", [])],
|
||||
"current_model": MODEL_NAME
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Ollama测试失败: {str(e)}")
|
||||
|
||||
|
||||
# 添加模型切换端点(可选)
|
||||
@app.post("/switch_model/")
|
||||
async def switch_model(model_name: str):
|
||||
"""动态切换Ollama模型"""
|
||||
global MODEL_NAME
|
||||
old_model = MODEL_NAME
|
||||
|
||||
# 验证模型是否存在
|
||||
try:
|
||||
url = f"{OLLAMA_BASE_URL}/api/tags"
|
||||
response = requests.get(url)
|
||||
models = [model.get("name") for model in response.json().get("models", [])]
|
||||
|
||||
if model_name not in models:
|
||||
raise HTTPException(status_code=400, detail=f"模型 '{model_name}' 不存在。可用模型: {models}")
|
||||
|
||||
MODEL_NAME = model_name
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"模型已从 '{old_model}' 切换到 '{model_name}'"
|
||||
}
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise HTTPException(status_code=500, detail=f"无法获取模型列表: {str(e)}")
|
||||
|
||||
|
||||
# 测试用的GET端点
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Brand Extraction API is running"}
|
||||
return {"message": "Brand Extraction API is running (Ollama version)"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=5001)
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from modelscope import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 初始化模型和tokenizer(全局变量,只加载一次)
|
||||
model_name = "Qwen/Qwen3-0.6B"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
|
||||
class BrandExtractionRequest(BaseModel):
|
||||
goods: str
|
||||
brand_list: str # 改为字符串格式,如 "apple,samsung,huawei"
|
||||
|
||||
|
||||
@app.post("/extract_brand/")
|
||||
async def extract_brand(request: BrandExtractionRequest):
|
||||
goods = request.goods
|
||||
brand_set = set([brand.strip() for brand in request.brand_list.split(",")]) # 解析逗号分隔的字符串并去空格
|
||||
|
||||
# 构建prompt
|
||||
prompt = (
|
||||
f"商品名称:{goods}\n"
|
||||
"-只需要返回品牌名字,去掉多余的描述\n"
|
||||
#f"-请在以下品牌中选择:{brand_set}"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system",
|
||||
"content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
# 生成文本
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False
|
||||
)
|
||||
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
generated_ids = model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=32768
|
||||
)
|
||||
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
|
||||
|
||||
# 解析输出
|
||||
try:
|
||||
index = len(output_ids) - output_ids[::-1].index(151668) # 151668是tokenizer中的特殊token
|
||||
except ValueError:
|
||||
index = 0
|
||||
|
||||
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
|
||||
|
||||
# 后处理:确保返回的品牌在brand_set中,否则返回"失败"
|
||||
extracted_brand = content.strip()
|
||||
print(extracted_brand)
|
||||
if extracted_brand not in brand_set:
|
||||
# 尝试在goods中直接查找品牌名(简单匹配)
|
||||
for brand in brand_set:
|
||||
if brand in goods:
|
||||
extracted_brand = brand
|
||||
break
|
||||
else:
|
||||
extracted_brand = "失败"
|
||||
|
||||
return {"extracted_brand": extracted_brand}
|
||||
|
||||
|
||||
# 测试用的GET端点(可选)
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Brand Extraction API is running"}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
def main():
|
||||
print("Hello from brandselect!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
[project]
|
||||
name = "brandselect"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = []
|
||||
Loading…
Reference in New Issue