Esmail-AGumaan commited on
Commit
01342fe
1 Parent(s): f8b823c

Update encoder.py

Browse files
Files changed (1) hide show
  1. encoder.py +55 -55
encoder.py CHANGED
@@ -1,56 +1,56 @@
1
- import torch
2
- from torch import nn
3
- from torch.nn import functional as F
4
- from nanograd.models.stable_diffusion.decoder import VAE_AttentionBlock, VAE_ResidualBlock
5
-
6
- class VAE_Encoder(nn.Sequential):
7
- def __init__(self):
8
- super().__init__(
9
- nn.Conv2d(3, 128, kernel_size=3, padding=1),
10
-
11
- VAE_ResidualBlock(128, 128),
12
- VAE_ResidualBlock(128, 128),
13
-
14
- nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
15
-
16
- VAE_ResidualBlock(128, 256),
17
- VAE_ResidualBlock(256, 256),
18
-
19
- nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
20
-
21
- VAE_ResidualBlock(256, 512),
22
- VAE_ResidualBlock(512, 512),
23
-
24
- nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
25
-
26
- VAE_ResidualBlock(512, 512),
27
- VAE_ResidualBlock(512, 512),
28
- VAE_ResidualBlock(512, 512),
29
- VAE_AttentionBlock(512),
30
- VAE_ResidualBlock(512, 512),
31
-
32
- nn.GroupNorm(32, 512),
33
-
34
- nn.SiLU(),
35
-
36
- nn.Conv2d(512, 8, kernel_size=3, padding=1),
37
-
38
- nn.Conv2d(8, 8, kernel_size=1, padding=0),
39
- )
40
-
41
- def forward(self, x, noise):
42
- for module in self:
43
-
44
- if getattr(module, 'stride', None) == (2, 2):
45
- x = F.pad(x, (0, 1, 0, 1))
46
-
47
- x = module(x)
48
- mean, log_variance = torch.chunk(x, 2, dim=1)
49
- log_variance = torch.clamp(log_variance, -30, 20)
50
- variance = log_variance.exp()
51
- stdev = variance.sqrt()
52
- x = mean + stdev * noise
53
- # Constant taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L17C1-L17C1
54
- x *= 0.18215
55
-
56
  return x
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from decoder import VAE_AttentionBlock, VAE_ResidualBlock
5
+
6
+ class VAE_Encoder(nn.Sequential):
7
+ def __init__(self):
8
+ super().__init__(
9
+ nn.Conv2d(3, 128, kernel_size=3, padding=1),
10
+
11
+ VAE_ResidualBlock(128, 128),
12
+ VAE_ResidualBlock(128, 128),
13
+
14
+ nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
15
+
16
+ VAE_ResidualBlock(128, 256),
17
+ VAE_ResidualBlock(256, 256),
18
+
19
+ nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
20
+
21
+ VAE_ResidualBlock(256, 512),
22
+ VAE_ResidualBlock(512, 512),
23
+
24
+ nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
25
+
26
+ VAE_ResidualBlock(512, 512),
27
+ VAE_ResidualBlock(512, 512),
28
+ VAE_ResidualBlock(512, 512),
29
+ VAE_AttentionBlock(512),
30
+ VAE_ResidualBlock(512, 512),
31
+
32
+ nn.GroupNorm(32, 512),
33
+
34
+ nn.SiLU(),
35
+
36
+ nn.Conv2d(512, 8, kernel_size=3, padding=1),
37
+
38
+ nn.Conv2d(8, 8, kernel_size=1, padding=0),
39
+ )
40
+
41
+ def forward(self, x, noise):
42
+ for module in self:
43
+
44
+ if getattr(module, 'stride', None) == (2, 2):
45
+ x = F.pad(x, (0, 1, 0, 1))
46
+
47
+ x = module(x)
48
+ mean, log_variance = torch.chunk(x, 2, dim=1)
49
+ log_variance = torch.clamp(log_variance, -30, 20)
50
+ variance = log_variance.exp()
51
+ stdev = variance.sqrt()
52
+ x = mean + stdev * noise
53
+ # Constant taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L17C1-L17C1
54
+ x *= 0.18215
55
+
56
  return x