This commit is contained in:
parent
1e3aa9857a
commit
1ff2ef5762
|
|
@ -0,0 +1 @@
|
||||||
|
3.12
|
||||||
179
app.py
179
app.py
|
|
@ -1,22 +1,49 @@
|
||||||
from fastapi import FastAPI
|
import requests
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from modelscope import AutoModelForCausalLM, AutoTokenizer
|
from typing import List, Optional
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
# 初始化模型和tokenizer(全局变量,只加载一次)
|
# Ollama配置
|
||||||
model_name = "Qwen/Qwen3-0.6B"
|
OLLAMA_BASE_URL = "http://localhost:11434" # 默认Ollama地址
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
MODEL_NAME = "qwen3-coder:480b-cloud" # 或其他你安装的模型
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
torch_dtype="auto",
|
|
||||||
device_map="auto",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BrandExtractionRequest(BaseModel):
|
class BrandExtractionRequest(BaseModel):
|
||||||
goods: str
|
goods: str
|
||||||
brand_list: str # 改为字符串格式,如 "apple,samsung,huawei"
|
brand_list: str # 字符串格式,如 "apple,samsung,huawei"
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
prompt: str
|
||||||
|
stream: bool = False
|
||||||
|
options: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
def call_ollama(prompt: str, model: Optional[str] = None) -> str:
    """Generate text for ``prompt`` through Ollama's ``/api/generate`` endpoint.

    Args:
        prompt: Full prompt text sent to the model.
        model: Ollama model name. ``None`` (the default) resolves to the
            module-level ``MODEL_NAME`` at call time.

    Returns:
        The ``response`` field of Ollama's JSON reply with surrounding
        whitespace stripped, or ``""`` when the field is absent.

    Raises:
        HTTPException: Status 500 when the request fails, Ollama answers with
            an error status, or the reply body is not valid JSON.
    """
    # Resolve the default per call. A ``model=MODEL_NAME`` default is frozen
    # when the function is defined, so later rebinding of the global would
    # never be seen by callers that omit ``model``.
    if model is None:
        model = MODEL_NAME

    url = f"{OLLAMA_BASE_URL}/api/generate"
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {
            "temperature": 0.1,  # low temperature for near-deterministic output
            # Ollama's option for capping the number of generated tokens is
            # ``num_predict``; ``max_tokens`` is not a native option name.
            "num_predict": 100,
            "stop": ["\n", "。", "!"],  # stop sequences
        },
    }

    try:
        response = requests.post(url, json=payload, timeout=30)
        response.raise_for_status()
        result = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures on requests versions whose
        # decode error is not a RequestException subclass.
        raise HTTPException(status_code=500, detail=f"Ollama API调用失败: {str(e)}") from e

    return result.get("response", "").strip()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/extract_brand/")
