一、异步迭代概述
1. 什么是异步迭代?
异步迭代允许在迭代过程中执行异步操作,每次获取下一个值时可以等待 I/O 操作完成。
2. 异步迭代协议
| 方法 | 说明 | 返回值 | 对应的同步方法 |
|---|---|---|---|
| `__aiter__` | 返回异步迭代器对象(必须是普通方法,不能是 `async def`) | AsyncIterator | `__iter__` |
| `__anext__` | 返回一个可等待对象,其结果是下一个值 | awaitable | `__next__` |
3. 异步可迭代对象 vs 异步迭代器
# Async iterable: implements __aiter__, which returns an async iterator.
class AsyncIterable:
    """Minimal async iterable that hands out a new iterator per loop."""

    def __aiter__(self):
        # __aiter__ must be a plain (non-async) method: `async for` uses its
        # return value directly and raises TypeError if it gets a coroutine.
        return AsyncIterator()


# Async iterator: implements both __aiter__ and __anext__.
class AsyncIterator:
    """Skeleton async iterator; it is exhausted immediately.

    A real implementation returns the next value from ``__anext__`` and
    raises ``StopAsyncIteration`` once there is nothing left.
    """

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Return the next value here, or raise StopAsyncIteration when done.
        # Falling through (returning None) would make the loop run forever.
        raise StopAsyncIteration
二、基础异步迭代实现
1. 最简单的异步迭代器
import asyncio


class SimpleAsyncIterator:
    """Async iterator over the integers in ``[start, end)``.

    Args:
        start: First value produced.
        end: Exclusive upper bound; nothing is produced if ``start >= end``.
        delay: Seconds to sleep before each value is returned (simulates an
            asynchronous operation). Defaults to 0.1.
    """

    def __init__(self, start, end, delay=0.1):
        self.current = start
        self.end = end
        self.delay = delay

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next integer, or raise StopAsyncIteration at ``end``."""
        if self.current >= self.end:
            raise StopAsyncIteration
        value = self.current
        self.current += 1
        await asyncio.sleep(self.delay)  # simulated async operation
        return value


async def main():
    """Demo: print the numbers 1 to 4."""
    async for num in SimpleAsyncIterator(1, 5):
        print(f"得到数字: {num}")


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
2. 异步计数器
import asyncio
import time


class AsyncCounter:
    """Async counter yielding ``start, start + step, ...``.

    Args:
        start: First value produced.
        step: Increment applied after each value.
        delay: Seconds to sleep before each value is returned.
        limit: Maximum number of values to produce; ``None`` means
            unbounded. Defaults to 5.
    """

    def __init__(self, start=0, step=1, delay=0.5, limit=5):
        self.value = start
        self.step = step
        self.delay = delay
        self.limit = limit
        self.count = 0

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next counter value, or stop after ``limit`` values."""
        # Check the limit first so an exhausted counter neither sleeps nor
        # keeps mutating its state on further calls.
        if self.limit is not None and self.count >= self.limit:
            raise StopAsyncIteration

        self.count += 1
        current = self.value
        self.value += self.step

        await asyncio.sleep(self.delay)  # simulated async operation
        return current


async def main():
    """Demo: count from 10 in steps of 2."""
    print("开始异步计数:")
    async for num in AsyncCounter(start=10, step=2, delay=0.3):
        print(f" 计数: {num}")
        # Other async work can be awaited inside the loop body.
        await asyncio.sleep(0.1)


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
三、异步生成器
1. 使用 async for 和 yield
import asyncio


async def async_range(start, end, delay=0.1):
    """Async generator yielding the integers in ``[start, end)``.

    Args:
        start: First value produced.
        end: Exclusive upper bound.
        delay: Seconds to sleep before each value is yielded.
    """
    for i in range(start, end):
        await asyncio.sleep(delay)
        yield i


async def main():
    """Demo: print the numbers 1 to 5."""
    print("异步范围:")
    async for num in async_range(1, 6, 0.2):
        print(f" {num}")


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
2. 异步文件读取器
import asyncio
import aiofiles


