一、ml_collections是什么
ml_collections 是一个python库,初衷设计是专为机器学习配置。而我在实际使用过程中,把它当做来一个原生字典的替代项,感觉很好用。本文主要介绍基础的使用和锁机制。
地址:https://pypi.org/project/ml-collections/
核心特性:
- 可以让字典使用.的方式来访问和修改,而原生方式是[""]=xx;
- 提供了锁机制,当写完字典后,在一定程度上防止了误修改;
- 自动带了数据类型检查;
二、基础用法
下文代码所用的版本:1.1.0
1. 创建方式1:ConfigDict()
from ml_collections import ConfigDict# ========== 方式1:创建空的ConfigDict,最基础cfg = ConfigDict()# ========== 方式2:初始化时传入原生字典,一键创建cfg = ConfigDict({ "lr": 0.001, "batch_size": 32, "epochs": 100})# ========== 方式3:空配置逐步赋值创建cfg = ConfigDict()# 基础单层配置赋值cfg.lr = 0.001cfg.batch_size = 32cfg.epochs = 200# 配置支持所有Python基础数据类型cfg.device = "cuda"cfg.use_amp = Truecfg.weight_decay = 1e-5
2. 创建方式2:create(),可以一次性把配置都写完
from ml_collections import config_dict# 用 create 一次性写完所有嵌套配置cfg = config_dict.create( # 数据集配置 data=config_dict.create( dataset="cifar10", data_dir="./data", batch_size=32, num_workers=4 ), # 模型配置 model=config_dict.create( backbone="resnet18", num_classes=10, dropout=0.2, pretrained=True ), # 训练配置 + 嵌套的优化器子配置 train=config_dict.create( epochs=200, lr=0.001, optimizer=config_dict.create( name="adam", weight_decay=1e-5 ) ))# 打印配置(可选)print("实验配置:", cfg)
3. 嵌套
字典的值也可以是ConfigDict
from ml_collections import ConfigDict# 创建根配置cfg = ConfigDict()# --------------------------# 第一层:数据集相关配置(嵌套1层)# --------------------------cfg.data = ConfigDict() # 先给data赋值为一个空的ConfigDictcfg.data.dataset_name = "cifar10" # 数据集名称cfg.data.data_path = "./dataset/cifar10" # 数据存放路径cfg.data.batch_size = 32 # 批次大小cfg.data.num_workers = 4 # 加载数据的线程数# --------------------------# 第一层:模型相关配置(嵌套1层)# --------------------------cfg.model = ConfigDict()cfg.model.backbone = "resnet18" # 模型骨干网络cfg.model.num_classes = 10 # 分类任务的类别数cfg.model.dropout_rate = 0.2 # dropout正则化概率cfg.model.use_pretrain = True # 是否使用预训练权重# --------------------------# 第一层:训练相关配置(嵌套2层)# --------------------------cfg.train = ConfigDict()cfg.train.epochs = 200 # 训练总轮数cfg.train.base_lr = 0.001 # 基础学习率# 训练配置里再嵌套:优化器的子配置(嵌套第二层)cfg.train.optimizer = ConfigDict()cfg.train.optimizer.name = "adam" # 优化器名称cfg.train.optimizer.weight_decay = 0.00005 # 权重衰减# --------------------------# 取值:链式点语法,一层一层访问# --------------------------print("数据集名称:", cfg.data.dataset_name)print("模型骨干网络:", cfg.model.backbone)print("优化器名称:", cfg.train.optimizer.name)print("权重衰减系数:", cfg.train.optimizer.weight_decay)
4. 读取方式
同时支持.式和[]中括号方式
from ml_collections import ConfigDictcfg = ConfigDict({ "lr":0.001, "model": ConfigDict({"backbone":"resnet18"})})# ========== 方式1:点语法 读取print(cfg.lr) # 输出:0.001print(cfg.model.backbone) # 输出:resnet18# ========== 方式2:字典中括号语法 读取print(cfg["lr"]) # 输出:0.001print(cfg["model"]["backbone"]) # 输出:resnet18
5. 删除
from ml_collections import ConfigDictcfg = ConfigDict({"lr":0.001, "batch_size":32, "epochs":200})del cfg.epochs # 删除基础字段del cfg["batch_size"] # 用字典语法删除,等价print("删除后配置:", cfg) # 输出:lr: 0.001
6. 自动类型校验
int可以赋值给int,int可以赋值给float,int不能赋值给string,string也不能赋值给int
from ml_collections import config_dictcfg = config_dict.ConfigDict()cfg.float_field = 12.6 # float类型cfg.integer_field = 123 # int类型cfg.another_integer_field = 234 # int类型cfg.nested = config_dict.ConfigDict()cfg.nested.string_field = 'tom' # str类型print(cfg.integer_field) # Prints 123.print(cfg['integer_field']) # Prints 123 as well.try: cfg.integer_field = 'tom' # 将str赋值给int类型,报错except TypeError as e: print(e)cfg.float_field = 12 # int可以给float赋值,自动类型转换cfg.nested.string_field = u'bob' # str赋值print(cfg)
三、锁机制
1. 方式1: lock()
- lock后,不可以新增字段,不可以删除字段
- lock后,可以修改已有的字段,包括嵌套里的字段
from ml_collections import ConfigDict# 1. 创建配置并初始化【已有字段】cfg = ConfigDict()cfg.lr = 0.001cfg.batch_size = 32# 嵌套配置cfg.model = ConfigDict()cfg.model.dropout = 0.2# 2. 核心:执行上锁操作cfg.lock()# 上锁后 - 修改【已有字段】:正常生效,无报错 cfg.lr = 0.0005 # 修改根层级已有字段 ✔️cfg.batch_size = 64 # 修改根层级已有字段 ✔️cfg.model.dropout = 0.1 # 修改嵌套层级已有字段 ✔️print("修改后的值:", cfg.lr, cfg.batch_size, cfg.model.dropout)# 特性2:上锁后 - 新增【任何字段】:直接报错try: cfg.epochs = 200 # 新增根层级字段 except AttributeError as e: print("❌ 新增根字段报错:", e)try: cfg.model.backbone = "resnet18" # 新增嵌套层级字段 ❌except AttributeError as e: print("❌ 新增嵌套字段报错:", e)# 特性3:上锁后 - 删除【任何字段】:直接报错 try: del cfg.lr # 删除根层级已有字段 ❌except AttributeError as e: print("❌ 删除根字段报错:", e)try: del cfg.model.dropout # 删除嵌套层级已有字段 ❌except AttributeError as e: print("❌ 删除嵌套字段报错:", e)
- 判断是否上锁
print(cfg.is_locked) # 上锁后返回 True,解锁后返回 False
- 彻底解锁
- 临时解锁
from ml_collections import ConfigDict# 1. 创建配置 + 赋值已有字段cfg = ConfigDict()cfg.lr = 0.001cfg.batch_size = 32cfg.model = ConfigDict()# 2. 上锁(核心前提)cfg.lock()print(f"初始上锁状态: {cfg.is_locked}") # True# 临时解锁 with cfg.unlocked()# 特性:代码块内 临时解锁,可自由 新增/修改/删除 任何字段;代码块结束 自动重新上锁,无需手动操作with cfg.unlocked(): # ✅ 临时修改已有字段 cfg.lr = 0.0005 cfg.model.dropout = 0.1 # ✅ 临时新增字段(根+嵌套都可以) cfg.epochs = 200 cfg.model.backbone = "resnet18" # ✅ 临时删除字段 del cfg.batch_size# 验证:代码块结束后自动上锁 + 操作全部生效print(f"临时解锁后状态: {cfg.is_locked}") # True 自动上锁print(f"修改后: {cfg.lr}, {cfg.model.dropout}")print(f"新增后: {cfg.epochs}, {cfg.model.backbone}")print(f"删除后batch_size是否存在: {'batch_size'in cfg}") # False# 上锁状态下 依旧禁止新增/删除(安全兜底)try: cfg.device = "cuda"except Exception as e: print(f"\n上锁禁止新增: {e.args[0]}")
2. 方式2: FrozenConfigDict()
- 创建,特性:禁止修改、删除、新增
# 正确导入from ml_collections import ConfigDictfrom ml_collections.config_dict import FrozenConfigDict# ✔️ 方式1:ConfigDict 转 FrozenConfigDictcfg = ConfigDict()cfg.lr = 0.001cfg.batch_size = 32# 嵌套ConfigDict 正确写法cfg.model = ConfigDict()cfg.model.dropout = 0.2cfg.model.backbone = "resnet18"f_cfg = FrozenConfigDict(cfg)# ✔️ 方式2:直接创建 FrozenConfigDict (传字典)f_cfg2 = FrozenConfigDict({"lr":0.001, "batch_size":32})print("="*20)print("FrozenConfigDict只读配置:", f_cfg)print("是否为冰封配置:", isinstance(f_cfg, FrozenConfigDict)) # True# 所有【写操作】全部禁止# 1. ❌ 禁止修改已有字段try: f_cfg.lr = 0.0005except AttributeError as e: print(f"\n❌ 修改报错: {e}")# 2. ❌ 禁止新增字段try: f_cfg.epochs = 200except AttributeError as e: print(f"❌ 新增报错: {e}")# 3. ❌ 禁止删除字段try: del f_cfg.batch_sizeexcept AttributeError as e: print(f"❌ 删除报错: {e}")# 4. ❌ 无lock/unlock方法,天生只读,无需上锁try: f_cfg.lock()except AttributeError as e: print(f"❌ 无lock方法: {e}")# 唯一支持的操作:【读取】 print("\n" + "="*20)print("✅ 读取基础字段:", f_cfg.lr)print("✅ 读取嵌套字段:", f_cfg.model.dropout)print("✅ 字典方式读取:", f_cfg["batch_size"])
- FrozenConfigDict与ConfigDict转换
# 正确导入from ml_collections import ConfigDictfrom ml_collections.config_dict import FrozenConfigDict# ✔️ 转换1: ConfigDict → FrozenConfigDict (冻结,可变→不可变)cfg = ConfigDict({"lr":0.001, "batch_size":32})frozen_cfg = FrozenConfigDict(cfg)# ✔️ 转换2: FrozenConfigDict → ConfigDict (解冻,不可变→可变,恢复所有权限)unfrozen_cfg = ConfigDict(frozen_cfg)# 解冻后 ✅ 恢复全部操作权限:增/删/改 都可以unfrozen_cfg.lr = 1e-4 # 修改 ✔️unfrozen_cfg.device = "cuda" # 新增 ✔️del unfrozen_cfg.batch_size # 删除 ✔️print("✅ 解冻后ConfigDict:", unfrozen_cfg)