跳转至

第六章:记忆系统

大语言模型本身是无状态的:每次调用只能看到本次传入的消息。记忆系统负责维护对话上下文,并在每次调用时把之前的对话内容一并传给模型,让 AI 能够“记住”之前的对话。

记忆系统概述

LangChain 0.3+ 中,记忆系统主要通过以下方式实现:

  1. 消息列表:直接管理消息历史
  2. RunnableWithMessageHistory:LCEL 包装器,自动管理历史
  3. LangGraph:通过图状态和检查点(checkpointer)持久化对话,复杂场景推荐使用

直接管理消息历史

基本方式

from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")

# The conversation so far, starting with the system prompt.
# chat() appends every user message and model reply to this list.
messages = [SystemMessage("你是一个友好的助手")]

def chat(user_input: str) -> str:
    """Send one user turn to the model and record both sides in ``messages``.

    The module-level ``messages`` list is the memory: the user message is
    appended before the call so the model sees it, and the model's reply is
    appended afterwards so later turns can refer back to it.
    """
    messages.append(HumanMessage(user_input))
    reply = llm.invoke(messages)
    messages.append(reply)
    return reply.content

# Demo: the second question is answerable only because the first turn is kept in `messages`
print(chat("你好"))           # example reply: 你好!有什么可以帮助你的?
print(chat("我刚才说了什么?")) # example reply: 你刚才说"你好"

使用 ChatPromptTemplate

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser

# Prompt layout: system instructions, then prior turns, then the new input
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "你是一个友好的助手"),
        MessagesPlaceholder("history"),
        ("human", "{input}"),
    ]
)
chain = prompt | llm | StrOutputParser()  # yields the reply as a plain str

# Prior turns, managed by hand and passed in as "history"
history = []

def chat(user_input: str) -> str:
    """Answer ``user_input`` using ``history`` as context, then record the turn.

    ``history`` is only extended after ``chain.invoke`` returns, so a failed
    call leaves it unchanged.
    """
    reply = chain.invoke({"input": user_input, "history": history})
    history.extend([HumanMessage(user_input), AIMessage(reply)])
    return reply

RunnableWithMessageHistory

RunnableWithMessageHistory 是 LangChain 提供的自动历史管理包装器:每次调用时,它按 session_id 取出对应的历史消息填入提示词,并在调用结束后把本轮的用户输入和模型输出写回该历史。

基本用法

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory

# Base chain: system prompt, the session's prior turns, then the new input
llm = ChatOpenAI(model="gpt-4o-mini")
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "你是一个友好的助手"),
        MessagesPlaceholder("history"),
        ("human", "{input}"),
    ]
)
chain = prompt | llm | StrOutputParser()

# In-memory registry of histories, keyed by session_id
store = {}

def get_session_history(session_id: str):
    """Return the in-memory history for ``session_id``, creating it on first use."""
    try:
        return store[session_id]
    except KeyError:
        store[session_id] = ChatMessageHistory()
        return store[session_id]

# Wrap the chain so history is handled per session: get_session_history
# supplies the history object for a session_id, "input" is the key of the new
# user message, and "history" is the prompt placeholder the history fills.
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history"
)

# Usage: both calls use session_id "user-123", so the second call can see
# the first turn.
response1 = chain_with_history.invoke(
    {"input": "你好,我是张三"},
    config={"configurable": {"session_id": "user-123"}}
)

response2 = chain_with_history.invoke(
    {"input": "我叫什么名字?"},
    config={"configurable": {"session_id": "user-123"}}
)
# Example reply: "你叫张三"

不同会话隔离

# Session isolation: each session_id has its own history.
# User A's session
response_a = chain_with_history.invoke(
    {"input": "我是用户A"},
    config={"configurable": {"session_id": "user-a"}}
)

# User B's session (independent of user A's)
response_b = chain_with_history.invoke(
    {"input": "我是用户B"},
    config={"configurable": {"session_id": "user-b"}}
)

# User A continues; this call uses session "user-a" again
response_a2 = chain_with_history.invoke(
    {"input": "我是谁?"},
    config={"configurable": {"session_id": "user-a"}}
)
# Example reply: "你是用户A"

持久化存储

文件存储

from langchain_community.chat_message_histories import FileChatMessageHistory

def get_session_history(session_id: str):
    """Return a JSON-file-backed history for ``session_id``.

    Each session is stored in ``chat_history/<session_id>.json``.
    FileChatMessageHistory creates the file when it is missing but not its
    parent directory, so the directory is created here first; without this
    the very first call fails with FileNotFoundError.

    NOTE(review): ``session_id`` becomes part of a file path. If it can come
    from untrusted users, validate it (e.g. reject "/", "\\" and "..") first.
    """
    from pathlib import Path

    history_dir = Path("chat_history")
    history_dir.mkdir(parents=True, exist_ok=True)
    return FileChatMessageHistory(str(history_dir / f"{session_id}.json"))

