summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-04-04 17:43:19 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-04-04 17:43:19 +0800
commit85627d4cfb495453a7c28b3f131b84b1038af674 (patch)
treeffea2f7947e58666e36736370c405fe44aad1641 /models/model.py
parentcb05de36f5ffd8584d78c6776dbe90e21abff25a (diff)
Add cross entropy lossdisentangling_only
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py24
1 files changed, 19 insertions, 5 deletions
diff --git a/models/model.py b/models/model.py
index 0829f33..d976f5a 100644
--- a/models/model.py
+++ b/models/model.py
@@ -54,6 +54,7 @@ class Model:
self.total_iters = self.meta.get('total_iters', (self.total_iter,))
self.is_train: bool = True
+ self.num_class: Optional[int] = None
self.in_channels: int = 3
self.in_size: tuple[int, int] = (64, 48)
self.batch_size: Optional[int] = None
@@ -160,8 +161,13 @@ class Model:
optim_hp: dict = self.hp.get('optimizer', {}).copy()
sched_hp = self.hp.get('scheduler', {})
- self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp,
- image_log_on=self.image_log_on)
+ self.rgb_pn = RGBPartNet(
+ self.num_class,
+ self.in_channels,
+ self.in_size,
+ **model_hp,
+ image_log_on=self.image_log_on
+ )
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
@@ -202,7 +208,8 @@ class Model:
# forward + backward + optimize
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
- losses, features, images = self.rgb_pn(x_c1, x_c2)
+ y = batch_c1['label'].to(self.device)
+ losses, features, images = self.rgb_pn(x_c1, x_c2, y)
loss = losses.sum()
loss.backward()
self.optimizer.step()
@@ -254,8 +261,9 @@ class Model:
batch_c1, batch_c2 = next(val_dataloader)
x_c1 = batch_c1['clip'].to(self.device)
x_c2 = batch_c2['clip'].to(self.device)
+ y = batch_c1['label'].to(self.device)
with torch.no_grad():
- losses, _, _ = self.rgb_pn(x_c1, x_c2)
+ losses, _, _ = self.rgb_pn(x_c1, x_c2, y)
loss = losses.sum()
self._write_stat('Val', loss, losses)
@@ -302,7 +310,12 @@ class Model:
# Init models
model_hp: dict = self.hp.get('model', {}).copy()
- self.rgb_pn = RGBPartNet(self.in_channels, self.in_size, **model_hp)
+ self.rgb_pn = RGBPartNet(
+ self.num_class,
+ self.in_channels,
+ self.in_size,
+ **model_hp
+ )
# Try to accelerate computation using CUDA or others
self.rgb_pn = self.rgb_pn.to(self.device)
self.rgb_pn.eval()
@@ -419,6 +432,7 @@ class Model:
self,
dataset_config: DatasetConfiguration
) -> Union[CASIAB]:
+ self.num_class = dataset_config.get('train_size', 74)
self.in_channels = dataset_config.get('num_input_channels', 3)
self.in_size = dataset_config.get('frame_size', (64, 48))
self._dataset_sig = self._make_signature(