Capstone-6 StyleGAN 코드 구현
by Dongkyun Kim
원본 : https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
결과 이미지 :
pretrained weight를 받아 구현한 코드입니다.
FFHQ 데이터셋으로 학습된 가중치를 사용해 이미지를 생성했으며, 앞으로 이 부분을 이모티콘 도메인으로 바꾸어 학습시키는 방향으로 진행할 예정입니다.
랜덤 시드 값만 바꾸어 이미지를 생성한 결과입니다.
CODE
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import pickle
import numpy as np
import IPython
class MyLinear(nn.Module):
    """Linear layer with equalized learning rate and custom learning rate multiplier.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        gain: Gain used in the He constant ``gain / sqrt(input_size)``.
        use_wscale: If True, store weights at scale ``1 / lrmul`` and apply the
            He constant at runtime through ``w_mul`` (equalized learning rate).
        lrmul: Learning-rate multiplier applied to weight (and bias) at runtime.
        bias: Whether to learn an additive bias of size ``output_size``.
    """

    def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True):
        super().__init__()
        he_std = gain * input_size**(-0.5)  # He init
        # Equalized learning rate and custom learning rate multiplier.
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        # Weight is stored as (output_size, input_size), the layout F.linear expects,
        # drawn from a standard normal and multiplied by init_std.
        self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_size))
            self.b_mul = lrmul
        else:
            self.bias = None

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            # The bias is also scaled by the learning-rate multiplier at runtime.
            bias = bias * self.b_mul
        return F.linear(x, self.weight * self.w_mul, bias)
class MyConv2d(nn.Module):
    """Conv layer with equalized learning rate and custom learning rate multiplier.

    Args:
        input_channels: Number of input channels.
        output_channels: Number of output channels.
        kernel_size: Square kernel size; the non-fused path pads by kernel_size // 2.
        gain: Gain used in the He constant gain / sqrt(input_channels * kernel_size**2).
        use_wscale: If True, store weights at scale 1 / lrmul and apply the He
            constant at runtime through ``w_mul`` (equalized learning rate).
        lrmul: Learning-rate multiplier applied to weight and bias at runtime.
        bias: Whether to learn a per-output-channel bias.
        intermediate: Optional module applied after the convolution and before
            the bias is added.
        upscale: If True, the input is upsampled 2x, either by a separate
            Upscale2d step or fused into a stride-2 transposed convolution.
    """
    def __init__(self, input_channels, output_channels, kernel_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True,
                 intermediate=None, upscale=False):
        super().__init__()
        if upscale:
            self.upscale = Upscale2d()
        else:
            self.upscale = None
        he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init
        self.kernel_size = kernel_size
        if use_wscale:
            init_std = 1.0 / lrmul
            self.w_mul = he_std * lrmul
        else:
            init_std = he_std / lrmul
            self.w_mul = lrmul
        #init_std = he_std
        # Weight layout is (out, in, k, k), the layout F.conv2d expects.
        self.weight = torch.nn.Parameter(torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_channels))
            self.b_mul = lrmul
        else:
            self.bias = None
        self.intermediate = intermediate

    def forward(self, x):
        bias = self.bias
        if bias is not None:
            bias = bias * self.b_mul
        have_convolution = False
        # Fused path is taken only when the upsampled output's shorter side is >= 128 px.
        if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
            # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way
            # this really needs to be cleaned up and go into the conv...
            w = self.weight * self.w_mul
            w = w.permute(1, 0, 2, 3)
            # Swap the output/input channel dims: conv_transpose2d takes weights as (in, out, kH, kW).
            # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!
            w = F.pad(w, (1,1,1,1))
            # Zero-pad the two kernel spatial dims of w by 1 on each side (k x k -> (k+2) x (k+2)).
            # Summing the four 1-pixel-shifted crops yields a (k+1) x (k+1) kernel.
            w = w[:, :, 1:, 1:]+ w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
            # Stride-2 transposed conv; for kernel_size=3 (w is 4x4, padding 1) the output is 2H x 2W.
            x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1)-1)//2)
            have_convolution = True
        elif self.upscale is not None:
            # Non-fused path: nearest-neighbour upsample first, then the regular conv below.
            x = self.upscale(x)
        if not have_convolution and self.intermediate is None:
            # Simple case: plain conv with the bias folded into F.conv2d.
            return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size//2)
        elif not have_convolution:
            x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size//2)
        # When an intermediate module is present (or the fused path ran), it is applied
        # before the bias, so the bias is added last.
        if self.intermediate is not None:
            x = self.intermediate(x)
        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        return x
