renzhiyuan 2025-08-21 04:35:54 +08:00
parent 7845b82718
commit a1de80892b
2 changed files with 131 additions and 88 deletions

rag.py (214 changed lines)

@@ -1,106 +1,132 @@
 from langchain_community.document_loaders import TextLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.schema import Document, BaseRetriever
-from langchain_community.vectorstores import Chroma, FAISS
-from langchain.retrievers import BM25Retriever
+from langchain_community.vectorstores import FAISS
+from langchain_community.retrievers import BM25Retriever
+from langchain.retrievers import EnsembleRetriever
 from typing import List
-from pydantic_settings import BaseSettings
-from pydantic import BaseModel, Field
-from modelscope import AutoModel, AutoTokenizer, AutoModelForCausalLM  # switched to AutoModel
 import torch
-import numpy as np
-
-# 1. Custom EnsembleRetriever (unchanged)
-class EnsembleRetriever(BaseRetriever):
-    def __init__(self, retrievers: List[BaseRetriever], weights: List[float] = None):
-        self.retrievers = retrievers
-        self.weights = weights or [1.0 / len(retrievers)] * len(retrievers)
-
-    def get_relevant_documents(self, query: str) -> List[Document]:
-        all_results = []
-        for retriever, weight in zip(self.retrievers, self.weights):
-            docs = retriever.get_relevant_documents(query)
-            all_results.extend(docs)
-        # Simple de-duplication
-        seen = set()
-        unique_docs = []
-        for doc in all_results:
-            if doc.page_content not in seen:
-                seen.add(doc.page_content)
-                unique_docs.append(doc)
-        return unique_docs
-
-# 2. Custom ModelScopeEmbeddings (switched to AutoModel)
 class ModelScopeEmbeddings:
+    """Embedding generator backed by a ModelScope model."""
     def __init__(self, model_name: str, device: str = None):
+        from modelscope import AutoModel, AutoTokenizer
         self.model_name = model_name
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = "cpu" if device is None else device
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         self.model = AutoModel.from_pretrained(
             model_name,
             torch_dtype=torch.float16,
             device_map=self.device,
             trust_remote_code=True,
-            use_safetensors=True  # force safetensors
+            use_safetensors=True
         )
-    def embed_query(self, text: str) -> np.ndarray:
+
+    def __call__(self, text: str) -> List[float]:
+        """Allow calling embeddings(text) directly."""
+        return self.embed_query(text)
+
+    def embed_query(self, text: str) -> List[float]:
         inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
         with torch.no_grad():
             outputs = self.model(**inputs)
-        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
-        return embeddings.squeeze(0)
+        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().float().numpy()
+        return embeddings.squeeze(0).tolist()

-    def embed_documents(self, texts: List[str]) -> np.ndarray:
-        return np.array([self.embed_query(text) for text in texts])
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        return [self.embed_query(text) for text in texts]

-# 3. Document loading and chunking (unchanged)
-loader = TextLoader("./lsxd.txt", encoding="utf-8")
-pages = loader.load()
-text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200)
-docs = text_splitter.split_documents(pages)
-
-# 4. Initialize the custom embedding model and vector stores
-embeddings = ModelScopeEmbeddings(model_name="BAAI/bge-large-zh-v1.5")  # switched to a dedicated embedding model
-
-# Initialize Chroma and FAISS
-vector_store = Chroma.from_documents(docs, embeddings)  # primary store
-faiss_store = FAISS.from_documents(docs, embeddings)
-faiss_retriever = faiss_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
-
-# 5. Hybrid retriever (BM25Retriever)
-bm25_retriever = BM25Retriever.from_documents(docs)
-ensemble_retriever = EnsembleRetriever(
-    retrievers=[bm25_retriever, faiss_retriever],
-    weights=[0.3, 0.7]
-)
-
-# 6. Model configuration (Qwen model from ModelScope)
-class Config(BaseSettings):
-    model_name: str = Field("Qwen/Qwen-7B-Chat", description="model name")
-    device: str = Field("cuda" if torch.cuda.is_available() else "cpu", description="device to run on")
-
-config = Config()
-tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(
-    config.model_name,
-    torch_dtype=torch.float16,
-    device_map=config.device,
-    trust_remote_code=True
-)
-
-# 7. Query and generation (unchanged)
-def generate_response(query: str) -> str:
-    results = ensemble_retriever.get_relevant_documents(query)
-    context = "\n".join([f"Document fragment: {doc.page_content[:500]}..." for doc in results[:3]])
-    prompt = f"""You are an intelligent assistant. Answer the user's question from the context below. If the information is insufficient, answer "I don't know".
+def main():
+    # 1. Load and chunk the document
+    print("Loading and chunking the document...")
+    try:
+        loader = TextLoader("./lsxd.txt", encoding="utf-8")
+        pages = loader.load()
+    except FileNotFoundError:
+        print("Error: document file './lsxd.txt' not found")
+        return
+    except Exception as e:
+        print(f"Error while loading the document: {str(e)}")
+        return
+
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200)
+    docs = text_splitter.split_documents(pages)
+    print(f"Chunking finished, {len(docs)} fragments in total")
+
+    # 2. Initialize the embedding model and vector store
+    print("Initializing the embedding model and vector database...")
+    embeddings = None
+    try:
+        embeddings = ModelScopeEmbeddings(model_name="AI-ModelScope/bge-large-zh-v1.5", device="cpu")
+    except Exception as e:
+        print(f"Error while initializing the embedding model: {str(e)}")
+        return
+
+    faiss_store = None
+    try:
+        # Store the vectors with FAISS
+        faiss_store = FAISS.from_documents(docs, embeddings)
+        faiss_retriever = faiss_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
+    except Exception as e:
+        print(f"Error while initializing the vector store: {str(e)}")
+        return
+
+    # 3. BM25 retriever
+    print("Initializing the BM25 retriever...")
+    bm25_retriever = None
+    try:
+        bm25_retriever = BM25Retriever.from_documents(docs)
+    except Exception as e:
+        print(f"Error while initializing the BM25 retriever: {str(e)}")
+        return
+
+    # 4. Hybrid retriever
+    print("Initializing the hybrid retriever...")
+    ensemble_retriever = None
+    try:
+        ensemble_retriever = EnsembleRetriever(
+            retrievers=[bm25_retriever, faiss_retriever],
+            weights=[0.3, 0.7]
+        )
+    except Exception as e:
+        print(f"Error while initializing the hybrid retriever: {str(e)}")
+        return
+
+    # 5. Load the Qwen language model
+    print("Loading the large language model...")
+    model = None
+    tokenizer = None
+    device = "auto"
+    try:
+        from modelscope import AutoTokenizer, AutoModelForCausalLM
+        model_name = "Qwen/Qwen3-4B-AWQ"
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            device_map=device,
+            trust_remote_code=True
+        )
+        model.eval()
+    except Exception as e:
+        print(f"Error while loading the language model: {str(e)}")
+        return
+
+    # 6. Query and generation
+    def generate_response(query: str) -> str:
+        if not ensemble_retriever or not model or not tokenizer:
+            return "System initialization incomplete, cannot handle this request"
+
+        try:
+            # Fetch relevant documents with the hybrid retriever
+            results = ensemble_retriever.get_relevant_documents(query)
+            context = "\n".join([f"Document fragment: {doc.page_content[:500]}..." for doc in results[:3]])

+            # Build the prompt
+            prompt = f"""You are an intelligent assistant. Answer the user's question from the context below. If the information is insufficient, answer "I don't know".

User question: {query}
@@ -108,22 +134,36 @@ def generate_response(query: str) -> str:
{context}

Answer:"""
-    inputs = tokenizer(prompt, return_tensors="pt").to(config.device)
-    outputs = model.generate(
-        inputs.input_ids,
-        max_new_tokens=512,
-        temperature=0.3,
-        repetition_penalty=1.1,
-        do_sample=True,
-        top_p=0.9,
-        eos_token_id=tokenizer.eos_token_id
-    )
-    response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
-    return response.strip()
+            # Generate the answer
+            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+            with torch.no_grad():
+                outputs = model.generate(
+                    inputs.input_ids,
+                    max_new_tokens=512,
+                    temperature=0.3,
+                    repetition_penalty=1.1,
+                    do_sample=True,
+                    top_p=0.9,
+                    eos_token_id=tokenizer.eos_token_id
+                )
+            response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
+            return response.strip()
+        except Exception as e:
+            return f"Error while generating the answer: {str(e)}"

-# Example query
+    # Interactive queries
+    print("System ready, you can start asking questions!")
+    while True:
+        query = input("\nEnter a question (type 'quit' to exit): ")
+        if query.strip().lower() == "quit":
+            break
+        if not query.strip():
+            print("Please enter a valid question!")
+            continue
+        answer = generate_response(query)
+        print("AI answer:", answer)
+
 if __name__ == "__main__":
-    query = "What kind of company is 蓝色兄弟?"
-    answer = generate_response(query)
-    print("AI answer:", answer)
+    main()

requirements.txt

@@ -27,6 +27,9 @@ edge-tts
langchain-community
sentence-transformers
chromadb
+faiss-cpu
+rank_bm25
+# conda install -c pytorch faiss-gpu
-#pip install chromadb -i https://mirrors.aliyun.com/pypi/simple/
+#pip install autoawq -i https://mirrors.aliyun.com/pypi/simple/
#modelscope download --model qwen/Qwen-1_8B --local_dir ./qwen
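
The two added packages can be sanity-checked with a quick import, assuming a plain pip install of faiss-cpu and rank_bm25 (or the conda faiss-gpu variant noted above):

# Import smoke test for the newly added dependencies.
import faiss      # shipped by the faiss-cpu / faiss-gpu package
import rank_bm25  # BM25 backend behind langchain's BM25Retriever

print("faiss", faiss.__version__)
print("rank_bm25 provides BM25Okapi:", hasattr(rank_bm25, "BM25Okapi"))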