Skip to content
Snippets Groups Projects
Commit 8df9672d authored by 柏灌's avatar 柏灌
Browse files

add discriminator

parent a67dc2d8
No related branches found
No related tags found
No related merge requests found
......@@ -676,4 +676,61 @@ class FullGenerator(nn.Module):
outs = self.generator([outs], return_latents, inject_index, truncation, truncation_latent, input_is_latent, noise=noise[1:])
return outs
class Discriminator(nn.Module):
def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], narrow=1):
super().__init__()
channels = {
4: int(512 * narrow),
8: int(512 * narrow),
16: int(512 * narrow),
32: int(512 * narrow),
64: int(256 * channel_multiplier * narrow),
128: int(128 * channel_multiplier * narrow),
256: int(64 * channel_multiplier * narrow),
512: int(32 * channel_multiplier * narrow),
1024: int(16 * channel_multiplier * narrow)
}
convs = [ConvLayer(3, channels[size], 1)]
log_size = int(math.log(size, 2))
in_channel = channels[size]
for i in range(log_size, 2, -1):
out_channel = channels[2 ** (i - 1)]
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
in_channel = out_channel
self.convs = nn.Sequential(*convs)
self.stddev_group = 4
self.stddev_feat = 1
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
self.final_linear = nn.Sequential(
EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
EqualLinear(channels[4], 1),
)
def forward(self, input):
out = self.convs(input)
batch, channel, height, width = out.shape
group = min(batch, self.stddev_group)
stddev = out.view(
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
)
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
stddev = stddev.repeat(group, 1, height, width)
out = torch.cat([out, stddev], 1)
out = self.final_conv(out)
out = out.view(batch, -1)
out = self.final_linear(out)
return out
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment