|
import torch
import torch.nn as nn
| 2 | + |
| 3 | + |
| 4 | +def nonlinearity(x): |
| 5 | + # swish |
| 6 | + return x * torch.sigmoid(x) |
| 7 | + |
| 8 | + |
def Normalize(in_channels, num_groups=32):
    """Build an affine GroupNorm layer over ``in_channels`` channels."""
    return nn.GroupNorm(
        num_channels=in_channels,
        num_groups=num_groups,
        affine=True,
        eps=1e-6,
    )
| 13 | + |
| 14 | + |
class ResnetBlock(nn.Module):
    """Residual block: two 3x3 conv layers with GroupNorm + swish, an
    optional timestep-embedding bias between them, and a learned shortcut
    projection when the channel count changes.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output channels; defaults to ``in_channels``.
        conv_shortcut: if True and channels change, use a 3x3 conv shortcut
            instead of a 1x1 ("network-in-network") projection.
        dropout: dropout probability applied before the second conv.
        temb_channels: width of the timestep embedding; set <= 0 to build
            the block without a ``temb_proj`` layer.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512
    ):
        super().__init__()
        self.in_channels = in_channels
        # Default to a channel-preserving block.
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            # Projects the timestep embedding to a per-channel additive bias.
            self.temb_proj = nn.Linear(temb_channels, out_channels)

        self.norm2 = Normalize(out_channels)
        # BUGFIX: was `self.droput` (typo); forward() reads `self.dropout`,
        # so every forward pass raised AttributeError.
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        """Apply the block to feature map ``x``.

        Args:
            x: input tensor of shape (N, in_channels, H, W).
            temb: timestep embedding of shape (N, temb_channels), or None to
                skip the embedding injection (required if the block was built
                with ``temb_channels <= 0``).

        Returns:
            Tensor of shape (N, out_channels, H, W): shortcut(x) + residual.
        """
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Broadcast the projected embedding over the spatial dims.
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        # Project the skip connection if the channel count changed.
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
0 commit comments