一、什么是生成器?
1. 生成器的定义
生成器(Generator) 是一种特殊的迭代器,它使用 yield 关键字而不是 return 来返回值。生成器函数在每次产生值时会暂停执行,下次调用时从暂停处继续执行。
2. 生成器 vs 普通函数 vs 迭代器
# 普通函数:一次性返回所有结果def normal_function(): result = [] for i in range(5): result.append(i) return result# 生成器函数:逐个产生结果def generator_function(): for i in range(5): yield i# 迭代器:需要实现 __iter__ 和 __next__class IteratorClass: def __init__(self): self.n = 0 def __iter__(self): return self def __next__(self): if self.n < 5: value = self.n self.n += 1 return value raise StopIteration# 使用对比print("普通函数:", normal_function())print("生成器:", list(generator_function()))print("迭代器:", list(IteratorClass()))
二、创建生成器
1. 生成器函数(yield)
def simple_generator(): """最简单的生成器""" yield 1 yield 2 yield 3# 使用gen = simple_generator()print(next(gen)) # 1print(next(gen)) # 2print(next(gen)) # 3# print(next(gen)) # StopIteration# for 循环自动处理 StopIterationfor value in simple_generator(): print(value, end=' ') # 1 2 3print()
2. 生成器表达式
# 列表推导式list_comp = [x**2 for x in range(10)]print(f"列表推导式: {list_comp}")print(f"类型: {type(list_comp)}")# 生成器表达式gen_exp = (x**2 for x in range(10))print(f"生成器表达式: {gen_exp}")print(f"类型: {type(gen_exp)}")# 逐个取值for value in gen_exp: print(value, end=' ') if value > 20: breakprint()# 内存对比import syslist_comp = [x for x in range(1000000)]gen_exp = (x for x in range(1000000))print(f"列表内存: {sys.getsizeof(list_comp) / 1024:.2f} KB")print(f"生成器内存: {sys.getsizeof(gen_exp) / 1024:.2f} KB")
3. yield 的工作原理
def generator_with_print(): """演示 yield 的执行流程""" print("生成器开始执行") yield 1 print("继续执行,准备 yield 2") yield 2 print("继续执行,准备 yield 3") yield 3 print("生成器结束")# 逐步执行gen = generator_with_print()print("第一次调用 next:")value = next(gen)print(f"得到: {value}\n")print("第二次调用 next:")value = next(gen)print(f"得到: {value}\n")print("第三次调用 next:")value = next(gen)print(f"得到: {value}\n")try: print("第四次调用 next:") value = next(gen)except StopIteration: print("生成器已耗尽")
三、生成器的高级特性
1. send() - 向生成器发送值
def echo_generator(): """接收并返回发送的值""" print("生成器启动") while True: received = yield print(f"收到: {received}")gen = echo_generator()next(gen) # 启动生成器gen.send("Hello") # 发送值gen.send("World")gen.send("!")# 实际应用:累加器def accumulator(): """累加器生成器""" total = 0 while True: value = yield total if value is not None: total += valueacc = accumulator()next(acc) # 启动print(acc.send(10)) # 10print(acc.send(20)) # 30print(acc.send(30)) # 60
2. throw() - 向生成器抛出异常
def generator_with_exception(): """处理异常的生成器""" try: yield 1 yield 2 yield 3 except ValueError as e: print(f"捕获到异常: {e}") yield "异常处理完成"gen = generator_with_exception()print(next(gen)) # 1print(next(gen)) # 2# 抛出异常result = gen.throw(ValueError("自定义错误"))print(result) # "异常处理完成"
3. close() - 关闭生成器
def closable_generator(): """可关闭的生成器""" try: yield 1 yield 2 yield 3 except GeneratorExit: print("生成器被关闭,执行清理") # 可以在这里做清理工作 # 注意:不能 yield,只能 returngen = closable_generator()print(next(gen)) # 1print(next(gen)) # 2gen.close() # 关闭生成器# print(next(gen)) # StopIteration
四、最佳实践和注意事项
1. 生成器设计模式
class GeneratorPatterns: """生成器设计模式""" @staticmethod def producer(): """生产者模式""" for i in range(5): print(f"生产: {i}") yield i @staticmethod def consumer(generator): """消费者模式""" for item in generator: print(f"消费: {item}") @staticmethod def pipeline(): """管道模式""" def stage1(): for i in range(5): yield i * 2 def stage2(input_gen): for item in input_gen: yield item + 1 # 构建管道 s1 = stage1() s2 = stage2(s1) return s2 @staticmethod def coroutine(): """协程模式""" def coroutine_example(): while True: received = yield print(f"处理: {received}") co = coroutine_example() next(co) # 启动 co.send("数据1") co.send("数据2")
2. 常见陷阱
# 陷阱1:生成器只能遍历一次def trap_once(): gen = (x for x in range(3)) print(list(gen)) # [0, 1, 2] print(list(gen)) # [] - 已经空了# 正确做法:需要多次遍历时使用列表def fix_once(): data = [x for x in range(3)] # 列表 print(data) print(data)# 陷阱2:在生成器内部修改外部变量def trap_modify(): x = 10 gen = (x for _ in range(3)) x = 20 # 不影响生成器 print(list(gen)) # [10, 10, 10]# 陷阱3:生成器表达式的作用域def trap_scope(): gen = (x for x in range(3)) x = 100 # 不影响生成器内的 x print(list(gen)) # [0, 1, 2]# 陷阱4:过早耗尽生成器def trap_exhaust(): def process(items): if items: # 这里会消耗生成器 return sum(items) return 0 gen = (x for x in range(5)) # print(process(gen)) # 错误:gen 被消耗 # print(list(gen)) # [] - 已经空了 # 正确做法 items = list(gen) # 先转换为列表 print(process(items))# 陷阱5:在生成器中使用递归没有 yield fromdef trap_recursion(): def flatten_wrong(nested): for item in nested: if isinstance(item, (list, tuple)): flatten_wrong(item) # 错误:没有 yield else: yield item def flatten_correct(nested): for item in nested: if isinstance(item, (list, tuple)): yield from flatten_correct(item) else: yield item nested = [1, [2, [3, 4], 5]] print(list(flatten_correct(nested))) # [1, 2, 3, 4, 5]
3. 性能优化技巧
class GeneratorOptimization: """生成器优化技巧""" @staticmethod def use_local_variables(): """使用局部变量加速""" def slow_generator(n): for i in range(n): yield i ** 2 def fast_generator(n): square = lambda x: x ** 2 # 局部变量 for i in range(n): yield square(i) # 性能对比 import time n = 1000000 start = time.perf_counter() sum(slow_generator(n)) slow_time = time.perf_counter() - start start = time.perf_counter() sum(fast_generator(n)) fast_time = time.perf_counter() - start print(f"普通: {slow_time:.3f}s") print(f"优化: {fast_time:.3f}s") print(f"提升: {(slow_time/fast_time - 1)*100:.1f}%") @staticmethod def avoid_yield_in_loops(): """避免在循环中yield""" # 不好的做法 def bad(): for i in range(10): yield i for i in range(10, 20): yield i # 好的做法:使用 yield from def good(): yield from range(10) yield from range(10, 20) @staticmethod def use_itertools(): """使用 itertools 优化""" from itertools import islice, chain, count # 手动实现 def take_manual(n, iterable): for i, x in enumerate(iterable): if i >= n: break yield x # 使用 itertools def take_itertools(n, iterable): yield from islice(iterable, n) data = range(1000000) print(list(take_itertools(5, data))) # [0, 1, 2, 3, 4]
总结
1. 生成器的优点
内存高效:一次只产生一个值
惰性计算:需要时才计算
无限序列:可以表示无限数据流
简洁代码:比实现迭代器类更简单
协程支持:可以实现协作式多任务
2. 适用场景
3. 选择指南
# 什么时候使用生成器?def generator_guide(): """生成器使用指南""" # ✅ 适合使用生成器 # 1. 处理大数据集 def process_large_file(): for line in open('large_file.txt'): yield process(line) # 2. 无限序列 def fibonacci(): a, b = 0, 1 while True: yield a a, b = b, a + b # 3. 数据管道 def pipeline(): data = (x for x in range(100)) data = (x**2 for x in data) data = (x for x in data if x % 2 == 0) return data # ❌ 不适合使用生成器 # 1. 需要随机访问 def bad_for_random(): gen = (x for x in range(10)) # 不能 gen[5] # 2. 需要多次遍历 def bad_for_multiple(): gen = (x for x in range(10)) # list(gen) 只能用一次 # 3. 数据量小且需要重复使用 def better_as_list(): data = [1, 2, 3, 4, 5] # 小数据集用列表
4. 最佳实践总结
class GeneratorBestPractices: """生成器最佳实践""" # 1. 使用有意义的名称 def generate_numbers(self): """生成数字""" yield from range(10) # 2. 提供文档字符串 def fibonacci(self): """生成斐波那契数列(无限)""" a, b = 0, 1 while True: yield a a, b = b, a + b # 3. 处理异常 def safe_generator(self): try: yield 1 yield 2 except GeneratorExit: print("清理资源") finally: print("确保清理") # 4. 使用类型提示 from typing import Generator def typed_generator(self, n: int) -> Generator[int, None, None]: """带类型提示的生成器""" for i in range(n): yield i # 5. 考虑使用 yield from def combined(self): yield from self.generate_numbers() yield from self.fibonacci() # 6. 避免副作用 def pure_generator(self, data): """纯生成器,没有副作用""" for item in data: yield item ** 2 # 不修改输入
生成器是 Python 中强大的特性,正确使用可以写出内存高效、可维护的代码。它们是处理大数据流、实现惰性计算和构建数据处理管道的理想工具。