"""Interactive retrieval-augmented Q&A over a local text file.

Pipeline: load and chunk ``./lsxd.txt`` -> index the chunks with a FAISS
vector store (ModelScope BGE embeddings) and a BM25 retriever -> combine them
in an EnsembleRetriever -> answer user questions with a Qwen3 chat model.
"""
import re
from pathlib import Path
from typing import List, Optional, Tuple

import torch
from langchain.retrievers import EnsembleRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS

DOC_PATH = "./lsxd.txt"
EMBED_MODEL_NAME = "AI-ModelScope/bge-large-zh-v1.5"
LLM_MODEL_NAME = "Qwen/Qwen3-4B-AWQ"

# Runs of CJK ideographs, or runs of (lower-cased) ASCII letters/digits.
# Everything else (whitespace, punctuation, other scripts) acts as a separator.
_TOKEN_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff]+|[a-z0-9]+")


class ModelScopeEmbeddings:
    """Embedding generator backed by a ModelScope encoder model.

    Exposes ``embed_query`` / ``embed_documents`` so it can be handed to
    LangChain vector stores such as FAISS, and is also directly callable.

    Args:
        model_name: ModelScope model id passed to ``from_pretrained``.
        device: Device for the model and its inputs; ``None`` means ``"cpu"``.
        pooling: ``"mean"`` (attention-mask-weighted mean over all tokens;
            the default, matching the original behaviour) or ``"cls"``
            (hidden state of the first token).
        normalize: If True, L2-normalise every embedding.
        max_length: Tokenizer truncation length.
        batch_size: Number of texts encoded per forward pass in
            ``embed_documents``.

    Raises:
        ValueError: If ``pooling`` or ``batch_size`` is invalid.
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        pooling: str = "mean",
        normalize: bool = False,
        max_length: int = 512,
        batch_size: int = 16,
    ):
        from modelscope import AutoModel, AutoTokenizer

        if pooling not in ("mean", "cls"):
            raise ValueError(f"unsupported pooling: {pooling!r}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.model_name = model_name
        self.device = "cpu" if device is None else device
        self.pooling = pooling
        self.normalize = normalize
        self.max_length = max_length
        self.batch_size = batch_size

        # float16 is slow on CPU (and some ops such as LayerNorm lack half
        # kernels on older PyTorch CPU builds), so only use it off-CPU.
        on_cpu = str(self.device).startswith("cpu")
        dtype = torch.float32 if on_cpu else torch.float16

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map=self.device,
            trust_remote_code=True,
            use_safetensors=True,
        )
        self.model.eval()

    def __call__(self, text: str) -> List[float]:
        """Allow ``embeddings(text)``; identical to ``embed_query(text)``."""
        return self.embed_query(text)

    def _encode(self, texts: List[str]) -> List[List[float]]:
        """Embed one batch of texts in a single forward pass."""
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)
        with torch.no_grad():
            # Pool in float32 regardless of the model's compute dtype.
            hidden = self.model(**inputs).last_hidden_state.float()

        if self.pooling == "cls":
            pooled = hidden[:, 0]
        else:
            # Weight by the attention mask so padding added for batching does
            # not dilute the mean. For a single unpadded text this equals a
            # plain mean over the sequence.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

        if self.normalize:
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
        return pooled.cpu().tolist()

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single text."""
        return self._encode([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per text, encoding ``batch_size`` texts at a time."""
        texts = list(texts)
        vectors: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            vectors.extend(self._encode(texts[start:start + self.batch_size]))
        return vectors


def _tokenize_zh(text: str) -> List[str]:
    """Tokenise mixed Chinese/ASCII text for BM25.

    BM25Retriever's default preprocessing is a whitespace split. That leaves
    unspaced Chinese text as a few huge "words", so lexical matching almost
    never fires. Instead, each CJK run yields its single characters plus
    adjacent-character bigrams, and each ASCII run yields one lower-cased
    word. Characters outside these ranges are ignored.
    """
    tokens: List[str] = []
    for run in _TOKEN_RE.findall(text.lower()):
        if run.isascii():
            tokens.append(run)
        else:
            tokens.extend(run)
            tokens.extend(run[i:i + 2] for i in range(len(run) - 1))
    return tokens


def _load_documents(path: str) -> Optional[list]:
    """Load *path* as UTF-8 text and split it into overlapping chunks.

    Returns the non-empty list of chunk Documents, or None after printing an
    error message.
    """
    # TextLoader re-raises every load failure (a missing file included) as
    # RuntimeError, so catching FileNotFoundError around load() never
    # triggers. Check for the file up front instead.
    if not Path(path).is_file():
        print(f"错误:未找到文档文件 '{path}'")
        return None
    try:
        pages = TextLoader(path, encoding="utf-8").load()
    except Exception as e:
        print(f"加载文档时出错: {str(e)}")
        return None

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200)
    docs = text_splitter.split_documents(pages)
    print(f"文档分块完成,共 {len(docs)} 个片段")
    if not docs:
        # FAISS and BM25 both fail to build from an empty corpus.
        print("错误:文档内容为空,无法建立索引")
        return None
    return docs


def _build_retriever(docs: list) -> Optional[EnsembleRetriever]:
    """Build a BM25 (0.3) + FAISS (0.7) ensemble retriever over *docs*.

    Returns None after printing an error message if any stage fails.
    """
    print("初始化嵌入模型和向量数据库...")
    try:
        # BGE models are meant to be used with CLS pooling and L2 normalisation.
        embeddings = ModelScopeEmbeddings(
            model_name=EMBED_MODEL_NAME, device="cpu", pooling="cls", normalize=True
        )
    except Exception as e:
        print(f"初始化嵌入模型时出错: {str(e)}")
        return None

    try:
        faiss_store = FAISS.from_documents(docs, embeddings)
        faiss_retriever = faiss_store.as_retriever(
            search_type="similarity", search_kwargs={"k": 5}
        )
    except Exception as e:
        print(f"初始化向量存储时出错: {str(e)}")
        return None

    print("初始化 BM25 检索器...")
    try:
        bm25_retriever = BM25Retriever.from_documents(docs, preprocess_func=_tokenize_zh)
    except Exception as e:
        print(f"初始化 BM25 检索器时出错: {str(e)}")
        return None

    print("初始化混合检索器...")
    try:
        return EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever], weights=[0.3, 0.7]
        )
    except Exception as e:
        print(f"初始化混合检索器时出错: {str(e)}")
        return None


def _load_llm(model_name: str) -> Optional[Tuple[object, object]]:
    """Load a causal LM and its tokenizer from ModelScope.

    Returns ``(model, tokenizer)``, or None after printing an error message.
    """
    print("加载大语言模型...")
    try:
        from modelscope import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        model.eval()
    except Exception as e:
        print(f"加载大模型时出错: {str(e)}")
        return None
    return model, tokenizer


def _generate_response(query: str, retriever, model, tokenizer) -> str:
    """Retrieve context for *query* and generate an answer with the LLM.

    Never raises: any failure is returned as an error string.
    """
    try:
        results = retriever.invoke(query)
        context = "\n".join(
            f"文档片段:{doc.page_content[:500]}..." for doc in results[:3]
        )

        prompt = f"""你是一个智能助手,请根据以下上下文回答用户问题。若信息不足,请回答"我不知道"。

用户问题:{query}

上下文信息:
{context}

回答:"""

        # Qwen3 is a chat model: wrap the prompt in its chat template.
        # Thinking mode is disabled so the reply is the answer itself.
        chat_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        # "auto" is a device_map setting, not a torch device. Place inputs on
        # the device that holds the model's (first) parameters.
        inputs = tokenizer(chat_text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,  # includes attention_mask
                max_new_tokens=512,
                temperature=0.3,
                repetition_penalty=1.1,
                do_sample=True,
                top_p=0.9,
                eos_token_id=tokenizer.eos_token_id,
            )
        prompt_len = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return response.strip()
    except Exception as e:
        return f"生成回答时出错: {str(e)}"


def main():
    """Build the retrieval pipeline and LLM, then run an interactive Q&A loop."""
    print("加载并分块文档...")
    docs = _load_documents(DOC_PATH)
    if docs is None:
        return

    retriever = _build_retriever(docs)
    if retriever is None:
        return

    loaded = _load_llm(LLM_MODEL_NAME)
    if loaded is None:
        return
    model, tokenizer = loaded

    print("系统准备就绪,可以开始提问!")
    while True:
        try:
            query = input("\n请输入问题(输入 '退出' 结束):").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C end the session instead of dumping a traceback.
            print()
            break
        if query.lower() == "退出":
            break
        if not query:
            print("请输入有效问题!")
            continue
        answer = _generate_response(query, retriever, model, tokenizer)
        print("AI回答:", answer)


if __name__ == "__main__":
    main()