diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000..e4fba21
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.12
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e69de29
diff --git a/app.py b/app.py
index 0614392..adb758c 100644
--- a/app.py
+++ b/app.py
@@ -1,22 +1,49 @@
-from fastapi import FastAPI
+import requests
+from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from modelscope import AutoModelForCausalLM, AutoTokenizer
+from typing import Optional
 
 app = FastAPI()
 
-# Initialize the model and tokenizer (global variables, loaded only once)
-model_name = "Qwen/Qwen3-0.6B"
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    torch_dtype="auto",
-    device_map="auto",
-)
+# Ollama configuration
+OLLAMA_BASE_URL = "http://localhost:11434"  # default Ollama address
+MODEL_NAME = "qwen3-coder:480b-cloud"  # or another model you have installed
 
 
 class BrandExtractionRequest(BaseModel):
     goods: str
-    brand_list: str  # changed to a string format, e.g. "apple,samsung,huawei"
+    brand_list: str  # string format, e.g. "apple,samsung,huawei"
+
+
+class OllamaRequest(BaseModel):
+    model: str
+    prompt: str
+    stream: bool = False
+    options: Optional[dict] = None
+
+
+def call_ollama(prompt: str, model: str = MODEL_NAME) -> str:
+    """Call the Ollama API to generate text."""
+    url = f"{OLLAMA_BASE_URL}/api/generate"
+
+    payload = {
+        "model": model,
+        "prompt": prompt,
+        "stream": False,
+        "options": {
+            "temperature": 0.1,  # low temperature for more deterministic output
+            "num_predict": 100,  # cap generation length (Ollama's name for max tokens)
+            "stop": ["\n", "。", "!"]  # stop sequences
+        }
+    }
+
+    try:
+        response = requests.post(url, json=payload, timeout=30)
+        response.raise_for_status()
+        result = response.json()
+        return result.get("response", "").strip()
+    except requests.exceptions.RequestException as e:
+        raise HTTPException(status_code=500, detail=f"Ollama API call failed: {str(e)}")
 
 
 @app.post("/extract_brand/")
@@ -24,58 +51,112 @@ async def extract_brand(request: BrandExtractionRequest):
     goods = request.goods
     brand_set = set([brand.strip() for brand in request.brand_list.split(",")])  # parse the comma-separated string and strip whitespace
 
-    # Build the prompt
-    prompt = (
-        f"商品名称:{goods}\n"
-        "-只需要返回品牌名字,去掉多余的描述\n"
-        #f"-请在以下品牌中选择:{brand_set}"
-    )
+    # Build a clearer prompt
+    brands_str = ", ".join(sorted(brand_set))
 
-    messages = [
-        {"role": "system",
-         "content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
-        {"role": "user", "content": prompt}
-    ]
+    prompt = f"""
+请从商品名称中提取品牌名称。
 
-    # Generate text
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True,
-        enable_thinking=False
-    )
+商品名称:{goods}
+可选的品牌列表:{brands_str}
 
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-    generated_ids = model.generate(
-        **model_inputs,
-        max_new_tokens=32768
-    )
-    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+要求:
+1. 仔细检查商品名称是否包含上述品牌列表中的任何一个
+2. 如果包含,只返回品牌名称本身
+3. 如果不包含任何品牌,返回"失败"
+4. 不要添加任何解释或额外文本
 
-    # Parse the output
-    try:
-        index = len(output_ids) - output_ids[::-1].index(151668)  # 151668 is a special token in the tokenizer
-    except ValueError:
-        index = 0
+示例:
+商品名称:苹果手机iPhone 13 → apple
+商品名称:三星Galaxy S23 → samsung
+商品名称:华为Mate 60 → huawei
+商品名称:普通智能手机 → 失败
 
-    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
+请分析商品名称:"{goods}"
+    """
+
+    # Call Ollama
+    extracted_brand = call_ollama(prompt)
+
+    # Clean up the response
+    extracted_brand = extracted_brand.strip()
+
+    # Remove possible quotes, periods, and other punctuation
+    extracted_brand = extracted_brand.strip('"').strip("'").strip("。").strip()
 
     # Post-processing: make sure the returned brand is in brand_set, otherwise return "失败"
-    extracted_brand = content.strip()
-    print(extracted_brand)
-    if extracted_brand not in brand_set:
-        # Try to find the brand name directly in goods (simple matching)
+    if extracted_brand.lower() not in [b.lower() for b in brand_set]:
+        # Try to find the brand name directly in goods (case-insensitive)
        for brand in brand_set:
-            if brand in goods:
+            if brand.lower() in goods.lower():
                 extracted_brand = brand
                 break
         else:
             extracted_brand = "失败"
 
-    return {"extracted_brand": extracted_brand}
+    return {
+        "extracted_brand": extracted_brand,
+        "goods": goods,
+        "available_brands": list(brand_set)
+    }
 
 
-# Test GET endpoint (optional)
+# Endpoint to test the Ollama connection
+@app.get("/test_ollama")
+async def test_ollama():
+    """Test the Ollama connection and model availability."""
+    try:
+        # Try a simple request
+        test_prompt = "请回答:你好,世界!"
+        response = call_ollama(test_prompt)
+
+        # Also fetch the list of models available in Ollama
+        models_url = f"{OLLAMA_BASE_URL}/api/tags"
+        models_response = requests.get(models_url, timeout=10)
+        models_data = models_response.json() if models_response.status_code == 200 else {}
+
+        return {
+            "status": "success",
+            "ollama_response": response[:50] + "..." if len(response) > 50 else response,
+            "available_models": [model.get("name") for model in models_data.get("models", [])],
+            "current_model": MODEL_NAME
+        }
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Ollama test failed: {str(e)}")
+
+
+# Model-switching endpoint (optional)
+@app.post("/switch_model/")
+async def switch_model(model_name: str):
+    """Dynamically switch the Ollama model."""
+    global MODEL_NAME
+    old_model = MODEL_NAME
+
+    # Verify that the model exists
+    try:
+        url = f"{OLLAMA_BASE_URL}/api/tags"
+        response = requests.get(url, timeout=10)
+        models = [model.get("name") for model in response.json().get("models", [])]
+
+        if model_name not in models:
+            raise HTTPException(status_code=400, detail=f"Model '{model_name}' does not exist. Available models: {models}")
+
+        MODEL_NAME = model_name
+        return {
+            "status": "success",
+            "message": f"Model switched from '{old_model}' to '{model_name}'"
+        }
+    except requests.exceptions.RequestException as e:
+        raise HTTPException(status_code=500, detail=f"Failed to fetch the model list: {str(e)}")
+
+
+# Test GET endpoint
 @app.get("/")
 async def root():
-    return {"message": "Brand Extraction API is running"}
\ No newline at end of file
+    return {"message": "Brand Extraction API is running (Ollama version)"}
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=5001)
\ No newline at end of file
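A quick sketch of how the new endpoints can be exercised once app.py is running and Ollama is serving the configured model. This is an illustration, not part of the commit: the port matches the uvicorn.run() call above, and the goods/brand values are made-up examples.

    import requests

    BASE = "http://localhost:5001"  # host/port from uvicorn.run() in app.py

    # Check Ollama connectivity and list the locally installed models.
    print(requests.get(f"{BASE}/test_ollama", timeout=60).json())

    # Extract a brand; brand_list is the comma-separated string the API expects.
    resp = requests.post(
        f"{BASE}/extract_brand/",
        json={"goods": "苹果手机iPhone 13", "brand_list": "apple,samsung,huawei"},
        timeout=60,
    )
    print(resp.json())  # e.g. {"extracted_brand": "apple", ...}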
diff --git a/app.py.bak b/app.py.bak
new file mode 100644
index 0000000..0614392
--- /dev/null
+++ b/app.py.bak
@@ -0,0 +1,81 @@
+from fastapi import FastAPI
+from pydantic import BaseModel
+from modelscope import AutoModelForCausalLM, AutoTokenizer
+
+app = FastAPI()
+
+# Initialize the model and tokenizer (global variables, loaded only once)
+model_name = "Qwen/Qwen3-0.6B"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype="auto",
+    device_map="auto",
+)
+
+
+class BrandExtractionRequest(BaseModel):
+    goods: str
+    brand_list: str  # changed to a string format, e.g. "apple,samsung,huawei"
+
+
+@app.post("/extract_brand/")
+async def extract_brand(request: BrandExtractionRequest):
+    goods = request.goods
+    brand_set = set([brand.strip() for brand in request.brand_list.split(",")])  # parse the comma-separated string and strip whitespace
+
+    # Build the prompt
+    prompt = (
+        f"商品名称:{goods}\n"
+        "-只需要返回品牌名字,去掉多余的描述\n"
+        #f"-请在以下品牌中选择:{brand_set}"
+    )
+
+    messages = [
+        {"role": "system",
+         "content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
+        {"role": "user", "content": prompt}
+    ]
+
+    # Generate text
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True,
+        enable_thinking=False
+    )
+
+    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+    generated_ids = model.generate(
+        **model_inputs,
+        max_new_tokens=32768
+    )
+    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+
+    # Parse the output
+    try:
+        index = len(output_ids) - output_ids[::-1].index(151668)  # 151668 is a special token in the tokenizer
+    except ValueError:
+        index = 0
+
+    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
+
+    # Post-processing: make sure the returned brand is in brand_set, otherwise return "失败"
+    extracted_brand = content.strip()
+    print(extracted_brand)
+    if extracted_brand not in brand_set:
+        # Try to find the brand name directly in goods (simple matching)
+        for brand in brand_set:
+            if brand in goods:
+                extracted_brand = brand
+                break
+        else:
+            extracted_brand = "失败"
+
+    return {"extracted_brand": extracted_brand}
+
+
+# Test GET endpoint (optional)
+@app.get("/")
+async def root():
+    return {"message": "Brand Extraction API is running"}
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..d0f1e89
--- /dev/null
+++ b/main.py
@@ -0,0 +1,6 @@
+def main():
+    print("Hello from brandselect!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..b5919ff
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,7 @@
+[project]
+name = "brandselect"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = []
diff --git a/uv.lock b/uv.lock
new file mode 100644
index 0000000..52c78c1
--- /dev/null
+++ b/uv.lock
@@ -0,0 +1,8 @@
+version = 1
+revision = 3
+requires-python = ">=3.12"
+
+[[package]]
+name = "brandselect"
+version = "0.1.0"
+source = { virtual = "." }
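One gap worth flagging: app.py now imports fastapi, requests, and uvicorn, but the new pyproject.toml declares dependencies = [] and uv.lock resolves no third-party packages. A sketch of what the [project] table would likely need; the unpinned entries mirror the imports in app.py and are an assumption, not something taken from this diff:

    [project]
    name = "brandselect"
    version = "0.1.0"
    description = "Add your description here"
    readme = "README.md"
    requires-python = ">=3.12"
    dependencies = [
        "fastapi",   # web framework used by app.py
        "requests",  # HTTP client for the Ollama REST API
        "uvicorn",   # ASGI server launched in app.py's __main__ block
    ]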