From a1de80892b8f2e7d32b453cb13cee93972b59f80 Mon Sep 17 00:00:00 2001
From: renzhiyuan <465386466@qq.com>
Date: Thu, 21 Aug 2025 04:35:54 +0800
Subject: [PATCH] rag.py: FAISS+BM25 ensemble retrieval, error handling, Qwen3-4B-AWQ
---
 rag.py           | 214 ++++++++++++++++++++++++++++-------------------
 requirements.txt |   5 +-
 2 files changed, 131 insertions(+), 88 deletions(-)

diff --git a/rag.py b/rag.py
index fb43746..7185892 100644
--- a/rag.py
+++ b/rag.py
@@ -1,106 +1,132 @@
 from langchain_community.document_loaders import TextLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.schema import Document, BaseRetriever
-from langchain_community.vectorstores import Chroma, FAISS
-from langchain.retrievers import BM25Retriever
+from langchain_community.vectorstores import FAISS
+from langchain_community.retrievers import BM25Retriever
+from langchain.retrievers import EnsembleRetriever
 from typing import List
-
-from pydantic_settings import BaseSettings
-from pydantic import BaseModel, Field
-from modelscope import AutoModel, AutoTokenizer,AutoModelForCausalLM  # switched to AutoModel
 import torch
-import numpy as np
 
-# 1. Custom EnsembleRetriever (unchanged)
-class EnsembleRetriever(BaseRetriever):
-    def __init__(self, retrievers: List[BaseRetriever], weights: List[float] = None):
-        self.retrievers = retrievers
-        self.weights = weights or [1.0 / len(retrievers)] * len(retrievers)
-
-    def get_relevant_documents(self, query: str) -> List[Document]:
-        all_results = []
-        for retriever, weight in zip(self.retrievers, self.weights):
-            docs = retriever.get_relevant_documents(query)
-            all_results.extend(docs)
-        # simple deduplication
-        seen = set()
-        unique_docs = []
-        for doc in all_results:
-            if doc.page_content not in seen:
-                seen.add(doc.page_content)
-                unique_docs.append(doc)
-        return unique_docs
-
-
-# 2. Custom ModelScopeEmbeddings (switched to AutoModel)
 class ModelScopeEmbeddings:
+    """Embedding generator backed by a ModelScope model."""
+
     def __init__(self, model_name: str, device: str = None):
+        from modelscope import AutoModel, AutoTokenizer
         self.model_name = model_name
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = "cpu" if device is None else device
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         self.model = AutoModel.from_pretrained(
             model_name,
             torch_dtype=torch.float16,
             device_map=self.device,
             trust_remote_code=True,
-            use_safetensors=True  # force safetensors
+            use_safetensors=True
         )
 
-    def embed_query(self, text: str) -> np.ndarray:
+    def __call__(self, text: str) -> List[float]:
+        """Allow the instance to be called directly: embeddings(text)."""
+        return self.embed_query(text)
+
+    def embed_query(self, text: str) -> List[float]:
         inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
         with torch.no_grad():
             outputs = self.model(**inputs)
-        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
-        return embeddings.squeeze(0)
+        # Mean-pool the last hidden state into a single sentence vector
+        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().float().numpy()
+        return embeddings.squeeze(0).tolist()
 
-    def embed_documents(self, texts: List[str]) -> np.ndarray:
-        return np.array([self.embed_query(text) for text in texts])
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        return [self.embed_query(text) for text in texts]
 
-# 3. Document loading and chunking (unchanged)
-loader = TextLoader("./lsxd.txt", encoding="utf-8")
-pages = loader.load()
-text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200)
-docs = text_splitter.split_documents(pages)
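+# A minimal usage sketch for the wrapper above (illustrative only; it reuses
+# the AI-ModelScope/bge-large-zh-v1.5 checkpoint that main() instantiates):
+#   emb = ModelScopeEmbeddings("AI-ModelScope/bge-large-zh-v1.5", device="cpu")
+#   vec = emb.embed_query("sample text")    # -> List[float], one pooled vector
+#   mat = emb.embed_documents(["a", "b"])   # -> List[List[float]]
+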
+def main():
+    # 1. Load and chunk the document
+    print("Loading and chunking the document...")
+    try:
+        loader = TextLoader("./lsxd.txt", encoding="utf-8")
+        pages = loader.load()
+    except FileNotFoundError:
+        print("Error: document file './lsxd.txt' not found")
+        return
+    except Exception as e:
+        print(f"Error while loading the document: {str(e)}")
+        return
 
-# 4. Initialize the custom embedding model and vector stores
-embeddings = ModelScopeEmbeddings(model_name="BAAI/bge-large-zh-v1.5")  # switched to a dedicated embedding model
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200)
+    docs = text_splitter.split_documents(pages)
+    print(f"Document split into {len(docs)} chunks")
 
-# Initialize Chroma and FAISS
-vector_store = Chroma.from_documents(docs, embeddings)  # primary store
-faiss_store = FAISS.from_documents(docs, embeddings)
-faiss_retriever = faiss_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
+    # 2. Initialize the embedding model and vector store
+    print("Initializing the embedding model and vector store...")
+    embeddings = None
+    try:
+        embeddings = ModelScopeEmbeddings(model_name="AI-ModelScope/bge-large-zh-v1.5", device="cpu")
+    except Exception as e:
+        print(f"Error while initializing the embedding model: {str(e)}")
+        return
 
-# 5. Hybrid retriever (BM25Retriever)
-bm25_retriever = BM25Retriever.from_documents(docs)
-ensemble_retriever = EnsembleRetriever(
-    retrievers=[bm25_retriever, faiss_retriever],
-    weights=[0.3, 0.7]
-)
+    faiss_store = None
+    try:
+        # Store the vectors in FAISS
+        faiss_store = FAISS.from_documents(docs, embeddings)
+        faiss_retriever = faiss_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
+    except Exception as e:
+        print(f"Error while initializing the vector store: {str(e)}")
+        return
 
+    # 3. BM25 retriever
+    print("Initializing the BM25 retriever...")
+    bm25_retriever = None
+    try:
+        bm25_retriever = BM25Retriever.from_documents(docs)
+    except Exception as e:
+        print(f"Error while initializing the BM25 retriever: {str(e)}")
+        return
 
-# 6. Model configuration (using a Qwen model from ModelScope)
-class Config(BaseSettings):
-    model_name: str = Field("Qwen/Qwen-7B-Chat", description="model name")
-    device: str = Field("cuda" if torch.cuda.is_available() else "cpu", description="device to run on")
+    # 4. Hybrid (ensemble) retriever
+    print("Initializing the hybrid retriever...")
+    ensemble_retriever = None
+    try:
+        ensemble_retriever = EnsembleRetriever(
+            retrievers=[bm25_retriever, faiss_retriever],
+            weights=[0.3, 0.7]
+        )
+    except Exception as e:
+        print(f"Error while initializing the hybrid retriever: {str(e)}")
+        return
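+    # Note on the weights above: EnsembleRetriever fuses the two ranked lists
+    # with weighted Reciprocal Rank Fusion, so [0.3, 0.7] biases the fused
+    # ranking toward the dense FAISS hits while BM25 still surfaces exact
+    # keyword matches. A quick smoke test (illustrative, any query works):
+    #   print(len(ensemble_retriever.get_relevant_documents("test query")))
+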
-config = Config()
-tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(
-    config.model_name,
-    torch_dtype=torch.float16,
-    device_map=config.device,
-    trust_remote_code=True
-)
+    # 5. Load the Qwen LLM
+    print("Loading the large language model...")
+    model = None
+    tokenizer = None
+    device = "auto"  # device_map placement policy for accelerate, not a torch.device
 
+    try:
+        from modelscope import AutoTokenizer, AutoModelForCausalLM
+        model_name = "Qwen/Qwen3-4B-AWQ"
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            device_map=device,
+            trust_remote_code=True
+        )
+        model.eval()
+    except Exception as e:
+        print(f"Error while loading the LLM: {str(e)}")
+        return
 
+    # 6. Query and generation
+    def generate_response(query: str) -> str:
+        if not ensemble_retriever or not model or not tokenizer:
+            return "System initialization incomplete; unable to process the request"
 
-# 7. Query and generation (unchanged)
-def generate_response(query: str) -> str:
-    results = ensemble_retriever.get_relevant_documents(query)
-    context = "\n".join([f"Document excerpt: {doc.page_content[:500]}..." for doc in results[:3]])
-    prompt = f"""You are an intelligent assistant; answer the user's question from the context below. If the information is insufficient, answer "I don't know".
+        try:
+            # Fetch relevant documents with the hybrid retriever
+            results = ensemble_retriever.get_relevant_documents(query)
+            context = "\n".join([f"Document excerpt: {doc.page_content[:500]}..." for doc in results[:3]])
+
+            # Build the prompt
+            prompt = f"""You are an intelligent assistant; answer the user's question from the context below. If the information is insufficient, answer "I don't know".
 
 User question: {query}
 
@@ -108,22 +134,36 @@ def generate_response(query: str) -> str:
 {context}
 
 Answer:"""
-    inputs = tokenizer(prompt, return_tensors="pt").to(config.device)
-    outputs = model.generate(
-        inputs.input_ids,
-        max_new_tokens=512,
-        temperature=0.3,
-        repetition_penalty=1.1,
-        do_sample=True,
-        top_p=0.9,
-        eos_token_id=tokenizer.eos_token_id
-    )
-    response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
-    return response.strip()
+            # Generate the answer; move inputs to model.device ("auto" is a
+            # device_map policy that torch's .to() cannot resolve)
+            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+            with torch.no_grad():
+                outputs = model.generate(
+                    inputs.input_ids,
+                    max_new_tokens=512,
+                    temperature=0.3,
+                    repetition_penalty=1.1,
+                    do_sample=True,
+                    top_p=0.9,
+                    eos_token_id=tokenizer.eos_token_id
+                )
+            response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
+            return response.strip()
+        except Exception as e:
+            return f"Error while generating the answer: {str(e)}"
+
+    # Interactive Q&A loop
+    print("System ready. Ask your questions!")
+    while True:
+        query = input("\nEnter a question (type 'quit' to exit): ")
+        if query.strip().lower() == "quit":
+            break
+        if not query.strip():
+            print("Please enter a valid question!")
+            continue
+
+        answer = generate_response(query)
+        print("AI answer:", answer)
 
-# Example query
 if __name__ == "__main__":
-    query = "What kind of company is 蓝色兄弟?"
-    answer = generate_response(query)
-    print("AI answer:", answer)
\ No newline at end of file
+    main()
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 28a31dd..ec5cff9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,6 +27,9 @@ edge-tts
 langchain-community
 sentence-transformers
 chromadb
+faiss-cpu
+rank_bm25
+# conda install -c pytorch faiss-gpu
-#pip install chromadb -i https://mirrors.aliyun.com/pypi/simple/
+#pip install autoawq -i https://mirrors.aliyun.com/pypi/simple/
 #modelscope download --model qwen/Qwen-1_8B --local_dir ./qwen
\ No newline at end of file