162 lines
4.9 KiB
Python
162 lines
4.9 KiB
Python
import requests
|
||
from fastapi import FastAPI, HTTPException
|
||
from pydantic import BaseModel
|
||
from typing import List, Optional
|
||
|
||
app = FastAPI()

# Ollama configuration
OLLAMA_BASE_URL = "http://localhost:11434"  # default local Ollama endpoint
MODEL_NAME = "qwen3:0.6B"  # or any other installed model (rebound at runtime by /switch_model/)
|
||
|
||
|
||
class BrandExtractionRequest(BaseModel):
    """Request body for POST /extract_brand/."""

    # Goods/product title to analyse.
    goods: str
    # Comma-separated brand candidates, e.g. "apple,samsung,huawei"
    brand_list: str
|
||
|
||
|
||
class OllamaRequest(BaseModel):
    """Schema of an Ollama /api/generate request payload.

    NOTE(review): not referenced anywhere in this file — call_ollama builds a
    plain dict instead. Confirm it is used elsewhere before removing.
    """

    model: str
    prompt: str
    stream: bool = False
    options: Optional[dict] = None
|
||
|
||
|
||
def call_ollama(prompt: str, model: str = MODEL_NAME) -> str:
    """Send a single non-streaming generate request to the Ollama API.

    Args:
        prompt: Full prompt text to submit to the model.
        model: Ollama model tag; defaults to the module-level MODEL_NAME.

    Returns:
        The generated text with surrounding whitespace stripped (empty string
        if the response carries no "response" field).

    Raises:
        HTTPException: 500 when the HTTP request fails or Ollama answers with
            an error status.
    """
    url = f"{OLLAMA_BASE_URL}/api/generate"

    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {
            "temperature": 0.1,  # low temperature for more deterministic output
            # Cap generation length. Ollama's option is "num_predict";
            # the previous "max_tokens" key is not recognized and was ignored.
            "num_predict": 100,
            "stop": ["\n", "。", "!"],  # expect a single short line as the answer
        },
    }

    try:
        response = requests.post(url, json=payload, timeout=30)
        response.raise_for_status()
        result = response.json()
        return result.get("response", "").strip()
    except requests.exceptions.RequestException as e:
        # Chain the original cause for easier debugging in server logs.
        raise HTTPException(status_code=500, detail=f"Ollama API调用失败: {str(e)}") from e
|
||
|
||
|
||
@app.post("/extract_brand/")
async def extract_brand(request: BrandExtractionRequest):
    """Extract a brand from a goods title via the LLM, with a substring fallback.

    Returns a dict with the extracted brand (or "失败" when no listed brand
    applies), the original goods title, and the parsed brand candidates.
    """
    goods = request.goods
    # Parse the comma-separated brand list, trimming whitespace and dropping
    # empty entries: a trailing/double comma would otherwise yield "", and
    # "" is a substring of every goods name, making the fallback always match.
    brand_set = {brand.strip() for brand in request.brand_list.split(",") if brand.strip()}

    # Case-insensitive lookup back to the caller's exact spelling.
    canonical = {b.lower(): b for b in brand_set}

    # Build a clear, example-driven prompt.
    brands_str = ", ".join(sorted(brand_set))

    prompt = f"""
请从商品名称中提取品牌名称。

商品名称:{goods}
可选的品牌列表:{brands_str}

要求:
1. 仔细检查商品名称是否包含上述品牌列表中的任何一个
2. 如果包含,只返回品牌名称本身
3. 如果不包含任何品牌,返回"失败"
4. 不要添加任何解释或额外文本

示例:
商品名称:苹果手机iPhone 13 → apple
商品名称:三星Galaxy S23 → samsung
商品名称:华为Mate 60 → huawei
商品名称:普通智能手机 → 失败

请分析商品名称:"{goods}"
"""

    # Ask the model, then clean stray quotes/punctuation around its answer.
    extracted_brand = call_ollama(prompt).strip()
    extracted_brand = extracted_brand.strip('"').strip("'").strip("。").strip()

    key = extracted_brand.lower()
    if key in canonical:
        # Normalize to the caller's spelling (the model may differ in case).
        extracted_brand = canonical[key]
    else:
        # Fallback: case-insensitive substring search in the goods title.
        goods_lower = goods.lower()
        for lowered, brand in canonical.items():
            if lowered in goods_lower:
                extracted_brand = brand
                break
        else:
            extracted_brand = "失败"

    return {
        "extracted_brand": extracted_brand,
        "goods": goods,
        "available_brands": list(brand_set),
    }
|
||
|
||
|
||
# Diagnostic endpoint for checking the Ollama connection
@app.get("/test_ollama")
async def test_ollama():
    """Check Ollama connectivity and report the locally available models.

    Raises:
        HTTPException: 500 when the generate call or the model listing fails.
    """
    try:
        # Fire a trivial prompt to verify the generate endpoint works.
        test_prompt = "请回答:你好,世界!"
        response = call_ollama(test_prompt)

        # Also fetch the installed model tags. A timeout prevents this
        # endpoint from hanging forever if Ollama is unresponsive.
        models_url = f"{OLLAMA_BASE_URL}/api/tags"
        models_response = requests.get(models_url, timeout=10)
        models_data = models_response.json() if models_response.status_code == 200 else {}

        return {
            "status": "success",
            # Truncate long generations so the diagnostic payload stays small.
            "ollama_response": response[:50] + "..." if len(response) > 50 else response,
            "available_models": [model.get("name") for model in models_data.get("models", [])],
            "current_model": MODEL_NAME,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Ollama测试失败: {str(e)}") from e
|
||
|
||
|
||
# Optional endpoint for switching the active model at runtime
@app.post("/switch_model/")
async def switch_model(model_name: str):
    """Switch the module-level Ollama model after validating it exists.

    Args:
        model_name: Exact Ollama model tag to switch to (e.g. "qwen3:0.6B").

    Raises:
        HTTPException: 400 when the requested model is not installed,
            500 when the model list cannot be fetched from Ollama.
    """
    global MODEL_NAME
    old_model = MODEL_NAME

    # Validate against the models actually installed in Ollama.
    try:
        url = f"{OLLAMA_BASE_URL}/api/tags"
        # Timeout + status check: a hung or erroring Ollama must not stall
        # the endpoint or be parsed as a successful JSON response.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        models = [model.get("name") for model in response.json().get("models", [])]

        if model_name not in models:
            raise HTTPException(status_code=400, detail=f"模型 '{model_name}' 不存在。可用模型: {models}")

        MODEL_NAME = model_name
        return {
            "status": "success",
            "message": f"模型已从 '{old_model}' 切换到 '{model_name}'",
        }
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=500, detail=f"无法获取模型列表: {str(e)}") from e
|
||
|
||
|
||
# Simple GET endpoint for smoke-testing the service
@app.get("/")
async def root():
    """Health-check endpoint confirming the service is running."""
    status_payload = {"message": "Brand Extraction API is running (Ollama version)"}
    return status_payload
|
||
|
||
|
||
if __name__ == "__main__":
    import uvicorn

    # Development entry point: serve on all interfaces, port 5001.
    uvicorn.run(app, host="0.0.0.0", port=5001)