Diffstat (limited to 'models')
-rw-r--r--  models/model.py  67
1 file changed, 51 insertions(+), 16 deletions(-)
diff --git a/models/model.py b/models/model.py
index 456c2f1..b343f86 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,4 +1,5 @@
 import os
+from datetime import datetime
 from typing import Union, Optional
 
 import numpy as np
@@ -86,6 +87,21 @@ class Model:
     def _checkpoint_name(self) -> str:
         return os.path.join(self.checkpoint_dir, self._signature)
 
+    def fit_all(
+            self,
+            dataset_config: DatasetConfiguration,
+            dataset_selectors: dict[
+                str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
+            ],
+            dataloader_config: DataloaderConfiguration,
+    ):
+        for (condition, selector) in dataset_selectors.items():
+            print(f'Training model {condition} ...')
+            self.fit(
+                dict(**dataset_config, **{'selector': selector}),
+                dataloader_config
+            )
+
     def fit(
             self,
             dataset_config: DatasetConfiguration,
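
Note on fit_all: `dict(**dataset_config, **{'selector': selector})` builds a
fresh per-condition config rather than mutating the shared dataset_config, so
every pass through the loop starts from the same base. A minimal sketch of the
merge semantics, with made-up keys (not the real DatasetConfiguration fields):

    # Hypothetical keys; only the merge behaviour matters.
    dataset_config = {'root_dir': 'data/dataset', 'num_frames': 30}
    selector = {'conditions': ['nm-01', 'nm-02']}

    merged = dict(**dataset_config, **{'selector': selector})
    # merged carries 'root_dir', 'num_frames' and 'selector', while
    # dataset_config is untouched, so the next condition in the loop
    # starts from the same base configuration.
    assert 'selector' not in dataset_config
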
@@ -115,6 +131,8 @@ class Model:
             self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
             self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
 
+        # Training start
+        start_time = datetime.now()
         for (batch_c1, batch_c2) in dataloader:
             self.curr_iter += 1
             # Zero the parameter gradients
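
The timer added here measures wall-clock time per checkpoint interval rather
than cumulative training time, because start_time is reset after each save
(see the checkpoint hunk below). The pattern in isolation:

    from datetime import datetime

    start_time = datetime.now()
    # ... training iterations run here ...
    elapsed = datetime.now() - start_time  # a datetime.timedelta
    print(elapsed, 'used')                 # e.g. "0:01:23.456789 used"
    start_time = datetime.now()            # restart for the next interval
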
@@ -137,7 +155,7 @@ class Model:
                 ], metrics)), self.curr_iter)
             if self.curr_iter % 100 == 0:
-                print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss),
+                print('{0:5d} loss: {1:6.3f}'.format(self.curr_iter, loss),
                       '(xrecon = {:f}, pose_sim = {:f},'
                       ' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics),
                       'lr:', self.scheduler.get_last_lr()[0])
 
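
The format change only affects alignment: `{1:6.3f}` pads the loss to a fixed
6-character field, so consecutive log lines stay in columns once the loss
drops below 10. Compare:

    for loss in (12.3456, 0.5678):
        print('loss: {:.3f}'.format(loss))   # ragged:  12.346 / 0.568
        print('loss: {:6.3f}'.format(loss))  # aligned: 12.346 /  0.568
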
@@ -149,8 +167,11 @@ class Model:
                     'optim_state_dict': self.optimizer.state_dict(),
                     'loss': loss,
                 }, self._checkpoint_name)
+                print(datetime.now() - start_time, 'used')
+                start_time = datetime.now()
 
             if self.curr_iter == self.total_iter:
+                self.curr_iter = 0
                 self.writer.close()
                 break
 
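
Two behavioural points in this hunk: the elapsed-time print pairs with the
timer reset so each checkpoint reports only its own interval, and zeroing
curr_iter at total_iter lets a subsequent fit() call (e.g. the next condition
in fit_all) start from iteration 0 instead of hitting the stop condition
immediately. The save itself uses the wrapped-dict layout that the prediction
path below now expects. A standalone sketch of that layout, with a placeholder
model and optimizer:

    import torch
    from torch import nn, optim

    # Placeholder model/optimizer just to show the wrapped-dict layout.
    model = nn.Linear(4, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    torch.save({
        'iter': 1000,
        'model_state_dict': model.state_dict(),
        'optim_state_dict': optimizer.state_dict(),
        'loss': 0.123,
    }, 'checkpoint.pth')
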
@@ -160,7 +181,7 @@ class Model:
             self.rgb_pn = nn.DataParallel(self.rgb_pn)
         self.rgb_pn = self.rgb_pn.to(self.device)
 
-    def predict(
+    def predict_all(
             self,
             iter_: int,
             dataset_config: DatasetConfiguration,
@@ -189,23 +210,36 @@ class Model:
         gallery_samples, probe_samples = [], {}
 
         # Gallery
-        self.rgb_pn.load_state_dict(torch.load(list(checkpoints.values())[0]))
+        checkpoint = torch.load(list(checkpoints.values())[0])
+        self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
         for sample in tqdm(gallery_dataloader,
                            desc='Transforming gallery', unit='clips'):
+            label = sample.pop('label').item()
             clip = sample.pop('clip').to(self.device)
             feature = self.rgb_pn(clip).detach()
-            gallery_samples.append({**sample, **{'feature': feature}})
+            gallery_samples.append({
+                **{'label': label},
+                **sample,
+                **{'feature': feature}
+            })
         gallery_samples = default_collate(gallery_samples)
 
         # Probe
-        for (name, dataloader) in probe_dataloaders.items():
-            self.rgb_pn.load_state_dict(torch.load(checkpoints[name]))
-            probe_samples[name] = []
+        for (condition, dataloader) in probe_dataloaders.items():
+            checkpoint = torch.load(checkpoints[condition])
+            self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
+            probe_samples[condition] = []
             for sample in tqdm(dataloader,
-                               desc=f'Transforming probe {name}', unit='clips'):
+                               desc=f'Transforming probe {condition}',
+                               unit='clips'):
+                label = sample.pop('label').item()
                 clip = sample.pop('clip').to(self.device)
                 feature = self.rgb_pn(clip).detach()
-                probe_samples[name].append({**sample, **{'feature': feature}})
+                probe_samples[condition].append({
+                    **{'label': label},
+                    **sample,
+                    **{'feature': feature}
+                })
 
         for (k, v) in probe_samples.items():
             probe_samples[k] = default_collate(v)
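
This is the load-side counterpart of the wrapped checkpoint: load_state_dict
expects a bare state_dict, so the dict saved during training has to be
unwrapped first. The other change, `sample.pop('label').item()`, pulls the
one-element label tensor out as a plain Python number before the sample dict
goes through default_collate. A sketch of the unwrap, continuing the
placeholder file from the training hunk above:

    import torch
    from torch import nn

    model = nn.Linear(4, 2)  # placeholder, as in the earlier sketch
    checkpoint = torch.load('checkpoint.pth')
    # model.load_state_dict(checkpoint) would complain about unexpected
    # keys ('iter', 'loss', ...); the wrapped dict must be unwrapped:
    model.load_state_dict(checkpoint['model_state_dict'])
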
@@ -243,11 +277,11 @@ class Model:
             f_p = features_p[probe_view_mask]
             y_p = labels_p[probe_view_mask]
             # Euclidean distance
-            f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(1)
-            f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(0)
-            f_g_times_f_p_sum = f_g @ f_p.T
+            f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1)
+            f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0)
+            f_p_times_f_g_sum = f_p @ f_g.T
             dist = torch.sqrt(F.relu(
-                f_g_squared_sum - 2*f_g_times_f_p_sum + f_p_squared_sum
+                f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum
             ))
             # Ranked accuracy
             rank_mask = dist.argsort(1)[:, :num_ranks]
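
The swap changes the orientation of the distance matrix: dist[i, j] is now the
distance from probe i to gallery j, so the argsort(1) that follows ranks
gallery candidates per probe clip, which is what rank-k accuracy needs. The
computation itself is the standard expansion ||p - g||^2 = ||p||^2 - 2 p.g +
||g||^2 broadcast over all pairs, with F.relu clamping the tiny negative
values that floating-point cancellation can produce before the sqrt. A
self-contained check against torch.cdist:

    import torch
    import torch.nn.functional as F

    f_p = torch.randn(4, 256)   # probe features   (n_p, d)
    f_g = torch.randn(10, 256)  # gallery features (n_g, d)

    f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1)  # (n_p, 1)
    f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0)  # (1, n_g)
    f_p_times_f_g_sum = f_p @ f_g.T                            # (n_p, n_g)
    dist = torch.sqrt(F.relu(
        f_p_squared_sum - 2 * f_p_times_f_g_sum + f_g_squared_sum
    ))

    assert torch.allclose(dist, torch.cdist(f_p, f_g), atol=1e-4)
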
@@ -354,8 +388,8 @@ class Model:
             dataloader_config: DataloaderConfiguration
     ) -> DataLoader:
         config: dict = dataloader_config.copy()
+        (self.pr, self.k) = config.pop('batch_size')
         if self.is_train:
-            (self.pr, self.k) = config.pop('batch_size')
             self._log_name = '_'.join(
                 (self._log_name, str(self.pr), str(self.k)))
             triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
@@ -364,7 +398,6 @@ class Model:
                               collate_fn=self._batch_splitter,
                               **config)
         else:  # is_test
-            config.pop('batch_size')
             return DataLoader(dataset, **config)
 
     def _batch_splitter(
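
Hoisting the pop above the branch means batch_size is consumed in both modes
with one statement instead of two: here batch_size is a (P, K) pair for the
triplet sampler, not an integer, so it must never leak through **config into
DataLoader. A sketch with a hypothetical config dict:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Hypothetical dataloader config: batch_size is a (P, K) pair
    # consumed by the sampler, not by DataLoader itself.
    config = {'batch_size': (8, 16), 'num_workers': 0}
    (pr, k) = config.pop('batch_size')

    dataset = TensorDataset(torch.arange(10))
    loader = DataLoader(dataset, **config)  # only DataLoader kwargs remain
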
@@ -399,8 +432,10 @@ class Model:
         for v in values:
             if isinstance(v, str):
                 strings.append(v)
-            elif isinstance(v, (tuple, list, set)):
+            elif isinstance(v, (tuple, list)):
                 strings.append(self._gen_sig(v))
+            elif isinstance(v, set):
+                strings.append(self._gen_sig(sorted(list(v))))
             elif isinstance(v, dict):
                 strings.append(self._gen_sig(list(v.values())))
             else:
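
Sorting sets before recursing keeps the signature deterministic: set iteration
order for strings varies between interpreter runs (string hashing is salted),
so the same selector could otherwise produce two different checkpoint names.
An illustration:

    views = {'090', '000', '180'}
    # Direct iteration order is not stable across interpreter runs,
    # so '_'.join(views) may differ from one run to the next.
    signature = '_'.join(sorted(views))
    print(signature)  # always '000_090_180'
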