class AsyncFileReader:
    """Read a file in fixed-size binary chunks via ``async for``.

    Must be used as an async context manager; the file is opened on entry
    and closed on exit.

    Args:
        filename: Path of the file to read.
        chunk_size: Number of bytes requested per read.
    """

    def __init__(self, filename, chunk_size=1024):
        self.filename = filename
        self.chunk_size = chunk_size
        self.file = None

    async def __aenter__(self):
        self.file = await aiofiles.open(self.filename, 'rb')
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.file:
            # Drop the reference before closing so that iterating after the
            # context has exited stops cleanly instead of reading a closed
            # file handle.
            handle, self.file = self.file, None
            await handle.close()

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next chunk, or stop at end of file / when not open."""
        if not self.file:
            raise StopAsyncIteration

        chunk = await self.file.read(self.chunk_size)
        if not chunk:
            raise StopAsyncIteration
        return chunk


async def read_file_async(filename):
    """Read ``filename`` in 64-byte chunks and print each chunk's size."""
    async with AsyncFileReader(filename, 64) as reader:
        chunk_count = 0
        async for chunk in reader:
            chunk_count += 1
            print(f"块 {chunk_count}: {len(chunk)} 字节")
            # Process the chunk (simulated).
            await asyncio.sleep(0.1)


# Create the test file.
async def create_test_file():
    """Write 1000 'A' characters to ``test.txt``."""
    async with aiofiles.open('test.txt', 'w') as f:
        await f.write('A' * 1000)


async def main():
    """Demo: create the test file, then read it back chunk by chunk."""
    await create_test_file()
    await read_file_async('test.txt')


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
四、实际应用场景
1. 异步数据流处理
import asyncio
import random
from collections import deque
from typing import List, Any


class AsyncDataStream:
    """Async stream that fetches items in batches and yields them one by one.

    Args:
        source: Identifier of the data source (stored, not contacted; the
            fetch is simulated).
        batch_size: Number of items per fetched batch.
        max_batches: Number of batches after which the stream ends;
            ``None`` means the stream never ends. Defaults to 3.
        fetch_delay: Seconds to sleep per batch (simulated network latency).
    """

    def __init__(self, source, batch_size=5, max_batches=3, fetch_delay=0.3):
        self.source = source
        self.batch_size = batch_size
        self.max_batches = max_batches
        self.fetch_delay = fetch_delay
        self.batches_fetched = 0
        # deque gives O(1) removal from the front; list.pop(0) is O(n).
        self.buffer = deque()

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next item, fetching a new batch when the buffer is empty."""
        if not self.buffer:
            # Load the next batch.
            self.buffer = deque(await self._fetch_batch())
            if not self.buffer:
                raise StopAsyncIteration
        return self.buffer.popleft()

    async def _fetch_batch(self):
        """Simulate fetching one batch; return ``[]`` once the source is drained."""
        # Without an upper bound the stream would never terminate and the
        # consumer could never finish.
        if self.max_batches is not None and self.batches_fetched >= self.max_batches:
            return []
        await asyncio.sleep(self.fetch_delay)  # simulated network latency
        self.batches_fetched += 1
        batch = [f"数据{i}" for i in range(self.batch_size)]
        print(f"获取到一批数据: {batch}")
        return batch


class AsyncDataProcessor:
    """Consume an ``AsyncDataStream`` and process every item.

    Args:
        stream: The stream to consume.
        failure_rate: Probability in ``[0, 1]`` that processing one item
            fails (simulated). Defaults to 0.1.
        delay_range: ``(min, max)`` seconds of simulated work per item.
    """

    def __init__(self, stream: AsyncDataStream, failure_rate=0.1,
                 delay_range=(0.1, 0.3)):
        self.stream = stream
        self.failure_rate = failure_rate
        self.delay_range = delay_range
        self.processed_count = 0
        self.errors = []

    async def process(self):
        """Process the whole stream, recording failures in ``self.errors``."""
        async for item in self.stream:
            try:
                result = await self._process_item(item)
                self.processed_count += 1
                print(f"处理: {item} -> {result}")
            except Exception as e:
                # Best effort: one failing item must not abort the stream,
                # so the error is recorded and processing continues.
                self.errors.append((item, str(e)))

    async def _process_item(self, item):
        """Process one item; raises ValueError on a simulated failure."""
        await asyncio.sleep(random.uniform(*self.delay_range))
        if random.random() < self.failure_rate:
            raise ValueError(f"处理失败: {item}")
        return f"processed_{item}"


async def main():
    """Demo: process a bounded stream and print statistics."""
    stream = AsyncDataStream("api.example.com/data", batch_size=3)
    processor = AsyncDataProcessor(stream)

    await processor.process()

    print(f"\n统计:")
    print(f" 成功处理: {processor.processed_count}")
    print(f" 失败数量: {len(processor.errors)}")
    if processor.errors:
        print(f" 错误详情: {processor.errors}")


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
2. 异步分页 API
import asyncio
from typing import Optional, Dict, Any


class AsyncPaginatedAPI:
    """Simulated paginated API client; each iteration yields one page.

    Args:
        base_url: Base URL of the API (stored only; requests are simulated).
        page_size: Maximum number of items per page.
        total_items: Size of the simulated data set. Defaults to 50.
        fetch_delay: Seconds to sleep per page (simulated network latency).
    """

    def __init__(self, base_url: str, page_size: int = 10,
                 total_items: int = 50, fetch_delay: float = 0.5):
        self.base_url = base_url
        self.page_size = page_size
        self.total_items = total_items
        self.fetch_delay = fetch_delay
        self.current_page = 0
        self.has_more = True

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next non-empty page, or stop when the data is exhausted."""
        if not self.has_more:
            raise StopAsyncIteration

        self.current_page += 1
        page_data = await self._fetch_page(self.current_page)

        if not page_data:
            # When the total is an exact multiple of page_size the last
            # request comes back empty; end the iteration instead of
            # handing an empty page to the consumer.
            self.has_more = False
            raise StopAsyncIteration

        if len(page_data) < self.page_size:
            # A short page is the last one.
            self.has_more = False

        return page_data

    async def _fetch_page(self, page: int) -> list:
        """Simulate fetching page number ``page`` (1-based)."""
        print(f"获取第 {page} 页...")
        await asyncio.sleep(self.fetch_delay)  # simulated network latency

        # Simulated data.
        start = (page - 1) * self.page_size
        if start >= self.total_items:
            return []

        end = min(start + self.page_size, self.total_items)
        return [{"id": i, "name": f"Item {i}"} for i in range(start, end)]


class AsyncDataFetcher:
    """Collect every page from an ``AsyncPaginatedAPI`` and process the items.

    Args:
        api: The paginated API to drain.
        process_delay: Seconds of simulated work per item.
    """

    def __init__(self, api: AsyncPaginatedAPI, process_delay: float = 0.1):
        self.api = api
        self.process_delay = process_delay
        self.items = []

    async def fetch_all(self):
        """Fetch all pages; return the accumulated list of items."""
        async for page in self.api:
            print(f"收到页面: {len(page)} 条数据")
            self.items.extend(page)

            # Pages can be processed as soon as they arrive.
            await self._process_page(page)

        return self.items

    async def _process_page(self, page):
        """Process all items of one page concurrently."""
        tasks = [self._process_item(item) for item in page]
        await asyncio.gather(*tasks)

    async def _process_item(self, item):
        """Process one item in place by marking it as processed."""
        await asyncio.sleep(self.process_delay)  # simulated processing
        item['processed'] = True


async def main():
    """Demo: fetch and process all pages, then print a summary."""
    api = AsyncPaginatedAPI("https://api.example.com/items", page_size=8)
    fetcher = AsyncDataFetcher(api)

    items = await fetcher.fetch_all()
    print(f"\n总共获取: {len(items)} 条数据")
    print(f"前3条: {items[:3]}")


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
3. 异步 WebSocket 消息流
import asyncio
import random
from typing import Optional


