核心功能实现
概览
这一节实现 RAG 的核心部分:
- 加载文档
- 切分文本
- 向量化并存储
- 检索相关片段
- 生成答案
文档加载
创建 rag.py,添加文档加载逻辑:
python
from langchain_community.document_loaders import (
PyPDFLoader,
Docx2txtLoader,
TextLoader
)
def load_document(file_path):
"""根据文件类型选择合适的加载器"""
if file_path.endswith('.pdf'):
loader = PyPDFLoader(file_path)
elif file_path.endswith('.docx'):
loader = Docx2txtLoader(file_path)
elif file_path.endswith('.txt'):
loader = TextLoader(file_path, encoding='utf-8')
else:
raise ValueError(f"不支持的文件格式: {file_path}")
documents = loader.load()
return documents测试:
python
# 测试代码
docs = load_document("data/sample.txt")
print(f"加载了 {len(docs)} 个文档片段")
print(f"第一个片段: {docs[0].page_content[:100]}...")文本切分
python
from langchain.text_splitter import RecursiveCharacterTextSplitter
from config import config
def split_documents(documents):
"""将文档切分成小块"""
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=config.CHUNK_SIZE, # 每块 500 字
chunk_overlap=config.CHUNK_OVERLAP, # 重叠 50 字
separators=["\n\n", "\n", "。", "!", "?", " ", ""]
)
splits = text_splitter.split_documents(documents)
return splits测试:
python
splits = split_documents(docs)
print(f"切分成 {len(splits)} 块")
print(f"第一块: {splits[0].page_content}")向量化并存储
python
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import DashScopeEmbeddings
from config import config
def get_embeddings():
"""获取 embedding 模型"""
if config.LLM_PROVIDER == "openai":
return OpenAIEmbeddings(openai_api_key=config.OPENAI_API_KEY)
elif config.LLM_PROVIDER == "dashscope":
return DashScopeEmbeddings(
dashscope_api_key=config.DASHSCOPE_API_KEY
)
else:
raise ValueError(f"不支持的 LLM Provider: {config.LLM_PROVIDER}")
def create_vector_store(splits):
"""创建向量数据库"""
embeddings = get_embeddings()
vectorstore = Chroma.from_documents(
documents=splits,
embedding=embeddings,
persist_directory=config.CHROMA_PERSIST_DIRECTORY
)
return vectorstore
def load_vector_store():
"""加载已存在的向量数据库"""
embeddings = get_embeddings()
vectorstore = Chroma(
persist_directory=config.CHROMA_PERSIST_DIRECTORY,
embedding_function=embeddings
)
return vectorstore检索相关片段
python
def retrieve_documents(vectorstore, query, k=3):
"""检索与问题最相关的文档片段"""
retriever = vectorstore.as_retriever(
search_kwargs={"k": k}
)
relevant_docs = retriever.get_relevant_documents(query)
return relevant_docs生成答案
python
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.tongyi import ChatTongyi
from config import config
def get_llm():
"""获取大模型"""
if config.LLM_PROVIDER == "openai":
return ChatOpenAI(
model=config.LLM_MODEL,
api_key=config.OPENAI_API_KEY,
temperature=0
)
elif config.LLM_PROVIDER == "dashscope":
return ChatTongyi(
model=config.LLM_MODEL,
dashscope_api_key=config.DASHSCOPE_API_KEY
)
else:
raise ValueError(f"不支持的 LLM Provider: {config.LLM_PROVIDER}")
def create_qa_chain(vectorstore):
"""创建问答链"""
llm = get_llm()
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=vectorstore.as_retriever(
search_kwargs={"k": config.TOP_K}
),
return_source_documents=True # 返回来源文档
)
return qa_chain完整流程
创建一个函数封装整个流程:
python
def answer_question(question, vectorstore):
"""完整的问答流程"""
# 创建问答链
qa_chain = create_qa_chain(vectorstore)
# 获取答案
result = qa_chain({"query": question})
# 提取答案和来源
answer = result["result"]
source_docs = result["source_documents"]
# 格式化来源
sources = []
for doc in source_docs:
source_info = {
"content": doc.page_content[:200] + "...",
"metadata": doc.metadata
}
sources.append(source_info)
return {
"answer": answer,
"sources": sources
}测试整个流程
创建 test_rag.py:
python
from rag import (
load_document,
split_documents,
create_vector_store,
answer_question
)
# 1. 加载文档
print("加载文档...")
docs = load_document("data/sample.pdf")
print(f"加载了 {len(docs)} 个片段")
# 2. 切分
print("切分文档...")
splits = split_documents(docs)
print(f"切分成 {len(splits)} 块")
# 3. 创建向量库
print("创建向量库...")
vectorstore = create_vector_store(splits)
print("向量库创建完成")
# 4. 测试问答
question = "这个文档主要讲了什么?"
print(f"\n问题: {question}")
result = answer_question(question, vectorstore)
print(f"\n答案: {result['answer']}")
print(f"\n来源:")
for i, source in enumerate(result['sources'], 1):
print(f"{i}. {source['content']}")运行:
bash
python test_rag.py性能优化
1. 持久化向量库
避免每次都重新向量化:
python
# 第一次运行后,向量库会保存到磁盘
# 后续直接加载即可
vectorstore = load_vector_store()2. 批量处理
一次处理多个文档:
python
def process_multiple_documents(file_paths):
"""批量处理文档"""
all_splits = []
for file_path in file_paths:
docs = load_document(file_path)
splits = split_documents(docs)
all_splits.extend(splits)
vectorstore = create_vector_store(all_splits)
return vectorstore3. 添加文档而不是重建
python
def add_documents(file_paths, vectorstore):
"""添加新文档到已有向量库"""
for file_path in file_paths:
docs = load_document(file_path)
splits = split_documents(docs)
vectorstore.add_documents(splits)
return vectorstore常见问题
1. 向量化很慢
如果是大文档(100 页以上),向量化可能需要几分钟。可以:
- 减小 chunk_size(比如从 500 改成 300)
- 换更快的 embedding 模型
- 考虑用 Pinecone 这样的托管服务
2. 检索结果不相关
尝试:
- 增加 TOP_K(比如从 3 改成 5)
- 调整 chunk_size
- 使用更高级的检索策略(如 MMR)
python
retriever = vectorstore.as_retriever(
search_type="mmr", # 最大边际相关性
search_kwargs={"k": 5}
)3. 内存不足
如果是超大文档(10MB 以上),考虑:
- 分批处理
- 使用 Chroma 的持久化模式
- 换用更强大的向量库(如 Pinecone)
下一步
核心功能完成了,接下来做一个简单的用户界面。