前言

关于扩散模型的数学基础性的知识,可以参照原始论文[1][2];亦可以参考本文相关的DDPM的博客DDIM的博客以及条件扩散模型。本文主要讲述如何使用Python作为脚本语言,从零开始实现扩散模型的工程问题,参考论文LDM[4]提供的源代码,使用pytorch-lightning框架完成扩散模型的训练和推理。

代码

本文的全部代码在gitee仓库中开源,见于此处。目前正在整理,后续将整理结果发布于github上。

TODO

  • [x] 一个便于复用的DDPM框架。
  • [ ] 一个便于复用的DDIM框架。
  • [ ] 一个便于复用的LDM框架。

扩散模型——从入门到入土

本博客计划使用pytorch-lightning搭建整个扩散模型的训练框架和数据加载器,并使用此模型完成推理。

DDPM基本框架

实现DDPM也将分为训练和推理两个部分的内容单独处理。首先需要将全部使用的参数,如αt,βt,αˉt\alpha_t,\beta_t,\bar{\alpha}_t等等都提前保存下来,在某个模型中,它们全部是定值,不能学习。设self.denoiser是本模型所使用的去噪UNet,该模型设计较为简单,本文不再赘述。

亦可以根据自身需要,将其更换为如TransformerRCNN等其他形式的估计器,用于更加复杂的视觉任务。