class NoiseLayer(nn.Module):
    """Add single-channel noise, scaled per channel by a learned weight.

    The noise map has one channel and is broadcast over all channels of the
    input; ``self.weight`` (initialised to zeros) scales it per channel.

    Noise source priority in ``forward``:
      1. the explicit ``noise`` argument,
      2. a pre-set ``self.noise`` tensor (handy for analysis: set ``.noise`` on
         every NoiseLayer to get fixed, reproducible noise),
      3. fresh standard-normal noise shaped (batch, 1, H, W).
    """

    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(channels))
        self.noise = None

    def forward(self, x, noise=None):
        if noise is None:
            noise = self.noise
        if noise is None:
            noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
        per_channel = self.weight.view(1, -1, 1, 1)
        return x + per_channel * noise
class StyleMod(nn.Module):
    """Per-channel affine modulation of a feature map driven by a latent vector.

    A linear layer maps the latent to ``2 * channels`` values that are split
    into a scale and a shift; the output is ``x * (scale + 1) + shift``.
    """

    def __init__(self, latent_size, channels, use_wscale):
        super().__init__()
        self.lin = MyLinear(latent_size, channels * 2, gain=1.0, use_wscale=use_wscale)

    def forward(self, x, latent):
        n_channels = x.size(1)
        # [batch, 2*C] -> [batch, 2, C, 1, ..., 1] so it broadcasts over the spatial dims.
        spatial_ones = [1] * (x.dim() - 2)
        style = self.lin(latent).view(-1, 2, n_channels, *spatial_ones)
        scale, shift = style[:, 0], style[:, 1]
        return x * (scale + 1.) + shift