# Same wrapper as before, now backed by the file-based get_session_history.
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history"
)

Redis 存储

from langchain_community.chat_message_histories import RedisChatMessageHistory

# Redis server address (localhost, port 6379); adjust for your deployment.
REDIS_URL = "redis://localhost:6379"

def get_session_history(session_id: str):
    """Return a Redis-backed history for ``session_id`` on the server at ``REDIS_URL``."""
    history = RedisChatMessageHistory(session_id, url=REDIS_URL)
    return history

# Same wrapper again; histories now live in Redis via get_session_history.
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history"
)

SQLite 存储

from langchain_community.chat_message_histories import SQLChatMessageHistory

def get_session_history(session_id: str):
    """Return a SQL-backed history for ``session_id`` in the SQLite file ``chat_history.db``."""
    db_url = "sqlite:///chat_history.db"
    return SQLChatMessageHistory(session_id, connection=db_url)

历史消息管理

获取历史

# Load the stored history for session "user-123"
history = get_session_history("user-123")

# All stored messages
messages = history.messages

for msg in messages:
    print(f"{msg.type}: {msg.content}")

添加消息

from langchain_core.messages import HumanMessage, AIMessage

history = get_session_history("user-123")

# Append a user (human) message
history.add_user_message("你好")

# Append an AI message
history.add_ai_message("你好!有什么可以帮助你的?")

# Append an arbitrary message object
history.add_message(HumanMessage("继续"))

清除历史

# Remove all messages stored for this session
history = get_session_history("user-123")
history.clear()

历史消息截断

限制消息数量

from langchain_core.messages import trim_messages

# Load the stored history
history = get_session_history("user-123")

# Keep only the last 10 messages. With token_counter=len every message counts
# as one "token", so max_tokens is a message count here, not a token budget.
trimmed = trim_messages(
    history.messages,
    max_tokens=10,          # = at most 10 messages (because of len below)
    strategy="last",        # keep the most recent messages
    token_counter=len,      # counts messages, not real tokens
    include_system=True,    # keep a leading SystemMessage, if any
    start_on="human",       # start the kept window on a user message (may keep fewer than 10)
    allow_partial=False,
)

使用 token 计数器

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o-mini")

# Trim by real token count. Pass the model itself as token_counter: trim_messages
# then counts with llm.get_num_tokens_from_messages, which takes a list of
# messages. (llm.get_num_tokens takes a single string and fails on a list.)
trimmed = trim_messages(
    history.messages,
    max_tokens=4000,        # real token budget for the kept history
    strategy="last",        # keep the most recent messages
    token_counter=llm,      # the model's tokenizer counts the tokens
    include_system=True,    # keep a leading SystemMessage, if any
    start_on="human",       # start the kept window on a user message
)

自定义记忆管理

滑动窗口记忆

class SlidingWindowMemory:
    """Per-session message buffer that keeps only the most recent messages.

    Args:
        window_size: Maximum number of messages retained per session. Must be
            non-negative; 0 keeps nothing.

    Raises:
        ValueError: If ``window_size`` is negative.
    """

    def __init__(self, window_size: int = 10):
        if window_size < 0:
            raise ValueError(f"window_size must be >= 0, got {window_size}")
        self.window_size = window_size
        self.store = {}  # session_id -> list of messages, oldest first

    def get_history(self, session_id: str):
        """Return the live message list for ``session_id``, creating it if needed."""
        return self.store.setdefault(session_id, [])

    def add_message(self, session_id: str, message):
        """Append ``message`` and drop the oldest messages beyond ``window_size``.

        Trimming happens in place, so a list previously returned by
        ``get_history`` keeps reflecting the session's current history.
        (Slicing with ``history[-window_size:]`` would also never trim when
        ``window_size`` is 0, because ``[-0:]`` is the whole list.)
        """
        history = self.get_history(session_id)
        history.append(message)
        overflow = len(history) - self.window_size
        if overflow > 0:
            del history[:overflow]

# Keep at most the 5 most recent messages per session
memory = SlidingWindowMemory(window_size=5)

摘要记忆

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate

