# ai_test/rag.py

from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document, BaseRetriever
from langchain_community.vectorstores import Chroma, FAISS
from langchain_community.retrievers import BM25Retriever  # BM25Retriever lives in langchain_community
from typing import List
from pydantic_settings import BaseSettings
from pydantic import Field
from modelscope import AutoModel, AutoTokenizer, AutoModelForCausalLM  # AutoModel (encoder) for embeddings
import torch
import numpy as np
# 1. Custom EnsembleRetriever: merge results from several retrievers
class EnsembleRetriever(BaseRetriever):
    # BaseRetriever is a pydantic model, so the sub-retrievers and weights are
    # declared as fields rather than assigned in a hand-written __init__
    # (direct attribute assignment would raise a pydantic validation error).
    retrievers: List[BaseRetriever]
    weights: List[float] = Field(default_factory=list)

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        weights = self.weights or [1.0 / len(self.retrievers)] * len(self.retrievers)
        all_results = []
        for retriever, weight in zip(self.retrievers, weights):
            docs = retriever.get_relevant_documents(query)
            all_results.extend(docs)
        # simple de-duplication by page content
        seen = set()
        unique_docs = []
        for doc in all_results:
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                unique_docs.append(doc)
        return unique_docs
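
# Note: the merge above never actually uses the weights; duplicates are dropped
# and BM25 hits simply precede the dense hits. A minimal sketch of the weighted
# reciprocal-rank fusion the weights suggest (an assumption, not part of the
# original pipeline):
def weighted_rank_fusion(ranked_lists: List[List[Document]],
                         weights: List[float], k: int = 60) -> List[Document]:
    scores = {}
    docs_by_key = {}
    for docs, weight in zip(ranked_lists, weights):
        for rank, doc in enumerate(docs):
            key = doc.page_content
            docs_by_key[key] = doc
            # reciprocal-rank contribution, scaled by the retriever's weight
            scores[key] = scores.get(key, 0.0) + weight / (k + rank + 1)
    ordered = sorted(scores, key=scores.get, reverse=True)
    return [docs_by_key[key] for key in ordered]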
# 2. Custom ModelScopeEmbeddings: uses AutoModel (an encoder) instead of a causal LM
class ModelScopeEmbeddings:
    def __init__(self, model_name: str, device: str = None):
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map=self.device,
            trust_remote_code=True,
            use_safetensors=True  # force safetensors weights
        )
        self.model.eval()

    def embed_query(self, text: str) -> List[float]:
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # mean-pool the token states into a single sentence vector
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze(0)
        return embedding.float().cpu().numpy().tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # LangChain vector stores expect plain lists of floats
        return [self.embed_query(text) for text in texts]
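
# Note: the class above mean-pools token states. The BGE model card instead
# recommends taking the [CLS] token state and L2-normalizing it, so cosine /
# inner-product search behaves as intended. A minimal sketch of that variant
# (an alternative, not what the pipeline below uses):
def bge_cls_embedding(emb: ModelScopeEmbeddings, text: str) -> np.ndarray:
    inputs = emb.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(emb.device)
    with torch.no_grad():
        hidden = emb.model(**inputs).last_hidden_state
    cls_vec = hidden[:, 0]  # [CLS] token state
    cls_vec = torch.nn.functional.normalize(cls_vec, p=2, dim=1)  # unit norm
    return cls_vec.squeeze(0).float().cpu().numpy()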
# 3. Document loading and chunking
loader = TextLoader("./lsxd.txt", encoding="utf-8")
pages = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200)
docs = text_splitter.split_documents(pages)
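# chunk_overlap=200 repeats the tail of each chunk at the head of the next, so
# answers that straddle a chunk boundary are still retrievable; 800/200 is a
# common starting point, not a tuned value.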
# 4. Initialize the custom embedding model and the vector stores
embeddings = ModelScopeEmbeddings(model_name="BAAI/bge-large-zh-v1.5")  # a dedicated embedding model
# Build Chroma and FAISS stores over the same chunks
vector_store = Chroma.from_documents(docs, embeddings)  # primary store (not queried below)
faiss_store = FAISS.from_documents(docs, embeddings)
faiss_retriever = faiss_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
# 5. Hybrid retrieval: sparse BM25 plus the dense FAISS retriever
bm25_retriever = BM25Retriever.from_documents(docs)
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, faiss_retriever],
    weights=[0.3, 0.7]
)
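# Note: LangChain also ships its own EnsembleRetriever
# (langchain.retrievers.EnsembleRetriever), which applies the weights via
# Reciprocal Rank Fusion; the custom class above is a simplified stand-in.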
# 6. Model configuration (Qwen via ModelScope)
class Config(BaseSettings):
    model_name: str = Field("Qwen/Qwen-7B-Chat", description="model name")
    device: str = Field("cuda" if torch.cuda.is_available() else "cpu", description="device")

config = Config()
tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    torch_dtype=torch.float16,
    device_map=config.device,
    trust_remote_code=True
)
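# If GPU memory is tight, the checkpoint can instead be loaded quantized
# (assumption: bitsandbytes is installed), e.g. by passing
# device_map="auto", load_in_8bit=True to from_pretrained above.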
# 7. Retrieval-augmented query and generation
def generate_response(query: str) -> str:
    results = ensemble_retriever.get_relevant_documents(query)
    context = "\n".join([f"Document snippet: {doc.page_content[:500]}..." for doc in results[:3]])
    prompt = f"""You are an intelligent assistant. Answer the user's question based on the context below. If the information is insufficient, answer "I don't know".
User question: {query}
Context:
{context}
Answer:"""
    inputs = tokenizer(prompt, return_tensors="pt").to(config.device)
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=512,
        temperature=0.3,
        repetition_penalty=1.1,
        do_sample=True,
        top_p=0.9,
        eos_token_id=tokenizer.eos_token_id
    )
    # decode only the newly generated tokens, skipping the echoed prompt
    response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
    return response.strip()
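
# Qwen-7B-Chat is tuned on the ChatML format; its trust_remote_code build also
# exposes a chat() helper that applies the template and tracks history, e.g.
# (a sketch, not used above): response, history = model.chat(tokenizer, prompt, history=None)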
# Example query
if __name__ == "__main__":
    query = "What kind of company is 蓝色兄弟?"
    answer = generate_response(query)
    print("AI answer:", answer)