注册定值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def register_schedule(
self,
given_betas=None,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
if exists(given_betas):
betas = given_betas
else:
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert (
alphas_cumprod.shape[0] == self.num_timesteps
), "alphas have to be defined for each timestep"

to_torch = partial(torch.tensor, dtype=torch.float32)

self.register_buffer("betas", to_torch(betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer(
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (
1.0 - alphas_cumprod_prev
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer("posterior_variance", to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer(
"posterior_log_variance_clipped",
to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
)
self.register_buffer(
"posterior_mean_coef1",
to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
)
self.register_buffer(
"posterior_mean_coef2",
to_torch(
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
),
)

if self.parameterization == "eps":
lvlb_weights = self.betas**2 / (
2
* self.posterior_variance
* to_torch(alphas)
* (1 - self.alphas_cumprod)
)
elif self.parameterization == "x0":
lvlb_weights = (
0.5
* np.sqrt(torch.Tensor(alphas_cumprod))
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
)
else:
raise NotImplementedError("mu not supported")
# TODO how to choose this term
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).all()

扩散过程设计

虽然一步步加噪声是很直观的设计,但是python里面写循环往往意味着效率的下降,因此可以根据αˉt\bar{\alpha}_t的应用,直接一步完成扩散过程,如公式:

xt=αˉtx0+1αˉtϵ~t \mathbf{x}_t=\sqrt{\bar{\alpha}_t}\mathbf{x}_0+\sqrt{1-\bar{\alpha}_t}\tilde{\boldsymbol{\epsilon}}_t

后者可以直接从ϵ~tN(0,I)\tilde{\bm{\epsilon}}_t\sim\mathcal{N}(0,\mathbf{I})中采样取得,剩下的量全部是已知量,可以简单地实现为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
@abstractmethod
def q_sample(self, x_start, t) -> tuple[torch.Tensor, torch.Tensor]:
"""
扩散过程主函数,用于从输入图像获取完成噪声化的带噪图像。\\
use reparametrization trick to sample from q(x_t | x_0)
x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * noise\\
Arguments:
x_start (torch.Tensor) : 输入的数据结构,可以是单张图(最简单的DDPM),也可以是其他的类型。如检测框、分割模板等。
t (int) : 在批维度上与x_start对应的时间步,一般是随机采样取得。
Returns:
x_noisy (torch.Tensor) : 生成的带噪图像
noise (torch.Tensor) : 生成的噪声,充当GT
"""
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)

损失函数设计

其中训练部分的内容甚至比推理更简单,根据DDPM论文[1]得出的结论,实际中使用的损失函数可以不像公式

Eq{DKL[q(xTx0)p(xT)]LT+t=2TDKL[q(xt1x0,xt)pθ(xt1xt)]Lt1logpθ(x0x1)L0}\mathbb{E}_q\left\{\underbrace{D_\mathrm{KL}\left[q(\mathbf{x}_T|\mathbf{x}_0)||p(\mathbf{x}_T)\right]}_{L_T} +\sum_{t=2}^T\underbrace{D_\mathrm{KL}\left[q(\mathbf{x}_{t-1}|\mathbf{x}_0,\mathbf{x}_t)||p_{\boldsymbol{\theta}}(\mathbf{x}_{t-1}|\mathbf{x}_t)\right]}_{L_{t-1}}\underbrace{-\log p_{\boldsymbol{\theta}}(\mathbf{x}_0|\mathbf{x}_1)}_{L_0}\right\}

这般复杂,只需要把噪声估计器当作是一个普通的图像恢复网络即可,即使用损失函数:

Lsimple(θ)=Et,x0,ϵ[ϵϵθ(αˉtx0+1αˉtϵ,t)2]L_\mathrm{simple}(\bm{\theta})=\mathbb{E}_{t,\mathbf{x}_0,\bm{\epsilon}}\left[\|\bm{\epsilon}-\bm{\epsilon_\theta}(\sqrt{\bar{\alpha}_t}\bm{x}_0+\sqrt{1-\bar{\alpha}_t}\bm{\epsilon},t)\|^2\right]

就够了,可以不考虑诸如变分下界损失和权重学习等因素,这样可以写出训练的损失函数为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def get_loss(self, pred, target, mean=True):
"""
用于从噪声估计器的输出值和某个输入值计算损失,这里的输出值和输入值都可以不是图像格式,需要注意的是,返回值一定是形状等于batch数的列向量,例如:
>>> def get_loss(self, pred, target, mean=True):
>>> if self.loss_type == 'l1':
>>> loss = (target - pred).abs().mean(dim=[1, 2, 3])
>>> if mean:
>>> loss = loss.mean()
>>> return loss
"""
if self.loss_type == "l1":
loss = (target - pred).abs().mean(dim=[1, 2, 3])
if mean:
loss = loss.mean()
elif self.loss_type == "l2":
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(
target, pred, reduction="none"
).mean(dim=[1, 2, 3])
else:
raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

return loss

其中,本步损失函数还可以进一步设计,根据自己的需求,修改得更加复杂。牢牢记住损失函数是给噪声估计器设计的。然后,加入逐时间步的损失加权和变分下界损失,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def p_losses(self, x_start, t):
"""
本步骤是主要的训练步骤
Arguments:
x_start (torch.Tensor) : 输入的数据结构,可以是单张图(最简单的DDPM),也可以是其他的类型。如检测框、分割模板等。
t (int) : 在批维度上与x_start对应的时间步,一般是随机采样取得。
noise (torch.Tensor) : 输入的噪声,形状上应当与x_start相同(但是不一定每个通道都有噪声)。
"""
x_noisy, noise = self.q_sample(x_start=x_start, t=t)
pred_noise = self.denoiser(x_noisy, t)

loss_dict = {}
# 这一步决定是否采用残差学习还是直接恢复学习
if self.parameterization == "eps":
target = noise
elif self.parameterization == "x0":
target = x_start
else:
raise NotImplementedError(
f"Paramterization {self.parameterization} not yet supported"
)

loss = self.get_loss(pred_noise, target, mean=False)

log_prefix = "train" if self.training else "val"

loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()})
loss_simple = loss.mean() * self.l_simple_weight
# 这一步求解出不同时间步长的损失加权
loss_vlb = (self.lvlb_weights[t] * loss).mean()
loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb})

loss = loss_simple + self.original_elbo_weight * loss_vlb

loss_dict.update({f"{log_prefix}/loss": loss})

return loss, loss_dict

训练步骤设计

设计完毕损失函数后,训练步骤就变得简单了,也就是从pytorch-lightning的输入batch中获取需要的数据结构,经扩散过程产生带噪图像后加入至噪声估计器中,然后由噪声估计器返回估计噪声,与加入的噪声直接计算损失即可,不需要迭代计算。以上逻辑可以使用如下代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
@abstractmethod
def get_input(self, batch: dict, k: list):
"""
从数据加载器的batch中获取输入数据,需要注意是,加载的数据张量的形状,通道量被放在了最后一个维度,而如果要使用`torch`自带的卷积等模板,应当将通道数放在第二个维度。
Arguments:
batch (dict) : 数据加载器的一个batch
k (list) : batch中的键值,用于获取输入数据
"""
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, "b h w c -> b c h w")
x = x.to(memory_format=torch.contiguous_format).float()
return x

