前言

这是啥?从未见过,甚至一来还屡出Bug,我差点就想把这个模块完全给删了。

真是不用不知道,一用吓一跳,好吧,原来这个东西竟然是基于torch构建的一个顶层/分布式/工作流/工具集的究极糅合框架,极大方便了用户的使用,可以说是赛博炼丹师的必学框架。

你最好是真的方便了使用

本框架的核心工作流是从yaml文件不经过显式的模型构建,直接将配置文件实例化得到待训练网络,分离网络构建和数据预处理,然后使用Trainer.fit(model, data)完成隐式的网络训练主流程,对新手不是很友好,因为直接通过Type checker往往不能定位到函数的地址。

使用体验

  • 初次上手这个玩意,我感觉极难用。
  • 再一次接触它时,我似乎多了一些感觉,它确实有存在的意义。

一些说明

本框架作为对Pytorch的上层封装,确实在一定程度上方便了用户的使用,在2021年以前主流的版本是pytorch-lightning 1.4版本,包括latent-diffusion在内的大模型框架都使用的是1.4版本,而在2021年8月后,pytorch-lightning迎来了一波大的更新,直接将版本号从1.4更新至2.0,此次更新幅度较大,大量的旧版本方法、属性都被删减(向下兼容这一点做得没有torch好),同时将较多数量的类迁移至不同的目录下,直接割裂了新旧版本之间的兼容性,而当前的torch版本已经更新到2.2,如果仍然要在旧的pytorch-lightning框架下使用torch,则势必使用较旧的、Bug较多的1.7.0版本, 实际上无益于模型的训练与迭代。

因此还是从2.0版本以上开始学习兼容使用吧

把旧版本的一些方法迁移至新版本下重新写一遍,主要问题集中在pytorch-lightning新定义了一个LightningArgumentParser的命令行参数解析器,用该派生类取代了在1.4版本中使用的标准库解析器argparser,在调试过程中遇到的主要问题出现在这里。

一些经验之谈

基于pytorch-lightning搭建自定义的网络训练框架,当且仅当有较大的显存和资源占用需求时,才推荐使用它(因为它方便管理分布式训练资源和数据集)。而训练小模型时,杀鸡焉用牛刀,其实是没有必要的,而且它不太方便调试和复现学习,因为所有的模型说明和数据集定义都放在config的yaml文件中,不会看那个文件,就很难读懂整个工程项目的逻辑结构。

收回我刚才说的话,谁说pytorch-lightning不方便调试来着,只是我之前不会使用罢了。

pytorch-lightning的API参考文档,所有的API都有详细的解释,是本博客的学习来源。此外,该框架还配置了完善的上手教程演示示例,拥有完备的社区,常见的问题和解决方案都可以通过社区与ChatGPT获取。

越用越觉得这玩意不错,设备管理真是省心。

请注意,要详细地学习pytorch-lightning,请务必参考原始文档,本博客仅记录个人学习感受,视角非常片面!

主体架构

pytorch-lightning用于辅助构建深度学习的网络框架,可以用下图表示其与pytorch、CUDA等的拓扑关系:
pytorch-lightning的拓扑结构
其可以在软件层级应用PyTorch的接口,由PyTorch完成不同设备间的分布式使用,同时集成了包括命令行参数工具、可视化报表工具等,极大方便了深度学习训练过程中的监视与测试。

LightningArgumentParser类

命令行参数,继承自python标准库的argparser,以更好地兼容pytorch-lightning的Trainer等派生类的参数,便于直接实例化。实际上使用起来并不是特别方便,除非模型就完全从命令行中构建。该类对yaml文件的解码结果(一般用dict存储)兼容性并不是很好,在latent-diffusion脚本中,手工使用了OmegaConf库管理yaml解码的结果,并没有使用到此类。

Trainer类

本框架的一个核心类,有一个叫Trainer.fit(model, data)的方法,是将整个训练过程全部封装起来,另有testpredict方法,可实现测试和推理。Trainer类的另一个关键是各种各样的Callbacks,正是这些回调函数实现了对训练模型的监视、管理、乃至加载和初始化等一切功能,支持非常灵活的自定义,主要用到的将是它的日志生成器、学习率管理器、CUDA设备管理器、梯度累积器、模型保存检查器。

与直接从torch.nn.module继承并定义网络类型一样,训练器模型全部定义自pytorch-lightning.module继承的对象中,该对象必须重写以下函数:train_stepconfig_optimizers__init__(甚至可以不用重写forward)。由于pytorch-lightning.module实际上是继承自torch.nn.module的对象,因此torch.nn.module中拥有的forward方法也是可以使用的,调用该方法与torch.nn.module一致,也是直接使用self(inputs)即可。

