目标:让 AI Agent 记住对话历史,服务重启也不丢失
- 实现方式:自定义 BaseCheckpointSaver,用 Redis 存储
- 用 thread_id 区分不同用户的对话
Checkpointer(检查点器) 作用:保存 Agent 的运行状态,下次运行时恢复
类比:游戏存档
thread_id | |
checkpoint | |
TTL | |
BaseCheckpointSaver |
架构图:
用户请求
  ↓
LangGraph Agent(StateGraph)
  ├─ 自动调用 checkpointer.get() 恢复历史
  ├─ 执行 agent 节点(LLM 调用)
  └─ 自动调用 checkpointer.put() 保存状态
  ↓
Redis Checkpointer
  ├─ Redis Hash 存储(Key = langgraph:checkpoint:{thread_id})
  ├─ pickle 序列化状态快照
  ├─ TTL = 7 天自动过期
  └─ 最多保留 50 个快照
调用流程图:

```python
class AgentState(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]
```
解释:定义 Agent 的状态结构,`messages` 是消息列表,`add_messages` 是累加器(每次新消息追加到列表)。
```python
class RedisCheckpointSaver(BaseCheckpointSaver):
    def __init__(self, redis_url, ttl_seconds=604800):
        self.client = redis_lib.Redis.from_url(redis_url, decode_responses=False)
        self.ttl = ttl_seconds  # 默认 7 天过期
```
解释:初始化 Redis 连接,`decode_responses=False` 表示返回原始字节(因为用 pickle 序列化)。
```python
def _make_key(self, thread_id: str) -> str:
    return f"langgraph:checkpoint:{thread_id}"
```
解释:每个用户的对话存在独立的 Redis key 下,实现数据隔离。
get()defget(self, config: dict) -> Optional[CheckpointTuple]: thread_id = config.get("configurable", {}).get("thread_id", "default") key = self._make_key(thread_id) all_fields = self.client.hgetall(key) # 获取该用户所有快照ifnot all_fields:returnNone# 按 checkpoint_id 降序,取最新 latest_id = sorted(all_fields.keys(), reverse=True)[0] raw = all_fields[latest_id] data = self._deserialize(raw) # pickle 反序列化return CheckpointTuple(config=..., checkpoint=..., metadata=...)解释:
thread_id找到 Redis keyHGETALLpickle.loadsCheckpointTupleput()defput(self, config, checkpoint, metadata, new_versions) -> dict: thread_id = config.get("configurable", {}).get("thread_id", "default") key = self._make_key(thread_id) checkpoint_id = checkpoint.get("id") data = {"channel_values": checkpoint.get("channel_values", {}),"metadata": metadata, ... } self.client.hset(key, checkpoint_id, self._serialize(data)) # pickle 序列化后存储 self.client.expire(key, self.ttl) # 刷新 TTL# 只保留最近 50 个快照iflen(all_fields) > 50: oldest_fields = sorted(all_fields)[:-50] self.client.hdel(key, *oldest_fields)解释:
pickle.dumpsHSET存储list()deflist(self, config: dict) -> Iterator[CheckpointTuple]: thread_id = config.get("configurable", {}).get("thread_id", "default") key = self._make_key(thread_id) all_fields = self.client.hgetall(key)for field_id insorted(all_fields.keys(), reverse=True):yield CheckpointTuple(...) # 用 yield 返回迭代器解释:返回所有历史快照的迭代器(而非一次性加载全部),节省内存。
`adelete()`
```python
async def adelete(self, config: dict) -> None:
    thread_id = config.get("configurable", {}).get("thread_id", "default")
    key = self._make_key(thread_id)
    result = self.client.delete(key)  # 删除整个 key
```
解释:删除用户的所有对话数据,用于账号注销或测试清理。
defbuild_agent_with_checkpointer(checkpointer):defcall_model(state: AgentState) -> dict: messages = state["messages"] response = llm.invoke(messages)return {"messages": [response]} graph = StateGraph(AgentState) graph.add_node("agent", call_model) graph.add_edge(START, "agent") graph.add_edge("agent", END) compiled = graph.compile(checkpointer=checkpointer) # ← 关键:注入 checkpointerreturn compiled解释:
- `compile(checkpointer=...)`

```python
config = {"configurable": {"thread_id": "user_001"}}

# 第一次对话
result = agent.invoke({"messages": [HumanMessage(content="你好")]}, config=config)
# LangGraph 自动调用 checkpointer.put() 保存状态

# 第二次对话(自动带上历史)
result = agent.invoke({"messages": [HumanMessage(content="还记得我吗?")]}, config=config)
# LangGraph 自动调用 checkpointer.get() 恢复历史
```

