summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-02-15 11:08:52 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-02-15 11:08:52 +0800
commit24b5968bfc5799e44c9bbbc00e3a9be00f4509ac (patch)
tree425e22421acec244b34352603737f3a9741c4ecf /models
parent34d2f9017e77a7bdef761ab3d92cd0340c5154c3 (diff)
Revert "Memory usage improvement"
This reverts commit be508061
Diffstat (limited to 'models')
-rw-r--r--models/auto_encoder.py70
-rw-r--r--models/model.py21
-rw-r--r--models/rgb_part_net.py117
3 files changed, 92 insertions, 116 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 918a95c..7b9b29f 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -119,47 +119,32 @@ class AutoEncoder(nn.Module):
embedding_dims: Tuple[int, int, int] = (128, 128, 64)
):
super().__init__()
- self.f_c_c1_t2_ = None
- self.f_p_c1_t2_ = None
- self.f_c_c1_t1_ = None
self.encoder = Encoder(channels, feature_channels, embedding_dims)
self.decoder = Decoder(embedding_dims, feature_channels, channels)
- def forward(self, x_t2, is_c1=True):
- n, t, c, h, w = x_t2.size()
- if is_c1: # condition 1
- # x_c1_t2 is the frame for later module
- x_c1_t2_ = x_t2.view(n * t, c, h, w)
- (f_a_c1_t2_, self.f_c_c1_t2_, self.f_p_c1_t2_) \
- = self.encoder(x_c1_t2_)
-
- if self.training:
- # t1 is random time step
- x_c1_t1 = x_t2[:, torch.randperm(t), :, :, :]
- x_c1_t1_ = x_c1_t1.view(n * t, c, h, w)
- (f_a_c1_t1_, self.f_c_c1_t1_, _) = self.encoder(x_c1_t1_)
-
- x_c1_t2_pred_ = self.decoder(
- f_a_c1_t1_, self.f_c_c1_t1_, self.f_p_c1_t2_
- )
- x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
-
- xrecon_loss = torch.stack([
- F.mse_loss(x_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
- for i in range(t)
- ]).sum()
-
- return ((f_a_c1_t2_, self.f_c_c1_t2_, self.f_p_c1_t2_),
- xrecon_loss)
- else: # evaluating
- return self.f_c_c1_t2_, self.f_p_c1_t2_
- else: # condition 2
- # c2 is another condition
- x_c2_t2_ = x_t2.view(n * t, c, h, w)
- (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2_)
-
- f_c_c1_t1 = self.f_c_c1_t1_.view(n, t, -1)
- f_c_c1_t2 = self.f_c_c1_t2_.view(n, t, -1)
+ def forward(self, x_c1_t2, x_c1_t1=None, x_c2_t2=None):
+ n, t, c, h, w = x_c1_t2.size()
+ # x_c1_t2 is the frame for later module
+ x_c1_t2_ = x_c1_t2.view(n * t, c, h, w)
+ (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_) = self.encoder(x_c1_t2_)
+
+ if self.training:
+ # t1 is random time step, c2 is another condition
+ x_c1_t1 = x_c1_t1.view(n * t, c, h, w)
+ (f_a_c1_t1_, f_c_c1_t1_, _) = self.encoder(x_c1_t1)
+ x_c2_t2 = x_c2_t2.view(n * t, c, h, w)
+ (_, f_c_c2_t2_, f_p_c2_t2_) = self.encoder(x_c2_t2)
+
+ x_c1_t2_pred_ = self.decoder(f_a_c1_t1_, f_c_c1_t1_, f_p_c1_t2_)
+ x_c1_t2_pred = x_c1_t2_pred_.view(n, t, c, h, w)
+
+ xrecon_loss = torch.stack([
+ F.mse_loss(x_c1_t2[:, i, :, :, :], x_c1_t2_pred[:, i, :, :, :])
+ for i in range(t)
+ ]).sum()
+
+ f_c_c1_t1 = f_c_c1_t1_.view(n, t, -1)
+ f_c_c1_t2 = f_c_c1_t2_.view(n, t, -1)
f_c_c2_t2 = f_c_c2_t2_.view(n, t, -1)
cano_cons_loss = torch.stack([
F.mse_loss(f_c_c1_t1[:, i, :], f_c_c1_t2[:, i, :])
@@ -167,8 +152,13 @@ class AutoEncoder(nn.Module):
for i in range(t)
]).mean()
- f_p_c1_t2 = self.f_p_c1_t2_.view(n, t, -1)
+ f_p_c1_t2 = f_p_c1_t2_.view(n, t, -1)
f_p_c2_t2 = f_p_c2_t2_.view(n, t, -1)
pose_sim_loss = F.mse_loss(f_p_c1_t2.mean(1), f_p_c2_t2.mean(1))
- return cano_cons_loss, pose_sim_loss * 10
+ return (
+ (f_a_c1_t2_, f_c_c1_t2_, f_p_c1_t2_),
+ (xrecon_loss, cano_cons_loss, pose_sim_loss * 10)
+ )
+ else: # evaluating
+ return f_c_c1_t2_, f_p_c1_t2_
diff --git a/models/model.py b/models/model.py
index 3aeb754..9748e46 100644
--- a/models/model.py
+++ b/models/model.py
@@ -182,7 +182,7 @@ class Model:
# Training start
start_time = datetime.now()
running_loss = torch.zeros(5, device=self.device)
- print(f"{'Time':^8} {'Iter':^5} {'Loss':^5}",
+ print(f"{'Time':^8} {'Iter':^5} {'Loss':^6}",
f"{'Xrecon':^8} {'CanoCons':^8} {'PoseSim':^8}",
f"{'BATripH':^8} {'BATripP':^8} {'LRs':^19}")
for (batch_c1, batch_c2) in dataloader:
@@ -190,21 +190,12 @@ class Model:
# Zero the parameter gradients
self.optimizer.zero_grad()
# forward + backward + optimize
- # Feed data twice in order to reduce memory usage
x_c1 = batch_c1['clip'].to(self.device)
+ x_c2 = batch_c2['clip'].to(self.device)
y = batch_c1['label'].to(self.device)
# Duplicate labels for each part
y = y.unsqueeze(1).repeat(1, self.rgb_pn.num_total_parts)
- # Feed condition 1 clips first
- losses, images = self.rgb_pn(x_c1, y)
- (xrecon_loss, hpm_ba_trip, pn_ba_trip) = losses
- x_c2 = batch_c2['clip'].to(self.device)
- # Then feed condition 2 clips
- cano_cons_loss, pose_sim_loss = self.rgb_pn(x_c2, is_c1=False)
- losses = torch.stack((
- xrecon_loss, cano_cons_loss, pose_sim_loss,
- hpm_ba_trip, pn_ba_trip
- ))
+ losses, images = self.rgb_pn(x_c1, x_c2, y)
loss = losses.sum()
loss.backward()
self.optimizer.step()
@@ -234,9 +225,7 @@ class Model:
self.writer.add_images(
'Canonical image', i_c, self.curr_iter
)
- for (i, (o, a, p)) in enumerate(zip(
- batch_c1['clip'], i_a, i_p
- )):
+ for (i, (o, a, p)) in enumerate(zip(x_c1, i_a, i_p)):
self.writer.add_images(
f'Original image/batch {i}', o, self.curr_iter
)
@@ -250,7 +239,7 @@ class Model:
remaining_minute, second = divmod(time_used.seconds, 60)
hour, minute = divmod(remaining_minute, 60)
print(f'{hour:02}:{minute:02}:{second:02}',
- f'{self.curr_iter:5d} {running_loss.sum() / 100:5.3f}',
+ f'{self.curr_iter:5d} {running_loss.sum() / 100:6.3f}',
'{:f} {:f} {:f} {:f} {:f}'.format(*running_loss / 100),
'{:.3e} {:.3e}'.format(lrs[0], lrs[1]))
running_loss.zero_()
diff --git a/models/rgb_part_net.py b/models/rgb_part_net.py
index c489ec6..260eabd 100644
--- a/models/rgb_part_net.py
+++ b/models/rgb_part_net.py
@@ -58,67 +58,64 @@ class RGBPartNet(nn.Module):
def fc(self, x):
return x @ self.fc_mat
- def forward(self, x, y=None, is_c1=True):
- # Step 1a: Disentangle condition 1 clips
- if is_c1:
- # n, t, c, h, w
- ((x_c, x_p), xrecon_loss, images) = self._disentangle(x, is_c1)
-
- # Step 2.a: Static Gait Feature Aggregation & HPM
- # n, c, h, w
- x_c = self.hpm(x_c)
- # p, n, c
-
- # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
- # n, t, c, h, w
- x_p = self.pn(x_p)
- # p, n, c
-
- # Step 3: Cat feature map together and fc
- x = torch.cat((x_c, x_p))
- x = self.fc(x)
-
- if self.training:
- y = y.T
- hpm_ba_trip = self.hpm_ba_trip(
- x[:self.hpm_num_parts], y[:self.hpm_num_parts]
- )
- pn_ba_trip = self.pn_ba_trip(
- x[self.hpm_num_parts:], y[self.hpm_num_parts:]
- )
- return (xrecon_loss, hpm_ba_trip, pn_ba_trip), images
- else: # evaluating
- return x.unsqueeze(1).view(-1)
- else: # Step 1b: Disentangle condition 2 clips
- return self._disentangle(x, is_c1)
-
- def _disentangle(self, x_t2, is_c1=True):
- if is_c1: # condition 1
- n, t, *_ = x_size = x_t2.size()
- device = x_t2.device
- if self.training:
- (f_a_, f_c_, f_p_), xrecon_loss = self.ae(x_t2, is_c1)
- # Decode features
- with torch.no_grad():
- x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p = self._decode_pose_feature(f_p_, *x_size, device)
-
- i_a, i_c, i_p = None, None, None
- if self.image_log_on:
- i_a = self._decode_appr_feature(f_a_, *x_size, device)
- # Continue decoding canonical features
- i_c = self.ae.decoder.trans_conv3(x_c)
- i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
- i_p = x_p
-
- return (x_c, x_p), xrecon_loss, (i_a, i_c, i_p)
- else: # evaluating
- f_c_, f_p_ = self.ae(x_t2)
+ def forward(self, x_c1, x_c2=None, y=None):
+ # Step 1: Disentanglement
+ # n, t, c, h, w
+ ((x_c, x_p), losses, images) = self._disentangle(x_c1, x_c2)
+
+ # Step 2.a: Static Gait Feature Aggregation & HPM
+ # n, c, h, w
+ x_c = self.hpm(x_c)
+ # p, n, c
+
+ # Step 2.b: FPFE & TFA (Dynamic Gait Feature Aggregation)
+ # n, t, c, h, w
+ x_p = self.pn(x_p)
+ # p, n, c
+
+ # Step 3: Cat feature map together and fc
+ x = torch.cat((x_c, x_p))
+ x = self.fc(x)
+
+ if self.training:
+ y = y.T
+ hpm_ba_trip = self.hpm_ba_trip(
+ x[:self.hpm_num_parts], y[:self.hpm_num_parts]
+ )
+ pn_ba_trip = self.pn_ba_trip(
+ x[self.hpm_num_parts:], y[self.hpm_num_parts:]
+ )
+ losses = torch.stack((*losses, hpm_ba_trip, pn_ba_trip))
+ return losses, images
+ else:
+ return x.unsqueeze(1).view(-1)
+
+ def _disentangle(self, x_c1_t2, x_c2_t2=None):
+ n, t, c, h, w = x_c1_t2.size()
+ device = x_c1_t2.device
+ x_c1_t1 = x_c1_t2[:, torch.randperm(t), :, :, :]
+ if self.training:
+ ((f_a_, f_c_, f_p_), losses) = self.ae(x_c1_t2, x_c1_t1, x_c2_t2)
+ # Decode features
+ with torch.no_grad():
x_c = self._decode_cano_feature(f_c_, n, t, device)
- x_p = self._decode_pose_feature(f_p_, *x_size, device)
- return (x_c, x_p), None, None
- else: # condition 2
- return self.ae(x_t2, is_c1)
+ x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+
+ i_a, i_c, i_p = None, None, None
+ if self.image_log_on:
+ i_a = self._decode_appr_feature(f_a_, n, t, c, h, w, device)
+ # Continue decoding canonical features
+ i_c = self.ae.decoder.trans_conv3(x_c)
+ i_c = torch.sigmoid(self.ae.decoder.trans_conv4(i_c))
+ i_p = x_p
+
+ return (x_c, x_p), losses, (i_a, i_c, i_p)
+
+ else: # evaluating
+ f_c_, f_p_ = self.ae(x_c1_t2)
+ x_c = self._decode_cano_feature(f_c_, n, t, device)
+ x_p = self._decode_pose_feature(f_p_, n, t, c, h, w, device)
+ return (x_c, x_p), None, None
def _decode_appr_feature(self, f_a_, n, t, c, h, w, device):
# Decode appearance features