class AsyncWebSocketStream:
    """Simulated WebSocket message stream consumable with ``async for``.

    Args:
        url: URL of the (simulated) WebSocket endpoint.
        receive_timeout: Seconds to wait for a message before the iteration
            ends. Defaults to 5.0.
    """

    def __init__(self, url: str, receive_timeout: float = 5.0):
        self.url = url
        self.receive_timeout = receive_timeout
        self.connected = False
        self.message_queue = asyncio.Queue()
        self.running = False
        # Strong reference to the background producer. The event loop keeps
        # only weak references to tasks, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._generator_task: Optional[asyncio.Task] = None

    async def connect(self):
        """Simulate connecting and start the background message producer."""
        if self.connected:
            # Already connected; starting a second producer would leak the
            # first task.
            return
        print(f"连接到 {self.url}...")
        await asyncio.sleep(0.5)
        self.connected = True
        self.running = True
        self._generator_task = asyncio.create_task(self._message_generator())
        print("连接成功")

    async def disconnect(self):
        """Disconnect and stop the background message producer."""
        print("断开连接...")
        self.running = False
        self.connected = False
        task, self._generator_task = self._generator_task, None
        if task is not None:
            # Cancel instead of waiting for the producer's next wake-up, then
            # wait for it to finish. gather(return_exceptions=True) absorbs
            # the task's own CancelledError without hiding a cancellation of
            # the caller.
            task.cancel()
            await asyncio.gather(task, return_exceptions=True)

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next message; stop when disconnected or on timeout."""
        if not self.connected:
            raise StopAsyncIteration

        try:
            # Wait for a message, bounded by the receive timeout.
            message = await asyncio.wait_for(
                self.message_queue.get(),
                timeout=self.receive_timeout
            )
            return message
        except asyncio.TimeoutError:
            # A timeout ends the iteration.
            raise StopAsyncIteration

    async def _message_generator(self):
        """Simulate incoming messages until ``running`` is cleared."""
        message_types = ['text', 'json', 'binary']

        while self.running:
            await asyncio.sleep(random.uniform(0.5, 1.5))

            msg_type = random.choice(message_types)
            if msg_type == 'text':
                message = f"文本消息 {random.randint(1, 100)}"
            elif msg_type == 'json':
                message = {"type": "data", "value": random.random()}
            else:
                message = bytes([random.randint(0, 255) for _ in range(10)])

            await self.message_queue.put(message)
            print(f"收到消息: {message}")


async def main():
    """Demo: connect, handle five messages, then disconnect."""
    ws = AsyncWebSocketStream("wss://echo.websocket.org")
    await ws.connect()

    try:
        message_count = 0
        async for message in ws:
            message_count += 1
            print(f"处理消息 {message_count}: {message}")

            if message_count >= 5:
                print("收到足够消息,停止")
                break
    finally:
        await ws.disconnect()


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
五、高级异步迭代模式
1. 异步迭代器组合器
import asyncio
import inspect
from typing import AsyncIterator, TypeVar, Callable, Awaitable, Union

T = TypeVar('T')
U = TypeVar('U')


async def _resolve(value):
    """Await ``value`` if it is awaitable; otherwise return it unchanged."""
    if inspect.isawaitable(value):
        return await value
    return value


class AsyncIteratorCombinator:
    """Combinators for async iterators.

    ``map`` and ``filter`` accept both plain callables and coroutine
    functions.
    """

    @staticmethod
    async def map(iterable: AsyncIterator[T],
                  func: Callable[[T], Union[U, Awaitable[U]]]) -> AsyncIterator[U]:
        """Yield ``func(item)`` for every item; ``func`` may be sync or async."""
        async for item in iterable:
            # Awaiting the result unconditionally fails with TypeError for
            # plain functions and lambdas.
            yield await _resolve(func(item))

    @staticmethod
    async def filter(iterable: AsyncIterator[T],
                     predicate: Callable[[T], Union[bool, Awaitable[bool]]]) -> AsyncIterator[T]:
        """Yield the items for which ``predicate`` (sync or async) is truthy."""
        async for item in iterable:
            if await _resolve(predicate(item)):
                yield item

    @staticmethod
    async def take(iterable: AsyncIterator[T], n: int) -> AsyncIterator[T]:
        """Yield at most the first ``n`` items, pulling no more than ``n``."""
        if n <= 0:
            return
        count = 0
        async for item in iterable:
            yield item
            count += 1
            # Stop right after the n-th item so that no extra element is
            # pulled from the source and silently discarded.
            if count >= n:
                break

    @staticmethod
    async def batch(iterable: AsyncIterator[T], size: int) -> AsyncIterator[list]:
        """Yield lists of up to ``size`` items; the last list may be shorter."""
        batch_items = []
        async for item in iterable:
            batch_items.append(item)
            if len(batch_items) >= size:
                yield batch_items
                batch_items = []

        if batch_items:
            yield batch_items


async def number_generator():
    """Yield the numbers 0 to 19, sleeping 0.1 s before each."""
    for i in range(20):
        await asyncio.sleep(0.1)
        yield i


async def main():
    """Demo of each combinator."""
    gen = number_generator()

    print("=== map: 平方 ===")
    async for num in AsyncIteratorCombinator.map(gen, lambda x: x * x):
        print(num, end=' ')
        if num >= 25:
            break
    print()

    print("\n=== filter: 偶数 ===")
    gen = number_generator()
    async for num in AsyncIteratorCombinator.filter(gen, lambda x: x % 2 == 0):
        print(num, end=' ')
        if num >= 10:
            break
    print()

    print("\n=== take: 前5个 ===")
    gen = number_generator()
    async for num in AsyncIteratorCombinator.take(gen, 5):
        print(num, end=' ')
    print()

    print("\n=== batch: 每批3个 ===")
    gen = number_generator()
    async for batch in AsyncIteratorCombinator.batch(gen, 3):
        print(f"批: {batch}")


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
2. 异步迭代器管道
import asyncio
import inspect
from typing import AsyncIterator, Any


async def _resolve(value):
    """Await ``value`` if it is awaitable; otherwise return it unchanged."""
    if inspect.isawaitable(value):
        return await value
    return value


class AsyncPipeline:
    """Chain of processing stages applied lazily to an async stream."""

    def __init__(self):
        self.stages = []

    def add_stage(self, processor):
        """Append a stage; returns ``self`` so that calls can be chained."""
        self.stages.append(processor)
        return self

    async def process(self, input_stream: AsyncIterator) -> AsyncIterator:
        """Wrap ``input_stream`` in every stage; return the resulting stream.

        Nothing is consumed here; items flow once the result is iterated.
        """
        current_stream = input_stream
        for stage in self.stages:
            current_stream = stage(current_stream)
        return current_stream


class Stage:
    """Base class of a stage: a callable mapping a stream to a stream."""

    def __init__(self, name):
        self.name = name

    async def __call__(self, stream: AsyncIterator) -> AsyncIterator:
        """Yield ``process(item)`` per item, dropping ``None`` results."""
        async for item in stream:
            result = await self.process(item)
            if result is not None:
                yield result

    async def process(self, item):
        """Process a single item; subclasses override. Returning None drops it."""
        return item


class MapStage(Stage):
    """Apply ``func`` (sync or async) to every item."""

    def __init__(self, func, name="map"):
        super().__init__(name)
        self.func = func

    async def process(self, item):
        # Awaiting unconditionally fails with TypeError for plain functions
        # and lambdas.
        return await _resolve(self.func(item))


class FilterStage(Stage):
    """Keep the items for which ``predicate`` (sync or async) is truthy."""

    def __init__(self, predicate, name="filter"):
        super().__init__(name)
        self.predicate = predicate

    async def __call__(self, stream: AsyncIterator) -> AsyncIterator:
        # Yield accepted items directly, so an accepted item that happens to
        # be None is not mistaken for the base class's "drop" marker.
        async for item in stream:
            if await _resolve(self.predicate(item)):
                yield item

    async def process(self, item):
        """Return ``item`` if the predicate accepts it, else ``None``."""
        if await _resolve(self.predicate(item)):
            return item
        return None


class BatchStage(Stage):
    """Group items into lists of ``size``; the last list may be shorter."""

    def __init__(self, size, name="batch"):
        super().__init__(name)
        self.size = size
        # Retained for backward compatibility only; batching state is kept
        # per call so one stage can serve several streams.
        self.buffer = []

    async def __call__(self, stream: AsyncIterator) -> AsyncIterator:
        # A per-call buffer prevents leftovers of one run from leaking into
        # the next run of the same pipeline.
        pending = []
        async for item in stream:
            pending.append(item)
            if len(pending) >= self.size:
                yield pending
                pending = []

        if pending:
            yield pending


async def data_source():
    """Yield the numbers 0 to 19, sleeping 0.1 s before each."""
    for i in range(20):
        await asyncio.sleep(0.1)
        yield i


async def main():
    """Demo: double, keep multiples of 3, batch in threes."""
    # Build the pipeline.
    pipeline = AsyncPipeline()

    # Add the stages.
    pipeline.add_stage(
        MapStage(lambda x: x * 2)
    ).add_stage(
        FilterStage(lambda x: x % 3 == 0)
    ).add_stage(
        BatchStage(3)
    )

    # Process the data.
    source = data_source()
    processed = await pipeline.process(source)

    print("管道处理结果:")
    async for batch in processed:
        print(f" 批: {batch}")


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
3. 异步迭代器与上下文管理器结合
import asyncio
from typing import Optional


class AsyncResourceIterator:
    """Async iterator whose data exists only inside ``async with``.

    Args:
        resource_name: Name used as prefix of the generated items.
        delay: Base delay in seconds. Acquiring the resource takes
            ``2 * delay``; releasing it and reading one item take ``delay``
            each. Defaults to 0.1.
    """

    def __init__(self, resource_name: str, delay: float = 0.1):
        self.resource_name = resource_name
        self.delay = delay
        self.resource = None
        self.position = 0

    async def __aenter__(self):
        """Enter the context and acquire the resource."""
        print(f"获取资源: {self.resource_name}")
        await asyncio.sleep(self.delay * 2)  # simulated acquisition
        self.resource = [f"{self.resource_name}_{i}" for i in range(5)]
        # Start from the beginning on every entry; otherwise a reused
        # instance would still be positioned at the end and yield nothing.
        self.position = 0
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Leave the context and release the resource."""
        print(f"释放资源: {self.resource_name}")
        await asyncio.sleep(self.delay)  # simulated release
        self.resource = None

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next item; stop at the end or when no resource is held."""
        if not self.resource or self.position >= len(self.resource):
            raise StopAsyncIteration

        value = self.resource[self.position]
        self.position += 1
        await asyncio.sleep(self.delay)  # simulated read latency
        return value


async def process_resources():
    """Iterate two resources that are managed by one ``async with``."""
    async with AsyncResourceIterator("database") as iter1, \
               AsyncResourceIterator("cache") as iter2:

        print("开始迭代资源1:")
        async for item in iter1:
            print(f" DB: {item}")

        print("\n开始迭代资源2:")
        async for item in iter2:
            print(f" Cache: {item}")


async def main():
    """Demo entry point."""
    await process_resources()


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
六、错误处理和边界情况
1. 异常处理
import asyncio
import random


class AsyncIteratorWithError:
    """Async iterator over ``items`` that fails randomly.

    A failure raises ValueError without advancing, so calling ``__anext__``
    again retries the same item; ``safe_iterate`` skips the item instead.

    Args:
        items: Sequence of values to produce.
        error_probability: Probability in ``[0, 1]`` that fetching an item
            raises ValueError.
        delay: Seconds to sleep per fetch (simulated work). Defaults to 0.1.
    """

    def __init__(self, items, error_probability=0.2, delay=0.1):
        self.items = items
        self.error_prob = error_probability
        self.delay = delay
        self.index = 0
        self.errors_handled = 0

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next item; raise ValueError on a simulated failure."""
        if self.index >= len(self.items):
            raise StopAsyncIteration

        await asyncio.sleep(self.delay)

        # Simulated random failure.
        if random.random() < self.error_prob:
            self.errors_handled += 1
            raise ValueError(f"处理项目 {self.items[self.index]} 时出错")

        value = self.items[self.index]
        self.index += 1
        return value

    async def safe_iterate(self):
        """Iterate to the end, skipping failed items; return the successes."""
        results = []
        while True:
            try:
                value = await self.__anext__()
                results.append(value)
                print(f"成功: {value}")
            except StopAsyncIteration:
                break
            except ValueError as e:
                print(f"捕获错误: {e}")
                # __anext__ does not advance on failure, so step past the
                # failed item here.
                self.index += 1
                continue

        print(f"完成,处理了 {self.errors_handled} 个错误")
        return results


