掌握 itertools,告别繁琐的循环,写出 Pythonic 的代码。
引言
在 Python 编程中,我们经常需要处理迭代操作——遍历、组合、筛选、分组……如果这些操作都用 for 循环硬写,代码不仅冗长,而且效率低下。
itertools 模块正是为了解决这些问题而生。它提供了一系列快速、高效的迭代器工具,让你的代码更简洁、更 Pythonic。
本文将全面介绍 itertools 的核心功能,配合实例代码,帮助你彻底掌握这个强大的工具库。
一、无限迭代器
无限迭代器是 itertools 的特色功能,它们可以无限生成数据,直到你主动停止。
1. count() - 计数器
从指定数字开始,无限递增。
from itertools import count
# 从 10 开始,每次增加 2
for i in count(10, 2):
if i > 20:
break
print(i)
# 输出: 10, 12, 14, 16, 18, 20
2. cycle() - 循环迭代
无限循环遍历一个可迭代对象。
from itertools import cycle
colors = ['红', '绿', '蓝']
color_cycle = cycle(colors)
# 取前 7 个
for _ in range(7):
print(next(color_cycle))
# 输出: 红, 绿, 蓝, 红, 绿, 蓝, 红
实际应用:实现轮询负载均衡。
servers = ['server1', 'server2', 'server3']
server_pool = cycle(servers)
# 每次请求选择下一个服务器
for request in range(10):
server = next(server_pool)
print(f"请求 {request+1} -> {server}")
3. repeat() - 重复元素
重复返回同一个元素,可指定次数。
from itertools import repeat
# 重复 '嗨' 5 次
for _ in repeat('嗨', 5):
print(_)
# 输出: 嗨, 嗨, 嗨, 嗨, 嗨
# 无限重复(不指定次数)
counter = 0
for _ in repeat('永远'):
if counter >= 3:
break
print(_)
counter += 1
二、有限迭代器
这类迭代器接受一个或多个可迭代对象,返回处理后的结果,直到输入耗尽。
1. accumulate() - 累积计算
对可迭代对象进行累积操作,默认是求和。
from itertools import accumulate
numbers = [1, 2, 3, 4, 5]
# 默认累加
result = list(accumulate(numbers))
print(result)
# 输出: [1, 3, 6, 10, 15] (1, 1+2, 1+2+3...)
# 自定义操作:累乘
import operator
result = list(accumulate(numbers, operator.mul))
print(result)
# 输出: [1, 2, 6, 24, 120] (阶乘)
# 计算最大值累积
prices = [100, 95, 102, 98, 105, 103]
max_prices = list(accumulate(prices, max))
print(max_prices)
# 输出: [100, 100, 102, 102, 105, 105]
2. chain() - 连接多个可迭代对象
将多个可迭代对象串联成一个。
from itertools import chain
list1 = [1, 2, 3]
list2 = ['a', 'b', 'c']
list3 = [True, False]
# 连接多个列表
result = list(chain(list1, list2, list3))
print(result)
# 输出: [1, 2, 3, 'a', 'b', 'c', True, False]
# 对比普通方法
# result = list1 + list2 + list3 # 效率低,创建新列表
# chain.from_iterable - 用于展平嵌套列表
nested = [[1, 2], [3, 4], [5, 6]]
flat = list(chain.from_iterable(nested))
print(flat)
# 输出: [1, 2, 3, 4, 5, 6]
3. compress() - 按条件筛选
根据选择器对数据列表进行过滤。
from itertools import compress
data = ['A', 'B', 'C', 'D', 'E']
selectors = [True, False, True, False, True]
# 只保留对应 True 的元素
result = list(compress(data, selectors))
print(result)
# 输出: ['A', 'C', 'E']
# 实际应用:选择特定列
columns = ['姓名', '年龄', '城市', '薪资']
use_column = [True, True, False, True] # 不需要城市
active_columns = list(compress(columns, use_column))
print(active_columns)
# 输出: ['姓名', '年龄', '薪资']
4. dropwhile() 和 takewhile()
根据条件截取序列。
from itertools import dropwhile, takewhile
numbers = [1, 3, 5, 7, 4, 6, 8, 2]
# dropwhile: 只要条件为真就丢弃,一旦为假则保留剩余所有
dropped = list(dropwhile(lambda x: x < 5, numbers))
print(dropped)
# 输出: [5, 7, 4, 6, 8, 2] (1,3被丢弃,5不小于5停止)
# takewhile: 只要条件为真就保留,一旦为假停止
taken = list(takewhile(lambda x: x < 5, numbers))
print(taken)
# 输出: [1, 3] (遇到5停止)
# 实际应用:跳过文件头部注释
lines = [
'# 这是注释',
'# 作者:张三',
'',
'import os',
'print("Hello")'
]
# 跳过空行和注释
code = list(dropwhile(lambda line: not line or line.startswith('#'), lines))
print(code)
5. filterfalse()
filter() 的反操作,保留使条件为假的元素。
from itertools import filterfalse
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# 筛选偶数(保留使条件为假的:x%2==1为奇数,取反)
evens = list(filterfalse(lambda x: x % 2, numbers))
print(evens)
# 输出: [2, 4, 6, 8, 10]
# 对比 filter
odds = list(filter(lambda x: x % 2, numbers))
print(odds)
# 输出: [1, 3, 5, 7, 9]
# 实际应用:过滤掉空值
data = ['hello', '', 'world', None, 'python', '', '']
filtered = list(filterfalse(lambda x: not x, data))
print(filtered)
# 输出: ['hello', 'world', 'python']
6. groupby()
对可迭代对象进行分组(需要先排序)。
from itertools import groupby
# 注意:groupby 要求数据先按分组键排序!
data = [
('A', 'Alice'),
('A', 'Anna'),
('B', 'Bob'),
('B', 'Ben'),
('C', 'Carol')
]
# 按第一个字母分组
for key, group in groupby(data, key=lambda x: x[0]):
print(f"字母 {key}: {list(group)}")
# 输出:
# 字母 A: [('A', 'Alice'), ('A', 'Anna')]
# 字母 B: [('B', 'Bob'), ('B', 'Ben')]
# 字母 C: [('C', 'Carol')]
# 实际应用:统计连续相同的元素
from operator import itemgetter
events = [
('2024-01-01', 'login'),
('2024-01-02', 'login'),
('2024-01-03', 'logout'),
('2024-01-04', 'login'),
]
# 按事件类型分组(需要先排序)
events.sort(key=itemgetter(1))
for event_type, group in groupby(events, key=itemgetter(1)):
count = len(list(group))
print(f"{event_type}: {count} 次")
7. islice() - 切片迭代器
对迭代器进行切片,支持负数索引(会消耗迭代器)。
from itertools import islice
numbers = iter(range(100))
# 取前 5 个
first_five = list(islice(numbers, 5))
print(first_five)
# 输出: [0, 1, 2, 3, 4]
# 取 10-14(从第10个开始,取5个)
middle = list(islice(numbers, 5, 10))
print(middle)
# 输出: [5, 6, 7, 8, 9]
# 每 2 个取一个(步长为2)
every_second = list(islice(numbers, 0, 10, 2))
print(every_second)
# 输出: [10, 12, 14, 16, 18]
# 对比普通切片(只能用于有索引的对象)
numbers_list = list(range(100))
print(numbers_list[5:10]) # 适用于列表
# islice 可以用于任何迭代器(文件、生成器等)
8. starmap()
将参数列表解包后传递给函数(类似 * 解包)。
from itertools import starmap
# 普通 map
args = [(2, 3), (4, 5), (6, 7)]
# 使用 lambda 手动解包
result1 = list(map(lambda x: x[0] ** x[1], args))
print(result1)
# 输出: [8, 1024, 279936]
# 使用 starmap 自动解包
result2 = list(starmap(lambda x, y: x ** y, args))
print(result2)
# 输出: [8, 1024, 279936]
# 实际应用:批量计算距离
import math
coordinates = [
((0, 0), (3, 4)), # 距离 5
((1, 1), (4, 5)), # 距离 5
((0, 0), (1, 1)), # 距离 √2
]
def distance(p1, p2):
return math.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)
distances = list(starmap(distance, coordinates))
print(distances)
# 输出: [5.0, 5.0, 1.414...]
9. tee()
将一个迭代器分裂成多个独立的迭代器。
from itertools import tee
numbers = iter([1, 2, 3, 4, 5])
# 分裂成 3 个独立迭代器
iter1, iter2, iter3 = tee(numbers, 3)
print(list(iter1)) # [1, 2, 3, 4, 5]
print(list(iter2)) # [1, 2, 3, 4, 5]
print(list(iter3)) # [1, 2, 3, 4, 5]
# 原迭代器已耗尽,不能再用
# print(list(numbers)) # 空列表
# 实际应用:同时计算总和与平均值
def sum_and_avg(iterable):
it1, it2 = tee(iterable, 2)
total = sum(it1)
count = sum(1 for _ in it2)
return total, total / count if count > 0 else 0
numbers = range(1, 101) # 1-100
total, avg = sum_and_avg(numbers)
print(f"总和: {total}, 平均值: {avg}")
# 输出: 总和: 5050, 平均值: 50.5
10. zip_longest()
zip() 的增强版,可以处理长度不等的序列,用指定值填充。
from itertools import zip_longest
names = ['Alice', 'Bob', 'Carol']
ages = [25, 30]
cities = ['NYC', 'LA', 'Chicago', 'Houston']
# 普通 zip 以最短的为准
print(list(zip(names, ages, cities)))
# 输出: [('Alice', 25, 'NYC'), ('Bob', 30, 'LA')]
# zip_longest 以最长的为准,空缺填充 None
print(list(zip_longest(names, ages, cities)))
# 输出:
# [('Alice', 25, 'NYC'), ('Bob', 30, 'LA'), ('Carol', None, 'Chicago'), (None, None, 'Houston')]
# 自定义填充值
print(list(zip_longest(names, ages, cities, fillvalue='未知')))
# 输出:
# [('Alice', 25, 'NYC'), ('Bob', 30, 'LA'), ('Carol', '未知', 'Chicago'), ('未知', '未知', 'Houston')]
三、组合迭代器
组合迭代器用于生成各种排列组合,在算法、概率统计等领域非常有用。
1. product() - 笛卡尔积
计算多个可迭代对象的笛卡尔积(所有可能的组合)。
from itertools import product
# 两个列表的笛卡尔积
colors = ['红', '绿']
sizes = ['大', '小']
result = list(product(colors, sizes))
print(result)
# 输出: [('红', '大'), ('红', '小'), ('绿', '大'), ('绿', '小')]
# 三个列表
materials = ['棉', '麻']
result = list(product(colors, sizes, materials))
print(f"共有 {len(result)} 种组合")
# 输出: 共有 8 种组合
# 同一个列表的笛卡尔积(允许重复)
dice_faces = [1, 2, 3, 4, 5, 6]
# 掷两次骰子的所有可能
rolls = list(product(dice_faces, repeat=2))
print(f"掷两次骰子共有 {len(rolls)} 种可能")
# 输出: 掷两次骰子共有 36 种可能
# 计算点数和为 7 的概率
sum_to_7 = [roll for roll in rolls if sum(roll) == 7]
print(f"点数和为 7 的情况有 {len(sum_to_7)} 种,概率为 {len(sum_to_7)/len(rolls):.2%}")
# 输出: 点数和为 7 的情况有 6 种,概率为 16.67%
2. permutations() - 排列
生成所有可能的排列(考虑顺序)。
from itertools import permutations
items = ['A', 'B', 'C']
# 全排列
result = list(permutations(items))
print(f"3 个元素的全排列共 {len(result)} 种:")
for p in result:
print(p)
# 输出:
# 3 个元素的全排列共 6 种:
# ('A', 'B', 'C')
# ('A', 'C', 'B')
# ('B', 'A', 'C')
# ('B', 'C', 'A')
# ('C', 'A', 'B')
# ('C', 'B', 'A')
# 部分排列(取 2 个)
result = list(permutations(items, 2))
print(f"\n从 3 个中取 2 个的排列共 {len(result)} 种:")
for p in result:
print(p)
# 输出: 6 种 (A,B), (A,C), (B,A), (B,C), (C,A), (C,B)
# 实际应用:密码组合分析
digits = '123'
# 3 位数字密码的所有可能(数字不重复)
passwords = [''.join(p) for p in permutations(digits, 3)]
print(f"\n可能的 3 位密码(不重复): {passwords}")
# 输出: ['123', '132', '213', '231', '312', '321']
3. combinations() - 组合
生成所有可能的组合(不考虑顺序)。
from itertools import combinations
team = ['Alice', 'Bob', 'Carol', 'David']
# 选 2 人的所有组合
pairs = list(combinations(team, 2))
print(f"从 4 人中选 2 人,共 {len(pairs)} 种组合:")
for p in pairs:
print(f" {p[0]} & {p[1]}")
# 输出: 6 种组合
# 选 3 人的组合
trios = list(combinations(team, 3))
print(f"\n从 4 人中选 3 人,共 {len(trios)} 种组合")
# 输出: 4 种组合
# 实际应用:抽奖组合
import random
participants = ['张三', '李四', '王五', '赵六', '孙七']
# 抽取 3 人组成中奖小组
winning_groups = list(combinations(participants, 3))
print(f"\n共有 {len(winning_groups)} 种可能的 3 人组合")
# 随机选一组
lucky_group = random.choice(winning_groups)
print(f"中奖组合: {lucky_group}")
4. combinations_with_replacement() - 可重复组合
元素可以重复出现的组合。
from itertools import combinations_with_replacement
# 从 ['A', 'B'] 中选 2 个,允许重复
result = list(combinations_with_replacement(['A', 'B'], 2))
print(result)
# 输出: [('A', 'A'), ('A', 'B'), ('B', 'B')]
# 注意:普通 combinations 不会包含 ('A', 'A')
# 实际应用:购买组合
# 有 3 种水果,每种买 0-2 个,总共买 2 个
fruits = ['苹果', '香蕉', '橙子']
purchase_options = list(combinations_with_replacement(fruits, 2))
print(f"\n买 2 个水果的组合(可重复)共 {len(purchase_options)} 种:")
for opt in purchase_options:
print(f" {opt[0]} + {opt[1]}")
# 输出: 6 种组合(包括两个相同水果的情况)
四、实用技巧与最佳实践
1. 内存效率:迭代器 vs 列表
itertools 返回的都是迭代器,具有惰性求值特性,可以处理海量数据而不会耗尽内存。
from itertools import count, islice
# 高效:不存储所有数据
# 生成 0-999999,取前 10 个,不占用内存
first_10 = islice(count(), 10)
print(list(first_10))
# 低效:创建巨大列表
# numbers = list(range(1000000)) # 占用大量内存
2. 链式操作
多个 itertools 函数可以链式组合,实现复杂的数据处理。
from itertools import chain, islice, cycle
# 示例:生成无限循环的序列,取特定窗口
data = [1, 2, 3]
window = islice(cycle(data), 2, 8) # 跳过2个,取6个
print(list(window))
# 输出: [3, 1, 2, 3, 1, 2]
3. 与生成器表达式结合
itertools 与生成器表达式配合,可以编写简洁高效的代码。
from itertools import takewhile
# 生成斐波那契数列
# 使用 itertools 简化斐波那契生成
def fibonacci():
a, b = 0, 1
while True:
yield a
a, b = b, a + b
# 获取小于 1000 的所有斐波那契数
fib_under_1000 = takewhile(lambda x: x < 1000, fibonacci())
print(list(fib_under_1000))
# 输出: [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987]
4. 常见场景速查表
| 场景 |
推荐函数 |
示例 |
| 无限计数 |
count() |
count(10, 2) → 10, 12, 14... |
| 循环轮询 |
cycle() |
cycle(['A', 'B', 'C']) |
| 累加/累乘 |
accumulate() |
accumulate([1,2,3]) → 1, 3, 6 |
| 扁平化 |
chain() |
chain(list1, list2) |
| 展平嵌套 |
chain.from_iterable() |
chain.from_iterable(matrix) |
| 分组 |
groupby() |
groupby(data, key=func) |
| 笛卡尔积 |
product() |
product(A, B) |
| 排列 |
permutations() |
permutations(items, 2) |
| 组合 |
combinations() |
combinations(items, 2) |
五、总结
itertools 是 Python 标准库中最强大、最高效的模块之一。掌握它可以让你:
学习建议
- 从常用函数开始:先掌握
chain()、groupby()、islice() - 结合实际场景:在工作中寻找可以用 itertools 优化的循环
- 阅读源码:itertools 是用 C 实现的,了解实现原理有助于深入理解
掌握 itertools,让你的 Python 代码更加优雅、高效、Pythonic!