I've been studying autoencoders lately, so here is a reimplementation. The engineering scaffolding has been stripped out, leaving only the core; if you have multiple GPUs, set accelerate=True, otherwise leave it off.
The Encoder and Decoder follow the Stable Diffusion design. The encoder's final layer was changed to output a mean and a variance (that is, a genuine VAE); the feature map is produced by sampling from that distribution, and the sampled feature map is then quantized with VQ. The reconstructions still come out a bit blurry. My guess is that some source images had been resized to 256×256 by interpolation, or saved with lossy JPEG compression, so the model learned the noise in the data.
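As a quick orientation, here is a minimal sketch of that data flow. This is a usage sketch, not part of the training code; the shapes assume a 256×256 input, the default hyperparameters, and that the model code below is saved as vae.py:

import torch
from vae import VQVAE

model = VQVAE(in_channels=3, groups=8, z_channels=4, embedding_dim=4)
x = torch.randn(1, 3, 256, 256)                  # one RGB image
z = model.encode(x)                              # sample from N(mean, var): [1, 4, 32, 32]
quantized, indices, vq_loss = model.quantize(z)  # snap each latent vector to its nearest code
x_recon = model.decode(quantized)                # back to [1, 3, 256, 256]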
Data source:
Konachan动漫头像数据集_数据集-飞桨AI Studio星河社区
Results
epoch 0, step 100
epoch 6, step 10000
epoch 50, step 85000
epoch 100, step 176700
Model code
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, groups=1):
super(ConvBlock, self).__init__()
self.conv_block = nn.Sequential(
nn.GroupNorm(groups, in_channels),
nn.SiLU(),
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
)
def forward(self, x):
return self.conv_block(x)
class ResnetBlock(nn.Module):
def __init__(self, in_channels, out_channels, groups=32):
super(ResnetBlock, self).__init__()
self.conv_block = nn.Sequential(
ConvBlock(in_channels, out_channels, groups=groups),
ConvBlock(out_channels, out_channels, groups=groups),
)
if in_channels != out_channels:
self.skip_conn = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
else:
self.skip_conn = nn.Identity()
def forward(self, x):
return self.conv_block(x) + self.skip_conn(x)
class AttentionBlock(nn.Module):
def __init__(self, in_channels, out_channels, groups=32):
super(AttentionBlock, self).__init__()
self.q_conv = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
self.k_conv = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
self.v_conv = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
self.out_conv = ConvBlock(out_channels, out_channels, kernel_size=1, padding=0, groups=groups)
if in_channels != out_channels:
self.skip_conn = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
else:
self.skip_conn = nn.Identity()
def forward(self, x):
q = self.q_conv(x)
k = self.k_conv(x)
v = self.v_conv(x)
        b, c, h, w = q.shape
        attention = torch.einsum('bchw,bcHW->bhwHW', q, k)
        # scale by the contracted channel dimension, not the spatial width
        attention = attention / math.sqrt(c)
        # softmax over all key positions jointly, as in the SD attention block
        attention = attention.reshape(b, h, w, -1).softmax(dim=-1).reshape(b, h, w, h, w)
        out = torch.einsum('bhwHW,bcHW->bchw', attention, v)
out = self.out_conv(out)
return out + self.skip_conn(x)
class MiddleBlock(nn.Module):
def __init__(self, in_channels, out_channels, groups=32):
super(MiddleBlock, self).__init__()
self.conv_block = nn.Sequential(
ResnetBlock(in_channels, out_channels, groups=groups),
AttentionBlock(out_channels, out_channels, groups=groups),
ResnetBlock(out_channels, out_channels, groups=groups),
)
def forward(self, x):
return self.conv_block(x)
class UpSample(nn.Module):
def __init__(self, in_channels, out_channels):
super(UpSample, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
def forward(self, x):
x = nn.functional.interpolate(x, scale_factor=2)
x = self.conv(x)
return x
class DownSample(nn.Module):
def __init__(self, in_channels, out_channels):
super(DownSample, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 3, 2, 0)
def forward(self, x):
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode='constant', value=0)
x = self.conv(x)
return x
class DownBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(DownBlock, self).__init__()
self.down_block = nn.Sequential(
ResnetBlock(in_channels, out_channels),
ResnetBlock(out_channels, out_channels),
)
def forward(self, x):
return self.down_block(x)
class UpBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(UpBlock, self).__init__()
self.up_block = nn.Sequential(
ResnetBlock(in_channels, out_channels),
ResnetBlock(out_channels, out_channels),
)
def forward(self, x):
return self.up_block(x)
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, groups=32):
super(Encoder, self).__init__()
self.conv = nn.Conv2d(in_channels, 128, 3, 1, 1)
self.res_block = self.create_resnet_block(128, 128, 2, groups=groups)
self.res_block2 = self.create_resnet_block(128, 256, 2, groups=groups)
self.res_block3 = self.create_resnet_block(256, 512, 2, groups=groups)
self.down_block = DownBlock(512, 512)
self.middle_block = MiddleBlock(512, 512, groups=groups)
self.conv_block = ConvBlock(512, z_channels * 2, groups=groups)
@staticmethod
def create_resnet_block(in_channels, out_channels, num_blocks, groups=32):
res_blocks = []
for _ in range(num_blocks):
res_blocks.append(ResnetBlock(in_channels, in_channels, groups=groups))
res_blocks.append(DownSample(in_channels, out_channels))
return nn.Sequential(*res_blocks)
def forward(self, x):
x = self.conv(x)
x = self.res_block(x)
x = self.res_block2(x)
x = self.res_block3(x)
x = self.down_block(x)
x = self.middle_block(x)
x = self.conv_block(x)
return x
class Decoder(nn.Module):
def __init__(self, in_channels, groups=32):
super(Decoder, self).__init__()
self.conv = nn.Conv2d(in_channels, 512, 3, 1, 1)
self.middle_block = MiddleBlock(512, 512, groups=groups)
self.resnet_block = self.create_resnet_block(512, 512, 3, groups=groups)
self.resnet_block2 = self.create_resnet_block(512, 256, 3, groups=groups)
self.resnet_block3 = self.create_resnet_block(256, 128, 3, groups=groups)
self.up_block = UpBlock(128, 128)
self.conv_block = ConvBlock(128, 3, groups=groups)
@staticmethod
def create_resnet_block(in_channels, out_channels, num_blocks, groups=32):
res_blocks = []
for _ in range(num_blocks):
res_blocks.append(ResnetBlock(in_channels, in_channels, groups=groups))
res_blocks.append(UpSample(in_channels, out_channels))
return nn.Sequential(*res_blocks)
def forward(self, x):
x = self.conv(x)
x = self.middle_block(x)
x = self.resnet_block(x)
x = self.resnet_block2(x)
x = self.resnet_block3(x)
x = self.up_block(x)
x = self.conv_block(x)
return x
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=None):
if dims is None:
dims = [1, 2, 3]
if self.deterministic:
return torch.Tensor([0.])
log_two_pi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
log_two_pi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
class VectorQuantizer(nn.Module):
"""带EMA更新的向量量化层"""
def __init__(self, num_embeddings, embedding_dim, beta=0.25, decay=0.99, epsilon=1e-5, ema=False):
super().__init__()
self.embedding_dim = embedding_dim
self.num_embeddings = num_embeddings
self.beta = beta
self.decay = decay
self.epsilon = epsilon
self.ema = ema
        # Codebook initialization
self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
self.embedding.weight.data.normal_()
# self.embedding.requires_grad_(False)
        # EMA statistics (used when ema=True)
self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
self.register_buffer('_ema_w', self.embedding.weight.data.clone())
def forward(self, z):
        # reshape to channel-last, then flatten to a list of D-dim vectors
z = z.permute(0, 2, 3, 1) # [B, D, H, W] -> [B, H, W, D]
z_flattened = z.reshape(-1, self.embedding_dim)
        # squared L2 distance from every vector to every codebook entry
        distances = torch.cdist(z_flattened, self.embedding.weight, p=2.0) ** 2
        # nearest-neighbour code for each vector
encoding_indices = torch.argmin(distances, dim=1)
quantized = self.embedding(encoding_indices).view(z.shape)
quantized = quantized.permute(0, 3, 1, 2)
vq_loss = self.beta * F.mse_loss(quantized.detach(), z.permute(0, 3, 1, 2))
        # EMA codebook update (no gradients flow into the embedding)
        if self.training and self.ema:
            with torch.no_grad():
                # one-hot assignment of every vector to its nearest code
                encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=z.device)
                encodings.scatter_(1, encoding_indices.view(-1, 1), 1)
                updated_ema_cluster_size = (self._ema_cluster_size * self.decay
                                            + (1 - self.decay) * torch.sum(encodings, 0))
                # Laplace smoothing keeps rarely used codes alive
                n = torch.sum(updated_ema_cluster_size)
                updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon)
                                            / (n + self.num_embeddings * self.epsilon) * n)
                dw = torch.matmul(encodings.t(), z_flattened)
                updated_ema_w = self._ema_w * self.decay + (1 - self.decay) * dw
                # write the running statistics and the refreshed codebook back
                self._ema_cluster_size.data.copy_(updated_ema_cluster_size)
                self._ema_w.data.copy_(updated_ema_w)
                self.embedding.weight.data.copy_(updated_ema_w / updated_ema_cluster_size.unsqueeze(1))
        if not self.ema:
            # without EMA, the codebook is trained with an explicit codebook loss
            codebook_loss = F.mse_loss(quantized, z.permute(0, 3, 1, 2).detach())
            vq_loss = vq_loss + codebook_loss
        # straight-through estimator: gradients bypass the hard assignment
quantized = z.permute(0, 3, 1, 2) + (quantized - z.permute(0, 3, 1, 2)).detach()
return quantized, encoding_indices, vq_loss
class VAE(nn.Module):
def __init__(self, in_channels, groups=32, z_channels=4, embedding_dim=4):
super(VAE, self).__init__()
self.scale_factor = 0.18215
self.encoder = Encoder(in_channels, z_channels, groups=groups)
self.decoder = Decoder(z_channels, groups=groups)
self.quant_conv = nn.Conv2d(z_channels * 2, embedding_dim * 2, 1, 1, 0)
self.post_quant_conv = nn.Conv2d(embedding_dim, z_channels, 1, 1, 0)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
out = posterior.sample()
out = self.scale_factor * out
return out
def decode(self, z):
z = 1. / self.scale_factor * z
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, x):
z = self.encode(x)
dec = self.decode(z)
return dec
    def generate(self, x):
        # decode a latent through the full decode path (post_quant_conv + decoder)
        return self.decode(x)
class VQVAE(VAE):
def __init__(self, in_channels=3, groups=8, z_channels=4, embedding_dim=4, num_embeddings=8196, beta=0.25,
decay=0.99, epsilon=1e-5):
super(VQVAE, self).__init__(in_channels, groups, z_channels, embedding_dim)
self.quantize = VectorQuantizer(num_embeddings,
embedding_dim,
ema=True,
beta=beta,
decay=decay,
epsilon=epsilon)
def forward(self, x):
z = self.encode(x)
quantized, _, vq_loss = self.quantize(z)
x_recon = self.decode(quantized)
return x_recon, vq_loss
    def calculate_balance_factor(self, perceptual_loss, gan_loss):
last_layer = self.decoder.conv_block.conv_block[-1]
last_layer_weight = last_layer.weight
perceptual_loss_grads = torch.autograd.grad(perceptual_loss, last_layer_weight, retain_graph=True)[0]
gan_loss_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0]
alpha = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4)
alpha = torch.clamp(alpha, 0, 1e4).detach()
return 0.8 * alpha
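Before the training script, one detail worth isolating: the quantizer's straight-through estimator. The nearest-code lookup is not differentiable, so the forward pass uses the quantized codes while the backward pass routes gradients to the encoder output unchanged. A self-contained toy check (independent of the classes above):

import torch

z = torch.randn(2, 4, requires_grad=True)
codebook = torch.randn(8, 4)
indices = torch.cdist(z, codebook).argmin(dim=1)
quantized = codebook[indices]
# forward: quantized values; backward: identity with respect to z
quantized = z + (quantized - z).detach()
quantized.sum().backward()
print(z.grad)  # all ones, as if quantization were the identity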
Training script
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from accelerate import Accelerator, DistributedDataParallelKwargs
from lpips import LPIPS
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import VGG19_Weights
from tqdm import tqdm
from vae import VQVAE
# --------------------------
# Adversarial components
# --------------------------
class Discriminator(nn.Module):
"""多尺度判别器"""
def __init__(self, in_channels=3, base_channels=4, num_layers=3):
super().__init__()
layers = [nn.Conv2d(in_channels, base_channels, 4, 2, 1), nn.LeakyReLU(0.2)]
channels = base_channels
for _ in range(1, num_layers):
layers += [
nn.Conv2d(channels, channels * 2, 4, 2, 1),
nn.InstanceNorm2d(channels * 2),
nn.LeakyReLU(0.2)
]
channels *= 2
layers += [
nn.Conv2d(channels, channels, 4, 1, 0),
nn.InstanceNorm2d(channels),
nn.LeakyReLU(0.2),
nn.Conv2d(channels, 1, 1)
]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class PerceptualLoss(nn.Module):
def __init__(self, layers=None):
super(PerceptualLoss, self).__init__()
if layers is None:
layers = ['1', '2', '4', '7']
self.layers = layers
self.vgg = torchvision.models.vgg19(weights=VGG19_Weights.DEFAULT).features.eval()
self.vgg.requires_grad_(False)
for name, module in self.vgg.named_modules():
if name in layers:
module.register_forward_hook(self.forward_hook)
self.features = []
def forward_hook(self, module, input, output):
self.features.append(output)
def forward(self, x, x_recon):
x_and_x_recon = torch.cat((x, x_recon), dim=0)
self.features = []
self.vgg(x_and_x_recon)
x_and_x_recon_features = self.features
loss = torch.tensor(0.0, device=x.device)
for i, layer in enumerate(self.layers):
x_feature = x_and_x_recon_features[i][:x.shape[0]]
            # unit-normalize features channel-wise; eps guards against division by zero
            x_norm_factor = torch.sqrt(torch.mean(x_feature ** 2, dim=1, keepdim=True) + 1e-8)
            x_feature = x_feature / x_norm_factor
            x_recon_feature = x_and_x_recon_features[i][x.shape[0]:]
            x_recon_norm_factor = torch.sqrt(torch.mean(x_recon_feature ** 2, dim=1, keepdim=True) + 1e-8)
            x_recon_feature = x_recon_feature / x_recon_norm_factor
loss += F.l1_loss(x_feature, x_recon_feature)
return loss
# --------------------------
# Training loop
# --------------------------
def train_vqgan(dataloader, epochs=100, mixed_precision=False, accelerate=False, disc_start=10000, rec_factor=1,
perceptual_factor=1, learning_rate=4.5e-6, in_channels=3, groups=8, z_channels=4, embedding_dim=4,
num_embeddings=8196, beta=0.25, decay=0.99, epsilon=1e-5):
os.makedirs('results', exist_ok=True)
    # Initialize the models
model = VQVAE(in_channels, groups, z_channels, embedding_dim, num_embeddings, beta, decay, epsilon)
discriminator = Discriminator()
# perceptual_loss_fn = PerceptualLoss()
perceptual_loss_fn = LPIPS().eval()
    # Optimizers
opt_ae = Adam(list(model.encoder.parameters()) + list(model.decoder.parameters())
+ list(model.quantize.parameters()), lr=learning_rate, betas=(0.5, 0.9))
opt_disc = Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.9))
gradient_accumulation_steps = 4
step = 0
start_epoch = 0
if os.path.exists("vqgan.pth"):
state_dict = torch.load("vqgan.pth")
step = state_dict.get("step", 0)
start_epoch = state_dict.get("epoch", 0)
model.load_state_dict(state_dict.get("model", {}))
discriminator.load_state_dict(state_dict.get("discriminator", {}))
opt_ae.load_state_dict(state_dict.get("opt_ae", {}))
opt_disc.load_state_dict(state_dict.get("opt_disc", {}))
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
if accelerate:
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps,
mixed_precision='fp16' if mixed_precision else 'no',
kwargs_handlers=[ddp_kwargs])
        # wrap models, optimizers and dataloader for (multi-)GPU training
model, discriminator, perceptual_loss_fn, opt_ae, opt_disc, dataloader = accelerator.prepare(
model, discriminator, perceptual_loss_fn, opt_ae, opt_disc, dataloader)
device = accelerator.device
else:
accelerator = None
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
discriminator = discriminator.to(device)
perceptual_loss_fn = perceptual_loss_fn.to(device)
for epoch in range(start_epoch, epochs):
with tqdm(range(len(dataloader))) as pbar:
for _, batch in zip(pbar, dataloader):
x, _ = batch
x = x.to(device)
if accelerator is not None:
                    # forward pass: reconstruction and all loss terms
with accelerator.autocast():
disc_loss, g_loss, perceptual_loss, rec_loss, total_loss, vq_loss, x_recon = train_step(
accelerator,
disc_start,
discriminator,
model,
perceptual_factor,
perceptual_loss_fn,
rec_factor,
step,
x)
                    opt_ae.zero_grad()
                    accelerator.backward(total_loss)
                    opt_disc.zero_grad()
                    accelerator.backward(disc_loss)
                    opt_ae.step()
                    opt_disc.step()
else:
                    # forward pass: reconstruction and all loss terms
with torch.amp.autocast(device, enabled=mixed_precision):
disc_loss, g_loss, perceptual_loss, rec_loss, total_loss, vq_loss, x_recon = train_step(
accelerator,
disc_start,
discriminator,
model,
perceptual_factor,
perceptual_loss_fn,
rec_factor,
step,
x)
                    opt_ae.zero_grad()
                    total_loss.backward()
                    opt_disc.zero_grad()
                    disc_loss.backward()
                    opt_ae.step()
                    opt_disc.step()
pbar.set_postfix(
TotalLoss=np.round(total_loss.cpu().detach().numpy().item(), 5),
DiscLoss=np.round(disc_loss.cpu().detach().numpy().item(), 3),
PerceptualLoss=np.round(perceptual_loss.cpu().detach().numpy().item(), 5),
RecLoss=np.round(rec_loss.cpu().detach().numpy().item(), 5),
GenLoss=np.round(g_loss.cpu().detach().numpy().item(), 5),
VqLoss=np.round(vq_loss.cpu().detach().numpy().item(), 5)
)
pbar.update(0)
                # Log sample reconstructions every 100 steps (main process only)
                if step % 100 == 0 and (accelerator is None or accelerator.is_main_process):
                    with torch.no_grad():
                        # undo the ImageNet normalization applied by the dataloader
                        means = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1, 3, 1, 1)
                        stds = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1, 3, 1, 1)
                        fake_image = (x_recon[:4] * stds + means).clamp(0, 1)
                        real_image = (x[:4] * stds + means).clamp(0, 1)
                        real_fake_images = torch.cat((real_image, fake_image))
                        torchvision.utils.save_image(real_fake_images,
                                                     os.path.join("results", f"{epoch}_{step}.jpg"),
                                                     nrow=4)
step += 1
        if accelerate:
            if accelerator.is_main_process:
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_discriminator = accelerator.unwrap_model(discriminator)
                # Save checkpoint (weights unwrapped from the DDP wrapper)
state_dict = {
"model": unwrapped_model.state_dict(),
"discriminator": unwrapped_discriminator.state_dict(),
"opt_ae": opt_ae.state_dict(),
"opt_disc": opt_disc.state_dict(),
"step": step,
"epoch": epoch
}
torch.save(state_dict, "vqgan.pth")
else:
            # Save checkpoint
state_dict = {
"model": model.state_dict(),
"discriminator": discriminator.state_dict(),
"opt_ae": opt_ae.state_dict(),
"opt_disc": opt_disc.state_dict(),
"step": step,
"epoch": epoch
}
torch.save(state_dict, "vqgan.pth")
return model, discriminator, opt_ae, opt_disc
def train_step(accelerator, disc_start, discriminator, model, perceptual_factor, perceptual_loss_fn, rec_factor, step,
               x):
    x_recon, vq_loss = model(x)
    disc_real = discriminator(x)
    disc_fake = discriminator(x_recon)
    # detached copy for the discriminator loss, so it does not backprop into the generator
    disc_fake_detached = discriminator(x_recon.detach())
    # the adversarial term is switched on only after disc_start steps
    disc_factor = 0 if disc_start > step else 1
    perceptual_loss = perceptual_loss_fn(x, x_recon).mean()
    rec_loss = F.l1_loss(x_recon, x)
    perceptual_rec_loss = perceptual_factor * perceptual_loss + rec_factor * rec_loss
    g_loss = -torch.mean(disc_fake)
    # adaptive weight balancing the GAN gradient against the reconstruction gradient
    core_model = accelerator.unwrap_model(model) if accelerator else model
    balance_factor = core_model.calculate_balance_factor(perceptual_rec_loss, g_loss)
    total_loss = perceptual_rec_loss + vq_loss + disc_factor * balance_factor * g_loss
    d_real_loss = F.binary_cross_entropy_with_logits(
        disc_real, torch.ones_like(disc_real))
    d_fake_loss = F.binary_cross_entropy_with_logits(
        disc_fake_detached, torch.zeros_like(disc_fake_detached))
    disc_loss = disc_factor * 0.5 * (d_real_loss + d_fake_loss)
    return disc_loss, g_loss, perceptual_loss, rec_loss, total_loss, vq_loss, x_recon
def get_imagenet_dataloader(batch_size=32, data_path="datasets/faces"):
    # preprocessing: ImageNet-style normalization (undone when saving samples)
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = ImageFolder(data_path, transform=transform)
return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)
# --------------------------
# Usage example
# --------------------------
if __name__ == "__main__":
    # Build the dataloader (example)
train_loader = get_imagenet_dataloader(batch_size=12, data_path="faces")
    # Start training
train_vqgan(train_loader, epochs=100, mixed_precision=True, accelerate=True, disc_start=10000, rec_factor=1,
perceptual_factor=1, learning_rate=4.5e-6, in_channels=3, groups=8, z_channels=4, embedding_dim=4,
num_embeddings=8196, beta=0.25, decay=0.99, epsilon=1e-5)
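Finally, a sketch of running a trained checkpoint on a single image. This is hedged: example.jpg is a placeholder filename, and the preprocessing simply mirrors get_imagenet_dataloader above:

import torch
from PIL import Image
from torchvision import transforms
from vae import VQVAE

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VQVAE(in_channels=3, groups=8, z_channels=4, embedding_dim=4, num_embeddings=8196)
state = torch.load("vqgan.pth", map_location="cpu")
model.load_state_dict(state["model"])
model.eval().to(device)

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
x = transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)
with torch.no_grad():
    x_recon, _ = model(x)  # reconstruction, still in normalized space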