async def main():
    """Demo: iterate ten numbers with a 30% failure rate."""
    iterator = AsyncIteratorWithError(
        list(range(10)),
        error_probability=0.3
    )

    results = await iterator.safe_iterate()
    print(f"结果: {results}")


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
2. 超时和取消
import asyncio


class AsyncIteratorWithTimeout:
    """Async iterator yielding 1 to 5, sleeping ``delay`` seconds per item."""

    def __init__(self, delay=0.5):
        self.delay = delay
        self.count = 0

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.count >= 5:
            raise StopAsyncIteration

        self.count += 1
        await asyncio.sleep(self.delay)
        return self.count


async def iterate_with_timeout(iterator, timeout=2.0):
    """Iterate ``iterator`` and print every item, with a per-item timeout.

    Args:
        iterator: Any async iterable.
        timeout: Maximum number of seconds to wait for each item; ``None``
            disables the timeout.

    Raises:
        asyncio.TimeoutError: An item did not arrive within ``timeout``.
        asyncio.CancelledError: The surrounding task was cancelled.

    Any other exception raised during iteration is printed and suppressed.
    """
    try:
        source = iterator.__aiter__()
        while True:
            try:
                # Bound each fetch separately: a plain `async for` cannot
                # time out an individual __anext__ call.
                item = await asyncio.wait_for(source.__anext__(), timeout)
            except StopAsyncIteration:
                break
            print(f"得到: {item}")
    except asyncio.CancelledError:
        print("迭代被取消")
        raise
    except asyncio.TimeoutError:
        # Let the caller decide how to react to a timeout instead of
        # reporting it as a generic iteration error.
        raise
    except Exception as e:
        print(f"迭代错误: {e}")