train_step

本方法内定义每个迭代step对应的全部流程,包括数据加载、模型推理计算、损失计算、优化器的调用与参数更新、其他功能(如监视变量、输出保存结果等),是每个迭代步计算一次(并不是对全部训练数据集跑一次)。该方法必须有以下的输入参数类型:

1
2
3
def train_step(self, batch, batch_idx, *arg, **kwarg):
inputs, targets = batch
...

其中batch是从输入数据集中产生的批加载数据,可以是字典类型(每个键值对中存储一批数据),也可以是元组等。batch_idx是当前批次的索引,该索引在每一个epoch内是不相同的,但是在不同的epoch内是重复出现的,需要与self.global_step结合使用,以实现对训练过程的监视。

推荐在train_step函数中使用手动优化self.automatic_optimizition=False

手工设计优化器、学习率管理器等迭代step(),自动优化虽然好,但是不灵活,不适用于多优化器、多学习率管理器等一般情况。更为重要的一点是:

自动优化器管理不利于调试

因为自动优化将参数更新的过程封装在函数内部,直接调试断点打不到参数更新前后,因此并不能得知参数是否更新,容易影响后续工作进行(比如我)。因此这里非常建议关闭自动优化(甚至是屏蔽它最好),手工在train_step()方法中添加如下的优化器更新代码:

1
2
3
4
5
6
7
x, targets = self.get_inputs(batch) # 定义数据准备函数
inputs = self(x) # 模型推理计算
opt = self.optimizers()
opt.zero_grad()
loss = self.loss_fn(inputs, targets) # 模型计算损失
loss.backward()
opt.step()

如有多个优化器,在不使用梯度保留retain_graph=True的情况下,可以引用batch_idx输入量,根据频率调用不同的优化器,如下:

1
2
3
4
5
6
7
8
9
10
11
opt1, opt2 = self.optimizers()
if batch_idx % 2 == 0:
opt1.zero_grad()
loss = self.loss_fn1(inputs, targets)
loss.backward()
opt1.step()
elif batch_idx % 2 == 1:
opt2.zero_grad()
loss = self.loss_fn2(inputs, targets)
loss.backward()
opt2.step()

同理,对于学习率管理器的调用也是类似的,根据学习率更新的频率判定其步进即可,如下:

1
2
3
lr = self.lr_schedulers()
if batch_idx % 200 == 0:
lr.step()

train_step方法整个迭代step作用前后的两个时间节点处,还提供了on_train_epoch_endon_train_epoch_start两个方法,这两个方法可以在训练迭代每一步的前后执行某些特别的工作(如梯度累积、新建迭代对象等),例如,在训练BYOL型自对比网络时,需要对分支作动量更新,直接在train_step内定义参数更新达不到梯度累积后动量更新的效果,因此可在on_train_epoch_end中定义如下的步骤:

1
2
3
4
5
6
for p1, p2 in zip(self.branch1.parameters(), self.branch2.parameters()):
p2.data = (1 - self.momentum) * p1.data + self.momentum * p2.data
p1.data = p2.data
schedular = self.lr_schedulers()
if schedular is not None:
schedular.step()

最后需要提醒的一点是,train_step的返回值只能是torch.TensordictNone三种类型,分别代表损失函数返回值、某些需要返回的值(可包括损失函数)和无返回值,如果没有特别的操作,推荐设定为无返回值

config_optimizers

本方法控制优化器和学习率管理器的配置,操作非常简单,与传统训练框架中定义完全一致,唯一需要注意的是返回值,对于只定义了优化器的情况,返回值应当是优化器对象列表或元组;如果同时定义了优化器和学习率管理器,返回对象需要是优化器列表、学习率管理器列表(实际上也可以是字典),如下所示:

1
2
3
4
5
6
def config_optimizers(self):
opt1 = torch.optim.Adam(self.model1.parameters(), lr=1e-3)
opt2 = torch.optim.SGD(self.model2.parameters(), lr=1e-3) # 可以使用不同类型的优化器
lr1 = torch.optim.lr_scheduler.StepLR(opt1, step_size=100, gamma=0.1)
lr2 = torch.optim.lr_scheduler.StepLR(opt2, step_size=100, gamma=0.1) # 也可以使用不同类型的学习率管理器
return [opt1, opt2], [lr1, lr2]

然后在train_stepon_train_epoch_end等方法中可直接调用self.optimizers()self.lr_schedulers()这两个对象,进行参数更新和学习率调整。