qwen_brand/app.py

162 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
app = FastAPI()

# Ollama configuration
OLLAMA_BASE_URL = "http://localhost:11434"  # default Ollama address
MODEL_NAME = "qwen3:0.6B"  # or another model you have installed
class BrandExtractionRequest(BaseModel):
    """Request body for POST /extract_brand/."""
    # Product name/title to analyze.
    goods: str
    # Comma-separated candidate brands, e.g. "apple,samsung,huawei".
    brand_list: str
class OllamaRequest(BaseModel):
    """Schema mirroring the Ollama /api/generate request payload.

    NOTE(review): not referenced anywhere in this file — presumably kept for
    documentation or future use; confirm before removing.
    """
    # Ollama model tag, e.g. "qwen3:0.6B".
    model: str
    # Prompt text sent to the model.
    prompt: str
    # Whether Ollama should stream the response.
    stream: bool = False
    # Extra generation options (temperature, stop words, ...).
    options: Optional[dict] = None
def call_ollama(prompt: str, model: str = MODEL_NAME) -> str:
    """Call the Ollama /api/generate endpoint and return the generated text.

    Args:
        prompt: Prompt text sent to the model.
        model: Ollama model tag; defaults to the module-level MODEL_NAME.

    Returns:
        The model's response text with surrounding whitespace stripped
        (empty string if the API returned no "response" field).

    Raises:
        HTTPException: 500 when the HTTP request to Ollama fails.
    """
    url = f"{OLLAMA_BASE_URL}/api/generate"
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {
            "temperature": 0.1,  # low temperature for more deterministic output
            # Ollama's generate options cap output length via "num_predict";
            # the previous "max_tokens" key is not an Ollama option and was
            # silently ignored.
            "num_predict": 100,
            # Stop words. The original list also contained empty strings
            # (Unicode-stripped residue from the pasted source); empty stops
            # are meaningless and have been dropped.
            "stop": ["\n"],
        },
    }
    try:
        response = requests.post(url, json=payload, timeout=30)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Chain the cause so the traceback keeps the underlying network error.
        raise HTTPException(status_code=500, detail=f"Ollama API调用失败: {str(e)}") from e
    result = response.json()
    return result.get("response", "").strip()
@app.post("/extract_brand/")
async def extract_brand(request: BrandExtractionRequest):
    """Extract a brand from a goods title, constrained to a caller-supplied list.

    Strategy: ask the Ollama model first; if its answer is not one of the
    allowed brands, fall back to a case-insensitive substring search of the
    goods title. Returns "失败" when no brand matches.
    """
    goods = request.goods
    # Parse the comma-separated brand string. Empty fragments (from trailing
    # or doubled commas) are dropped: an empty brand "" is a substring of
    # every goods string and would make the fallback search always "match".
    brand_set = {brand.strip() for brand in request.brand_list.split(",") if brand.strip()}
    # Build a clear prompt for the model.
    brands_str = ", ".join(sorted(brand_set))
    prompt = f"""
请从商品名称中提取品牌名称。
商品名称:{goods}
可选的品牌列表:{brands_str}
要求:
1. 仔细检查商品名称是否包含上述品牌列表中的任何一个
2. 如果包含,只返回品牌名称本身
3. 如果不包含任何品牌,返回"失败"
4. 不要添加任何解释或额外文本
示例:
商品名称苹果手机iPhone 13 → apple
商品名称三星Galaxy S23 → samsung
商品名称华为Mate 60 → huawei
商品名称:普通智能手机 → 失败
请分析商品名称:"{goods}"
"""
    # Ask the model, then clean the response (whitespace and stray quotes).
    # The original also called .strip("") — a no-op left over from a stripped
    # Unicode character — which has been removed.
    raw_answer = call_ollama(prompt).strip().strip("\"'")
    # Map the model's (case-insensitive) answer back to the canonical spelling
    # from the caller's brand list, so the response always uses their casing.
    canonical = {b.lower(): b for b in brand_set}
    extracted_brand = canonical.get(raw_answer.lower())
    if extracted_brand is None:
        # Fallback: case-insensitive substring search directly in the title.
        goods_lower = goods.lower()
        extracted_brand = next(
            (b for b in brand_set if b.lower() in goods_lower),
            "失败",
        )
    return {
        "extracted_brand": extracted_brand,
        "goods": goods,
        "available_brands": list(brand_set),
    }
# Endpoint for testing the Ollama connection
@app.get("/test_ollama")
async def test_ollama():
    """Check Ollama connectivity, model response, and list available models."""
    try:
        # Send one trivial generation request.
        test_prompt = "请回答:你好,世界!"
        response = call_ollama(test_prompt)
        # Also fetch the model list from Ollama (best-effort: empty on failure).
        models_url = f"{OLLAMA_BASE_URL}/api/tags"
        models_response = requests.get(models_url, timeout=10)
        models_data = models_response.json() if models_response.status_code == 200 else {}
        return {
            "status": "success",
            "ollama_response": response[:50] + "..." if len(response) > 50 else response,
            "available_models": [model.get("name") for model in models_data.get("models", [])],
            "current_model": MODEL_NAME,
        }
    except HTTPException:
        # call_ollama already raised a well-formed HTTP error; re-raise it
        # instead of double-wrapping it in another 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Ollama测试失败: {str(e)}") from e
# Model-switching endpoint (optional)
@app.post("/switch_model/")
async def switch_model(model_name: str):
    """Dynamically switch the Ollama model after verifying it is installed.

    Raises:
        HTTPException: 400 if the model is not in Ollama's tag list,
            500 if the tag list cannot be fetched.
    """
    global MODEL_NAME
    old_model = MODEL_NAME
    # Verify that the requested model exists. Keep the try body minimal so
    # only network/HTTP failures are converted into a 500.
    try:
        url = f"{OLLAMA_BASE_URL}/api/tags"
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        models = [model.get("name") for model in response.json().get("models", [])]
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=500, detail=f"无法获取模型列表: {str(e)}") from e
    if model_name not in models:
        raise HTTPException(status_code=400, detail=f"模型 '{model_name}' 不存在。可用模型: {models}")
    MODEL_NAME = model_name
    return {
        "status": "success",
        "message": f"模型已从 '{old_model}' 切换到 '{model_name}'",
    }
# Simple GET endpoint for smoke-testing the service
@app.get("/")
async def root():
    """Health-check endpoint confirming the API is up."""
    status_message = "Brand Extraction API is running (Ollama version)"
    return {"message": status_message}
if __name__ == "__main__":
    import uvicorn

    # Run the development server; 0.0.0.0 binds all interfaces, port 5001.
    uvicorn.run(app, host="0.0.0.0", port=5001)