class SummaryMemory:
    """Per-session memory that folds older messages into an LLM-written summary.

    Messages accumulate in ``recent``. Once more than ``max_messages`` are
    buffered for a session, the previous summary plus the buffered messages
    are sent to ``llm`` to produce a new summary, and the buffer is emptied.

    Args:
        llm: Chat model used for summarization; the result of invoking
            ``prompt | llm`` must expose ``.content`` (e.g. an AIMessage).
        max_messages: Buffer size that, once exceeded, triggers summarization.
    """

    def __init__(self, llm, max_messages: int = 10):
        self.llm = llm
        self.max_messages = max_messages
        self.summaries = {}  # session_id -> latest summary text
        self.recent = {}  # session_id -> messages not yet summarized
        # Built once and reused for every summarization.
        self._summary_prompt = ChatPromptTemplate.from_template(
            "请总结以下对话历史:\n旧摘要:{old_summary}\n新对话:{messages}"
        )

    def get_context(self, session_id: str):
        """Return ``(summary, recent_messages)`` for ``session_id``.

        An unknown session yields ``("", [])``.
        """
        return self.summaries.get(session_id, ""), self.recent.get(session_id, [])

    def add_message(self, session_id: str, message):
        """Buffer ``message``; summarize once the buffer exceeds ``max_messages``."""
        buffer = self.recent.setdefault(session_id, [])
        buffer.append(message)
        if len(buffer) > self.max_messages:
            self._summarize(session_id)

    @staticmethod
    def _format_message(message) -> str:
        """Render a message as ``"<type>: <content>"`` for the summary prompt.

        ``str()`` on a LangChain message dumps its field repr (content=...,
        additional_kwargs=..., ...), which wastes tokens and adds noise.
        Objects lacking ``type``/``content`` fall back to ``str(message)``.
        """
        role = getattr(message, "type", None)
        content = getattr(message, "content", None)
        if role is None or content is None:
            return str(message)
        return f"{role}: {content}"

    def _summarize(self, session_id: str):
        """Replace the session summary with one that also covers the buffered messages."""
        chain = self._summary_prompt | self.llm
        response = chain.invoke({
            "old_summary": self.summaries.get(session_id, ""),
            "messages": "\n".join(
                self._format_message(m) for m in self.recent[session_id]
            ),
        })
        self.summaries[session_id] = response.content
        # The buffer is cleared only after the LLM call returned, so a failed
        # call does not lose the buffered messages.
        self.recent[session_id] = []

LangGraph 状态管理(推荐)

对于复杂场景,推荐使用 LangGraph 的状态管理:

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, MessagesState
from langchain_openai import ChatOpenAI

# Graph state: just the message list inherited from MessagesState
class State(MessagesState):
    """Graph state; reuses MessagesState's ``messages`` field with no additions."""

# Model used by the chatbot node below
llm = ChatOpenAI(model="gpt-4o-mini")

def chatbot(state: State):
    """Graph node: reply to the conversation held in ``state["messages"]``.

    Returns a partial state update containing only the new model reply.
    """
    reply = llm.invoke(state["messages"])
    return {"messages": [reply]}

# Build the graph: one node that is both the entry and the finish point
graph = StateGraph(State)
graph.add_node("chatbot", chatbot)
graph.set_entry_point("chatbot")
graph.set_finish_point("chatbot")

# Add memory: compile the graph with a checkpointer so state is kept per thread_id
checkpointer = MemorySaver()
app = graph.compile(checkpointer=checkpointer)

# Usage: both calls share thread_id "user-123", so the second continues the first
config = {"configurable": {"thread_id": "user-123"}}

response1 = app.invoke(
    {"messages": [("user", "你好,我是张三")]},
    config
)

response2 = app.invoke(
    {"messages": [("user", "我叫什么名字?")]},
    config
)
# Example reply: "你叫张三"

完整示例:带记忆的聊天机器人

import os
from dotenv import load_dotenv
load_dotenv()

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory

# Model, prompt and chain
llm = ChatOpenAI(model="gpt-4o-mini")

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "你是一个友好的助手,回答简洁明了。"),
        MessagesPlaceholder("history"),
        ("human", "{input}"),
    ]
)

chain = prompt | llm | StrOutputParser()

# session_id -> ChatMessageHistory, kept in a plain dict in memory
store = {}

def get_session_history(session_id: str):
    """Look up (or lazily create) the in-memory history of ``session_id``."""
    if session_id in store:
        return store[session_id]
    fresh = ChatMessageHistory()
    store[session_id] = fresh
    return fresh

# Wrap the chain: get_session_history supplies each session's history for the
# "history" placeholder; "input" is the key holding the new user message.
chain_with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history"
)

def chat(user_input: str, session_id: str = "default") -> str:
    """Send one message in session ``session_id`` and return the model's reply text.

    Calls with the same ``session_id`` share history through ``chain_with_history``.
    """
    session_config = {"configurable": {"session_id": session_id}}
    return chain_with_history.invoke({"input": user_input}, config=session_config)

# Interactive demo: type 'quit' (or press Ctrl-D / Ctrl-C) to exit
if __name__ == "__main__":
    print("聊天机器人(输入 'quit' 退出)")
    while True:
        try:
            user_input = input("你: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C ends the session without a traceback.
            print()
            break
        if not user_input:
            continue  # don't send blank messages to the model
        if user_input.lower() == "quit":
            break
        response = chat(user_input)
        print(f"AI: {response}")

小结

本章学习了:

  • ✅ 消息历史基本管理
  • ✅ RunnableWithMessageHistory 包装器
  • ✅ 持久化存储(文件、Redis、SQLite)
  • ✅ 历史消息截断
  • ✅ 自定义记忆管理
  • ✅ LangGraph 状态管理

下一章

第七章:Agent 与工具 - 学习如何让 LLM 使用工具完成任务。