🐍 魔术方法 — 让对象像内置类型一样优雅
🕐 预计用时:2-3 小时 | 🎯目标:掌握 __str__/__repr__/__len__/__eq__/__lt__、运算符重载
📖 今日目录
- 比较运算:__eq__/__ne__/__lt__/__gt__/__le__/__ge__
- 算术运算:__add__/__sub__/__mul__/__truediv__
- 容器协议:__getitem__/__setitem__/__contains__
1. 什么是魔术方法?
魔术方法(Magic Methods)是 Python 中以双下划线开头和结尾的特殊方法——它们让自定义对象像内置类型一样工作。
# 魔术方法无处不在
print(len([1, 2, 3])) # __len__
print(1 + 2) # __add__
print([1, 2] == [1, 2]) # __eq__
print("hello" < "world") # __lt__
print(str(42)) # __str__
# 你也可以让自己的类支持这些操作!
| | |
|---|
__str__ | str(obj) | |
__repr__ | repr(obj) | |
__len__ | len(obj) | |
__eq__ | obj1 == obj2 | |
__add__ | obj1 + obj2 | |
__getitem__ | obj[key] | |
__call__ | obj(args) | |
2. 字符串表示:__str__ 与 __repr__
class Point:
def __init__(self, x, y):
self.x = x
self.y = y
def __str__(self):
"""给用户看的:友好、易读"""
return f"({self.x}, {self.y})"
def __repr__(self):
"""给开发者看的:精确、可重建"""
return f"Point({self.x}, {self.y})"
p = Point(3, 4)
# print() 和 str() 调用 __str__
print(p) # (3, 4)
print(str(p)) # (3, 4)
# 交互式环境和 repr() 调用 __repr__
print(repr(p)) # Point(3, 4)
# 在列表中显示的是 __repr__
points = [Point(1, 2), Point(3, 4)]
print(points) # [Point(1, 2), Point(3, 4)]
# 没有 __str__ 时,回退到 __repr__
# 没有 __repr__ 时,显示 <__main__.Point object at 0x...>
💡 黄金法则:
__repr__ 返回的字符串应该能 eval() 重建对象(理想情况)。
__str__ 返回用户友好的显示。
如果只实现一个,优先实现 __repr__。
3. 长度:__len__
class ShoppingCart:
def __init__(self):
self.items = []
def add(self, name, price, quantity=1):
self.items.append({"name": name, "price": price, "quantity": quantity})
def __len__(self):
"""让 len() 支持自定义对象"""
return sum(item["quantity"] for item in self.items)
def __str__(self):
return f"购物车: {len(self)} 件商品, ¥{self.total:.2f}"
@property
def total(self):
return sum(item["price"] * item["quantity"] for item in self.items)
cart = ShoppingCart()
cart.add("苹果", 5.5, 3)
cart.add("牛奶", 12, 2)
cart.add("面包", 8, 1)
print(len(cart)) # 6(3+2+1 件)
print(cart) # 购物车: 6 件商品, ¥47.50
# len() 支持的前提是实现了 __len__
# if cart: ← 这也会用到 __len__(非零为 True)
4. 比较运算
实现比较魔术方法,让对象支持 ==、!=、<、>、<=、>= 运算。
from functools import total_ordering
@total_ordering # 只需实现 __eq__ 和 __lt__,自动生成其他
class Student:
def __init__(self, name, score):
self.name = name
self.score = score
def __eq__(self, other):
"""== 等于"""
if not isinstance(other, Student):
return NotImplemented
return self.score == other.score
def __lt__(self, other):
"""< 小于"""
if not isinstance(other, Student):
return NotImplemented
return self.score < other.score
# @total_ordering 自动生成以下方法:
# __ne__(!=)、__gt__(>)、__le__(<=)、__ge__(>=)
def __repr__(self):
return f"Student('{self.name}', {self.score})"
s1 = Student("张三", 85)
s2 = Student("李四", 92)
s3 = Student("王五", 85)
print(s1 == s3) # True(分数相同)
print(s1 < s2) # True(85 < 92)
print(s2 > s1) # True(92 > 85)
print(s1 <= s3) # True
print(s2 >= s1) # True
# 排序直接可用!
students = [s1, s2, s3]
print(sorted(students)) # [Student('张三',85), Student('王五',85), Student('李四',92)]
💡 @total_ordering 装饰器:
只需实现 __eq__ 和 __lt__,自动生成 __ne__、__gt__、__le__、__ge__。
省时省力,强烈推荐!
5. 算术运算
class Vector:
"""二维向量类"""
def __init__(self, x, y):
self.x = x
self.y = y
def __add__(self, other):
"""+ 加法"""
return Vector(self.x + other.x, self.y + other.y)
def __sub__(self, other):
"""- 减法"""
return Vector(self.x - other.x, self.y - other.y)
def __mul__(self, scalar):
"""* 乘法(标量)"""
return Vector(self.x * scalar, self.y * scalar)
def __rmul__(self, scalar):
"""* 右乘(3 * v 也能用)"""
return self.__mul__(scalar)
def __neg__(self):
"""- 取负"""
return Vector(-self.x, -self.y)
def __abs__(self):
"""abs() 向量长度"""
return (self.x ** 2 + self.y ** 2) ** 0.5
def __repr__(self):
return f"Vector({self.x}, {self.y})"
v1 = Vector(3, 4)
v2 = Vector(1, 2)
print(v1 + v2) # Vector(4, 6)
print(v1 - v2) # Vector(2, 2)
print(v1 * 3) # Vector(9, 12)
print(3 * v1) # Vector(9, 12)(__rmul__)
print(-v1) # Vector(-3, -4)
print(abs(v1)) # 5.0(勾股定理)
📋 常用算术魔术方法
| | |
|---|
__add__ | | a + b |
__sub__ | | a - b |
__mul__ | | a * b |
__truediv__ | | a / b |
__floordiv__ | | a // b |
__mod__ | | a % b |
__pow__ | | a ** b |
__neg__ | | -a |
__abs__ | | abs(a) |
6. 容器协议
class Playlist:
"""播放列表:像列表一样使用"""
def __init__(self, name):
self.name = name
self.songs = []
def add(self, song):
self.songs.append(song)
def __getitem__(self, index):
"""支持索引访问:playlist[0]"""
return self.songs[index]
def __setitem__(self, index, value):
"""支持索引赋值:playlist[0] = '新歌'"""
self.songs[index] = value
def __len__(self):
"""支持 len()"""
return len(self.songs)
def __contains__(self, item):
"""支持 in 运算符"""
return item in self.songs
def __iter__(self):
"""支持 for 循环"""
return iter(self.songs)
def __str__(self):
return f"🎵 {self.name}: {len(self)} 首歌"
playlist = Playlist("我的最爱")
playlist.add("晴天")
playlist.add("七里香")
playlist.add("稻香")
# 索引访问
print(playlist[0]) # 晴天
print(playlist[-1]) # 稻香
# 索引赋值
playlist[0] = "青花瓷"
print(playlist[0]) # 青花瓷
# in 运算符
print("七里香" in playlist) # True
print("双截棍" in playlist) # False
# for 循环
for song in playlist:
print(f" 🎵 {song}")
# len()
print(len(playlist)) # 3
💡 容器协议四件套:
__getitem__ — 索引/键访问
__setitem__ — 索引/键赋值
__contains__ — in 运算符
__iter__ — for 循环遍历
实现这四个,你的对象就像列表/字典一样好用!
7. __call__:让对象像函数一样调用
class Multiplier:
"""可调用的乘法器"""
def __init__(self, factor):
self.factor = factor
def __call__(self, x):
return x * self.factor
double = Multiplier(2)
triple = Multiplier(3)
print(double(5)) # 10
print(triple(5)) # 15
# 检查对象是否可调用
print(callable(double)) # True
print(callable(triple)) # True
# 用途:函数工厂、缓存、装饰器
class Cache:
"""简易缓存"""
def __init__(self):
self.data = {}
def __call__(self, func):
def wrapper(*args):
if args not in self.data:
self.data[args] = func(*args)
return self.data[args]
return wrapper
cache = Cache()
@cache
def expensive_calc(n):
print(f" 计算 {n}...")
return n ** 2
print(expensive_calc(5)) # 计算 5... → 25
print(expensive_calc(5)) # 25(直接返回缓存,不计算)
8. 上下文管理:__enter__ / __exit__
class Timer:
"""计时器上下文管理器"""
def __init__(self, label=""):
self.label = label
def __enter__(self):
import time
self.start = time.time()
print(f"⏱️ 开始: {self.label}")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
import time
self.elapsed = time.time() - self.start
print(f"⏱️ 结束: {self.label} ({self.elapsed:.4f}秒)")
return False # 不抑制异常
# 使用 with 语句
with Timer("排序测试"):
data = sorted(range(100000, 0, -1))
with Timer("求和测试"):
total = sum(range(1000000))
# 也可以用类的实例
class DBConnection:
"""模拟数据库连接"""
def __init__(self, host):
self.host = host
def __enter__(self):
print(f"🔗 连接数据库: {self.host}")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
print(f"🔌 断开数据库: {self.host}")
return False
def query(self, sql):
print(f" 📝 执行: {sql}")
with DBConnection("localhost") as db:
db.query("SELECT * FROM users")
db.query("INSERT INTO logs ...")
# 自动断开连接
9. 实战练习
🎯 练习 1:Matrix 矩阵类(完整运算符重载)
class Matrix:
def __init__(self, data):
self.data = [row[:] for row in data]
self.rows = len(data)
self.cols = len(data[0]) if data else 0
def __repr__(self):
return f"Matrix({self.data})"
def __str__(self):
max_len = max(len(str(x)) for row in self.data for x in row)
lines = []
for row in self.data:
line = " ".join(f"{x:>{max_len}}" for x in row)
lines.append(f"| {line} |")
return "\n".join(lines)
def __eq__(self, other):
return self.data == other.data
def __getitem__(self, pos):
row, col = pos
return self.data[row][col]
def __setitem__(self, pos, value):
row, col = pos
self.data[row][col] = value
def __add__(self, other):
if self.rows != other.rows or self.cols != other.cols:
raise ValueError("矩阵尺寸不匹配")
result = [
[self.data[i][j] + other.data[i][j] for j in range(self.cols)]
for i in range(self.rows)
]
return Matrix(result)
def __sub__(self, other):
if self.rows != other.rows or self.cols != other.cols:
raise ValueError("矩阵尺寸不匹配")
result = [
[self.data[i][j] - other.data[i][j] for j in range(self.cols)]
for i in range(self.rows)
]
return Matrix(result)
def __mul__(self, other):
if isinstance(other, (int, float)):
# 标量乘法
result = [[self.data[i][j] * other for j in range(self.cols)] for i in range(self.rows)]
return Matrix(result)
elif isinstance(other, Matrix):
# 矩阵乘法
if self.cols != other.rows:
raise ValueError(f"无法相乘: {self.rows}x{self.cols} * {other.rows}x{other.cols}")
result = [
[sum(self.data[i][k] * other.data[k][j] for k in range(self.cols)) for j in range(other.cols)]
for i in range(self.rows)
]
return Matrix(result)
return NotImplemented
def __rmul__(self, scalar):
return self.__mul__(scalar)
def __neg__(self):
return self * -1
def __len__(self):
return self.rows * self.cols
@property
def T(self):
"""转置"""
return Matrix([[self.data[j][i] for j in range(self.rows)] for i in range(self.cols)])
# 测试
A = Matrix([[1, 2], [3, 4]])
B = Matrix([[5, 6], [7, 8]])
print("A =")
print(A)
print("\nB =")
print(B)
print("\nA + B =")
print(A + B)
print("\nA * B (矩阵乘法) =")
print(A * B)
print("\nA * 3 (标量乘法) =")
print(A * 3)
print("\nA.T (转置) =")
print(A.T)
print(f"\nA[1][0] = {A[1, 0]}")
print(f"len(A) = {len(A)}")
🎯 练习 2:Money 货币类
from functools import total_ordering
@total_ordering
class Money:
"""货币类:支持运算和比较"""
EXCHANGE_RATES = {
("USD", "CNY"): 7.24,
("CNY", "USD"): 1 / 7.24,
("EUR", "CNY"): 7.89,
("CNY", "EUR"): 1 / 7.89,
("USD", "EUR"): 0.92,
("EUR", "USD"): 1 / 0.92,
}
def __init__(self, amount, currency="CNY"):
self.amount = round(amount, 2)
self.currency = currency
def _convert(self, other):
"""统一货币后比较"""
if self.currency == other.currency:
return self.amount, other.amount
key = (self.currency, other.currency)
if key in self.EXCHANGE_RATES:
return self.amount, round(other.amount * self.EXCHANGE_RATES[key], 2)
raise ValueError(f"不支持的转换: {key}")
def __eq__(self, other):
if not isinstance(other, Money):
return NotImplemented
s, o = self._convert(other)
return s == o
def __lt__(self, other):
if not isinstance(other, Money):
return NotImplemented
s, o = self._convert(other)
return s < o
def __add__(self, other):
if not isinstance(other, Money):
return NotImplemented
if self.currency == other.currency:
return Money(self.amount + other.amount, self.currency)
key = (other.currency, self.currency)
converted = round(other.amount * self.EXCHANGE_RATES.get(key, 0), 2)
return Money(self.amount + converted, self.currency)
def __sub__(self, other):
if not isinstance(other, Money):
return NotImplemented
if self.currency == other.currency:
return Money(self.amount - other.amount, self.currency)
key = (other.currency, self.currency)
converted = round(other.amount * self.EXCHANGE_RATES.get(key, 0), 2)
return Money(self.amount - converted, self.currency)
def __mul__(self, scalar):
return Money(self.amount * scalar, self.currency)
def __rmul__(self, scalar):
return self.__mul__(scalar)
def __neg__(self):
return Money(-self.amount, self.currency)
def __abs__(self):
return Money(abs(self.amount), self.currency)
def __repr__(self):
return f"Money({self.amount}, '{self.currency}')"
def __str__(self):
symbols = {"CNY": "¥", "USD": "$", "EUR": "€"}
symbol = symbols.get(self.currency, self.currency + " ")
return f"{symbol}{self.amount:,.2f}"
# 测试
price1 = Money(100, "CNY")
price2 = Money(15, "USD")
price3 = Money(12, "EUR")
print(f"价格1: {price1}") # ¥100.00
print(f"价格2: {price2}") # $15.00
print(f"价格1 + 价格2: {price1 + price2}") # ¥208.60
print(f"价格1 > 价格2: {price1 > price2}") # True
print(f"3倍价格1: {3 * price1}") # ¥300.00
# 排序
prices = [price1, price2, price3, Money(50, "CNY")]
for p in sorted(prices):
print(f" {p}")
10. 今日小结
| | |
|---|
| __str__ | print() |
| __len__ | len(obj) |
| __eq__ | == |
| __add__ | + |
| __getitem__ / __contains__ / __iter__ | [] |
| __call__ | obj() |
| __enter__ | with obj |
🧠 记忆口诀:
双下划线魔术法,对象秒变内置家。
str 给人看,repr 给码看。
len 算长度,eq 判相等。
add 加 sub 减,mul 乘 truediv 除。
getitem 像列表,call 像函数。
enter exit with 用,total_ordering 省代码。
🔮 预告: Day 25 综合练习 — 🎯 项目 1:学生管理系统(OOP 架构、文件存储、增删改查菜单)。把 Day21-Day24 学的全部用起来!