class PixelNormLayer(nn.Module):
    """Rescale each feature vector (along dim 1) to unit root-mean-square.

    ``epsilon`` is added to the mean square before the reciprocal square root
    to avoid division by zero.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x * (mean_square + self.epsilon).rsqrt()
class BlurLayer(nn.Module):
    """Depthwise low-pass (blur) filter.

    Args:
        kernel: Filter taps. A 1-D sequence is turned into a 2-D kernel via its
            outer product with itself; a 2-D (nested) sequence is used as-is.
            Defaults to (1, 2, 1). (Previously this argument was silently
            overwritten with [1, 2, 1]; it is now honoured.)
        normalize: Scale the 2-D kernel so that it sums to 1.
        flip: Flip the kernel in both spatial dimensions.
        stride: Convolution stride.
    """

    def __init__(self, kernel=(1, 2, 1), normalize=True, flip=False, stride=1):
        super(BlurLayer, self).__init__()
        kernel = torch.tensor(kernel, dtype=torch.float32)
        if kernel.dim() == 1:
            kernel = kernel[:, None] * kernel[None, :]
        elif kernel.dim() != 2:
            raise ValueError('kernel must be 1-D or 2-D, got {}-D'.format(kernel.dim()))
        kernel = kernel[None, None]
        if normalize:
            kernel = kernel / kernel.sum()
        if flip:
            # Tensors do not support negative-step slicing ([::-1] raises); torch.flip is the equivalent.
            kernel = torch.flip(kernel, dims=[2, 3])
        self.register_buffer('kernel', kernel)
        self.stride = stride

    def forward(self, x):
        # Replicate the single kernel once per input channel for a depthwise (grouped) conv.
        kernel = self.kernel.expand(x.size(1), -1, -1, -1)
        # Pad each spatial dim separately so non-square kernels are handled too.
        padding = ((self.kernel.size(2) - 1) // 2, (self.kernel.size(3) - 1) // 2)
        x = F.conv2d(
            x,
            kernel,
            stride=self.stride,
            padding=padding,
            groups=x.size(1)
        )
        return x
def upscale2d(x, factor=2, gain=1):
    """Nearest-neighbour upsample a 4-D tensor by an integer factor.

    Each pixel of ``x`` (N, C, H, W) is repeated ``factor`` times along both
    spatial axes, giving (N, C, H*factor, W*factor). The input is first
    multiplied by ``gain`` when ``gain != 1``.
    """
    assert x.dim() == 4
    out = x * gain if gain != 1 else x
    if factor == 1:
        return out
    n, c, h, w = out.shape
    # Insert singleton axes after H and W, broadcast them to `factor`, then merge.
    tiled = out.view(n, c, h, 1, w, 1).expand(-1, -1, -1, factor, -1, factor)
    return tiled.contiguous().view(n, c, factor * h, factor * w)
class Upscale2d(nn.Module):
    """Module wrapper around ``upscale2d`` (nearest-neighbour upsampling).

    Args:
        factor: Positive integer upsampling factor.
        gain: Multiplier applied to the input before upsampling.
    """

    def __init__(self, factor=2, gain=1):
        super().__init__()
        assert isinstance(factor, int) and factor >= 1
        self.gain, self.factor = gain, factor

    def forward(self, x):
        return upscale2d(x, gain=self.gain, factor=self.factor)
class G_mapping(nn.Sequential):
    """StyleGAN mapping network: latent z (512) -> style vector w (512).

    Pixel-norm followed by eight 512->512 dense layers (lrmul=0.01), each
    followed by the activation. ``forward`` repeats w along a new dim 1 so that
    every synthesis layer gets its own copy.

    Args:
        nonlinearity: 'relu' or 'lrelu' (leaky ReLU with slope 0.2).
        use_wscale: Enable equalized learning rate in the dense layers.
        dlatent_broadcast: Number of copies of w per sample (18 matches the
            1024x1024 synthesis network; extra copies are ignored by smaller ones).
    """

    def __init__(self, nonlinearity='lrelu', use_wscale=True, dlatent_broadcast=18):
        # nn.Sequential only accepts Modules, so 'relu' must be nn.ReLU() rather
        # than the plain function torch.relu (which raised a TypeError).
        act, gain = {'relu': (nn.ReLU(), np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        layers = [('pixel_norm', PixelNormLayer())]
        for i in range(8):
            # Layer names must stay 'denseN' / 'denseN_act' to match the converted weights.
            layers.append(('dense{}'.format(i), MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)))
            layers.append(('dense{}_act'.format(i), act))
        super().__init__(OrderedDict(layers))
        self.dlatent_broadcast = dlatent_broadcast

    def forward(self, x):
        x = super().forward(x)
        # Broadcast: [batch, 512] -> [batch, dlatent_broadcast, 512]
        x = x.unsqueeze(1).expand(-1, self.dlatent_broadcast, -1)
        return x
class Truncation(nn.Module):
    """Truncation trick on broadcast style vectors.

    For the first ``max_layer`` entries along dim 1 of ``x``, the style vector
    is replaced by ``lerp(avg_latent, x, threshold)``, i.e. pulled toward
    ``avg_latent``; the remaining entries pass through unchanged.

    Args:
        avg_latent: Average style vector, registered as a buffer.
        max_layer: Number of leading layers to truncate.
        threshold: Interpolation weight toward ``x`` (1.0 = no truncation).
    """

    def __init__(self, avg_latent, max_layer=8, threshold=0.7):
        super().__init__()
        self.max_layer = max_layer
        self.threshold = threshold
        self.register_buffer('avg_latent', avg_latent)

    def forward(self, x):
        assert x.dim() == 3
        interp = torch.lerp(self.avg_latent, x, self.threshold)
        # Build the layer mask on x's device: a CPU mask cannot be combined with
        # CUDA tensors in torch.where.
        do_trunc = (torch.arange(x.size(1), device=x.device) < self.max_layer).view(1, -1, 1)
        return torch.where(do_trunc, interp, x)
class LayerEpilogue(nn.Module):
    """Things to do at the end of each layer.

    Applies, in order: optional per-pixel noise, the activation, optional pixel
    norm, optional instance norm (all in ``top_epi``), then optional style
    modulation from the given style-vector slice.

    Args:
        channels: Number of feature channels.
        dlatent_size: Size of the style vector fed to StyleMod.
        use_wscale: Equalized learning rate for the StyleMod linear layer.
        use_noise / use_pixel_norm / use_instance_norm / use_styles: Toggle each stage.
        activation_layer: nn.Module used as the activation.
    """

    def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        layers = []
        if use_noise:
            layers.append(('noise', NoiseLayer(channels)))
        layers.append(('activation', activation_layer))
        if use_pixel_norm:
            # The pixel-norm class in this file is PixelNormLayer (``PixelNorm`` was undefined).
            layers.append(('pixel_norm', PixelNormLayer()))
        if use_instance_norm:
            layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
        self.top_epi = nn.Sequential(OrderedDict(layers))
        if use_styles:
            self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale)
        else:
            self.style_mod = None

    def forward(self, x, dlatents_in_slice=None):
        x = self.top_epi(x)
        if self.style_mod is not None:
            x = self.style_mod(x, dlatents_in_slice)
        else:
            assert dlatents_in_slice is None
        return x
class InputBlock(nn.Module):
    """First synthesis block producing the initial 4x4 feature map.

    The starting map is either a learned constant (plus a learned per-channel
    bias) or, when ``const_input_layer`` is False, a dense projection of the
    first style vector. It then goes through epilogue -> 3x3 conv -> epilogue,
    using style vectors 0 and 1 of ``dlatents_in_range``.
    """

    def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        self.const_input_layer = const_input_layer
        self.nf = nf
        if const_input_layer:
            # called 'const' in tf
            self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
            self.bias = nn.Parameter(torch.ones(nf))
        else:
            # gain / 4 tweaks the gain to match the official Progressive GAN implementation
            self.dense = MyLinear(dlatent_size, nf*16, gain=gain/4, use_wscale=use_wscale)
        epilogue_args = (dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
        self.epi1 = LayerEpilogue(nf, *epilogue_args)
        self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(nf, *epilogue_args)

    def forward(self, dlatents_in_range):
        n = dlatents_in_range.size(0)
        if not self.const_input_layer:
            x = self.dense(dlatents_in_range[:, 0]).view(n, self.nf, 4, 4)
        else:
            x = self.const.expand(n, -1, -1, -1) + self.bias.view(1, -1, 1, 1)
        x = self.epi1(x, dlatents_in_range[:, 0])
        return self.epi2(self.conv(x), dlatents_in_range[:, 1])
class GSynthesisBlock(nn.Module):
    """Synthesis block that doubles the spatial resolution.

    Structure: upsampling 3x3 conv (optionally blurred) -> epilogue ->
    3x3 conv -> epilogue, using style vectors 0 and 1 of ``dlatents_in_range``.
    Builds the 2**res x 2**res stage, for res = 3..resolution_log2.
    """

    def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
        super().__init__()
        blur = BlurLayer(blur_filter) if blur_filter else None
        self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
                                 intermediate=blur, upscale=True)
        epilogue_args = (dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
        self.epi1 = LayerEpilogue(out_channels, *epilogue_args)
        self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
        self.epi2 = LayerEpilogue(out_channels, *epilogue_args)

    def forward(self, x, dlatents_in_range):
        stages = ((self.conv0_up, self.epi1, 0), (self.conv1, self.epi2, 1))
        for conv, epilogue, style_idx in stages:
            x = epilogue(conv(x), dlatents_in_range[:, style_idx])
        return x
class G_synthesis(nn.Module):
    """StyleGAN synthesis network: per-layer style vectors -> RGB image.

    One InputBlock (4x4) followed by a GSynthesisBlock per doubling up to
    ``resolution``; a final 1x1 conv (``torgb``) maps features to
    ``num_channels``. ``dtype`` and ``randomize_noise`` are accepted for parity
    with the TF signature but are not used by this implementation.
    """

    def __init__(self,
                 dlatent_size = 512,          # Disentangled latent (W) dimensionality.
                 num_channels = 3,            # Number of output color channels.
                 resolution = 1024,           # Output resolution.
                 fmap_base = 8192,            # Overall multiplier for the number of feature maps.
                 fmap_decay = 1.0,            # log2 feature map reduction when doubling the resolution.
                 fmap_max = 512,              # Maximum number of feature maps in any layer.
                 use_styles = True,           # Enable style inputs?
                 const_input_layer = True,    # First layer is a learned constant?
                 use_noise = True,            # Enable noise inputs?
                 randomize_noise = True,      # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
                 nonlinearity = 'lrelu',      # Activation function: 'relu', 'lrelu'
                 use_wscale = True,           # Enable equalized learning rate?
                 use_pixel_norm = False,      # Enable pixelwise feature vector normalization?
                 use_instance_norm = True,    # Enable instance normalization?
                 dtype = torch.float32,       # Data type to use for activations and outputs.
                 blur_filter = [1,2,1],       # Low-pass filter to apply when resampling activations. None = no filtering.
                 ):
        super().__init__()

        def nf(stage):
            # Feature maps per stage: halves with each doubling, capped at fmap_max.
            return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

        self.dlatent_size = dlatent_size
        resolution_log2 = int(np.log2(resolution))
        assert resolution == 2**resolution_log2 and resolution >= 4
        # Activations must be nn.Modules because LayerEpilogue places them in an
        # nn.Sequential; the plain function torch.relu raised a TypeError there.
        act, gain = {'relu': (nn.ReLU(), np.sqrt(2)),
                     'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
        blocks = []
        last_channels = None
        for res in range(2, resolution_log2 + 1):
            channels = nf(res - 1)
            name = '{s}x{s}'.format(s=2**res)
            if res == 2:
                block = InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,
                                   use_noise, use_pixel_norm, use_instance_norm, use_styles, act)
            else:
                block = GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale,
                                        use_noise, use_pixel_norm, use_instance_norm, use_styles, act)
            blocks.append((name, block))
            last_channels = channels
        # Only the final-resolution toRGB layer is needed.
        self.torgb = MyConv2d(last_channels, num_channels, 1, gain=1, use_wscale=use_wscale)
        self.blocks = nn.ModuleDict(OrderedDict(blocks))

    def forward(self, dlatents_in):
        # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
        # Block i consumes style vectors 2*i and 2*i+1.
        x = None
        for i, block in enumerate(self.blocks.values()):
            styles = dlatents_in[:, 2*i:2*i+2]
            x = block(styles) if i == 0 else block(x, styles)
        return self.torgb(x)
# Full generator: the mapping network (latent z -> broadcast style vectors w)
# followed by the synthesis network (w -> image). The truncation stage is left
# commented out here because it needs an avg_latent tensor.
g_all = nn.Sequential(OrderedDict([
    ('g_mapping', G_mapping()),
    #('truncation', Truncation(avg_latent)),
    ('g_synthesis', G_synthesis())
]))
if 0:
    # this can be run to get the weights, but you need the reference implementation and weights
    # Disabled one-off cell: loads the official TF pickle (via dnnlib) and saves the
    # trainable variables of every network in it as PyTorch tensors.
    import dnnlib, dnnlib.tflib, pickle, torch, collections
    dnnlib.tflib.init_tf()
    # NOTE: pickle.load can execute arbitrary code -- only load trusted weight files.
    with open('./karras2019stylegan-ffhq-1024x1024.pkl', 'rb') as f:
        weights = pickle.load(f)
    weights_pt = [collections.OrderedDict([(k, torch.from_numpy(v.value().eval())) for k, v in w.trainables.items()]) for w in weights]
    torch.save(weights_pt, './karras2019stylegan-ffhq-1024x1024.pt')
if 0:
    # then on the PyTorch side run
    # Disabled one-off cell: renames the dumped TF variables to g_all's state_dict
    # keys, reorders weight layouts, loads them into g_all and saves the result.
    state_G, state_D, state_Gs = torch.load('./karras2019stylegan-ffhq-1024x1024.pt')
    def key_translate(k):
        """Translate a '/'-separated TF variable name into a '.'-separated g_all key."""
        k = k.lower().split('/')
        if k[0] == 'g_synthesis':
            # Synthesis blocks live under G_synthesis.blocks; torgb hangs directly off G_synthesis.
            if not k[1].startswith('torgb'):
                k.insert(1, 'blocks')
            k = '.'.join(k)
            # Map TF sub-layer names onto the attribute names used by InputBlock,
            # GSynthesisBlock and LayerEpilogue (epi1/epi2, top_epi.noise, style_mod.lin).
            k = (k.replace('const.const','const').replace('const.bias','bias').replace('const.stylemod','epi1.style_mod.lin')
                  .replace('const.noise.weight','epi1.top_epi.noise.weight')
                  .replace('conv.noise.weight','epi2.top_epi.noise.weight')
                  .replace('conv.stylemod','epi2.style_mod.lin')
                  .replace('conv0_up.noise.weight', 'epi1.top_epi.noise.weight')
                  .replace('conv0_up.stylemod','epi1.style_mod.lin')
                  .replace('conv1.noise.weight', 'epi2.top_epi.noise.weight')
                  .replace('conv1.stylemod','epi2.style_mod.lin')
                  .replace('torgb_lod0','torgb'))
        else:
            k = '.'.join(k)
        return k
    def weight_translate(k, w):
        """Reorder a '.weight' tensor into this file's layouts; other tensors pass through.

        2-D weights are transposed (MyLinear stores (out, in)). 4-D weights are
        permuted with (3, 2, 0, 1) -- presumably TF's (kH, kW, in, out) into the
        (out, in, kH, kW) layout MyConv2d uses; TODO confirm against the TF code.
        """
        k = key_translate(k)
        if k.endswith('.weight'):
            if w.dim() == 2:
                w = w.t()
            elif w.dim() == 1:
                pass
            else:
                assert w.dim() == 4
                w = w.permute(3, 2, 0, 1)
        return w
    # we delete the useless torgb filters
    # (torgb_lod0 was renamed to 'torgb' above; the other torgb_lodN outputs are dropped.)
    param_dict = {key_translate(k) : weight_translate(k, v) for k,v in state_Gs.items() if 'torgb_lod' not in key_translate(k)}
    if 1:
        # Diagnostic: report keys present on only one side and shape mismatches.
        sd_shapes = {k : v.shape for k,v in g_all.state_dict().items()}
        param_shapes = {k : v.shape for k,v in param_dict.items() }
        for k in list(sd_shapes)+list(param_shapes):
            pds = param_shapes.get(k)
            sds = sd_shapes.get(k)
            if pds is None:
                print ("sd only", k, sds)
            elif sds is None:
                print ("pd only", k, pds)
            elif sds != pds:
                print ("mismatch!", k, pds, sds)
    g_all.load_state_dict(param_dict, strict=False) # needed for the blur kernels
    torch.save(g_all.state_dict(), './karras2019stylegan-ffhq-1024x1024.for_g_all.pt')
# Load the converted generator weights. map_location='cpu' keeps loading working
# even if the file was saved from GPU tensors; the model is moved to `device` below.
g_all.load_state_dict(torch.load('./karras2019stylegan-ffhq-1024x1024.for_g_all.pt', map_location='cpu'))
# `%matplotlib inline` is IPython-only syntax and a SyntaxError in a .py file;
# issue the same magic through the IPython API when running inside IPython/Jupyter.
_ipython = IPython.get_ipython()
if _ipython is not None:
    _ipython.run_line_magic('matplotlib', 'inline')
from matplotlib import pyplot
import torchvision
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
g_all.eval()
g_all.to(device)
# Fix the RNG seed so the sampled latents (and hence the images) are reproducible.
torch.manual_seed(-5)
nb_rows = 2
nb_cols = 5
nb_samples = nb_rows * nb_cols
latents = torch.randn(nb_samples, 512, device=device)
with torch.no_grad():
    imgs = g_all(latents)
    imgs = (imgs.clamp(-1, 1) + 1) / 2.0  # map generator output from [-1, 1] to [0, 1]
imgs = imgs.cpu()
imgs = torchvision.utils.make_grid(imgs, nrow=nb_cols)
pyplot.figure(figsize=(15, 6))
pyplot.imshow(imgs.permute(1, 2, 0).detach().numpy())
Subscribe via RSS