1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
|
import torch
from models.auto_encoder import Encoder, Decoder, AutoEncoder
# Geometry of the synthetic input batch shared by every test in this module:
# batch size, input channels, image height and image width (NCHW order, as
# used by the torch.rand(N, C, H, W) calls below).
N = 128
C = 3
H = 64
W = 32
def test_default_encoder():
    """A default Encoder maps an (N, C, H, W) batch to three embeddings.

    The three outputs must have widths 128, 128 and 64 respectively.
    """
    model = Encoder()
    images = torch.rand(N, C, H, W)
    # Unpack explicitly so the test still fails if the output count changes.
    feat_a, feat_c, feat_p = model(images)
    expected_shapes = ((N, 128), (N, 128), (N, 64))
    for embedding, expected in zip((feat_a, feat_c, feat_p), expected_shapes):
        assert tuple(embedding.size()) == expected
def test_custom_encoder():
    """An Encoder built for 1-channel input with custom output_dims.

    Each of the three embeddings must have the width given by the
    corresponding entry of output_dims.
    """
    dims = (64, 64, 32)
    model = Encoder(in_channels=1, feature_channels=32, output_dims=dims)
    single_channel = torch.rand(N, 1, H, W)
    # Unpack explicitly so the test still fails if the output count changes.
    feat_a, feat_c, feat_p = model(single_channel)
    for embedding, width in zip((feat_a, feat_c, feat_p), dims):
        assert tuple(embedding.size()) == (N, width)
def test_default_decoder():
    """A default Decoder turns 128/128/64-wide embeddings into images.

    By default the output has shape (N, C, H, W). With no_trans_conv=True
    it must instead return a (N, 64 * 8, 4, 2) tensor.
    """
    model = Decoder()
    embeddings = (torch.rand(N, 128), torch.rand(N, 128), torch.rand(N, 64))

    full_output = model(*embeddings)
    assert tuple(full_output.size()) == (N, C, H, W)

    truncated_output = model(*embeddings, no_trans_conv=True)
    assert tuple(truncated_output.size()) == (N, 64 * 8, 4, 2)
def test_custom_decoder():
    """A Decoder with custom input_dims, feature_channels and out_channels.

    Both output paths must have the expected shapes: the default path and
    the no_trans_conv=True path.
    """
    dims = (64, 64, 32)
    channels = 32
    model = Decoder(input_dims=dims, feature_channels=channels, out_channels=1)
    # One random embedding per configured input width, created in order.
    embeddings = tuple(torch.rand(N, width) for width in dims)

    image_shape = (N, 1, H, W)
    feature_map_shape = (N, channels * 8, 4, 2)

    assert tuple(model(*embeddings).size()) == image_shape
    truncated = model(*embeddings, no_trans_conv=True)
    assert tuple(truncated.size()) == feature_map_shape
def test_default_auto_encoder():
    """Check forward-pass output shapes of a default AutoEncoder.

    In train mode the forward takes three image batches plus labels. It must
    return three pairs; the last pair are 0-d tensors. In eval mode it takes
    only the three image batches and returns a single pair.
    """
    model = AutoEncoder()
    images = torch.rand(N, C, H, W)
    # NOTE(review): 74 presumably matches AutoEncoder's default num_class;
    # confirm against models.auto_encoder.
    labels = torch.randint(74, (N,))

    def dims_of(tensor):
        return tuple(tensor.size())

    feature_map_shape = (N, 64 * 8, 4, 2)
    image_shape = (N, C, H, W)

    model.train()
    (x_c, x_p), (f_p_c1, f_p_c2), (xrecon, cano) = model(
        images, images, images, labels
    )
    assert dims_of(x_c) == feature_map_shape
    assert dims_of(x_p) == image_shape
    assert dims_of(f_p_c1) == dims_of(f_p_c2) == (N, 64)
    # Scalar (0-d) outputs; presumably loss terms, confirm in AutoEncoder.
    assert dims_of(xrecon) == dims_of(cano) == ()

    model.eval()
    x_c, x_p = model(images, images, images)
    assert dims_of(x_c) == feature_map_shape
    assert dims_of(x_p) == image_shape
def test_custom_auto_encoder():
    """Check forward-pass shapes of a customised AutoEncoder.

    The model uses custom num_class, channels, feature_channels and
    embedding_dims. Shapes are checked in both train and eval mode.
    """
    num_class = 10
    channels = 1
    embedding_dims = (64, 64, 32)
    feature_channels = 32
    model = AutoEncoder(
        num_class=num_class,
        channels=channels,
        feature_channels=feature_channels,
        embedding_dims=embedding_dims,
    )
    images = torch.rand(N, channels, H, W)
    labels = torch.randint(num_class, (N,))

    def dims_of(tensor):
        return tuple(tensor.size())

    feature_map_shape = (N, feature_channels * 8, 4, 2)
    image_shape = (N, channels, H, W)

    model.train()
    (x_c, x_p), (f_p_c1, f_p_c2), (xrecon, cano) = model(
        images, images, images, labels
    )
    assert dims_of(x_c) == feature_map_shape
    assert dims_of(x_p) == image_shape
    # Both of these follow the third embedding width.
    assert dims_of(f_p_c1) == dims_of(f_p_c2) == (N, embedding_dims[2])
    # Scalar (0-d) outputs; presumably loss terms, confirm in AutoEncoder.
    assert dims_of(xrecon) == dims_of(cano) == ()

    model.eval()
    x_c, x_p = model(images, images, images)
    assert dims_of(x_c) == feature_map_shape
    assert dims_of(x_p) == image_shape
|