parent 7845b82718
commit a1de80892b

rag.py | 160
@@ -1,106 +1,132 @@
 from langchain_community.document_loaders import TextLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.schema import Document, BaseRetriever
-from langchain_community.vectorstores import Chroma, FAISS
-from langchain.retrievers import BM25Retriever
+from langchain_community.vectorstores import FAISS
+from langchain_community.retrievers import BM25Retriever
+from langchain.retrievers import EnsembleRetriever
 from typing import List
 
-from pydantic_settings import BaseSettings
-from pydantic import BaseModel, Field
-from modelscope import AutoModel, AutoTokenizer, AutoModelForCausalLM  # switched to AutoModel
 import torch
-import numpy as np
 
 
-# 1. Custom EnsembleRetriever (kept unchanged)
-class EnsembleRetriever(BaseRetriever):
-    def __init__(self, retrievers: List[BaseRetriever], weights: List[float] = None):
-        self.retrievers = retrievers
-        self.weights = weights or [1.0 / len(retrievers)] * len(retrievers)
-
-    def get_relevant_documents(self, query: str) -> List[Document]:
-        all_results = []
-        for retriever, weight in zip(self.retrievers, self.weights):
-            docs = retriever.get_relevant_documents(query)
-            all_results.extend(docs)
-        # simple deduplication
-        seen = set()
-        unique_docs = []
-        for doc in all_results:
-            if doc.page_content not in seen:
-                seen.add(doc.page_content)
-                unique_docs.append(doc)
-        return unique_docs
-
-
-# 2. Custom ModelScopeEmbeddings (switched to AutoModel)
 class ModelScopeEmbeddings:
+    """Embedding generator backed by a ModelScope model"""
 
     def __init__(self, model_name: str, device: str = None):
+        from modelscope import AutoModel, AutoTokenizer
         self.model_name = model_name
-        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = "cpu" if device is None else device
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         self.model = AutoModel.from_pretrained(
             model_name,
             torch_dtype=torch.float16,
             device_map=self.device,
             trust_remote_code=True,
-            use_safetensors=True  # force safetensors
+            use_safetensors=True
         )
 
-    def embed_query(self, text: str) -> np.ndarray:
+    def __call__(self, text: str) -> List[float]:
+        """Allow calling embeddings(text) directly"""
+        return self.embed_query(text)
+
+    def embed_query(self, text: str) -> List[float]:
         inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(self.device)
         with torch.no_grad():
             outputs = self.model(**inputs)
-            embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
-        return embeddings.squeeze(0)
+            embeddings = outputs.last_hidden_state.mean(dim=1).cpu().float().numpy()
+        return embeddings.squeeze(0).tolist()
 
-    def embed_documents(self, texts: List[str]) -> np.ndarray:
-        return np.array([self.embed_query(text) for text in texts])
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        return [self.embed_query(text) for text in texts]
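Note on the pooling above: embed_query mean-pools last_hidden_state over every token. BGE-family encoders are more commonly read out at the [CLS] token and L2-normalized, so that dot product equals cosine similarity. A minimal sketch of that alternative (embed_query_cls is a hypothetical helper, not part of this commit):

    import torch
    import torch.nn.functional as F

    def embed_query_cls(tokenizer, model, device, text: str) -> list:
        # Same tokenization as ModelScopeEmbeddings.embed_query above.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        cls = outputs.last_hidden_state[:, 0]   # [CLS] vector, shape (1, dim)
        cls = F.normalize(cls, p=2, dim=1)      # unit length: dot product == cosine
        return cls.squeeze(0).cpu().float().tolist()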
 
 
-# 3. Document loading and chunking (kept unchanged)
+def main():
+    # 1. Document loading and chunking
+    print("加载并分块文档...")
+    try:
         loader = TextLoader("./lsxd.txt", encoding="utf-8")
         pages = loader.load()
+    except FileNotFoundError:
+        print("错误:未找到文档文件 './lsxd.txt'")
+        return
+    except Exception as e:
+        print(f"加载文档时出错: {str(e)}")
+        return
 
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200)
     docs = text_splitter.split_documents(pages)
+    print(f"文档分块完成,共 {len(docs)} 个片段")
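For intuition on chunk_size=800 / chunk_overlap=200: consecutive chunks share up to 200 characters, so a sentence cut at a chunk boundary still appears intact in at least one chunk. A quick self-contained check on synthetic input (illustrative only; should print True):

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=200)
    text = "".join(f"{i:03d}" for i in range(600))   # 1800 characters, no natural separators
    chunks = splitter.split_text(text)
    print(chunks[0][-200:] == chunks[1][:200])        # the 200-char overlap region is shared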
 
-# 4. Initialize the custom embedding model and vector store
-embeddings = ModelScopeEmbeddings(model_name="BAAI/bge-large-zh-v1.5")  # switched to a dedicated embedding model
+    # 2. Initialize the embedding model and vector store
+    print("初始化嵌入模型和向量数据库...")
+    embeddings = None
+    try:
+        embeddings = ModelScopeEmbeddings(model_name="AI-ModelScope/bge-large-zh-v1.5", device="cpu")
+    except Exception as e:
+        print(f"初始化嵌入模型时出错: {str(e)}")
+        return
 
-# Initialize Chroma and FAISS
-vector_store = Chroma.from_documents(docs, embeddings)  # primary store
+    faiss_store = None
+    try:
+        # Store the vectors in FAISS
         faiss_store = FAISS.from_documents(docs, embeddings)
         faiss_retriever = faiss_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})
+    except Exception as e:
+        print(f"初始化向量存储时出错: {str(e)}")
+        return
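The wrapper can also be sanity-checked on its own before handing it to FAISS; a small illustrative snippet (the 1024 in the comment is the usual bge-large output width, stated here as an assumption):

    emb = ModelScopeEmbeddings(model_name="AI-ModelScope/bge-large-zh-v1.5", device="cpu")
    vec = emb.embed_query("测试句子")
    print(type(vec), len(vec))   # expect a plain list of floats, typically 1024-dimensional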
 
-# 5. Hybrid retrieval (BM25Retriever)
+    # 3. BM25 retriever
+    print("初始化 BM25 检索器...")
+    bm25_retriever = None
+    try:
         bm25_retriever = BM25Retriever.from_documents(docs)
+    except Exception as e:
+        print(f"初始化 BM25 检索器时出错: {str(e)}")
+        return
+
+    # 4. Hybrid retriever
+    print("初始化混合检索器...")
+    ensemble_retriever = None
+    try:
         ensemble_retriever = EnsembleRetriever(
             retrievers=[bm25_retriever, faiss_retriever],
             weights=[0.3, 0.7]
         )
+    except Exception as e:
+        print(f"初始化混合检索器时出错: {str(e)}")
+        return
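For reference, LangChain's EnsembleRetriever fuses the two ranked lists with weighted Reciprocal Rank Fusion, so weights=[0.3, 0.7] tilts the merged ranking toward the FAISS results. A minimal sketch of that fusion (weighted_rrf is our name, not LangChain's; c=60 is the conventional RRF constant):

    def weighted_rrf(ranked_lists, weights, c=60):
        # ranked_lists: one list of document ids per retriever, best first.
        scores = {}
        for docs, weight in zip(ranked_lists, weights):
            for rank, doc_id in enumerate(docs, start=1):
                scores[doc_id] = scores.get(doc_id, 0.0) + weight / (rank + c)
        return sorted(scores, key=scores.get, reverse=True)

    # e.g. weighted_rrf([bm25_ids, faiss_ids], weights=[0.3, 0.7])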
 
-# 6. Model configuration (using a ModelScope Qwen model)
-class Config(BaseSettings):
-    model_name: str = Field("Qwen/Qwen-7B-Chat", description="模型名称")
-    device: str = Field("cuda" if torch.cuda.is_available() else "cpu", description="运行设备")
-
-
-config = Config()
-tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)
+    # 5. Load the Qwen LLM
+    print("加载大语言模型...")
+    model = None
+    tokenizer = None
+    device = "auto"
+
+    try:
+        from modelscope import AutoTokenizer, AutoModelForCausalLM
+        model_name = "Qwen/Qwen3-4B-AWQ"
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         model = AutoModelForCausalLM.from_pretrained(
-    config.model_name,
+            model_name,
             torch_dtype=torch.float16,
-    device_map=config.device,
+            device_map=device,
             trust_remote_code=True
         )
+        model.eval()
+    except Exception as e:
+        print(f"加载大模型时出错: {str(e)}")
+        return
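Operational note on the block above: passing device_map to from_pretrained routes through the accelerate package in current transformers/modelscope stacks, so it can be worth failing fast when it is missing (a defensive sketch, not part of the commit):

    try:
        import accelerate  # noqa: F401  (used by from_pretrained(device_map=...))
    except ImportError as exc:
        raise SystemExit("Missing dependency: pip install accelerate") from exc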
-# 7. Query and generation (kept unchanged)
+    # 6. Query and generation
     def generate_response(query: str) -> str:
+        if not ensemble_retriever or not model or not tokenizer:
+            return "系统初始化未完成,无法处理请求"
+
+        try:
+            # Retrieve relevant documents with the hybrid retriever
             results = ensemble_retriever.get_relevant_documents(query)
             context = "\n".join([f"文档片段:{doc.page_content[:500]}..." for doc in results[:3]])
-    prompt = f"""你是一个智能助手,请根据以下上下文回答用户问题。若信息不足,请回答“我不知道”。
+            # Build the prompt
+            prompt = f"""你是一个智能助手,请根据以下上下文回答用户问题。若信息不足,请回答"我不知道"。
 
 用户问题:{query}
 
@@ -108,7 +134,9 @@ def generate_response(query: str) -> str:
 {context}
 
 回答:"""
-    inputs = tokenizer(prompt, return_tensors="pt").to(config.device)
+            # Generate the answer
+            inputs = tokenizer(prompt, return_tensors="pt").to(device)
             with torch.no_grad():
                 outputs = model.generate(
                     inputs.input_ids,
                     max_new_tokens=512,
@@ -120,10 +148,22 @@ def generate_response(query: str) -> str:
                 )
             response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
             return response.strip()
+        except Exception as e:
+            return f"生成回答时出错: {str(e)}"
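Caution on the generation block: device here is the string "auto", which is only meaningful as a device_map argument to from_pretrained; calling .to("auto") on tensors raises at runtime. A safer pattern (our suggestion, not what this commit does) is to follow wherever the loaded model actually landed:

    # Move inputs to the device of the (possibly sharded) model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512)   # also forwards attention_mask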
 
     # Example queries
-if __name__ == "__main__":
-    query = "蓝色兄弟是一家怎样的公司?"
+    print("系统准备就绪,可以开始提问!")
+    while True:
+        query = input("\n请输入问题(输入 '退出' 结束):")
+        if query.strip().lower() == "退出":
+            break
+        if not query.strip():
+            print("请输入有效问题!")
+            continue
+
         answer = generate_response(query)
         print("AI回答:", answer)
+
+
+if __name__ == "__main__":
+    main()
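Qwen chat checkpoints such as Qwen3-4B-AWQ are normally driven through the tokenizer's chat template rather than a raw f-string; a hedged sketch of that variant, assuming the tokenizer/model objects created above and the same prompt string built in generate_response:

    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=512)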
@@ -27,6 +27,9 @@ edge-tts
 langchain-community
 sentence-transformers
 chromadb
+faiss-cpu
+rank_bm25
+# conda install -c pytorch faiss-gpu
 
-#pip install chromadb -i https://mirrors.aliyun.com/pypi/simple/
+#pip install autoawq -i https://mirrors.aliyun.com/pypi/simple/
 #modelscope download --model qwen/Qwen-1_8B --local_dir ./qwen
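On the dependency changes: faiss-cpu and rank_bm25 back the FAISS store and BM25Retriever added in rag.py, and GPU builds of FAISS are distributed mainly via conda, which is presumably why that command stays commented. A quick import check (illustrative):

    import faiss

    print(faiss.__version__)
    print(getattr(faiss, "get_num_gpus", lambda: 0)())   # 0 on CPU-only builds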