ai_test/rag.py

169 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from typing import List
import torch
class ModelScopeEmbeddings:
    """Text-embedding generator backed by a model loaded through ModelScope.

    The instance offers ``embed_query`` and ``embed_documents`` and is also
    directly callable (``embeddings(text)``), so it can be handed to a vector
    store either as an embeddings object or as a plain embedding function.

    A text is embedded by averaging the model's ``last_hidden_state`` over
    ``dim=1`` (presumably the token axis of a ``(batch, seq, hidden)``
    tensor -- confirm against the model in use).
    """

    def __init__(self, model_name: str, device: str = None):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: ModelScope model identifier passed to ``from_pretrained``.
            device: Value forwarded as ``device_map``; ``None`` means ``"cpu"``.
        """
        # Imported lazily so this module can be imported without modelscope.
        from modelscope import AutoModel, AutoTokenizer

        self.model_name = model_name
        self.device = "cpu" if device is None else device
        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own Python code at load time -- only use with trusted models.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=self._select_dtype(self.device),
            device_map=self.device,
            trust_remote_code=True,
            use_safetensors=True,
        )

    @staticmethod
    def _select_dtype(device) -> torch.dtype:
        """Return float32 for CPU targets and float16 for any other target.

        Half precision is mainly an accelerator format; on CPU it is slow and,
        in older PyTorch releases, unsupported by some kernels.
        """
        if str(device).lower().startswith("cpu"):
            return torch.float32
        return torch.float16

    def __call__(self, text: str) -> List[float]:
        """Allow ``embeddings(text)`` as a shorthand for ``embed_query(text)``."""
        return self.embed_query(text)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single text as a flat list of floats.

        The input is truncated to at most 512 tokens before encoding.
        """
        # Prefer the device the model actually lives on: ``self.device`` may be
        # a device-map keyword (e.g. "auto") that is not a valid tensor device.
        target_device = getattr(self.model, "device", self.device)
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).to(target_device)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Mean-pool over dim=1, then drop the batch axis (one text per call).
        pooled = outputs.last_hidden_state.mean(dim=1)
        return pooled.squeeze(0).cpu().float().tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text independently; returns one vector per input text."""
        return [self.embed_query(text) for text in texts]
def main():
    """Build a hybrid retriever over ``./lsxd.txt`` and run an interactive Q&A loop.

    Stages, in order:
      1. Load the text file and split it into overlapping chunks.
      2. Embed the chunks and index them in a FAISS store.
      3. Build a BM25 retriever over the same chunks.
      4. Combine both in an ensemble retriever (weights 0.3 BM25 / 0.7 FAISS).
      5. Load the Qwen causal language model.
      6. Read questions from stdin and print generated answers until the
         user types the exit word or stdin is closed.

    Each initialization stage prints an error message and returns early if it
    fails, so the question loop only starts once everything is loaded.
    """
    # 1. Load the document and split it into chunks
    print("加载并分块文档...")
    try:
        pages = TextLoader("./lsxd.txt", encoding="utf-8").load()
    except Exception as e:
        # The loader may wrap the underlying OS error in another exception
        # type, so also inspect the chained cause to detect a missing file.
        missing = isinstance(e, FileNotFoundError) or isinstance(
            e.__cause__, FileNotFoundError
        )
        if missing:
            print("错误:未找到文档文件 './lsxd.txt'")
        else:
            print(f"加载文档时出错: {e}")
        return
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200)
    docs = text_splitter.split_documents(pages)
    print(f"文档分块完成,共 {len(docs)} 个片段")

    # 2. Initialize the embedding model and the vector store
    print("初始化嵌入模型和向量数据库...")
    try:
        embeddings = ModelScopeEmbeddings(
            model_name="AI-ModelScope/bge-large-zh-v1.5", device="cpu"
        )
    except Exception as e:
        print(f"初始化嵌入模型时出错: {e}")
        return
    try:
        # Store the chunk vectors in FAISS
        faiss_store = FAISS.from_documents(docs, embeddings)
        faiss_retriever = faiss_store.as_retriever(
            search_type="similarity", search_kwargs={"k": 5}
        )
    except Exception as e:
        print(f"初始化向量存储时出错: {e}")
        return

    # 3. BM25 retriever
    print("初始化 BM25 检索器...")
    try:
        bm25_retriever = BM25Retriever.from_documents(docs)
    except Exception as e:
        print(f"初始化 BM25 检索器时出错: {e}")
        return

    # 4. Hybrid (ensemble) retriever
    print("初始化混合检索器...")
    try:
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.3, 0.7],
        )
    except Exception as e:
        print(f"初始化混合检索器时出错: {e}")
        return

    # 5. Load the Qwen large language model
    print("加载大语言模型...")
    try:
        from modelscope import AutoTokenizer, AutoModelForCausalLM

        model_name = "Qwen/Qwen3-4B-AWQ"
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            # "auto" is a device-map keyword, NOT a tensor device; inputs must
            # be moved to ``model.device`` instead (see generate_response).
            device_map="auto",
            trust_remote_code=True,
        )
        model.eval()
    except Exception as e:
        print(f"加载大模型时出错: {e}")
        return

    # 6. Retrieval + generation
    def generate_response(query: str) -> str:
        """Answer ``query`` from the retrieved context.

        Returns the generated text, or an error message string if retrieval
        or generation raises. All components are guaranteed to be initialized
        here because every failed stage above returns before this point.
        """
        try:
            # Fetch relevant chunks through the hybrid retriever; only the
            # first three are used, each cut to 500 characters.
            results = ensemble_retriever.get_relevant_documents(query)
            context = "\n".join(
                f"文档片段:{doc.page_content[:500]}..." for doc in results[:3]
            )
            # Build the prompt
            prompt = (
                '你是一个智能助手,请根据以下上下文回答用户问题。若信息不足,请回答"我不知道"\n'
                f"用户问题:{query}\n"
                "上下文信息:\n"
                f"{context}\n"
                "回答:"
            )
            # Generate the answer. Inputs go to the device the model was
            # actually placed on, and the attention mask is passed along
            # with the input ids.
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=512,
                    temperature=0.3,
                    repetition_penalty=1.1,
                    do_sample=True,
                    top_p=0.9,
                    eos_token_id=tokenizer.eos_token_id,
                )
            # Strip the prompt tokens so only the newly generated part is decoded.
            prompt_len = inputs["input_ids"].shape[1]
            response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
            return response.strip()
        except Exception as e:
            return f"生成回答时出错: {e}"

    # Interactive question loop
    print("系统准备就绪,可以开始提问!")
    while True:
        try:
            query = input("\n请输入问题(输入 '退出' 结束):")
        except (EOFError, KeyboardInterrupt):
            # stdin closed or Ctrl-C: leave quietly instead of a traceback.
            print()
            break
        question = query.strip()
        if question.lower() == "退出":
            break
        if not question:
            print("请输入有效问题!")
            continue
        answer = generate_response(query)
        print("AI回答", answer)
# Script entry point: start the interactive RAG session only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()