Diffusion Source Code Walkthrough (PyTorch)
This blog post reads through and compares the DDPM code paths for image data and for one-dimensional data.
Code repository:
https://github.com/lucidrains/denoising-diffusion-pytorch/tree/main
References / articles:
- https://arxiv.org/abs/2006.11239
- https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
- https://huggingface.co/blog/annotated-diffusion
How to Run
1D
import torch
from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D, Trainer1D, Dataset1D
model = Unet1D(
dim = 64,
dim_mults = (1, 2, 4, 8),
channels = 32
)
diffusion = GaussianDiffusion1D(
model,
seq_length = 128,
timesteps = 1000,
objective = 'pred_v'
)
training_seq = torch.rand(64, 32, 128) # features are normalized from 0 to 1
dataset = Dataset1D(training_seq) # this is just an example, but you can formulate your own Dataset and pass it into the `Trainer1D` below
loss = diffusion(training_seq)
loss.backward()
# Or using trainer
trainer = Trainer1D(
diffusion,
dataset = dataset,
train_batch_size = 32,
train_lr = 8e-5,
train_num_steps = 700000, # total training steps
gradient_accumulate_every = 2, # gradient accumulation steps
ema_decay = 0.995, # exponential moving average decay
amp = True, # turn on mixed precision
)
trainer.train()
# after a lot of training
sampled_seq = diffusion.sample(batch_size = 4)
sampled_seq.shape # (4, 32, 128)
Multi-dimensional
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
model = Unet(
dim = 64,
dim_mults = (1, 2, 4, 8),
flash_attn = True
)
diffusion = GaussianDiffusion(
model,
image_size = 128,
timesteps = 1000, # number of steps
sampling_timesteps = 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
)
trainer = Trainer(
diffusion,
'path/to/your/images',
train_batch_size = 32,
train_lr = 8e-5,
train_num_steps = 700000, # total training steps
gradient_accumulate_every = 2, # gradient accumulation steps
ema_decay = 0.995, # exponential moving average decay
amp = True, # turn on mixed precision
calculate_fid = True # whether to calculate fid during training
)
trainer.train()
1. Importing the Required Libraries
import math
from pathlib import Path
from random import random
from tqdm.auto import tqdm
from ema_pytorch import EMA
from functools import partial
from accelerate import Accelerator
from collections import namedtuple
from multiprocessing import cpu_count
import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
The imported libraries serve the following purposes:
- math: mathematical operations.
- pathlib.Path: object-oriented handling of file and directory paths.
- random.random: generates random numbers between 0 and 1.
- tqdm: displays progress bars in loops.
- ema_pytorch: a PyTorch extension for exponential moving averages (EMA).
- functools.partial: creates partial functions.
- accelerate: training acceleration, used here for automatic mixed precision and distributed training.
- collections.namedtuple: creates named tuples.
- multiprocessing.cpu_count: returns the number of CPU cores on the machine.
- torch: the root PyTorch module.
- torch.nn: the neural network module, containing the various layers and functions.
- torch.cuda.amp.autocast: the automatic mixed precision context.
- torch.optim.Adam: the Adam optimizer.
- torch.utils.data.Dataset: the dataset base class, used to define custom datasets.
- torch.utils.data.DataLoader: the data loader, used to load data in batches.
- einops: provides a concise way to rearrange and manipulate PyTorch tensors. rearrange reorders tensor dimensions; reduce performs reductions (sum, mean, etc.) along specified dimensions; Rearrange is a PyTorch layer provided by einops for rearranging tensor dimensions inside a model.
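As a quick illustration (a toy example of the einops helpers described above, not part of the repository code):
import torch
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

x = torch.randn(2, 8, 16)                  # (batch, channels, length)
y = rearrange(x, 'b c n -> b n c')         # swap channel and length axes -> (2, 16, 8)
m = reduce(x, 'b c n -> b c', 'mean')      # average over the length axis -> (2, 8)
flatten = Rearrange('b c n -> b (c n)')    # the same kind of rearrangement as an nn.Module layer
z = flatten(x)                             # -> (2, 128)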
2. Dataset
1D
class Dataset1D(Dataset):
def __init__(self, tensor: Tensor):
super().__init__()
self.tensor = tensor.clone()
def __len__(self):
return len(self.tensor)
def __getitem__(self, idx):
return self.tensor[idx].clone()
This code defines a dataset class for one-dimensional data. Specifically:
- __init__: the constructor takes a Tensor, clones it, and stores the clone in self.tensor, so the original data is never modified.
- __len__: returns the size of the dataset, i.e. the number of sequences. PyTorch's DataLoader uses it to determine how many iterations make up an epoch.
- __getitem__: takes an index and returns (a clone of) the corresponding item. The DataLoader uses this method to load data on demand, which saves memory and keeps data loading efficient.
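A minimal usage sketch (assuming the imports above; the shapes match the 1D example from the "How to Run" section):
import torch
from torch.utils.data import DataLoader

data = torch.rand(64, 32, 128)             # 64 sequences, 32 channels, length 128
dataset = Dataset1D(data)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch.shape)                         # torch.Size([8, 32, 128])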
Multi-dimensional
class Dataset(Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png', 'tiff'],
augment_horizontal_flip = False,
convert_image_to = None
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else nn.Identity()
self.transform = T.Compose([
T.Lambda(maybe_convert_fn),
T.Resize(image_size),
T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
T.CenterCrop(image_size),
T.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
Compared with the Dataset1D class for one-dimensional data, this dataset handles images. Its input is the path to an image folder rather than a Tensor: it searches the folder for image files with the given extensions and stores their paths in self.paths. The __init__ also defines a preprocessing pipeline: an optional image-mode conversion, resizing, an optional random horizontal flip, center cropping, and conversion to a tensor; this pipeline is applied to every image inside __getitem__. (The transforms module T comes from torchvision and Image from PIL; exists() and convert_image_to_fn() are small helpers defined elsewhere in the repository and omitted from the import list above.)
The tensor returned by __getitem__ has shape (C, H, W), where:
- C is the number of channels: typically 3 for color (RGB) images and 1 for grayscale images;
- H is the image height, determined by the image_size argument passed when the dataset is created;
- W is the image width, also determined by image_size.
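A minimal usage sketch (the folder path is hypothetical and is assumed to contain RGB images):
ds = Dataset('path/to/your/images', image_size=128, augment_horizontal_flip=True)
print(len(ds))          # number of image files found under the folder
img = ds[0]             # one preprocessed image tensor
print(img.shape)        # torch.Size([3, 128, 128]) for an RGB image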
3. Trainer
Part 1
1D
class Trainer1D(object):
def __init__(
self,
diffusion_model: GaussianDiffusion1D,
dataset: Dataset,
*,
train_batch_size = 16,
gradient_accumulate_every = 1,
train_lr = 1e-4,
train_num_steps = 100000,
ema_update_every = 10,
ema_decay = 0.995,
adam_betas = (0.9, 0.99),
save_and_sample_every = 1000,
num_samples = 25,
results_folder = './results',
amp = False,
mixed_precision_type = 'fp16',
split_batches = True,
):
super().__init__()
# accelerator
self.accelerator = Accelerator(
split_batches = split_batches,
mixed_precision = mixed_precision_type if amp else 'no'
)
# model
self.model = diffusion_model
self.channels = diffusion_model.channels
# sampling and training hyperparameters
assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
self.num_samples = num_samples
self.save_and_sample_every = save_and_sample_every
self.batch_size = train_batch_size
self.gradient_accumulate_every = gradient_accumulate_every
self.train_num_steps = train_num_steps
# dataset and dataloader
dl = DataLoader(dataset, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
dl = self.accelerator.prepare(dl)
self.dl = cycle(dl)
# optimizer
self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
# for logging results in a folder periodically
if self.accelerator.is_main_process:
self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
self.ema.to(self.device)
self.results_folder = Path(results_folder)
self.results_folder.mkdir(exist_ok = True)
# step counter state
self.step = 0
# prepare model, dataloader, optimizer with accelerator
self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
The initializer takes a large number of parameters; their meanings are:
- diffusion_model: the GaussianDiffusion1D instance, i.e. the model being optimized during training.
- dataset: the dataset object that supplies the training data.
- train_batch_size: int, the batch size of each training step.
- gradient_accumulate_every: int, the number of batches over which gradients are accumulated before each optimizer step.
- train_lr: float, the learning rate.
- train_num_steps: int, the total number of training steps.
- ema_update_every: int, how often (in steps) the exponential moving average (EMA) weights are updated.
- ema_decay: float, the EMA decay rate.
- adam_betas: tuple, the beta parameters of the Adam optimizer.
- save_and_sample_every: int, how often (in steps) the model is saved and samples are generated.
- num_samples: int, the number of samples generated at each milestone.
- results_folder: string, the folder in which results are saved.
- amp: bool, whether to use automatic mixed precision (AMP) during training.
- mixed_precision_type: string, the mixed-precision mode passed to Accelerate when amp is enabled (typically 'fp16' or 'bf16').
- split_batches: bool, whether Accelerate should split each batch across devices.
The initializer first creates an Accelerator object, which manages device placement and mixed-precision training. It then stores the model and the training hyperparameters, builds the data loader (wrapped with a cycle helper so it can be drawn from indefinitely, see the sketch below) and the Adam optimizer. If the current process is the main process, it also creates an EMA object that tracks a moving average of the model weights and creates the results folder. Finally, the model, the data loader, and the optimizer are prepared (wrapped) by the Accelerator.
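The cycle wrapper used for self.dl simply turns the finite DataLoader into an endless iterator, so the training loop can call next(self.dl) at every step; a rough equivalent of the repository's small helper:
def cycle(dl):
    # yield batches forever, restarting the dataloader whenever it is exhausted
    while True:
        for data in dl:
            yield data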
Multi-dimensional
class Trainer(object):
def __init__(
self,
diffusion_model,
folder,
*,
train_batch_size = 16,
gradient_accumulate_every = 1,
augment_horizontal_flip = True,
train_lr = 1e-4,
train_num_steps = 100000,
ema_update_every = 10,
ema_decay = 0.995,
adam_betas = (0.9, 0.99),
save_and_sample_every = 1000,
num_samples = 25,
results_folder = './results',
amp = False,
mixed_precision_type = 'fp16',
split_batches = True,
convert_image_to = None,
calculate_fid = True,
inception_block_idx = 2048,
max_grad_norm = 1.,
num_fid_samples = 50000,
save_best_and_latest_only = False
):
super().__init__()
# accelerator
self.accelerator = Accelerator(
split_batches = split_batches,
mixed_precision = mixed_precision_type if amp else 'no'
)
# model
self.model = diffusion_model
self.channels = diffusion_model.channels
# sampling and training hyperparameters
assert has_int_squareroot(num_samples), 'number of samples must have an integer square root'
self.num_samples = num_samples
self.save_and_sample_every = save_and_sample_every
self.batch_size = train_batch_size
self.gradient_accumulate_every = gradient_accumulate_every
assert (train_batch_size * gradient_accumulate_every) >= 16, f'your effective batch size (train_batch_size x gradient_accumulate_every) should be at least 16 or above'
self.train_num_steps = train_num_steps
self.image_size = diffusion_model.image_size
self.max_grad_norm = max_grad_norm
# dataset and dataloader
self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip, convert_image_to = convert_image_to)
assert len(self.ds) >= 100, 'you should have at least 100 images in your folder. at least 10k images recommended'
dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
dl = self.accelerator.prepare(dl)
self.dl = cycle(dl)
# optimizer
self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas)
# for logging results in a folder periodically
if self.accelerator.is_main_process:
self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every)
self.ema.to(self.device)
self.results_folder = Path(results_folder)
self.results_folder.mkdir(exist_ok = True)
# step counter state
self.step = 0
# prepare model, dataloader, optimizer with accelerator
self.model, self.opt = self.accelerator.prepare(self.model, self.opt)
# FID-score computation
self.calculate_fid = calculate_fid and self.accelerator.is_main_process
if self.calculate_fid:
if not self.model.is_ddim_sampling:
self.accelerator.print(
"WARNING: Robust FID computation requires a lot of generated samples and can therefore be very time consuming."\
"Consider using DDIM sampling to save time."
)
self.fid_scorer = FIDEvaluation(
batch_size=self.batch_size,
dl=self.dl,
sampler=self.ema.ema_model,
channels=self.channels,
accelerator=self.accelerator,
stats_dir=results_folder,
device=self.device,
num_fid_samples=num_fid_samples,
inception_block_idx=inception_block_idx
)
if save_best_and_latest_only:
assert calculate_fid, "`calculate_fid` must be True to provide a means for model evaluation for `save_best_and_latest_only`."
self.best_fid = 1e10 # infinite
self.save_best_and_latest_only = save_best_and_latest_only
Compared with the previous __init__, this one adds parameters for data augmentation, image-format conversion, and FID evaluation, which makes the trainer better suited to image data. The new or changed parameters are:
- folder: string, the path of the folder containing the training images.
- augment_horizontal_flip: bool, whether to apply random horizontal flips as data augmentation.
- convert_image_to: string, a PIL image mode to convert the images to before the other transforms (for example 'RGB' or 'L').
- calculate_fid: bool, whether to compute the Frechet Inception Distance (FID) during training.
- inception_block_idx: int, which Inception feature block is used for FID, indexed by feature dimensionality (the default 2048 corresponds to the final pooling features).
- max_grad_norm: float, the gradient-clipping threshold.
- num_fid_samples: int, the number of generated samples used to compute FID.
- save_best_and_latest_only: bool, whether to keep only the best-FID and the latest checkpoints.
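The has_int_squareroot assertion only checks that num_samples can be arranged into a square grid of generated samples; a rough equivalent of the repository's small helper:
import math

def has_int_squareroot(num):
    # True when num is a perfect square, e.g. num_samples = 25 gives a 5 x 5 sample grid
    return math.isqrt(num) ** 2 == num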
Part 2
@property
def device(self):
return self.accelerator.device
def save(self, milestone):
if not self.accelerator.is_local_main_process:
return
data = {
'step': self.step,
'model': self.accelerator.get_state_dict(self.model),
'opt': self.opt.state_dict(),
'ema': self.ema.state_dict(),
'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None,
'version': __version__
}
torch.save(data, str(self.results_folder / f'model-{milestone}.pt'))
def load(self, milestone):
accelerator = self.accelerator
device = accelerator.device
data = torch.load(str(self.results_folder / f'model-{milestone}.pt'), map_location=device)
model = self.accelerator.unwrap_model(self.model)
model.load_state_dict(data['model'])
self.step = data['step']
self.opt.load_state_dict(data['opt'])
if self.accelerator.is_main_process:
self.ema.load_state_dict(data["ema"])
if 'version' in data:
print(f"loading from version {data['version']}")
if exists(self.accelerator.scaler) and exists(data['scaler']):
self.accelerator.scaler.load_state_dict(data['scaler'])
- device: returns the device the trainer is running on (CPU or GPU), taken from the Accelerator.
- save: saves the trainer state: the current step, the model weights, the optimizer state, the EMA state, the AMP scaler state (if any) and the package version. It is called from train every save_and_sample_every steps (on the main process only), so that training progress can later be restored.
- load: loads a saved trainer state (model weights, step counter, optimizer state, EMA state, scaler state). It can be called before training to resume from a checkpoint.
(The 1D and multi-dimensional versions are essentially identical here.)
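For example, to resume an interrupted run from a saved milestone (a sketch, assuming a checkpoint model-3.pt already exists in ./results):
trainer = Trainer1D(diffusion, dataset = dataset, train_batch_size = 32)
trainer.load(3)     # restores model weights, optimizer, EMA state and step counter from model-3.pt
trainer.train()     # continues training from the restored step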
Part 3
def train(self):
accelerator = self.accelerator
device = accelerator.device
with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:
while self.step < self.train_num_steps:
total_loss = 0.
for _ in range(self.gradient_accumulate_every):
data = next(self.dl).to(device)
with self.accelerator.autocast():
loss = self.model(data)
loss = loss / self.gradient_accumulate_every
total_loss += loss.item()
self.accelerator.backward(loss)
accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
pbar.set_description(f'loss: {total_loss:.4f}')
accelerator.wait_for_everyone()
self.opt.step()
self.opt.zero_grad()
accelerator.wait_for_everyone()
self.step += 1
if accelerator.is_main_process:
self.ema.update()
# check whether we have reached a save-and-sample milestone; if so, generate samples and save a checkpoint
if self.step != 0 and self.step % self.save_and_sample_every == 0:
self.ema.ema_model.eval()
with torch.no_grad():
milestone = self.step // self.save_and_sample_every
batches = num_to_groups(self.num_samples, self.batch_size)
all_samples_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
all_samples = torch.cat(all_samples_list, dim = 0)
torch.save(all_samples, str(self.results_folder / f'sample-{milestone}.png'))
self.save(milestone)
pbar.update(1)
accelerator.print('training complete')
- train: the main training loop. Each iteration runs gradient_accumulate_every forward/backward passes: the forward pass happens at loss = self.model(data), the backward pass at self.accelerator.backward(loss), and the parameter update at self.opt.step(). After each optimizer step the EMA weights are updated on the main process. When the step counter reaches a multiple of save_and_sample_every, the EMA model is put into eval mode, self.ema.ema_model.sample(batch_size=n) generates samples (num_samples is split into batches with the num_to_groups helper, sketched below), the samples are written to the results folder, and the trainer state is saved. The multi-dimensional version of train is listed after the sketch.
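The num_to_groups helper splits num_samples into chunks no larger than the training batch size before sampling; a rough equivalent of the repository's helper:
def num_to_groups(num, divisor):
    # e.g. num_to_groups(25, 16) -> [16, 9]
    groups, remainder = divmod(num, divisor)
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr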
def train(self):
accelerator = self.accelerator
device = accelerator.device
with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar:
while self.step < self.train_num_steps:
total_loss = 0.
for _ in range(self.gradient_accumulate_every):
data = next(self.dl).to(device)
with self.accelerator.autocast():
loss = self.model(data)
loss = loss / self.gradient_accumulate_every
total_loss += loss.item()
self.accelerator.backward(loss)
accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
pbar.set_description(f'loss: {total_loss:.4f}')
accelerator.wait_for_everyone()
self.opt.step()
self.opt.zero_grad()
accelerator.wait_for_everyone()
self.step += 1
if accelerator.is_main_process:
self.ema.update()
if self.step != 0 and divisible_by(self.step, self.save_and_sample_every):
self.ema.ema_model.eval()
with torch.inference_mode():
milestone = self.step // self.save_and_sample_every
batches = num_to_groups(self.num_samples, self.batch_size)
all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
all_images = torch.cat(all_images_list, dim = 0)
utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))
# whether to calculate fid
if self.calculate_fid:
fid_score = self.fid_scorer.fid_score()
accelerator.print(f'fid_score: {fid_score}')
if self.save_best_and_latest_only:
if self.best_fid > fid_score:
self.best_fid = fid_score
self.save("best")
self.save("latest")
else:
self.save(milestone)
pbar.update(1)
accelerator.print('training complete')
The logic is essentially the same as in the 1D version, with a few differences: samples are written as an image grid with utils.save_image (the 1D trainer saves the raw tensor with torch.save, despite the .png filename), gradient clipping uses the configurable max_grad_norm, the FID score can be computed after sampling, and with save_best_and_latest_only the trainer keeps only the best-FID and the latest checkpoints (the 1D trainer cannot do this, mainly because it lacks such an evaluation metric).
4. GaussianDiffusion
Rough execution flow of the code:
- __init__: the constructor. It stores the model and its hyperparameters, computes the beta and alpha schedules and registers the derived buffers, computes the loss weight according to the objective parameter, and decides whether to auto-normalize the data according to auto_normalize. It runs when the class is instantiated.
- forward: the forward pass used during training to compute the loss; it samples a random timestep, normalizes the input, and calls p_losses.
- p_losses: computes the loss. It draws random noise, calls q_sample to produce a noisy version of the input, optionally applies self-conditioning (self_condition), runs the network, and computes a weighted MSE loss against the target implied by objective. Called from forward.
- model_predictions: converts the raw network output into a (pred_noise, pred_x_start) pair, using predict_start_from_noise, predict_noise_from_start, or predict_start_from_v depending on objective. It is called by the samplers (via p_mean_variance and ddim_sample) and, when self-conditioning, by p_losses.
- p_sample: performs a single reverse-diffusion step, producing x_{t-1} from x_t. It is called repeatedly by p_sample_loop and by interpolate.
- p_sample_loop: the ancestral sampling loop; it starts from pure noise and calls p_sample once for every timestep. Called from sample.
- ddim_sample: the DDIM sampling loop; it iterates over a reduced set of timesteps and uses model_predictions directly, which makes sampling much faster. Called from sample.
- sample: chooses the sampling function (p_sample_loop or ddim_sample) based on is_ddim_sampling and generates new data.
- interpolate: interpolates between two inputs by noising both with q_sample, mixing them, and then denoising the mixture with repeated p_sample calls.
- q_sample: the forward (noising) process; it produces x_t from the clean input and noise. Called from p_losses and interpolate.
Concretely, in the GaussianDiffusion class forward is the main entry point: it adds noise in the forward process and computes the loss used to learn the reverse process.
In the forward process, forward first normalizes the input and then calls p_losses. Inside p_losses, random noise with the same shape as the input is drawn, and q_sample adds it to the input, producing a noisy sample x_t. This is where noise is added in the forward process.
p_losses then runs the network on the noisy sample (optionally with a self-conditioning input), compares the output with the target implied by objective (the noise itself, x_0, or v), and computes a mean-squared-error loss. The loss is multiplied by the loss weight for that timestep and returned to forward, which returns it to the caller. This is where the loss for the reverse process is computed.
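In formula terms, q_sample implements the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A small standalone illustration of this step (a sketch with a toy linear schedule, independent of the class):
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)                 # toy linear schedule
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)   # alpha_bar_t

x0 = torch.randn(4, 32, 128)                          # a batch of clean 1D signals
t = torch.randint(0, T, (4,))                         # one timestep per sample
eps = torch.randn_like(x0)

a = alphas_cumprod[t].sqrt().view(-1, 1, 1)           # per-sample scalars, broadcast over (c, n)
s = (1. - alphas_cumprod[t]).sqrt().view(-1, 1, 1)
x_t = a * x0 + s * eps                                # the same computation as q_sample / extract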
1D
class GaussianDiffusion1D(nn.Module):
def __init__(
self,
model,
*,
seq_length,
timesteps = 1000,
sampling_timesteps = None,
objective = 'pred_noise',
beta_schedule = 'cosine',
ddim_sampling_eta = 0.,
auto_normalize = True
):
super().__init__()
self.model = model
self.channels = self.model.channels
self.self_condition = self.model.self_condition
self.seq_length = seq_length
self.objective = objective
assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
if beta_schedule == 'linear':
betas = linear_beta_schedule(timesteps)
elif beta_schedule == 'cosine':
betas = cosine_beta_schedule(timesteps)
else:
raise ValueError(f'unknown beta schedule {beta_schedule}')
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
# sampling related parameters
self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
assert self.sampling_timesteps <= timesteps
self.is_ddim_sampling = self.sampling_timesteps < timesteps
self.ddim_sampling_eta = ddim_sampling_eta
# helper function to register buffer from float64 to float32
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
register_buffer('betas', betas)
register_buffer('alphas_cumprod', alphas_cumprod)
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
# calculate loss weight
snr = alphas_cumprod / (1 - alphas_cumprod)
if objective == 'pred_noise':
loss_weight = torch.ones_like(snr)
elif objective == 'pred_x0':
loss_weight = snr
elif objective == 'pred_v':
loss_weight = snr / (snr + 1)
register_buffer('loss_weight', loss_weight)
# whether to autonormalize
self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_noise_from_start(self, x_t, t, x0):
return (
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
def predict_v(self, x_start, t, noise):
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
)
def predict_start_from_v(self, x_t, t, v):
return (
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
model_output = self.model(x, t, x_self_cond)
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
if self.objective == 'pred_noise':
pred_noise = model_output
x_start = self.predict_start_from_noise(x, t, pred_noise)
x_start = maybe_clip(x_start)
if clip_x_start and rederive_pred_noise:
pred_noise = self.predict_noise_from_start(x, t, x_start)
elif self.objective == 'pred_x0':
x_start = model_output
x_start = maybe_clip(x_start)
pred_noise = self.predict_noise_from_start(x, t, x_start)
elif self.objective == 'pred_v':
v = model_output
x_start = self.predict_start_from_v(x, t, v)
x_start = maybe_clip(x_start)
pred_noise = self.predict_noise_from_start(x, t, x_start)
return ModelPrediction(pred_noise, x_start)
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
preds = self.model_predictions(x, t, x_self_cond)
x_start = preds.pred_x_start
if clip_denoised:
x_start.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad()
def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
b, *_, device = *x.shape, x.device
batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised)
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
return pred_img, x_start
@torch.no_grad()
def p_sample_loop(self, shape):
batch, device = shape[0], self.betas.device
img = torch.randn(shape, device=device)
x_start = None
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
self_cond = x_start if self.self_condition else None
img, x_start = self.p_sample(img, t, self_cond)
img = self.unnormalize(img)
return img
@torch.no_grad()
def ddim_sample(self, shape, clip_denoised = True):
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
img = torch.randn(shape, device = device)
x_start = None
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
self_cond = x_start if self.self_condition else None
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = clip_denoised)
if time_next < 0:
img = x_start
continue
alpha = self.alphas_cumprod[time]
alpha_next = self.alphas_cumprod[time_next]
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c = (1 - alpha_next - sigma ** 2).sqrt()
noise = torch.randn_like(img)
img = x_start * alpha_next.sqrt() + \
c * pred_noise + \
sigma * noise
img = self.unnormalize(img)
return img
@torch.no_grad()
def sample(self, batch_size = 16):
seq_length, channels = self.seq_length, self.channels
sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
return sample_fn((batch_size, channels, seq_length))
@torch.no_grad()
def interpolate(self, x1, x2, t = None, lam = 0.5):
b, *_, device = *x1.shape, x1.device
t = default(t, self.num_timesteps - 1)
assert x1.shape == x2.shape
t_batched = torch.full((b,), t, device = device)
xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
img = (1 - lam) * xt1 + lam * xt2
x_start = None
for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
self_cond = x_start if self.self_condition else None
img, x_start = self.p_sample(img, i, self_cond)
return img
@autocast(enabled = False)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, x_start, t, noise = None):
b, c, n = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))
# noise sample
x = self.q_sample(x_start = x_start, t = t, noise = noise)
# if doing self-conditioning, 50% of the time, predict x_start from current set of times
# and condition with unet with that
# this technique will slow down training by 25%, but seems to lower FID significantly
x_self_cond = None
if self.self_condition and random() < 0.5:
with torch.no_grad():
x_self_cond = self.model_predictions(x, t).pred_x_start
x_self_cond.detach_()
# predict and take gradient step
model_out = self.model(x, t, x_self_cond)
if self.objective == 'pred_noise':
target = noise
elif self.objective == 'pred_x0':
target = x_start
elif self.objective == 'pred_v':
v = self.predict_v(x_start, t, noise)
target = v
else:
raise ValueError(f'unknown objective {self.objective}')
loss = F.mse_loss(model_out, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b (...)', 'mean')
loss = loss * extract(self.loss_weight, t, loss.shape)
return loss.mean()
def forward(self, img, *args, **kwargs):
b, c, n, device, seq_length, = *img.shape, img.device, self.seq_length
assert n == seq_length, f'seq length must be {seq_length}'
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
img = self.normalize(img)
return self.p_losses(img, t, *args, **kwargs)
Multi-dimensional
class GaussianDiffusion(nn.Module):
def __init__(
self,
model,
*,
image_size,
timesteps = 1000,
sampling_timesteps = None,
objective = 'pred_v',
beta_schedule = 'sigmoid',
schedule_fn_kwargs = dict(),
ddim_sampling_eta = 0.,
auto_normalize = True,
offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
min_snr_gamma = 5
):
super().__init__()
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
assert not model.random_or_learned_sinusoidal_cond
self.model = model
self.channels = self.model.channels
self.self_condition = self.model.self_condition
self.image_size = image_size
self.objective = objective
assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
if beta_schedule == 'linear':
beta_schedule_fn = linear_beta_schedule
elif beta_schedule == 'cosine':
beta_schedule_fn = cosine_beta_schedule
elif beta_schedule == 'sigmoid':
beta_schedule_fn = sigmoid_beta_schedule
else:
raise ValueError(f'unknown beta schedule {beta_schedule}')
betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
# sampling related parameters
self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
assert self.sampling_timesteps <= timesteps
self.is_ddim_sampling = self.sampling_timesteps < timesteps
self.ddim_sampling_eta = ddim_sampling_eta
# helper function to register buffer from float64 to float32
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
register_buffer('betas', betas)
register_buffer('alphas_cumprod', alphas_cumprod)
register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
# offset noise strength - in blogpost, they claimed 0.1 was ideal
self.offset_noise_strength = offset_noise_strength
# derive loss weight
# snr - signal noise ratio
snr = alphas_cumprod / (1 - alphas_cumprod)
# https://arxiv.org/abs/2303.09556
maybe_clipped_snr = snr.clone()
if min_snr_loss_weight:
maybe_clipped_snr.clamp_(max = min_snr_gamma)
if objective == 'pred_noise':
register_buffer('loss_weight', maybe_clipped_snr / snr)
elif objective == 'pred_x0':
register_buffer('loss_weight', maybe_clipped_snr)
elif objective == 'pred_v':
register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))
# auto-normalization of data [0, 1] -> [-1, 1] - can turn off by setting it to be False
self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity
@property
def device(self):
return self.betas.device
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def predict_noise_from_start(self, x_t, t, x0):
return (
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
)
def predict_v(self, x_start, t, noise):
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
)
def predict_start_from_v(self, x_t, t, v):
return (
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
model_output = self.model(x, t, x_self_cond)
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
if self.objective == 'pred_noise':
pred_noise = model_output
x_start = self.predict_start_from_noise(x, t, pred_noise)
x_start = maybe_clip(x_start)
if clip_x_start and rederive_pred_noise:
pred_noise = self.predict_noise_from_start(x, t, x_start)
elif self.objective == 'pred_x0':
x_start = model_output
x_start = maybe_clip(x_start)
pred_noise = self.predict_noise_from_start(x, t, x_start)
elif self.objective == 'pred_v':
v = model_output
x_start = self.predict_start_from_v(x, t, v)
x_start = maybe_clip(x_start)
pred_noise = self.predict_noise_from_start(x, t, x_start)
return ModelPrediction(pred_noise, x_start)
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
preds = self.model_predictions(x, t, x_self_cond)
x_start = preds.pred_x_start
if clip_denoised:
x_start.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.inference_mode()
def p_sample(self, x, t: int, x_self_cond = None):
b, *_, device = *x.shape, self.device
batched_times = torch.full((b,), t, device = device, dtype = torch.long)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
return pred_img, x_start
@torch.inference_mode()
def p_sample_loop(self, shape, return_all_timesteps = False):
batch, device = shape[0], self.device
img = torch.randn(shape, device = device)
imgs = [img]
x_start = None
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
self_cond = x_start if self.self_condition else None
img, x_start = self.p_sample(img, t, self_cond)
imgs.append(img)
ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
ret = self.unnormalize(ret)
return ret
@torch.inference_mode()
def ddim_sample(self, shape, return_all_timesteps = False):
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
img = torch.randn(shape, device = device)
imgs = [img]
x_start = None
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
self_cond = x_start if self.self_condition else None
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)
if time_next < 0:
img = x_start
imgs.append(img)
continue
alpha = self.alphas_cumprod[time]
alpha_next = self.alphas_cumprod[time_next]
sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c = (1 - alpha_next - sigma ** 2).sqrt()
noise = torch.randn_like(img)
img = x_start * alpha_next.sqrt() + \
c * pred_noise + \
sigma * noise
imgs.append(img)
ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
ret = self.unnormalize(ret)
return ret
@torch.inference_mode()
def sample(self, batch_size = 16, return_all_timesteps = False):
image_size, channels = self.image_size, self.channels
sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)
@torch.inference_mode()
def interpolate(self, x1, x2, t = None, lam = 0.5):
b, *_, device = *x1.shape, x1.device
t = default(t, self.num_timesteps - 1)
assert x1.shape == x2.shape
t_batched = torch.full((b,), t, device = device)
xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))
img = (1 - lam) * xt1 + lam * xt2
x_start = None
for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
self_cond = x_start if self.self_condition else None
img, x_start = self.p_sample(img, i, self_cond)
return img
@autocast(enabled = False)
def q_sample(self, x_start, t, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
b, c, h, w = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))
# offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
if offset_noise_strength > 0.:
offset_noise = torch.randn(x_start.shape[:2], device = self.device)
noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
# noise sample
x = self.q_sample(x_start = x_start, t = t, noise = noise)
# if doing self-conditioning, 50% of the time, predict x_start from current set of times
# and condition with unet with that
# this technique will slow down training by 25%, but seems to lower FID significantly
x_self_cond = None
if self.self_condition and random() < 0.5:
with torch.inference_mode():
x_self_cond = self.model_predictions(x, t).pred_x_start
x_self_cond.detach_()
# predict and take gradient step
model_out = self.model(x, t, x_self_cond)
if self.objective == 'pred_noise':
target = noise
elif self.objective == 'pred_x0':
target = x_start
elif self.objective == 'pred_v':
v = self.predict_v(x_start, t, noise)
target = v
else:
raise ValueError(f'unknown objective {self.objective}')
loss = F.mse_loss(model_out, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b (...)', 'mean')
loss = loss * extract(self.loss_weight, t, loss.shape)
return loss.mean()
def forward(self, img, *args, **kwargs):
b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
img = self.normalize(img)
return self.p_losses(img, t, *args, **kwargs)
The main difference between GaussianDiffusion and GaussianDiffusion1D is the dimensionality of the data they handle: GaussianDiffusion works on higher-dimensional data such as 2D images, while GaussianDiffusion1D works on 1D sequences.
The main differences in the individual functions:
- __init__: GaussianDiffusion takes image_size (images are assumed square, so height = width = image_size) plus several extra options: a sigmoid beta schedule with schedule_fn_kwargs, offset_noise_strength, and min-SNR loss weighting via min_snr_loss_weight / min_snr_gamma. GaussianDiffusion1D only needs seq_length and a smaller set of options.
- p_sample, p_sample_loop and ddim_sample: both classes implement the same reverse-process step and sampling loops. The image version additionally uses torch.inference_mode() instead of torch.no_grad(), supports return_all_timesteps (returning the whole denoising trajectory), and clips and re-derives the predicted noise during DDIM sampling.
- p_losses: the image version can add offset noise (offset_noise_strength) before calling q_sample; the 1D version does not.
- sample: both choose between p_sample_loop and ddim_sample via is_ddim_sampling. The image version samples tensors of shape (batch, channels, image_size, image_size), while the 1D version samples tensors of shape (batch, channels, seq_length).
5. Model
In the DDPM (Denoising Diffusion Probabilistic Models) code, the model's objective can be pred_noise, pred_x0, or pred_v; these options correspond to different prediction targets:
- pred_noise: the model predicts the noise added at each timestep. This is the most direct formulation: the forward process adds noise at every step, and the reverse process tries to predict and remove it.
- pred_x0: the model predicts the original sample x_0 (the fully denoised state of the noisy input). This is a harder task, because the model has to predict the whole clean sample at every timestep, even early steps where the noisy input is far from the original.
- pred_v: the model predicts the "velocity" v from the v-parameterization introduced in Appendix D of the progressive distillation paper (and used successfully in Imagen Video). As the predict_v function in the GaussianDiffusion code above shows, v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0, a combination of noise and signal, so predicting v smoothly interpolates between predicting the noise and predicting x_0 depending on the noise level.
These three targets are different ways of training the same denoiser; which one works best depends on the application and on model performance.
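A small numerical check of the v-parameterization identities used by predict_v and predict_start_from_v (standalone, with a single scalar noise level):
import torch

alpha_bar = torch.tensor(0.7)                                 # some intermediate noise level
x0 = torch.randn(2, 3, 8, 8)
eps = torch.randn_like(x0)

x_t = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps    # q_sample
v = alpha_bar.sqrt() * eps - (1 - alpha_bar).sqrt() * x0      # the pred_v target

# predict_start_from_v inverts this: x0 = sqrt(a) * x_t - sqrt(1 - a) * v
x0_rec = alpha_bar.sqrt() * x_t - (1 - alpha_bar).sqrt() * v
print(torch.allclose(x0_rec, x0, atol = 1e-6))                # True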
Construct
1D
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x
def Upsample(dim, dim_out = None):
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
nn.Conv1d(dim, default(dim_out, dim), 3, padding = 1)
)
def Downsample(dim, dim_out = None):
return nn.Conv1d(dim, default(dim_out, dim), 4, 2, 1)
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.g = nn.Parameter(torch.ones(1, dim, 1))
def forward(self, x):
return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = RMSNorm(dim)
def forward(self, x):
x = self.norm(x)
return self.fn(x)
# sinusoidal positional embeds
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
""" following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
def __init__(self, dim, is_random = False):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
def forward(self, x):
x = rearrange(x, 'b -> b 1')
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
fouriered = torch.cat((x, fouriered), dim = -1)
return fouriered
# building block modules
class Block(nn.Module):
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.proj = nn.Conv1d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
super().__init__()
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb = None):
scale_shift = None
if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
time_emb = rearrange(time_emb, 'b c -> b c 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, scale_shift = scale_shift)
h = self.block2(h)
return h + self.res_conv(x)
class LinearAttention(nn.Module):
def __init__(self, dim, heads = 4, dim_head = 32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv1d(hidden_dim, dim, 1),
RMSNorm(dim)
)
def forward(self, x):
b, c, n = x.shape
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h c n', h = self.heads), qkv)
q = q.softmax(dim = -2)
k = k.softmax(dim = -1)
q = q * self.scale
context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
out = rearrange(out, 'b h c n -> b (h c) n', h = self.heads)
return self.to_out(out)
class Attention(nn.Module):
def __init__(self, dim, heads = 4, dim_head = 32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv1d(hidden_dim, dim, 1)
def forward(self, x):
b, c, n = x.shape
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) n -> b h c n', h = self.heads), qkv)
q = q * self.scale
sim = einsum('b h d i, b h d j -> b h i j', q, k)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h d j -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b (h d) n')
return self.to_out(out)
This code defines the basic building blocks used by the model. The purpose of each function and class:
- Residual: a residual wrapper; it passes the input x through a function fn and adds the original x back, forming a residual connection.
- Upsample and Downsample: upsampling and downsampling. Upsampling enlarges the feature map (nearest-neighbour upsampling followed by a convolution); downsampling shrinks it (here a strided 1D convolution).
- RMSNorm: a normalization layer that uses RMS normalization; normalization helps the model train more stably.
- PreNorm: a pre-normalization wrapper; it normalizes the input first and then applies a function.
- SinusoidalPosEmb and RandomOrLearnedSinusoidalPosEmb: generate sinusoidal embeddings of the diffusion timestep, so the network knows how noisy its input is (see the quick shape check after this list).
- Block: a basic convolutional block consisting of a convolution, group normalization and a SiLU activation; it can also apply a scale and shift derived from the time embedding.
- ResnetBlock: a ResNet-style block containing two Blocks and a residual connection; the time embedding is mapped to the scale and shift that modulate the first Block.
- LinearAttention and Attention: two attention modules. Attention is standard (quadratic) self-attention, while LinearAttention is a cheaper linear-complexity approximation; both let the model focus on the important parts of the sequence.
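For intuition, a quick shape check of SinusoidalPosEmb (it maps a batch of scalar timesteps to dim-dimensional embeddings):
import torch

emb = SinusoidalPosEmb(64)
t = torch.randint(0, 1000, (8,)).float()   # 8 diffusion timesteps
print(emb(t).shape)                        # torch.Size([8, 64])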
Multi-dimensional
def Upsample(dim, dim_out = None):
return nn.Sequential(
nn.Upsample(scale_factor = 2, mode = 'nearest'),
nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
)
def Downsample(dim, dim_out = None):
return nn.Sequential(
Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = 2, p2 = 2),
nn.Conv2d(dim * 4, default(dim_out, dim), 1)
)
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
def forward(self, x):
return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)
# sinusoidal positional embeds
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
""" following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
""" https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
def __init__(self, dim, is_random = False):
super().__init__()
assert divisible_by(dim, 2)
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
def forward(self, x):
x = rearrange(x, 'b -> b 1')
freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
fouriered = torch.cat((x, fouriered), dim = -1)
return fouriered
# building block modules
class Block(nn.Module):
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift = None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
super().__init__()
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None
self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb = None):
scale_shift = None
if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
time_emb = rearrange(time_emb, 'b c -> b c 1 1')
scale_shift = time_emb.chunk(2, dim = 1)
h = self.block1(x, scale_shift = scale_shift)
h = self.block2(h)
return h + self.res_conv(x)
class LinearAttention(nn.Module):
def __init__(
self,
dim,
heads = 4,
dim_head = 32
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.norm = RMSNorm(dim)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Sequential(
nn.Conv2d(hidden_dim, dim, 1),
RMSNorm(dim)
)
def forward(self, x):
b, c, h, w = x.shape
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
q = q.softmax(dim = -2)
k = k.softmax(dim = -1)
q = q * self.scale
context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
return self.to_out(out)
class Attention(nn.Module):
def __init__(
self,
dim,
heads = 4,
dim_head = 32,
flash = False
):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.norm = RMSNorm(dim)
self.attend = Attend(flash = flash)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = 1)
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h (x y) c', h = self.heads), qkv)
out = self.attend(q, k, v)
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
return self.to_out(out)
This code implements much the same functionality as the 1D version, but since it targets two-dimensional data, several modules differ:
- Upsample and Downsample: use 2D convolutions; Downsample is a Rearrange space-to-depth step (2x2 patches folded into channels) followed by a 1x1 convolution instead of a strided convolution (see the shape check after this list).
- RMSNorm: the learnable gain has an extra spatial dimension, shape (1, dim, 1, 1) instead of (1, dim, 1).
- SinusoidalPosEmb and RandomOrLearnedSinusoidalPosEmb: identical to the 1D versions; they embed the diffusion timestep.
- Block and ResnetBlock: use 2D convolutions, and the time embedding is broadcast as 'b c -> b c 1 1'.
- LinearAttention and Attention: use 2D (1x1) convolutions and flatten the spatial dimensions into a sequence; both now normalize their input internally with RMSNorm (instead of relying on a PreNorm wrapper), and Attention delegates to an Attend module that can optionally use flash attention.
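For example, the 2D Downsample halves the spatial resolution without any strided convolution; a quick shape check:
import torch

down = Downsample(dim = 64, dim_out = 128)
x = torch.randn(1, 64, 32, 32)
print(down(x).shape)                       # torch.Size([1, 128, 16, 16])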
U-Net
U-Net is a deep learning architecture commonly used for image segmentation; its hallmark is the many skip connections between the encoder and the decoder. In this model, the data is first passed through a series of downsampling stages, processed by the middle blocks, and then restored to the original resolution through a series of upsampling stages, with residual connections and attention at every level.
1D
class Unet1D(nn.Module):
def __init__(
self,
dim,
init_dim = None,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
self_condition = False,
resnet_block_groups = 8,
learned_variance = False,
learned_sinusoidal_cond = False,
random_fourier_features = False,
learned_sinusoidal_dim = 16,
attn_dim_head = 32,
attn_heads = 4
):
super().__init__()
# determine dimensions
self.channels = channels
self.self_condition = self_condition
input_channels = channels * (2 if self_condition else 1)
init_dim = default(init_dim, dim)
self.init_conv = nn.Conv1d(input_channels, init_dim, 7, padding = 3)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ResnetBlock, groups = resnet_block_groups)
# time embeddings
time_dim = dim * 4
self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
if self.random_or_learned_sinusoidal_cond:
sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
fourier_dim = learned_sinusoidal_dim + 1
else:
sinu_pos_emb = SinusoidalPosEmb(dim)
fourier_dim = dim
self.time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(fourier_dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Downsample(dim_in, dim_out) if not is_last else nn.Conv1d(dim_in, dim_out, 3, padding = 1)
]))
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim, dim_head = attn_dim_head, heads = attn_heads)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)
self.ups.append(nn.ModuleList([
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Upsample(dim_out, dim_in) if not is_last else nn.Conv1d(dim_out, dim_in, 3, padding = 1)
]))
default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)
self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
self.final_conv = nn.Conv1d(dim, self.out_dim, 1)
def forward(self, x, time, x_self_cond = None):
if self.self_condition:
x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x_self_cond, x), dim = 1)
x = self.init_conv(x)
r = x.clone()
t = self.time_mlp(time)
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
h.append(x)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim = 1)
x = block1(x, t)
x = torch.cat((x, h.pop()), dim = 1)
x = block2(x, t)
x = attn(x)
x = upsample(x)
x = torch.cat((x, r), dim = 1)
x = self.final_res_block(x, t)
return self.final_conv(x)
The code above builds a one-dimensional U-Net intended for 1D sequence data such as audio or time series, so its input is a three-dimensional tensor of shape [batch_size, channels, length], and all convolution, upsampling and downsampling operations are 1D.
In Unet1D's forward function, if self-conditioning is enabled the input is first concatenated with the self-conditioning signal along the channel dimension. The result then passes through the initial convolution, the downsampling stages, the middle blocks, and the upsampling stages (with skip connections popped from the list h), and finally through a last residual block and a 1x1 convolution that produces the output. Throughout the network, the timestep is mapped by the time-embedding MLP and handed to every residual block as an extra input.
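A quick shape sanity check of the 1D U-Net (a sketch with self-conditioning disabled; the timestep tensor holds one integer timestep per sample):
import torch

model = Unet1D(dim = 64, dim_mults = (1, 2, 4, 8), channels = 32)
x = torch.randn(4, 32, 128)                # (batch, channels, length)
t = torch.randint(0, 1000, (4,))           # one timestep per sample
print(model(x, t).shape)                   # torch.Size([4, 32, 128])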
Multi-dimensional
class Unet(nn.Module):
def __init__(
self,
dim,
init_dim = None,
out_dim = None,
dim_mults = (1, 2, 4, 8),
channels = 3,
self_condition = False,
resnet_block_groups = 8,
learned_variance = False,
learned_sinusoidal_cond = False,
random_fourier_features = False,
learned_sinusoidal_dim = 16,
attn_dim_head = 32,
attn_heads = 4,
full_attn = (False, False, False, True),
flash_attn = False
):
super().__init__()
# determine dimensions
self.channels = channels
self.self_condition = self_condition
input_channels = channels * (2 if self_condition else 1)
init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ResnetBlock, groups = resnet_block_groups)
# time embeddings
time_dim = dim * 4
self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
if self.random_or_learned_sinusoidal_cond:
sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
fourier_dim = learned_sinusoidal_dim + 1
else:
sinu_pos_emb = SinusoidalPosEmb(dim)
fourier_dim = dim
self.time_mlp = nn.Sequential(
sinu_pos_emb,
nn.Linear(fourier_dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
# attention
num_stages = len(dim_mults)
full_attn = cast_tuple(full_attn, num_stages)
attn_heads = cast_tuple(attn_heads, num_stages)
attn_dim_head = cast_tuple(attn_dim_head, num_stages)
assert len(full_attn) == len(dim_mults)
FullAttention = partial(Attention, flash = flash_attn)
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(in_out, full_attn, attn_heads, attn_dim_head)):
is_last = ind >= (num_resolutions - 1)
attn_klass = FullAttention if layer_full_attn else LinearAttention
self.downs.append(nn.ModuleList([
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
]))
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
self.mid_attn = FullAttention(mid_dim, heads = attn_heads[-1], dim_head = attn_dim_head[-1])
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
for ind, ((dim_in, dim_out), layer_full_attn, layer_attn_heads, layer_attn_dim_head) in enumerate(zip(*map(reversed, (in_out, full_attn, attn_heads, attn_dim_head)))):
is_last = ind == (len(in_out) - 1)
attn_klass = FullAttention if layer_full_attn else LinearAttention
self.ups.append(nn.ModuleList([
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads),
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
]))
default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)
self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1)
@property
def downsample_factor(self):
return 2 ** (len(self.downs) - 1)
def forward(self, x, time, x_self_cond = None):
assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet'
if self.self_condition:
x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x_self_cond, x), dim = 1)
x = self.init_conv(x)
r = x.clone()
t = self.time_mlp(time)
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
h.append(x)
x = block2(x, t)
x = attn(x) + x
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x) + x
x = self.mid_block2(x, t)
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim = 1)
x = block1(x, t)
x = torch.cat((x, h.pop()), dim = 1)
x = block2(x, t)
x = attn(x) + x
x = upsample(x)
x = torch.cat((x, r), dim = 1)
x = self.final_res_block(x, t)
return self.final_conv(x)
Compared with the Unet1D above, the original U-Net model processes two-dimensional image data, so its input is usually a four-dimensional tensor of shape [batch_size, channels, height, width], and its convolutions, upsampling and downsampling are all 2D operations.
Comparison
Because the two U-Net models target data of different dimensionality, the operations they use also differ. The main differences:
- Data dimensionality: Unet processes 2D data, so it uses 2D convolutions (nn.Conv2d) and 2D up/downsampling; Unet1D processes 1D data, so it uses 1D convolutions (nn.Conv1d) and 1D up/downsampling.
- Model structure: the two models are structurally the same: downsampling, upsampling and skip connections, built from ResnetBlocks and attention. The implementation details differ only because of the data dimensionality.
- Attention: Unet can mix full attention (FullAttention) and linear attention (LinearAttention) per resolution through the full_attn tuple (by default only the lowest resolution uses full attention) and supports flash attention; Unet1D uses LinearAttention in the down/up stages and full Attention only in the middle (bottleneck) block.
- Time embeddings: both models embed the diffusion timestep with SinusoidalPosEmb or RandomOrLearnedSinusoidalPosEmb.
- Input shape: Unet expects [batch_size, channels, height, width]; Unet1D expects [batch_size, channels, length].
- Self-conditioning: both support self-conditioning; when enabled, the model input is the original data concatenated with the self-conditioning data along the channel dimension.
- Output shape: Unet outputs [batch_size, out_dim, height, width]; Unet1D outputs [batch_size, out_dim, length].
Note that both models also rely on timestep embeddings, attention, residual connections and normalization to improve performance. A shape sketch for the 2D U-Net follows below.
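The corresponding shape sketch for the 2D U-Net (flash attention disabled so it also runs on CPU; the input resolution must be divisible by the downsample factor, here 8):
import torch

model = Unet(dim = 64, dim_mults = (1, 2, 4, 8), flash_attn = False)
x = torch.randn(2, 3, 128, 128)            # (batch, channels, height, width)
t = torch.randint(0, 1000, (2,))
print(model(x, t).shape)                   # torch.Size([2, 3, 128, 128])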
Created: 2023-07-29