async def main():
    """Demo: one run that completes and one that hits an overall deadline."""
    # Normal iteration.
    iterator1 = AsyncIteratorWithTimeout(delay=0.3)
    print("正常迭代:")
    await iterate_with_timeout(iterator1)

    # Timeout example: the whole iteration is bounded by an outer deadline.
    print("\n带超时的迭代:")
    iterator2 = AsyncIteratorWithTimeout(delay=1.0)
    try:
        await asyncio.wait_for(
            iterate_with_timeout(iterator2),
            timeout=2.5
        )
    except asyncio.TimeoutError:
        print("迭代超时")


# Guarded so that importing this module does not start the demo.
if __name__ == "__main__":
    asyncio.run(main())
七、总结
1. 异步迭代方法速查表
| 方法 / 语法 | 说明 | 返回值 | 结束信号 |
|---|---|---|---|
| `__aiter__` | 返回异步迭代器对象 | AsyncIterator | — |
| `__anext__` | 返回下一个值的可等待对象 | awaitable | StopAsyncIteration |
| `async for` | 遍历异步可迭代对象 | — | — |
2. 应用场景
流式数据处理:处理连续的数据流
分页 API:异步获取分页数据
WebSocket:处理实时消息
文件处理:异步读写大文件
数据库查询:异步获取大量结果
网络请求:并发处理多个请求
3. 设计原则
非阻塞:所有操作应该是非阻塞的
可取消:支持任务取消
背压处理:控制数据生产速度
错误恢复:优雅处理异常
资源管理:正确释放资源
4. 常见陷阱
# 陷阱1:在 __anext__ 中执行阻塞操作class BadAsyncIterator: async def __anext__(self): time.sleep(1) # 阻塞!应该用 await asyncio.sleep() return value# 陷阱2:忘记抛出 StopAsyncIterationclass NoStopAsyncIterator: async def __anext__(self): if self.index >= len(self.items): return None # 应该抛出 StopAsyncIteration# 陷阱3:__aiter__ 返回 self 但不实现 __anext__class IncompleteAsyncIterator: def __aiter__(self): return self # 缺少 __anext__# 陷阱4:在异步迭代器中修改共享状态class SharedStateIterator: def __init__(self, shared_list): self.shared_list = shared_list # 多个迭代器共享 async def __anext__(self): # 并发修改可能导致问题 return self.shared_list.pop()