Skip to content

Commit bad5ee1

Browse files
committed
added resnet block
1 parent 2172f59 commit bad5ee1

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

controlnet/autoencoder.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import torch.nn as nn
2+
3+
4+
def nonlinearity(x):
5+
# swish
6+
return x * torch.sigmoid(x)
7+
8+
9+
def Normalize(in_channels, num_groups=32):
10+
return nn.GroupNorm(
11+
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
12+
)
13+
14+
15+
class ResnetBlock(nn.Module):
16+
def __init__(
17+
self,
18+
*,
19+
in_channels,
20+
out_channels=None,
21+
conv_shortcut=False,
22+
dropout,
23+
temb_channels=512
24+
):
25+
super().__init__()
26+
self.in_channels = in_channels
27+
out_channels = in_channels if out_channels is None else out_channels
28+
self.out_channels = out_channels
29+
self.use_conv_shortcut = conv_shortcut
30+
31+
self.norm1 = Normalize(in_channels)
32+
self.conv1 = nn.Conv2d(
33+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
34+
)
35+
if temb_channels > 0:
36+
self.temb_proj = nn.Linear(temb_channels, out_channels)
37+
38+
self.norm2 = Normalize(out_channels)
39+
self.droput = nn.Dropout(dropout)
40+
self.conv2 = nn.Conv2d(
41+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
42+
)
43+
44+
if self.in_channels != self.out_channels:
45+
if self.use_conv_shortcut:
46+
self.conv_shortcut = nn.Conv2d(
47+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
48+
)
49+
else:
50+
self.nin_shortcut = nn.Conv2d(
51+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
52+
)
53+
54+
def forward(self, x, temb):
55+
h = x
56+
h = self.norm1(h)
57+
h = nonlinearity(h)
58+
h = self.conv1(h)
59+
60+
if temb is not None:
61+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
62+
63+
h = self.norm2(h)
64+
h = nonlinearity(h)
65+
h = self.dropout(h)
66+
h = self.conv2(h)
67+
68+
if self.in_channels != self.out_channels:
69+
if self.use_conv_shortcut:
70+
x = self.conv_shortcut(x)
71+
else:
72+
x = self.nin_shortcut(x)
73+
74+
return x + h

0 commit comments

Comments
 (0)