def shared_step(self, batch):
x = self.get_input(batch, self.first_stage_key)
t = torch.randint(
0, self.num_timesteps, (x.shape[0],), device=self.device
).long()
loss, loss_dict = self.p_losses(x, t)
return loss, loss_dict

def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)

self.log_dict(
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True
)

self.log(
"global_step",
self.global_step,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=False,
)
opt = self.optimizers()
opt.zero_grad()
loss.backward()
opt.step()

由以上方法可以完成对DDPM的训练,尚属于比较容易的。真正让人觉得头大的是扩散过程的推理步骤。其内容较多,下文将其单列为一章。

推理步骤设计

推理步骤,可以简单地认为是迭代地从噪声输入中不断去除噪声,直到图像显现为止。因此一定会用到训练步骤中得到的self.denoiser,除此以外,该过程得到的估计噪声也会与当前图像叠加再次输入至噪声估计器完成下一步估计,因此必然有方法专门负责处理输入的带噪图像;也必然有方法专门处理从噪声中产生估计图像。

噪声估计

根据公式:

μ~t(xt,x^0)=1αt[xtβt1αˉtϵθ(xt,t)]\tilde{\bm{\mu}}_t(\mathbf{x}_t,\mathbf{\hat{x}}_0)=\frac{1}{\sqrt{\alpha_t}}\left[\mathbf{x}_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\bm{\epsilon_\theta}(\mathbf{x}_t,t)\right]

噪声估计器的输入中应当包括时间量,以及当前步骤的带噪图像xtx_t,根据每个batch采样的时间不一定相同,逐个实现为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
@torch.no_grad()
@abstractmethod
def p_sample(self, x, t, clip_denoised=True) -> torch.Tensor:
"""
用于从带噪图像中完成去噪得到图像,即生成过程。
Arguments:
x (torch.Tensor) : 输入的带噪图像,应当是完整的输入数据结构
t (torch.Tensor) : 在批维度上与x对应的时间步,一般是随机采样取得
Returns:
x_noisy (torch.Tensor) : 生成的去噪图像

Examples:
>>> b, *_ = x.shape
>>> pred_noise = self.denoiser(x, t)
>>> x_noisy = x[:, -1:]
>>> model_mean, _, model_log_variance = self.p_mean_variance(x_noisy, pred_noise, t=t, clip_denoised=clip_denoised)
>>> noise = torch.randn_like(x)
>>> # no noise when t == 0
>>> nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_noisy.shape) - 1)))
>>> output = x.detach().clone()
>>> output[:, -1:] = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
>>> return output
"""
b, *_ = x.shape
pred_noise = self.denoiser(x, t)
x_noisy = x[:, -1:]
model_mean, _, model_log_variance = self.p_mean_variance(
x_noisy, pred_noise, t=t, clip_denoised=clip_denoised
)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(
b, *((1,) * (len(x_noisy.shape) - 1))
)
output = x.detach().clone()
output[:, -1:] = (
model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
)
return output

此步骤中用到了函数self.p_mean_variance,该函数实现的实质上是以下公式:

