import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional

app = FastAPI()

# Ollama configuration
OLLAMA_BASE_URL = "http://localhost:11434"  # default local Ollama address
MODEL_NAME = "qwen:0.5b"                    # or any other model you have installed
REQUEST_TIMEOUT = 30                        # seconds, applied to every Ollama HTTP call


class BrandExtractionRequest(BaseModel):
    """Request body for /extract_brand/.

    Attributes:
        goods: the product title to analyze.
        brand_list: comma-separated candidate brands, e.g. "apple,samsung,huawei".
    """
    goods: str
    brand_list: str


class OllamaRequest(BaseModel):
    """Shape of the Ollama /api/generate payload (documentation / future use)."""
    model: str
    prompt: str
    stream: bool = False
    options: Optional[dict] = None


def call_ollama(prompt: str, model: str = MODEL_NAME) -> str:
    """Call the Ollama /api/generate endpoint and return the generated text.

    Args:
        prompt: the full prompt to send.
        model: Ollama model name; defaults to the module-level MODEL_NAME.

    Returns:
        The stripped "response" field of Ollama's JSON reply ("" if absent).

    Raises:
        HTTPException(500): when the HTTP request to Ollama fails.
    """
    url = f"{OLLAMA_BASE_URL}/api/generate"
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {
            "temperature": 0.1,  # low temperature -> more deterministic output
            # BUG FIX: Ollama's option key is "num_predict", not "max_tokens";
            # the old key was silently ignored, so no length cap was applied.
            "num_predict": 100,
            "stop": ["\n", "。", "!"],  # stop sequences
        },
    }
    try:
        response = requests.post(url, json=payload, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        result = response.json()
        return result.get("response", "").strip()
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=500, detail=f"Ollama API调用失败: {str(e)}")


@app.post("/extract_brand/")
async def extract_brand(request: BrandExtractionRequest):
    """Extract a brand name from a product title using the LLM.

    Falls back to a plain case-insensitive substring search over the brand
    list when the model's answer is not one of the allowed brands, and
    returns "失败" when nothing matches.

    NOTE(review): this async endpoint makes blocking `requests` calls, which
    stalls the event loop while Ollama responds; consider `def` (FastAPI
    threadpool) or an async HTTP client.
    """
    goods = request.goods
    # Parse the comma-separated list; drop whitespace AND empty entries —
    # an empty string would "match" every product in the fallback search below.
    brand_set = {brand.strip() for brand in request.brand_list.split(",") if brand.strip()}

    # Build a clear, example-driven prompt (kept in Chinese: it is runtime
    # model input, not a comment).
    brands_str = ", ".join(sorted(brand_set))
    prompt = f"""
请从商品名称中提取品牌名称。

商品名称:{goods}
可选的品牌列表:{brands_str}

要求:
1. 仔细检查商品名称是否包含上述品牌列表中的任何一个
2. 如果包含,只返回品牌名称本身
3. 如果不包含任何品牌,返回"失败"
4. 不要添加任何解释或额外文本

示例:
商品名称:苹果手机iPhone 13 → apple
商品名称:三星Galaxy S23 → samsung
商品名称:华为Mate 60 → huawei
商品名称:普通智能手机 → 失败

请分析商品名称:"{goods}"
"""

    # Ask the model, then clean up stray quotes / punctuation around its answer.
    extracted_brand = call_ollama(prompt).strip()
    extracted_brand = extracted_brand.strip('"').strip("'").strip("。").strip()

    # Post-process: only accept answers that are actually in the brand list
    # (case-insensitive); otherwise fall back to a direct substring search.
    allowed = {b.lower() for b in brand_set}
    if extracted_brand.lower() not in allowed:
        goods_lower = goods.lower()
        for brand in brand_set:
            if brand.lower() in goods_lower:
                extracted_brand = brand
                break
        else:
            extracted_brand = "失败"

    return {
        "extracted_brand": extracted_brand,
        "goods": goods,
        "available_brands": list(brand_set),
    }


@app.get("/test_ollama")
async def test_ollama():
    """Smoke-test the Ollama connection and report available models."""
    try:
        # One trivial generation round-trip.
        test_prompt = "请回答:你好,世界!"
        response = call_ollama(test_prompt)

        # Also list the models Ollama currently serves.
        models_url = f"{OLLAMA_BASE_URL}/api/tags"
        # timeout added: an unreachable Ollama must not hang the endpoint
        models_response = requests.get(models_url, timeout=REQUEST_TIMEOUT)
        models_data = models_response.json() if models_response.status_code == 200 else {}

        return {
            "status": "success",
            "ollama_response": response[:50] + "..." if len(response) > 50 else response,
            "available_models": [model.get("name") for model in models_data.get("models", [])],
            "current_model": MODEL_NAME,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Ollama测试失败: {str(e)}")


@app.post("/switch_model/")
async def switch_model(model_name: str):
    """Switch the active Ollama model at runtime.

    NOTE(review): mutates a module-level global; fine for a single-process
    dev server, not safe across multiple workers.
    """
    global MODEL_NAME
    old_model = MODEL_NAME

    # Verify the requested model actually exists on the Ollama server.
    try:
        url = f"{OLLAMA_BASE_URL}/api/tags"
        response = requests.get(url, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()  # surface HTTP errors instead of parsing garbage
        models = [model.get("name") for model in response.json().get("models", [])]
        if model_name not in models:
            raise HTTPException(
                status_code=400,
                detail=f"模型 '{model_name}' 不存在。可用模型: {models}",
            )
        MODEL_NAME = model_name
        return {
            "status": "success",
            "message": f"模型已从 '{old_model}' 切换到 '{model_name}'",
        }
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=500, detail=f"无法获取模型列表: {str(e)}")


@app.get("/")
async def root():
    """Liveness check."""
    return {"message": "Brand Extraction API is running (Ollama version)"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=5001)