一、Pickle 序列化概述
1. 什么是 Pickle?
Pickle 是 Python 内置的序列化模块,可以将 Python 对象转换为字节流,并能从字节流重建对象。
2. Pickle 协议方法
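| 方法 | 调用时机 | 作用 |
|---|---|---|
| `__getstate__` | 序列化时 | 返回要保存的对象状态 |
| `__setstate__` | 反序列化时 | 用保存的状态恢复对象 |
| `__reduce__` | 序列化时 | 返回重建对象所需的可调用对象和参数 |
| `__reduce_ex__` | 序列化时 | 与 `__reduce__` 相同,但会接收协议版本号 |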
3. 基本使用
import pickle


class SimpleClass:
    """Minimal value holder used to show a default pickle round trip."""

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def __repr__(self):
        return f"SimpleClass(name={self.name}, value={self.value})"


# Serialize: the class defines no pickle hooks, so the default protocol is used.
obj = SimpleClass("test", 42)
data = pickle.dumps(obj)
print(f"序列化数据: {data[:50]}...")

# Deserialize: builds a new, distinct object from the byte stream.
obj2 = pickle.loads(data)
print(f"恢复对象: {obj2}")
二、__getstate__ 和 __setstate__ 详解
1. 基本用法
import pickle


class Person:
    """Person record that keeps its password and cache out of the pickle.

    ``__getstate__`` strips ``password`` and ``_cache`` and stamps the state
    with ``_version``; ``__setstate__`` restores the remaining attributes and
    re-creates the stripped ones with neutral defaults.
    """

    def __init__(self, name, age, password):
        self.name = name
        self.age = age
        self.password = password  # sensitive: never written to the pickle
        self._cache = {}  # transient: never written to the pickle

    def __getstate__(self):
        """Return a copy of the instance dict without secrets or caches."""
        print("调用 __getstate__")
        # Work on a copy so the live object keeps its password and cache.
        state = self.__dict__.copy()
        state.pop('password', None)
        state.pop('_cache', None)
        # Version stamp lets __setstate__ recognise older layouts.
        state['_version'] = 1
        return state

    def __setstate__(self, state):
        """Restore attributes from ``state`` without modifying ``state``.

        A state lacking ``_version`` is treated as legacy data (version 0).
        ``password`` defaults to ``None`` and ``_cache`` to an empty dict, so
        a restored object always has every attribute ``__init__`` would set.
        """
        print("调用 __setstate__")
        # Copy first: the caller's dict must not be changed by restoring.
        restored = dict(state)
        restored.setdefault('_version', 0)
        self.__dict__.update(restored)
        # Attributes stripped by __getstate__ come back as safe defaults.
        self.__dict__.setdefault('password', None)
        self.__dict__.setdefault('_cache', {})


# Usage
person = Person("Alice", 30, "secret123")
person._cache['temp'] = "缓存数据"

print("原始对象:")
print(f" name: {person.name}")
print(f" age: {person.age}")
print(f" password: {person.password}")
print(f" cache: {person._cache}")

# Serialize
data = pickle.dumps(person)
print(f"\n序列化数据大小: {len(data)} 字节")

# Deserialize
person2 = pickle.loads(data)
print("\n恢复的对象:")
print(f" name: {person2.name}")
print(f" age: {person2.age}")
print(f" password: {person2.password}")
print(f" cache: {getattr(person2, '_cache', '不存在')}")
2. 处理不可序列化对象
import pickle
import threading
import socket


class NetworkConnection:
    """Connection wrapper holding resources that cannot be pickled.

    Only ``host`` and ``port`` travel in the pickle. The socket and the lock
    are re-created when the object is restored.
    """

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.socket = None
        self.lock = threading.Lock()
        self.connected = False
        self._connect()

    def _connect(self):
        """Create the socket and mark the object as connected."""
        print(f"连接到 {self.host}:{self.port}")
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # The actual connect() call is elided in this demo.
        self.connected = True

    def _disconnect(self):
        """Close the socket, if any, and mark the object as disconnected."""
        if self.socket:
            print("断开连接")
            self.socket.close()
            self.socket = None
        self.connected = False

    def __getstate__(self):
        """Return the picklable part of the state.

        Serializing must not disturb the live object, so the connection is
        left open; only the socket and lock are omitted from the result.
        """
        print("序列化网络连接")
        return {
            'host': self.host,
            'port': self.port,
            'connected': False,  # the connection state itself is not saved
        }

    def __setstate__(self, state):
        """Restore host/port, rebuild the lock and open a fresh connection."""
        print("反序列化网络连接")
        self.__dict__.update(state)
        # Re-create the attributes that were left out of the pickle.
        self.lock = threading.Lock()
        self.socket = None
        self._connect()

    def send(self, data):
        """Send ``data``; raise ``RuntimeError`` when not connected."""
        if not self.connected:
            raise RuntimeError("未连接")
        print(f"发送: {data}")

    def __repr__(self):
        return f"NetworkConnection({self.host}:{self.port}, connected={self.connected})"


# Usage
conn = NetworkConnection("localhost", 8080)
print(f"原始连接: {conn}")

# Serialize
data = pickle.dumps(conn)
print(f"\n序列化数据大小: {len(data)} 字节")

# Deserialize
conn2 = pickle.loads(data)
print(f"\n恢复的连接: {conn2}")
conn2.send("Hello")

# Release both sockets explicitly so the demo does not leak them.
conn2._disconnect()
conn._disconnect()
3. 处理循环引用
import pickle


class Node:
    """Tree node whose pickle omits the child-to-parent back reference.

    The parent link is rebuilt structurally on load: every node re-attaches
    itself as the parent of the children it restores. The parent's name is
    still stored so ``restore_parents`` can resolve a parent that was not
    part of the pickled subtree.
    """

    def __init__(self, name):
        self.name = name
        self.parent = None
        self.children = []

    def add_child(self, child):
        """Append ``child`` and point its ``parent`` at this node."""
        child.parent = self
        self.children.append(child)

    def __getstate__(self):
        """Return name and children; the parent is stored by name only."""
        print(f"序列化节点: {self.name}")
        state = {
            'name': self.name,
            'children': self.children,  # children are pickled recursively
        }
        # Store the parent's name rather than the parent object itself.
        if self.parent is not None:
            state['parent_name'] = self.parent.name
        return state

    def __setstate__(self, state):
        """Restore the node and re-link its children to it."""
        print(f"反序列化节点: {state['name']}")
        self.name = state['name']
        self.children = state.get('children', [])
        self.parent = None
        self._parent_name = state.get('parent_name')
        # Link by object identity, not by name: two nodes sharing a name
        # can then never be confused with each other.
        for child in self.children:
            child.parent = self

    def restore_parents(self, node_map):
        """Resolve still-missing parents by name, for this node and below.

        Args:
            node_map: mapping from node name to node.

        A parent that was already linked when the node was restored is left
        untouched. The temporary ``_parent_name`` attribute is removed.
        """
        pending = self.__dict__.pop('_parent_name', None)
        if self.parent is None and pending is not None:
            self.parent = node_map.get(pending)
        for child in self.children:
            child.restore_parents(node_map)

    def __repr__(self):
        parent_name = self.parent.name if self.parent else "None"
        return f"Node({self.name}, parent={parent_name})"


# Build the tree
root = Node("root")
child1 = Node("child1")
child2 = Node("child2")
grandchild = Node("grandchild")
root.add_child(child1)
root.add_child(child2)
child1.add_child(grandchild)

print("原始树:")
print(f" root: {root}")
print(f" child1: {child1}")
print(f" child2: {child2}")
print(f" grandchild: {grandchild}")

# Serialize
data = pickle.dumps(root)

# Deserialize
root2 = pickle.loads(data)

# Rebuild name-based references
node_map = {}


def collect_nodes(node):
    """Record ``node`` and all of its descendants in ``node_map`` by name."""
    node_map[node.name] = node
    for child in node.children:
        collect_nodes(child)


collect_nodes(root2)
root2.restore_parents(node_map)

print("\n恢复的树:")
print(f" root: {root2}")
print(f" child1: {root2.children[0]}")
print(f" child2: {root2.children[1]}")
print(f" grandchild: {root2.children[0].children[0]}")
三、__reduce__ 和 __reduce_ex__ 详解
1. 基本用法
import pickle


class CustomReduce:
    """Class rebuilt through a custom reconstructor returned by ``__reduce__``."""

    def __init__(self, value):
        self.value = value
        self.processed = False

    def __reduce__(self):
        """Return ``(callable, args)`` describing how to rebuild the object.

        Accepted tuple shapes:
            (callable, arguments)
            (callable, arguments, state)
            (callable, arguments, state, iterator)
            (callable, arguments, state, iterator, list)

        No state is returned here on purpose: a state dict would be applied
        after the reconstructor runs and would overwrite what it set.
        """
        print("调用 __reduce__")
        return (self._reconstructor, (self.value,))

    @staticmethod
    def _reconstructor(value):
        """Rebuild an instance and flag it as processed."""
        obj = CustomReduce(value)
        obj.processed = True
        return obj

    def __repr__(self):
        return f"CustomReduce(value={self.value}, processed={self.processed})"


class ReduceWithState:
    """Class whose ``__reduce__`` also returns a state for ``__setstate__``."""

    def __init__(self, name, data):
        self.name = name
        self.data = data
        self.temp = None

    def __reduce__(self):
        """Return ``(callable, args, state)``.

        ``args`` must satisfy ``__init__(name, data)``. ``data`` is passed as
        a ``None`` placeholder because the real value arrives through the
        state and is applied by ``__setstate__``.
        """
        print("调用 __reduce__")
        return (self.__class__, (self.name, None), {'data': self.data})

    def __setstate__(self, state):
        """Apply ``state`` and mark the object as restored from state."""
        print("调用 __setstate__")
        self.__dict__.update(state)
        self.temp = "从状态恢复"


# Usage
obj = CustomReduce(10)
print(f"原始: {obj}")
data = pickle.dumps(obj)
obj2 = pickle.loads(data)
print(f"恢复: {obj2}")

print("\n=== Reduce with State ===")
obj3 = ReduceWithState("test", [1, 2, 3])
data2 = pickle.dumps(obj3)
obj4 = pickle.loads(data2)
print(f"恢复: {obj4.name}, {obj4.data}, {obj4.temp}")
2. __reduce_ex__ 和协议版本
import pickle
import sys


class VersionedClass:
    """Class that picks a different reconstructor per pickle protocol."""

    def __init__(self, name, version=1):
        self.name = name
        self.version = version
        self.data = {}

    def __reduce_ex__(self, protocol):
        """Return a reduce tuple chosen by ``protocol``.

        Args:
            protocol: pickle protocol number in use (0-5).

        Returns:
            A tuple in the same format as ``__reduce__``.

        The ``migrated`` flag is left out of the state: it is set by the
        reconstructor, and a stored copy would override that value on load.
        """
        print(f"调用 __reduce_ex__ with protocol={protocol}")
        state = {key: val for key, val in self.__dict__.items() if key != 'migrated'}
        if protocol >= 4:
            # Newer protocols: pass both name and version to the rebuilder.
            return (self._reconstruct_v4, (self.name, self.version), state)
        # Older protocols: the rebuilder takes the name only; ``version``
        # is restored from the state afterwards.
        return (self._reconstruct_v3, (self.name,), state)

    @staticmethod
    def _reconstruct_v4(name, version):
        """Rebuild for protocol >= 4 and flag the object as migrated."""
        obj = VersionedClass(name, version)
        obj.migrated = True
        return obj

    @staticmethod
    def _reconstruct_v3(name):
        """Rebuild for protocol < 4 and flag the object as not migrated."""
        obj = VersionedClass(name, 1)
        obj.migrated = False
        return obj

    def __repr__(self):
        return f"VersionedClass(name={self.name}, version={self.version})"


# Try every protocol version
obj = VersionedClass("test", 2)
for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
    print(f"\n协议 {protocol}:")
    try:
        data = pickle.dumps(obj, protocol=protocol)
        obj2 = pickle.loads(data)
        print(f" 恢复: {obj2}")
        print(f" 数据大小: {len(data)} 字节")
    except Exception as e:
        print(f" 错误: {e}")
3. 实现自定义序列化
import ast
import pickle
import zlib
import base64


class CompressedObject:
    """Wrapper pickled as a zlib-compressed, base64-encoded ``repr`` of ``data``.

    Because the payload is text produced by ``repr``, only data made of
    Python literals (numbers, strings, bytes, tuples, lists, dicts, sets,
    booleans, ``None``) survives a round trip.
    """

    def __init__(self, data):
        self.data = data
        self.compressed = False

    def __reduce__(self):
        """Return ``(_decompress, (payload,))`` with the compressed payload."""
        print("压缩数据...")
        compressed = zlib.compress(repr(self.data).encode())
        encoded = base64.b64encode(compressed)
        return (self._decompress, (encoded,))

    @staticmethod
    def _decompress(encoded):
        """Decode and decompress ``encoded`` and rebuild the wrapper.

        Raises:
            ValueError or SyntaxError: if the payload is not a plain literal.
        """
        print("解压数据...")
        compressed = base64.b64decode(encoded)
        data_str = zlib.decompress(compressed).decode()
        # literal_eval only accepts literals, so a crafted payload cannot
        # execute code the way eval() would.
        return CompressedObject(ast.literal_eval(data_str))

    def __repr__(self):
        return f"CompressedObject(data={self.data})"


# Measure the effect of compression
big_data = list(range(10000))
obj = CompressedObject(big_data)

# Plain pickle
normal_data = pickle.dumps(big_data)
print(f"普通序列化大小: {len(normal_data)} 字节")

# Compressed pickle
compressed_data = pickle.dumps(obj)
print(f"压缩序列化大小: {len(compressed_data)} 字节")
print(f"压缩比: {len(compressed_data)/len(normal_data)*100:.1f}%")

# Restore
obj2 = pickle.loads(compressed_data)
print(f"恢复的数据长度: {len(obj2.data)}")
四、最佳实践和注意事项
1. 序列化方法选择指南
| 方法 | 适用场景 |
|---|---|
| `__getstate__` | 需要排除敏感信息、缓存或不可序列化的属性 |
| `__setstate__` | 需要在恢复时重建资源或处理版本兼容 |
| `__reduce__` | 需要完全自定义对象的重建方式 |
| `__reduce_ex__` | 需要根据协议版本选择不同的序列化方式 |
2. 安全注意事项
import pickle


class SecurityBestPractices:
    """Security practices for working with pickle."""

    @staticmethod
    def never_untrusted_pickle():
        """Never unpickle data from an untrusted source."""
        # Dangerous:
        # data = receive_from_network()
        # obj = pickle.loads(data)  # may execute attacker-chosen code

    @staticmethod
    def use_restricted_unpickler():
        """Return an ``Unpickler`` subclass that only loads allow-listed globals.

        Returns:
            The ``SafeUnpickler`` class. Instantiate it with a binary file
            object and call ``load()``.
        """
        # Allow-list of (module, name) pairs that may be resolved.
        allowed = {
            ('builtins', 'dict'),
            ('builtins', 'list'),
            ('builtins', 'str'),
        }

        class SafeUnpickler(pickle.Unpickler):
            """Unpickler that refuses every global outside the allow-list."""

            def find_class(self, module, name):
                if (module, name) not in allowed:
                    raise pickle.UnpicklingError(f"禁止加载 {module}.{name}")
                return super().find_class(module, name)

        return SafeUnpickler

    @staticmethod
    def sign_data(obj=None, key=b'secret-key'):
        """Pickle ``obj`` and prefix the bytes with an HMAC-SHA256 signature.

        Args:
            obj: object to pickle.
            key: HMAC key as bytes. The default is a placeholder for the
                demo; real code must supply a secret loaded from secure
                configuration.

        Returns:
            ``signature + data`` where ``signature`` is the 32-byte digest.
        """
        import hmac
        import hashlib
        data = pickle.dumps(obj)
        signature = hmac.new(key, data, hashlib.sha256).digest()
        return signature + data


class PerformanceConsiderations:
    """Performance considerations for pickle."""

    @staticmethod
    def use_protocol(obj=None):
        """Return ``obj`` pickled with the highest available protocol."""
        # Protocols 4 and 5 are faster and more compact than older ones.
        return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def avoid_circular():
        """Avoid circular references."""
        # The original note says circular references slow serialization
        # down, and suggests weak references or ID-based references instead.

    @staticmethod
    def compress_large(obj=None):
        """Return the zlib-compressed pickle of ``obj``."""
        import zlib
        data = pickle.dumps(obj)
        return zlib.compress(data)
3. 常见陷阱
import pickle
import socket


# Pitfall 1: forgetting about attributes that cannot be pickled.
class BadClass:
    """Holds a socket but defines no ``__getstate__`` to leave it out."""

    def __init__(self):
        self.socket = socket.socket()  # not picklable
    # Without __getstate__, pickling an instance fails.


# Pitfall 2: __reduce__ returning something that is not callable.
class BadReduce:
    """Deliberately broken: the first tuple element must be a callable."""

    def __reduce__(self):
        return (None, ())


# Pitfall 3: expecting later changes to show up in an earlier pickle.
class MutableObject:
    """Container used to show that a pickle is a snapshot."""

    def __init__(self):
        self.data = []

    def add(self, item):
        self.data.append(item)


obj = MutableObject()
obj.add(1)
data = pickle.dumps(obj)
obj.add(2)  # changes the original only
obj2 = pickle.loads(data)  # obj2 does not contain 2


# Pitfall 4: version compatibility.
class V1:
    def __init__(self):
        self.field1 = 1


class V2:
    def __init__(self):
        self.field1 = 1
        self.field2 = 2  # new field
    # Data pickled from V1 lacks field2 when loaded into this layout.


# Pitfall 5: forgetting to call the parent implementation.
class ParentClass:
    """Base class whose ``__getstate__`` drops a ``_transient`` entry."""

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop('_transient', None)
        return state


class ChildClass(ParentClass):
    """Deliberately broken: overrides ``__getstate__`` and bypasses the parent."""

    def __getstate__(self):
        state = self.__dict__.copy()
        # The parent's __getstate__ should be used as the starting point,
        # e.g. ``state = super().__getstate__()``; as written, whatever the
        # parent filters out is pickled anyway.
        return state
4. 总结要点
class PickleBestPractices:
    """Summary of pickle best practices."""

    # 1. Implement __getstate__ and __setstate__ for unpicklable attributes
    # 2. Use __slots__ to reduce memory use and pickle size
    # 3. Handle version compatibility
    # 4. Never unpickle untrusted data
    # 5. Use a high protocol version for performance
    # 6. Consider compressing large objects
    # 7. Avoid circular references
    # 8. Make sure __reduce__ returns a callable

    class Example:
        """Complete example combining slots, versioned state and ``__reduce__``."""

        __slots__ = ['id', 'name', '_cache']

        def __init__(self, id, name):
            self.id = id
            self.name = name
            self._cache = {}

        def __getstate__(self):
            """Return id, name and a version stamp; the cache is left out."""
            state = {
                'id': self.id,
                'name': self.name,
                '_version': 1,
            }
            return state

        def __setstate__(self, state):
            """Restore id and name from ``state`` and start with an empty cache."""
            self.id = state['id']
            self.name = state['name']
            self._cache = {}
            # Version compatibility: a state without a stamp counts as 1.
            if state.get('_version', 1) == 1:
                # Placeholder for migration code.
                pass

        def __reduce__(self):
            """Custom rebuild (if needed): class, constructor args and state."""
            return (self.__class__, (self.id, self.name), self.__getstate__())
Pickle 是 Python 强大的序列化工具,正确使用需要理解其协议和工作原理。通过实现 __getstate__、__setstate__ 和 __reduce__ 方法,可以精确控制对象的序列化行为,处理各种复杂场景。