Skip to content

核心功能实现

概览

这一节实现 RAG 的核心部分:

  1. 加载文档
  2. 切分文本
  3. 向量化并存储
  4. 检索相关片段
  5. 生成答案

文档加载

创建 rag.py,添加文档加载逻辑:

python
from langchain_community.document_loaders import (
    PyPDFLoader,
    Docx2txtLoader,
    TextLoader
)

def load_document(file_path):
    """根据文件类型选择合适的加载器"""
    if file_path.endswith('.pdf'):
        loader = PyPDFLoader(file_path)
    elif file_path.endswith('.docx'):
        loader = Docx2txtLoader(file_path)
    elif file_path.endswith('.txt'):
        loader = TextLoader(file_path, encoding='utf-8')
    else:
        raise ValueError(f"不支持的文件格式: {file_path}")

    documents = loader.load()
    return documents

测试:

python
# 测试代码
docs = load_document("data/sample.txt")
print(f"加载了 {len(docs)} 个文档片段")
print(f"第一个片段: {docs[0].page_content[:100]}...")

文本切分

python
from langchain.text_splitter import RecursiveCharacterTextSplitter
from config import config

def split_documents(documents):
    """将文档切分成小块"""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=config.CHUNK_SIZE,      # 每块 500 字
        chunk_overlap=config.CHUNK_OVERLAP, # 重叠 50 字
        separators=["\n\n", "\n", "。", "!", "?", " ", ""]
    )

    splits = text_splitter.split_documents(documents)
    return splits

测试:

python
splits = split_documents(docs)
print(f"切分成 {len(splits)} 块")
print(f"第一块: {splits[0].page_content}")

向量化并存储

python
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import DashScopeEmbeddings
from config import config

def get_embeddings():
    """获取 embedding 模型"""
    if config.LLM_PROVIDER == "openai":
        return OpenAIEmbeddings(openai_api_key=config.OPENAI_API_KEY)
    elif config.LLM_PROVIDER == "dashscope":
        return DashScopeEmbeddings(
            dashscope_api_key=config.DASHSCOPE_API_KEY
        )
    else:
        raise ValueError(f"不支持的 LLM Provider: {config.LLM_PROVIDER}")

def create_vector_store(splits):
    """创建向量数据库"""
    embeddings = get_embeddings()

    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embeddings,
        persist_directory=config.CHROMA_PERSIST_DIRECTORY
    )

    return vectorstore

def load_vector_store():
    """加载已存在的向量数据库"""
    embeddings = get_embeddings()

    vectorstore = Chroma(
        persist_directory=config.CHROMA_PERSIST_DIRECTORY,
        embedding_function=embeddings
    )

    return vectorstore

检索相关片段

python
def retrieve_documents(vectorstore, query, k=3):
    """检索与问题最相关的文档片段"""
    retriever = vectorstore.as_retriever(
        search_kwargs={"k": k}
    )

    relevant_docs = retriever.get_relevant_documents(query)
    return relevant_docs

生成答案

python
from langchain.chains import RetrievalQA
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.tongyi import ChatTongyi
from config import config

def get_llm():
    """获取大模型"""
    if config.LLM_PROVIDER == "openai":
        return ChatOpenAI(
            model=config.LLM_MODEL,
            api_key=config.OPENAI_API_KEY,
            temperature=0
        )
    elif config.LLM_PROVIDER == "dashscope":
        return ChatTongyi(
            model=config.LLM_MODEL,
            dashscope_api_key=config.DASHSCOPE_API_KEY
        )
    else:
        raise ValueError(f"不支持的 LLM Provider: {config.LLM_PROVIDER}")

def create_qa_chain(vectorstore):
    """创建问答链"""
    llm = get_llm()

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(
            search_kwargs={"k": config.TOP_K}
        ),
        return_source_documents=True  # 返回来源文档
    )

    return qa_chain

完整流程

创建一个函数封装整个流程:

python
def answer_question(question, vectorstore):
    """完整的问答流程"""
    # 创建问答链
    qa_chain = create_qa_chain(vectorstore)

    # 获取答案
    result = qa_chain({"query": question})

    # 提取答案和来源
    answer = result["result"]
    source_docs = result["source_documents"]

    # 格式化来源
    sources = []
    for doc in source_docs:
        source_info = {
            "content": doc.page_content[:200] + "...",
            "metadata": doc.metadata
        }
        sources.append(source_info)

    return {
        "answer": answer,
        "sources": sources
    }

测试整个流程

创建 test_rag.py

python
from rag import (
    load_document,
    split_documents,
    create_vector_store,
    answer_question
)

# 1. 加载文档
print("加载文档...")
docs = load_document("data/sample.pdf")
print(f"加载了 {len(docs)} 个片段")

# 2. 切分
print("切分文档...")
splits = split_documents(docs)
print(f"切分成 {len(splits)} 块")

# 3. 创建向量库
print("创建向量库...")
vectorstore = create_vector_store(splits)
print("向量库创建完成")

# 4. 测试问答
question = "这个文档主要讲了什么?"
print(f"\n问题: {question}")

result = answer_question(question, vectorstore)
print(f"\n答案: {result['answer']}")
print(f"\n来源:")
for i, source in enumerate(result['sources'], 1):
    print(f"{i}. {source['content']}")

运行:

bash
python test_rag.py

性能优化

1. 持久化向量库

避免每次都重新向量化:

python
# 第一次运行后,向量库会保存到磁盘
# 后续直接加载即可
vectorstore = load_vector_store()

2. 批量处理

一次处理多个文档:

python
def process_multiple_documents(file_paths):
    """批量处理文档"""
    all_splits = []

    for file_path in file_paths:
        docs = load_document(file_path)
        splits = split_documents(docs)
        all_splits.extend(splits)

    vectorstore = create_vector_store(all_splits)
    return vectorstore

3. 添加文档而不是重建

python
def add_documents(file_paths, vectorstore):
    """添加新文档到已有向量库"""
    for file_path in file_paths:
        docs = load_document(file_path)
        splits = split_documents(docs)
        vectorstore.add_documents(splits)

    return vectorstore

常见问题

1. 向量化很慢

如果是大文档(100 页以上),向量化可能需要几分钟。可以:

  • 减小 chunk_size(比如从 500 改成 300)
  • 换更快的 embedding 模型
  • 考虑用 Pinecone 这样的托管服务

2. 检索结果不相关

尝试:

  • 增加 TOP_K(比如从 3 改成 5)
  • 调整 chunk_size
  • 使用更高级的检索策略(如 MMR)
python
retriever = vectorstore.as_retriever(
    search_type="mmr",  # 最大边际相关性
    search_kwargs={"k": 5}
)

3. 内存不足

如果是超大文档(10MB 以上),考虑:

  • 分批处理
  • 使用 Chroma 的持久化模式
  • 换用更强大的向量库(如 Pinecone)

下一步

核心功能完成了,接下来做一个简单的用户界面。

继续:前端界面 →


← 返回环境搭建 | 返回项目一

最近更新

基于 Apache 2.0 许可发布