|
@app.post("/extract_brand/")
|
||||||
|
|
@ -24,58 +51,112 @@ async def extract_brand(request: BrandExtractionRequest):
|
||||||
goods = request.goods
|
goods = request.goods
|
||||||
brand_set = set([brand.strip() for brand in request.brand_list.split(",")]) # 解析逗号分隔的字符串并去空格
|
brand_set = set([brand.strip() for brand in request.brand_list.split(",")]) # 解析逗号分隔的字符串并去空格
|
||||||
|
|
||||||
# 构建prompt
|
# 构建更清晰的prompt
|
||||||
prompt = (
|
brands_str = ", ".join(sorted(brand_set))
|
||||||
f"商品名称:{goods}\n"
|
|
||||||
"-只需要返回品牌名字,去掉多余的描述\n"
|
|
||||||
#f"-请在以下品牌中选择:{brand_set}"
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
prompt = f"""
|
||||||
{"role": "system",
|
请从商品名称中提取品牌名称。
|
||||||
"content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
|
|
||||||
{"role": "user", "content": prompt}
|
|
||||||
]
|
|
||||||
|
|
||||||
# 生成文本
|
商品名称:{goods}
|
||||||
text = tokenizer.apply_chat_template(
|
可选的品牌列表:{brands_str}
|
||||||
messages,
|
|
||||||
tokenize=False,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
enable_thinking=False
|
|
||||||
)
|
|
||||||
|
|
||||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
要求:
|
||||||
generated_ids = model.generate(
|
1. 仔细检查商品名称是否包含上述品牌列表中的任何一个
|
||||||
**model_inputs,
|
2. 如果包含,只返回品牌名称本身
|
||||||
max_new_tokens=32768
|
3. 如果不包含任何品牌,返回"失败"
|
||||||
)
|
4. 不要添加任何解释或额外文本
|
||||||
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
|
|
||||||
|
|
||||||
# 解析输出
|
示例:
|
||||||
try:
|
商品名称:苹果手机iPhone 13 → apple
|
||||||
index = len(output_ids) - output_ids[::-1].index(151668) # 151668是tokenizer中的特殊token
|
商品名称:三星Galaxy S23 → samsung
|
||||||
except ValueError:
|
商品名称:华为Mate 60 → huawei
|
||||||
index = 0
|
商品名称:普通智能手机 → 失败
|
||||||
|
|
||||||
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
|
请分析商品名称:"{goods}"
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 调用Ollama
|
||||||
|
extracted_brand = call_ollama(prompt)
|
||||||
|
|
||||||
|
# 清理响应
|
||||||
|
extracted_brand = extracted_brand.strip()
|
||||||
|
|
||||||
|
# 移除可能的引号、句号等标点
|
||||||
|
extracted_brand = extracted_brand.strip('"').strip("'").strip("。").strip()
|
||||||
|
|
||||||
# 后处理:确保返回的品牌在brand_set中,否则返回"失败"
|
# 后处理:确保返回的品牌在brand_set中,否则返回"失败"
|
||||||
extracted_brand = content.strip()
|
if extracted_brand.lower() not in [b.lower() for b in brand_set]:
|
||||||
print(extracted_brand)
|
# 尝试在goods中直接查找品牌名(不区分大小写)
|
||||||
if extracted_brand not in brand_set:
|
|
||||||
# 尝试在goods中直接查找品牌名(简单匹配)
|
|
||||||
for brand in brand_set:
|
for brand in brand_set:
|
||||||
if brand in goods:
|
if brand.lower() in goods.lower():
|
||||||
extracted_brand = brand
|
extracted_brand = brand
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
extracted_brand = "失败"
|
extracted_brand = "失败"
|
||||||
|
|
||||||
return {"extracted_brand": extracted_brand}
|
return {
|
||||||
|
"extracted_brand": extracted_brand,
|
||||||
|
"goods": goods,
|
||||||
|
"available_brands": list(brand_set)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# 测试用的GET端点(可选)
|
# 添加一个测试Ollama连接的端点
|
||||||
|
@app.get("/test_ollama")
async def test_ollama():
    """Check that Ollama is reachable and report the models it lists.

    Sends a short test prompt through ``call_ollama`` and then queries
    ``/api/tags`` for the model list.

    Returns:
        A dict with ``status``, ``ollama_response`` (the test reply, cut to
        50 characters plus ``"..."`` when it is longer), ``available_models``
        (the ``name`` of each listed model; empty when ``/api/tags`` does not
        answer with status 200) and ``current_model``.

    Raises:
        HTTPException: Status 500 when any step fails.
    """
    try:
        # Round-trip a trivial prompt through the generate endpoint.
        test_prompt = "请回答:你好,世界!"
        response = call_ollama(test_prompt)

        # Also fetch the models known to Ollama. The timeout keeps a stalled
        # server from hanging this health check indefinitely.
        models_url = f"{OLLAMA_BASE_URL}/api/tags"
        models_response = requests.get(models_url, timeout=10)
        models_data = models_response.json() if models_response.status_code == 200 else {}

        # Parenthesised to make the existing precedence explicit.
        preview = (response[:50] + "...") if len(response) > 50 else response

        return {
            "status": "success",
            "ollama_response": preview,
            "available_models": [entry.get("name") for entry in models_data.get("models", [])],
            "current_model": MODEL_NAME,
        }
    except Exception as e:
        # Top-level boundary for a diagnostics endpoint: report every failure
        # as an HTTP 500 carrying the underlying message.
        raise HTTPException(status_code=500, detail=f"Ollama测试失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# 添加模型切换端点(可选)
|
||||||
|
@app.post("/switch_model/")
async def switch_model(model_name: str):
    """Rebind the module-level ``MODEL_NAME`` to another Ollama model.

    Args:
        model_name: Name of the model to switch to. It must appear among the
            ``name`` values returned by Ollama's ``/api/tags``.

    Returns:
        A dict with ``status`` and a ``message`` naming the old and new model.

    Raises:
        HTTPException: Status 400 when ``model_name`` is not among the listed
            models; status 500 when the model list cannot be fetched.
    """
    global MODEL_NAME
    old_model = MODEL_NAME

    # Validate the requested name against what Ollama reports.
    try:
        url = f"{OLLAMA_BASE_URL}/api/tags"
        response = requests.get(url, timeout=10)
        # Without this check an error reply would be parsed as an empty model
        # list and surface as a misleading "model does not exist" 400.
        response.raise_for_status()
        models = [entry.get("name") for entry in response.json().get("models", [])]
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures on requests versions whose
        # decode error is not a RequestException subclass.
        raise HTTPException(status_code=500, detail=f"无法获取模型列表: {str(e)}") from e

    if model_name not in models:
        raise HTTPException(status_code=400, detail=f"模型 '{model_name}' 不存在。可用模型: {models}")

    MODEL_NAME = model_name
    return {
        "status": "success",
        "message": f"模型已从 '{old_model}' 切换到 '{model_name}'",
    }
|
||||||
|
|
||||||
|
|
||||||
|
# 测试用的GET端点
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def root():
|
async def root():
|
||||||
return {"message": "Brand Extraction API is running"}
|
return {"message": "Brand Extraction API is running (Ollama version)"}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=5001)
|
||||||
|
|
@ -0,0 +1,81 @@
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from modelscope import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# 初始化模型和tokenizer(全局变量,只加载一次)
|
||||||
|
model_name = "Qwen/Qwen3-0.6B"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
torch_dtype="auto",
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BrandExtractionRequest(BaseModel):
|
||||||
|
goods: str
|
||||||
|
brand_list: str # 改为字符串格式,如 "apple,samsung,huawei"
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/extract_brand/")
async def extract_brand(request: BrandExtractionRequest):
    """Extract the brand of ``request.goods`` with the locally loaded model.

    The model's answer is accepted only when it equals one of the names in
    ``request.brand_list``. Otherwise the goods name is searched for a listed
    brand as a substring, and ``"失败"`` is returned when nothing matches.

    Args:
        request: Carries ``goods`` (the product name) and ``brand_list`` (a
            comma-separated string such as ``"apple,samsung,huawei"``).

    Returns:
        A dict with the single key ``extracted_brand``.
    """
    goods = request.goods
    # Split the comma-separated list and trim each name. Empty entries (from a
    # trailing comma or an empty string) are dropped: "" is a substring of
    # every goods name and would otherwise be returned as the "brand".
    brand_set = {
        name
        for name in (part.strip() for part in request.brand_list.split(","))
        if name
    }

    # Build the user prompt.
    # NOTE(review): the brand list itself is not sent to the model; the
    # restriction to listed brands is enforced only by the post-processing
    # below.
    prompt = (
        f"商品名称:{goods}\n"
        "-只需要返回品牌名字,去掉多余的描述\n"
    )

    messages = [
        {"role": "system",
         "content": '从商品名称中提取品牌名称,需严格匹配预定义的品牌列表。若未找到匹配品牌,返回 "失败" '},
        {"role": "user", "content": prompt}
    ]

    # Render the chat template to text and generate.
    # NOTE(review): generate() is a blocking call inside an async handler, so
    # it holds the event loop for its whole duration - confirm this is
    # acceptable for the expected load.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        # Only a short brand name is expected, so cap generation tightly
        # instead of allowing tens of thousands of tokens per request. Longer
        # output would fail the membership check below in any case.
        max_new_tokens=64
    )
    # Keep only the newly generated tokens (drop the echoed prompt tokens).
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

    # Skip everything up to and including the last occurrence of token id
    # 151668, a tokenizer special token (presumably the end-of-thinking
    # marker - TODO confirm); when it is absent the whole output is used.
    try:
        index = len(output_ids) - output_ids[::-1].index(151668)
    except ValueError:
        index = 0

    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

    # Post-processing: the result must be one of the listed brands.
    extracted_brand = content.strip()
    if extracted_brand not in brand_set:
        # Fall back to a plain substring search in the goods name. Candidates
        # are tried longest first (ties alphabetical) so the result does not
        # depend on set iteration order and the most specific name wins.
        candidates = sorted(brand_set, key=lambda name: (-len(name), name))
        extracted_brand = next(
            (brand for brand in candidates if brand in goods),
            "失败",
        )

    return {"extracted_brand": extracted_brand}
|
||||||
|
|
||||||
|
|
||||||
|
# 测试用的GET端点(可选)
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
return {"message": "Brand Extraction API is running"}
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
def main():
|
||||||
|
print("Hello from brandselect!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
[project]
|
||||||
|
name = "brandselect"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Add your description here"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = []
|
||||||
Loading…
Reference in New Issue