{μ~t(xt,x0)=αˉt1βt1αˉtx0+α(1αˉt1)1αˉtxtσt2=1αˉt11αˉtβt\left\{\begin{aligned} \bm{\tilde{\mu}_t}(x_t,x_0)&=\frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}\mathbf{x}_0+\frac{\alpha(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\mathbf{x}_t\\\\ \sigma^2_t&=\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t \end{aligned}\right.

由此把两个项的实现分别用self.predict_start_from_noiseself.q_posterior实现,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def predict_start_from_noise(self, x_t, t, noise):
return (
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
* noise
)

def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped

图像生成

图像生成过程就是反复执行self.p_sample的过程,并将最后一步执行结果作为最终输出即可,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
@torch.no_grad()
def sample(self, x_noisy, return_intermediates=False):
device = self.betas.device
b = x_noisy.shape[0]
intermediates = [self.log_samples(x_noisy)]
for i in tqdm(
reversed(range(0, self.num_timesteps)),
desc="Sampling t",
total=self.num_timesteps,
):
t = torch.full((b,), i, device=device, dtype=torch.long)
x_noisy = self.p_sample(x_noisy, t, self.clip_denoised)
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
intermediates.append(self.log_samples(x_noisy))
if return_intermediates:
return x_noisy, intermediates
return x_noisy

由此,推理过程也实现了。

自此DDPM的训练和推理都可以实现。

但是这只是一步开始,DDPM的性能不及LDM,推理速度不及DDIM,甚至也不能保证输入一幅噪声图像,输出一幅清晰图像(DDRM才能做到)。因此,以上步骤只能完成最基础内容的搭建,在后面的章节中,将介绍加入条件引导量的实现细节。

条件模型实现细节

条件分两类:条件引导项condition和模板(上下文context),条件引导指代从不是图像的引导信息转换为图像再将它与带噪图像融合;而模板指代本身就是图像,只将它融合进扩散模型的带噪图中的过程。条件引导项在DDPM生成过程的采样中体现为xtN(μ+sσtIxtlogpϕ(yxt),σtI)\mathbf{x}_t\sim\mathcal{N}(\bm{\mu}+s\sigma_t\mathbf{I}\nabla_{\mathbf{x}_t}\log p_{\bm{\phi}}(\mathbf{y|x}_t),\sigma_t\mathbf{I})中的导数项,其中条件概率pxt(yxt)p_{\mathbf{x}_t}(\mathbf{y|x}_t)既可以是分类器条件概率,也可以是诸如矩形框位置、回归坐标等其他形式的条件概率,通常使用字典实现。

字典

字典是NLP领域的术语[3],这里的字典指的是事先存储了大量的健值对(key-value)的数据结构(最简单的形式就是Hash字典,如python内置数据结构dict,但是深度学习中的键值较多,使用Hash字典内存占用过高,常用如线性表、矩阵)。字典的使用是给定一个简单的输入值(称为查询,query),由字典自动比对字典库中全部的键,按以下情况分类给出输出:

  1. 如果字典是可数且稀疏的,当且仅当有一个键与query完全匹配时,输出对应的值,否则报错;
  2. 如果字典是可数,但是不稀疏,则字典根据键与查询的相似度,输出距离查询最相似的键对应的值;
  3. 如果字典是不可数的,例如输入的查询是一个向量,字典的键也是向量,此时字典将根据相似度的定义给出最相似的键,输出对应的值。

著名如自注意力机制,也是可数键值对情况下的可学习字典。正是字典查询能够匹配整个键空间,自注意力机制才能获得了对整个输入序列的全局信息把握能力。(这里偏题了)

pytorch提供了nn.Embedding类,可以实现类似字典的功能,即输入类别的索引,可以给出该类别的编码向量,而不需要求解整个分类器的梯度。

不得不说是一种很聪明的工程技巧,避免了对网络求导的复杂计算量。

如果查询是实空间向量,直接使用nn.Embedding就不再可行了,因为它的输入量必须是可数的。此时一般可采用变换的方法,将该实空间向量经过量化(如码表量化,例如著名的VQGAN就采用的一步隐空间量化)、压缩等方式变为可数空间的向量,再使用nn.Embedding完成查询输出。

Embedding

参考自博客nn.Embedding可以看作是输入为one-hot码的无偏线性层,其权重的更新方式类似于线性层的反向传播,最多根据查询索引出现的频率,对反向传播的梯度作一次缩放,本质上仍然是一种高效的线性层。

参考文献


[1] Ho J. , Jain A. , Abbeel P. .Denoising Diffusion Probabilistic Models[M/OL].arXiv,2020


[2] Song J. , Meng C. , Ermon S. .Denoising Diffusion Implicit Models[M/OL].arXiv,2022


[3] He K. , Fan H. , Wu Y. , et al.Momentum Contrast for Unsupervised Visual Representation Learning[A].2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)[C/OL].Seattle, WA, USA:IEEE,2020:9726-9735


[4] Rombach R. , Blattmann A. , Lorenz D. , et al.High-resolution image synthesis with latent diffusion Models[A].2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)[C/OL].New Orleans, LA, USA:IEEE,2022:10674-10685