LangChain RAG: Retrieval-Augmented Generation
RAG (Retrieval-Augmented Generation) is the technique that lets an LLM answer questions grounded in your private data, and it is the core of knowledge-base question-answering systems.
What is RAG?
The RAG workflow has two phases:
1. Document preprocessing (indexing)
Documents → split → embed → store in a vector database
2. Query time
Question → embed → retrieve relevant documents → assemble context → LLM generates the answer
Quick Start
The simplest RAG
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# 1. Load documents
loader = TextLoader("./knowledge.txt", encoding="utf-8")
documents = loader.load()
# 2. Split documents
splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50
)
chunks = splitter.split_documents(documents)
# 3. Create the vector store
embeddings = OpenAIEmbeddings()
vectorstore = Chroma.from_documents(chunks, embeddings)
# 4. Create the retriever
retriever = vectorstore.as_retriever()
# 5. Build the RAG chain
prompt = ChatPromptTemplate.from_template("""
Answer the question based on the following context. If the context contains no relevant information, say "I don't know."
Context:
{context}
Question: {question}
Answer:
""")
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
rag_chain = (
{"context": retriever | format_docs, "question": RunnablePassthrough()}
| prompt
| ChatOpenAI()
| StrOutputParser()
)
# 6. Ask a question
answer = rag_chain.invoke("What does this project do?")
print(answer)
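Because the chain is composed with LCEL, it is a Runnable and also supports streaming out of the box; a minimal sketch:
# Stream the answer token by token instead of waiting for the full string
for token in rag_chain.stream("What does this project do?"):
    print(token, end="", flush=True)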
Document Loading
Common loaders
# Text files
from langchain_community.document_loaders import TextLoader
loader = TextLoader("./doc.txt")
# PDF
from langchain_community.document_loaders import PyPDFLoader
loader = PyPDFLoader("./doc.pdf")
# Word
from langchain_community.document_loaders import Docx2txtLoader
loader = Docx2txtLoader("./doc.docx")
# Markdown
from langchain_community.document_loaders import UnstructuredMarkdownLoader
loader = UnstructuredMarkdownLoader("./doc.md")
# HTML
from langchain_community.document_loaders import UnstructuredHTMLLoader
loader = UnstructuredHTMLLoader("./page.html")
# Web pages
from langchain_community.document_loaders import WebBaseLoader
loader = WebBaseLoader("https://example.com")
# CSV
from langchain_community.document_loaders.csv_loader import CSVLoader
loader = CSVLoader("./data.csv")
# Batch-load a directory
from langchain_community.document_loaders import DirectoryLoader
loader = DirectoryLoader("./docs/", glob="**/*.md")
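Note: most of these loaders rely on an extra package (pypdf for PyPDFLoader, docx2txt for Docx2txtLoader, unstructured for the Unstructured* loaders, beautifulsoup4 for WebBaseLoader); install whichever your loaders require.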
Loading documents
documents = loader.load()
# Inspect the loaded documents
for doc in documents:
print(f"内容: {doc.page_content[:100]}...")
print(f"元数据: {doc.metadata}")
Text Splitting
RecursiveCharacterTextSplitter (recommended)
from langchain_text_splitters import RecursiveCharacterTextSplitter
splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,     # target chunk size in characters
    chunk_overlap=200,   # overlap between adjacent chunks
    length_function=len,
    separators=["\n\n", "\n", " ", ""]  # separators tried in order
)
chunks = splitter.split_documents(documents)
print(f"Split into {len(chunks)} chunks")
Code splitting
from langchain_text_splitters import (
RecursiveCharacterTextSplitter,
Language
)
python_splitter = RecursiveCharacterTextSplitter.from_language(
language=Language.PYTHON,
chunk_size=2000,
chunk_overlap=200
)
# Supported languages include:
# Language.PYTHON, Language.JS, Language.TS, Language.JAVA, Language.GO, ...
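Code splitters are used like any other splitter; create_documents turns raw strings into chunked Documents. A sketch (the file path is hypothetical):
# Read raw source and split it along Python syntax boundaries first
python_code = open("./my_module.py", encoding="utf-8").read()
code_docs = python_splitter.create_documents([python_code])
print(f"Split into {len(code_docs)} code chunks")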
Markdown splitting
from langchain_text_splitters import MarkdownHeaderTextSplitter
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
]
splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on
)
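# markdown_document is the raw Markdown string, e.g. open("./doc.md", encoding="utf-8").read()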
splits = splitter.split_text(markdown_document)
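The header values end up in each chunk's metadata. Since a single section can still exceed your chunk budget, a common follow-up (sketched here, not required) is a second size-based pass:
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Second pass: size-split the header-based chunks; header metadata is preserved
size_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
final_chunks = size_splitter.split_documents(splits)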
Vector Stores
Chroma (local)
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
# Create
vectorstore = Chroma.from_documents(
documents=chunks,
embedding=embeddings,
    persist_directory="./chroma_db"  # persist to disk
)
# Load an existing store
vectorstore = Chroma(
persist_directory="./chroma_db",
embedding_function=embeddings
)
Pinecone (cloud)
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone
pc = Pinecone(api_key="your-api-key")
vectorstore = PineconeVectorStore.from_documents(
documents=chunks,
embedding=embeddings,
index_name="my-index"
)
FAISS (high-performance, local)
from langchain_community.vectorstores import FAISS
vectorstore = FAISS.from_documents(chunks, embeddings)
# Save
vectorstore.save_local("./faiss_index")
# Load (recent versions require allow_dangerous_deserialization
# because the saved index is pickle-based)
vectorstore = FAISS.load_local(
    "./faiss_index",
    embeddings,
    allow_dangerous_deserialization=True
)
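Whichever backend you use, you can sanity-check the index directly before wiring up a retriever; with FAISS the default score is an L2 distance (lower is closer):
# Returns (Document, score) pairs for quick inspection
results = vectorstore.similarity_search_with_score("your query", k=3)
for doc, score in results:
    print(f"{score:.3f}  {doc.page_content[:80]}")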
Retrievers
Basic retrieval
retriever = vectorstore.as_retriever(
search_type="similarity", # 相似度搜索
search_kwargs={"k": 5} # 返回 5 个结果
)
docs = retriever.invoke("your question")
MMR retrieval (diversity)
retriever = vectorstore.as_retriever(
search_type="mmr", # Maximum Marginal Relevance
search_kwargs={
"k": 5,
"fetch_k": 20, # 先检索 20 个,再筛选 5 个
"lambda_mult": 0.5 # 多样性参数
}
)
Retrieval with a score threshold
retriever = vectorstore.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={
"score_threshold": 0.8, # 只返回相似度 > 0.8 的
"k": 5
}
)
Hybrid retrieval
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
# Vector retrieval
vector_retriever = vectorstore.as_retriever()
# Keyword retrieval (BM25)
bm25_retriever = BM25Retriever.from_documents(chunks)
# Combine both retrievers
ensemble_retriever = EnsembleRetriever(
retrievers=[vector_retriever, bm25_retriever],
    weights=[0.5, 0.5]  # relative weights
)
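The ensemble fuses both result lists (weighted reciprocal rank fusion) and is queried like any other retriever; note that BM25Retriever needs the rank_bm25 package installed:
# Results from both retrievers are merged according to the weights above
docs = ensemble_retriever.invoke("your question")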
Complete RAG Pipeline
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
class RAGSystem:
def __init__(self, docs_path: str):
self.embeddings = OpenAIEmbeddings()
self.llm = ChatOpenAI(model="gpt-4o")
self.vectorstore = None
self._load_documents(docs_path)
self._create_chain()
def _load_documents(self, path: str):
        # Load documents
loader = DirectoryLoader(path, glob="**/*.txt", loader_cls=TextLoader)
documents = loader.load()
        # Split into chunks
splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
)
chunks = splitter.split_documents(documents)
        # Create the vector store
self.vectorstore = Chroma.from_documents(
chunks,
self.embeddings,
persist_directory="./rag_db"
)
def _create_chain(self):
retriever = self.vectorstore.as_retriever(
search_kwargs={"k": 5}
)
        prompt = ChatPromptTemplate.from_template("""
You are a professional question-answering assistant. Answer the question based on the context below.
Context:
{context}
Question: {question}
Requirements:
1. Answer only from the context; do not make things up.
2. If the context does not contain the answer, say "I could not find relevant information in the documents."
3. Keep the answer concise.
Answer:
""")
def format_docs(docs):
return "\n\n---\n\n".join([
f"[来源: {doc.metadata.get('source', '未知')}]\n{doc.page_content}"
for doc in docs
])
self.chain = (
RunnableParallel(
context=retriever | format_docs,
question=RunnablePassthrough()
)
| prompt
| self.llm
| StrOutputParser()
)
def query(self, question: str) -> str:
return self.chain.invoke(question)
def add_documents(self, texts: list[str]):
"""添加新文档"""
self.vectorstore.add_texts(texts)
# Usage
rag = RAGSystem("./documents/")
answer = rag.query("What are the main features of this project?")
print(answer)
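New documents can then be indexed on the fly; a usage sketch with hypothetical text:
# The new text is embedded with the same model and is searchable immediately
rag.add_documents(["The project also ships a REST API for programmatic access."])
print(rag.query("Does the project expose an API?"))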
Optimization Tips
1. Tune chunking
# Adjust chunk size to the content type:
# Technical docs: larger chunks (1000-2000)
# Chat/Q&A data: smaller chunks (200-500)
# Code: split by function/class
splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
    chunk_overlap=200,  # overlap preserves context across chunk boundaries
)
2. Add metadata
# Attach metadata after splitting
for chunk in chunks:
chunk.metadata["source"] = "user_manual"
chunk.metadata["category"] = "tutorial"
chunk.metadata["date"] = "2024-01-01"
# Filter on metadata at retrieval time
retriever = vectorstore.as_retriever(
search_kwargs={
"k": 5,
"filter": {"category": "tutorial"}
}
)
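With Chroma as the backend (the assumption in this guide), the filter can also combine conditions using Chroma's where-clause operators:
# Chroma where-clause operators: $eq, $ne, $and, $or, ...
retriever = vectorstore.as_retriever(
    search_kwargs={
        "k": 5,
        "filter": {"$and": [
            {"category": {"$eq": "tutorial"}},
            {"source": {"$eq": "user_manual"}},
        ]},
    }
)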
3. Reranking
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank
# Rerank with Cohere (requires a Cohere API key)
compressor = CohereRerank()
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=retriever
)
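CohereRerank lives in the langchain_cohere package in newer releases; the import path above is the older, deprecated location. The compression retriever is then queried like any other:
# The reranker re-scores the base retriever's candidates and keeps the top ones
reranked_docs = compression_retriever.invoke("your question")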
4. Query transformation
from langchain.retrievers.multi_query import MultiQueryRetriever
# Expand one question into multiple query variants
multi_query_retriever = MultiQueryRetriever.from_llm(
retriever=retriever,
llm=llm
)
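To see the query variants the LLM generates, enable the retriever's logger (the pattern shown in the LangChain docs):
import logging
# Log each generated query variant before retrieval runs
logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
docs = multi_query_retriever.invoke("your question")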
Next Steps
Tip: RAG quality depends heavily on your chunking strategy and retrieval quality; tune both against your actual use case.