一句话总结:Checkpointer 让 Agent 记住对话,Redis 实现持久化,thread_id 实现多用户隔离。
| 为什么需要 Checkpointer | |
| 为什么用 Redis | |
| thread_id 的作用 | |
| pickle vs JSON | |
| TTL 过期 | |
| 快照数量限制 |
Q: LangGraph 状态存在哪?服务重启怎么办?
A: 用 Checkpointer 接口,开发用 MemorySaver(重启丢),生产用 RedisCheckpointSaver。Redis 方案:State → pickle 序列化 → 存 Redis Hash,key = `langgraph:checkpoint:{thread_id}`,field = `{checkpoint_id}`,并设置 TTL 自动过期。
"""实现==============================目标:自定义 LangGraph BaseCheckpointSaver,用 Redis 做后端存储特性:TTL 自动过期 / 序列化快照 / 多 checkpoint 版本管理依赖:pip install redis langgraph langchain-openai"""import sysimport ioimport osimport jsonimport pickleimport asynciofrom typing import Any, Iterator, Optional, Tuplefrom datetime import datetimefrom langgraph.checkpoint.base import BaseCheckpointSaver, Checkpoint, CheckpointMetadata, CheckpointTuplesys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace')# ─── 尝试导入 Redis ──────────────────────────────────────────────────────────try:import redis as redis_libREDIS_AVAILABLE = Trueexcept ImportError:REDIS_AVAILABLE = Falseprint("[WARN] redis 未安装,将使用 MemorySaver 演示")from typing import Annotatedfrom typing_extensions import TypedDictfrom langgraph.graph import StateGraph, START, ENDfrom langgraph.graph.message import add_messagesfrom langgraph.checkpoint.memory import MemorySaverfrom langchain_core.messages import HumanMessage, AIMessage, BaseMessagefrom dotenv import load_dotenvload_dotenv()# ─── 1. 
# Custom Redis checkpointer ────────────────────────────────────────────────
class RedisCheckpointSaver(BaseCheckpointSaver):
    """Redis-backed LangGraph checkpointer.

    Storage layout:
        redis hash key: langgraph:checkpoint:{thread_id}
        hash field:     {checkpoint_id}
        hash value:     pickle of the checkpoint fields plus a "metadata" entry

    The key's TTL (default 7 days = 604800 seconds) is refreshed on every
    ``put``; at most ``max_checkpoints`` snapshots are kept per thread.

    NOTE(review): ``checkpoint_ns`` is echoed back in returned configs but is
    not part of the Redis key, so all namespaces of a thread share one hash.
    """

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        ttl_seconds: int = 604800,
        max_checkpoints: int = 50,
    ):
        """Create the Redis client.

        Args:
            redis_url: Connection URL handed to ``Redis.from_url``.
            ttl_seconds: Expiry applied to a thread's hash on every write.
            max_checkpoints: Snapshots retained per thread; values <= 0
                disable trimming.

        Raises:
            ImportError: If the ``redis`` package could not be imported.
        """
        if not REDIS_AVAILABLE:
            raise ImportError("请先 pip install redis")
        # Let the base class set up its own state (e.g. the serializer).
        super().__init__()
        # decode_responses=False: values are pickled bytes, keep them raw.
        self.client = redis_lib.Redis.from_url(redis_url, decode_responses=False)
        self.ttl = ttl_seconds
        self.max_checkpoints = max_checkpoints
        print(f"[Redis] 已连接:{redis_url}")

    def _make_key(self, thread_id: str) -> str:
        """Return the Redis hash key that holds every snapshot of a thread."""
        return f"langgraph:checkpoint:{thread_id}"

    def _serialize(self, data: Any) -> bytes:
        """Pickle ``data`` for storage as a hash value."""
        return pickle.dumps(data)

    def _deserialize(self, raw: bytes) -> Any:
        """Unpickle a stored hash value.

        SECURITY: ``pickle.loads`` executes arbitrary code embedded in the
        payload. Only point this saver at a Redis instance that untrusted
        parties cannot write to.
        """
        return pickle.loads(raw)

    def _build_tuple(
        self, thread_id: str, checkpoint_ns: str, field_id: Any, raw: bytes
    ) -> CheckpointTuple:
        """Turn one stored hash entry into a ``CheckpointTuple``."""
        # Hash fields come back as bytes because decode_responses=False.
        checkpoint_id = field_id.decode() if isinstance(field_id, bytes) else field_id
        data = self._deserialize(raw)

        checkpoint = {
            "v": data.get("v", 2),
            "id": checkpoint_id,
            "ts": data.get("ts", datetime.now().isoformat()),
            "channel_values": data.get("channel_values", {}),
            "channel_versions": data.get("channel_versions", {}),
            "versions_seen": data.get("versions_seen", {}),
            "updated_channels": data.get("updated_channels", {}),
        }

        stored_meta = data.get("metadata", {})
        if isinstance(stored_meta, dict):
            metadata = CheckpointMetadata(**stored_meta)
        else:
            metadata = CheckpointMetadata()

        return CheckpointTuple(
            config={
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": checkpoint_id,
                }
            },
            checkpoint=checkpoint,
            metadata=metadata,
        )

    # ─── Core interface: read a snapshot
    def get(self, config: dict) -> Optional[CheckpointTuple]:
        """Return a stored snapshot as a ``CheckpointTuple``.

        If ``config["configurable"]["checkpoint_id"]`` is set, that exact
        snapshot is returned; otherwise the one with the greatest
        checkpoint id (the newest, assuming ids sort chronologically).
        Returns ``None`` when nothing matches.

        NOTE(review): returns the full tuple rather than only the checkpoint
        dict; kept that way because callers in this file rely on it.
        """
        configurable = config.get("configurable", {})
        thread_id = configurable.get("thread_id", "default")
        checkpoint_ns = configurable.get("checkpoint_ns", "")
        key = self._make_key(thread_id)

        wanted_id = configurable.get("checkpoint_id")
        if not wanted_id:
            # Only the field names are needed to pick the newest entry, so
            # avoid pulling every pickled snapshot over the wire.
            field_ids = self.client.hkeys(key)
            if not field_ids:
                return None
            wanted_id = max(field_ids)

        raw = self.client.hget(key, wanted_id)
        if raw is None:
            # Unknown id, or the entry expired / was trimmed in the meantime.
            return None
        return self._build_tuple(thread_id, checkpoint_ns, wanted_id, raw)

    def get_tuple(self, config: dict) -> Optional[CheckpointTuple]:
        """``get_tuple`` entry point (the one LangGraph calls); same as ``get``."""
        return self.get(config)

    # ─── Core interface: write a snapshot
    def put(
        self,
        config: dict,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: dict,
    ) -> dict:
        """Store ``checkpoint`` and return the config that identifies it.

        Side effects: refreshes the hash TTL and deletes the oldest entries
        once the thread holds more than ``max_checkpoints`` snapshots.
        ``new_versions`` is accepted for interface compatibility and unused.
        """
        configurable = config.get("configurable", {})
        thread_id = configurable.get("thread_id", "default")
        checkpoint_ns = configurable.get("checkpoint_ns", "")
        key = self._make_key(thread_id)

        checkpoint_id = checkpoint.get("id") or ""

        payload = {
            "v": checkpoint.get("v", 2),
            "id": checkpoint_id,
            "ts": checkpoint.get("ts", datetime.now().isoformat()),
            "channel_values": checkpoint.get("channel_values", {}),
            "channel_versions": checkpoint.get("channel_versions", {}),
            "versions_seen": checkpoint.get("versions_seen", {}),
            "updated_channels": checkpoint.get("updated_channels", {}),
            "metadata": metadata if hasattr(metadata, '__dict__') else dict(metadata or {}),
        }

        self.client.hset(key, checkpoint_id, self._serialize(payload))
        self.client.expire(key, self.ttl)  # refresh TTL on every write

        # Cap the history so a long-lived thread cannot grow without bound.
        if self.max_checkpoints > 0:
            field_ids = self.client.hkeys(key)
            if len(field_ids) > self.max_checkpoints:
                stale = sorted(field_ids)[:-self.max_checkpoints]
                if stale:
                    self.client.hdel(key, *stale)

        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        }

    def put_writes(
        self,
        config: dict,
        writes: list,
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Accept pending writes and discard them.

        This minimal saver does not persist intermediate task writes; a
        fuller implementation would store them.
        """
        return None

    # ─── List the stored history
    def list(
        self,
        config: dict,
        *,
        filter: Optional[dict] = None,
        before: Optional[dict] = None,
        limit: Optional[int] = None,
    ) -> Iterator[CheckpointTuple]:
        """Yield a thread's snapshots in descending checkpoint-id order.

        Args:
            config: Config whose ``configurable.thread_id`` selects the thread.
            filter: If given, only snapshots whose metadata matches every
                key/value pair are yielded.
            before: If given, a config; only snapshots whose id is lower than
                its ``configurable.checkpoint_id`` are yielded.
            limit: If given, the maximum number of snapshots to yield.
        """
        configurable = config.get("configurable", {})
        thread_id = configurable.get("thread_id", "default")
        checkpoint_ns = configurable.get("checkpoint_ns", "")
        key = self._make_key(thread_id)

        before_id = (before or {}).get("configurable", {}).get("checkpoint_id")
        entries = self.client.hgetall(key)

        yielded = 0
        for field_id in sorted(entries, reverse=True):
            if limit is not None and yielded >= limit:
                break
            id_text = field_id.decode() if isinstance(field_id, bytes) else field_id
            if before_id is not None and id_text >= before_id:
                continue
            item = self._build_tuple(thread_id, checkpoint_ns, field_id, entries[field_id])
            if filter and any(item.metadata.get(k) != v for k, v in filter.items()):
                continue
            yielded += 1
            yield item

    # ─── Delete everything stored for a thread
    async def adelete(self, config: dict) -> None:
        """Delete the whole hash of a thread (account removal / test cleanup).

        NOTE(review): declared async but uses the synchronous Redis client,
        so the call blocks while the command runs.
        """
        thread_id = config.get("configurable", {}).get("thread_id", "default")
        key = self._make_key(thread_id)
        result = self.client.delete(key)
        print(f"[Redis] 已删除 thread_id={thread_id},删除条数={result}")

    # ─── Debugging aid: print a summary of a thread
    def inspect_thread(self, thread_id: str) -> None:
        """Print the snapshot count and the last three messages of a thread."""
        config = {"configurable": {"thread_id": thread_id}}
        checkpoints = list(self.list(config))
        print(f"\n[Redis 状态检查] thread_id={thread_id}")
        print(f" checkpoint 数量:{len(checkpoints)}")
        if not checkpoints:
            return

        state = checkpoints[0].checkpoint.get("channel_values", {})
        messages = state.get("messages", []) if isinstance(state, dict) else []
        print(f" 最新消息数:{len(messages)}")
        for msg in messages[-3:]:  # only the last 3 messages
            if isinstance(msg, dict):
                role = msg.get("type", "?")
                content = str(msg.get("content", ""))[:60]
            else:
                role = type(msg).__name__
                content = str(getattr(msg, "content", ""))[:60]
            print(f" [{role}] {content}")


# ─── 2.
# Simple agent (with a checkpointer attached) ──────────────────────────────
class AgentState(TypedDict):
    # Conversation messages; add_messages is the reducer declared for this
    # channel via Annotated.
    messages: Annotated[list[BaseMessage], add_messages]


def build_agent_with_checkpointer(checkpointer):
    """Build a single-node LangGraph agent compiled with ``checkpointer``.

    The graph is START -> "agent" -> END. The agent node calls a local
    OpenAI-compatible endpoint (llama.cpp at http://localhost:8080/v1); if
    the client cannot be created, a mock reply is produced instead.

    NOTE(review): only failures while importing/constructing the client
    trigger the mock fallback here; whether an unreachable server is
    detected at construction time is not visible from this file.

    Args:
        checkpointer: Checkpoint saver passed to ``graph.compile``.

    Returns:
        The compiled graph.
    """
    import re

    llm = None
    try:
        from langchain_openai import ChatOpenAI

        llm = ChatOpenAI(
            base_url="http://localhost:8080/v1",
            api_key="not-needed",
            model="qwen3.5-4b-q4_k_m",
            temperature=0.7,
        )
        print("[Agent] 使用 llama.cpp qwen3.5-4b-q4_k_m")
    except Exception as e:  # best-effort: any setup failure falls back to the mock
        print(f"[Agent] llama.cpp 不可用({e}),使用 Mock LLM")

    # Strips <tool_call>…</tool_call> spans. The optional "/" also keeps
    # matching the earlier "<tool_call>…<tool_call>" form, so nothing that
    # used to be removed is now left in.
    tool_call_span = re.compile(r"<tool_call>.*?</?tool_call>", flags=re.DOTALL)

    def call_model(state: AgentState) -> dict:
        """Agent node: answer the accumulated messages with one AI message."""
        messages = state["messages"]
        if not llm:
            # Mock reply echoing the start of the last message.
            last_msg = messages[-1].content if messages else "..."
            return {"messages": [AIMessage(content=f"[Mock 回复] 我收到了:{last_msg[:30]}")]}

        response = llm.invoke(messages)
        # Content is only cleaned when it is plain text; other shapes are
        # passed through untouched instead of crashing inside the regex.
        if isinstance(response.content, str):
            response.content = tool_call_span.sub("", response.content).strip()
        return {"messages": [response]}

    graph = StateGraph(AgentState)
    graph.add_node("agent", call_model)
    graph.add_edge(START, "agent")
    graph.add_edge("agent", END)
    # Injecting the checkpointer is what makes the conversation persistent.
    return graph.compile(checkpointer=checkpointer)


# ─── 3.
# Demo: multi-turn chat persisted through the checkpointer ─────────────────
def demo_redis_chat(agent, checkpointer, thread_id: str, messages: list[str]) -> None:
    """Run one user's multi-turn conversation and print each exchange.

    Args:
        agent: Compiled graph exposing ``invoke``.
        checkpointer: Saver in use; if it offers ``inspect_thread`` the
            persisted state is printed at the end.
        thread_id: Conversation identifier placed in the run config.
        messages: User utterances sent in order.
    """
    config = {"configurable": {"thread_id": thread_id}}
    print(f"\n{'='*50}")
    print(f"用户 {thread_id} 开始对话")
    print(f"{'='*50}")

    for msg in messages:
        print(f"\n[用户] {msg}")
        result = agent.invoke({"messages": [HumanMessage(content=msg)]}, config=config)
        last_ai = result["messages"][-1].content
        print(f"[AI] {last_ai[:100]}")

    # Show what was persisted, when the saver supports it.
    if hasattr(checkpointer, "inspect_thread"):
        checkpointer.inspect_thread(thread_id)


def demo_persistence(agent, checkpointer, thread_id: str) -> None:
    """Print how many messages can be restored for ``thread_id``.

    ``get_tuple`` is used rather than ``get`` so that the result exposes
    ``.checkpoint`` for every saver used in this file, including the
    MemorySaver fallback. ``agent`` is unused and kept for call
    compatibility.
    """
    config = {"configurable": {"thread_id": thread_id}}
    print(f"\n--- 验证持久化:从 Redis 恢复 thread_id={thread_id} ---")
    saved_state = checkpointer.get_tuple(config)
    if saved_state:
        # The actual graph state lives under checkpoint["channel_values"].
        channel_values = saved_state.checkpoint.get("channel_values", {})
        messages = channel_values.get("messages", [])
        print(f" 成功恢复 {len(messages)} 条消息记录")
    else:
        print(" 未找到持久化数据(可能是 Redis 未连接)")


# ─── Main ────────────────────────────────────────────────────────────────────
def main():
    """Run the demo: pick a saver, chat as two users, verify persistence."""
    print("=" * 60)
    print("T13-01: Redis Checkpointer 演示")
    print("=" * 60)

    # Choose the checkpointer, degrading to the in-memory saver when Redis
    # is missing or unreachable.
    if REDIS_AVAILABLE:
        try:
            checkpointer = RedisCheckpointSaver("redis://localhost:6379", ttl_seconds=86400)
            checkpointer.client.ping()  # connection test
            print("[OK] Redis 连接成功")
        except Exception as e:
            print(f"[WARN] Redis 连接失败({e}),降级使用 MemorySaver")
            checkpointer = MemorySaver()
    else:
        print("[INFO] 使用 MemorySaver(无 Redis)")
        checkpointer = MemorySaver()

    agent = build_agent_with_checkpointer(checkpointer)

    # User A
    demo_redis_chat(agent, checkpointer, "user_alice_001", [
        "你好,我叫 Alice,是一名前端工程师",
        "我最感兴趣的技术是 React 和 TypeScript",
        "你还记得我的名字和职业吗?",
    ])

    # User B
    demo_redis_chat(agent, checkpointer, "user_bob_002", [
        "我是 Bob,后端工程师,Python 是我的主语言",
        "我最近在学 FastAPI 和 LangGraph",
    ])

    demo_persistence(agent, checkpointer, "user_alice_001")

    # Storage statistics, only meaningful for the Redis saver.
    if isinstance(checkpointer, RedisCheckpointSaver):
        # list() yields lazily; materialise it before counting.
        alice_cps = list(checkpointer.list({"configurable": {"thread_id": "user_alice_001"}}))
        print(f"\n[统计] Alice 共有 {len(alice_cps)} 个 checkpoint 快照存在 Redis 中")

    print("[演示完成] 实际部署时 Redis 数据将在重启后保留")


if __name__ == "__main__":
    main()
