summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--eval.py16
-rw-r--r--models/model.py67
-rw-r--r--train.py16
-rw-r--r--utils/misc.py10
4 files changed, 81 insertions, 28 deletions
diff --git a/eval.py b/eval.py
new file mode 100644
index 0000000..fee4ab9
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,16 @@
+from config import config
+from models import Model
+from utils.dataset import ClipConditions
+from utils.misc import set_visible_cuda
+
+set_visible_cuda(config['system'])
+model = Model(config['system'], config['model'], config['hyperparameter'])
+
+dataset_selectors = {
+ 'nm': {'conditions': ClipConditions({r'nm-0\d'})},
+ 'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
+ 'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
+}
+
+accuracy = model.predict_all(config['model']['total_iter'], config['dataset'],
+ dataset_selectors, config['dataloader'])
diff --git a/models/model.py b/models/model.py
index 456c2f1..b343f86 100644
--- a/models/model.py
+++ b/models/model.py
@@ -1,4 +1,5 @@
import os
+from datetime import datetime
from typing import Union, Optional
import numpy as np
@@ -86,6 +87,21 @@ class Model:
def _checkpoint_name(self) -> str:
return os.path.join(self.checkpoint_dir, self._signature)
+ def fit_all(
+ self,
+ dataset_config: DatasetConfiguration,
+ dataset_selectors: dict[
+ str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
+ ],
+ dataloader_config: DataloaderConfiguration,
+ ):
+ for (condition, selector) in dataset_selectors.items():
+ print(f'Training model {condition} ...')
+ self.fit(
+ dict(**dataset_config, **{'selector': selector}),
+ dataloader_config
+ )
+
def fit(
self,
dataset_config: DatasetConfiguration,
@@ -115,6 +131,8 @@ class Model:
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optim_state_dict'])
+ # Training start
+ start_time = datetime.now()
for (batch_c1, batch_c2) in dataloader:
self.curr_iter += 1
# Zero the parameter gradients
@@ -137,7 +155,7 @@ class Model:
], metrics)), self.curr_iter)
if self.curr_iter % 100 == 0:
- print('{0:5d} loss: {1:.3f}'.format(self.curr_iter, loss),
+ print('{0:5d} loss: {1:6.3f}'.format(self.curr_iter, loss),
'(xrecon = {:f}, pose_sim = {:f},'
' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics),
'lr:', self.scheduler.get_last_lr()[0])
@@ -149,8 +167,11 @@ class Model:
'optim_state_dict': self.optimizer.state_dict(),
'loss': loss,
}, self._checkpoint_name)
+ print(datetime.now() - start_time, 'used')
+ start_time = datetime.now()
if self.curr_iter == self.total_iter:
+ self.curr_iter = 0
self.writer.close()
break
@@ -160,7 +181,7 @@ class Model:
self.rgb_pn = nn.DataParallel(self.rgb_pn)
self.rgb_pn = self.rgb_pn.to(self.device)
- def predict(
+ def predict_all(
self,
iter_: int,
dataset_config: DatasetConfiguration,
@@ -189,23 +210,36 @@ class Model:
gallery_samples, probe_samples = [], {}
# Gallery
- self.rgb_pn.load_state_dict(torch.load(list(checkpoints.values())[0]))
+ checkpoint = torch.load(list(checkpoints.values())[0])
+ self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
for sample in tqdm(gallery_dataloader,
desc='Transforming gallery', unit='clips'):
+ label = sample.pop('label').item()
clip = sample.pop('clip').to(self.device)
feature = self.rgb_pn(clip).detach()
- gallery_samples.append({**sample, **{'feature': feature}})
+ gallery_samples.append({
+ **{'label': label},
+ **sample,
+ **{'feature': feature}
+ })
gallery_samples = default_collate(gallery_samples)
# Probe
- for (name, dataloader) in probe_dataloaders.items():
- self.rgb_pn.load_state_dict(torch.load(checkpoints[name]))
- probe_samples[name] = []
+ for (condition, dataloader) in probe_dataloaders.items():
+ checkpoint = torch.load(checkpoints[condition])
+ self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
+ probe_samples[condition] = []
for sample in tqdm(dataloader,
- desc=f'Transforming probe {name}', unit='clips'):
+ desc=f'Transforming probe {condition}',
+ unit='clips'):
+ label = sample.pop('label').item()
clip = sample.pop('clip').to(self.device)
feature = self.rgb_pn(clip).detach()
- probe_samples[name].append({**sample, **{'feature': feature}})
+ probe_samples[condition].append({
+ **{'label': label},
+ **sample,
+ **{'feature': feature}
+ })
for (k, v) in probe_samples.items():
probe_samples[k] = default_collate(v)
@@ -243,11 +277,11 @@ class Model:
f_p = features_p[probe_view_mask]
y_p = labels_p[probe_view_mask]
# Euclidean distance
- f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(1)
- f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(0)
- f_g_times_f_p_sum = f_g @ f_p.T
+ f_p_squared_sum = torch.sum(f_p ** 2, dim=1).unsqueeze(1)
+ f_g_squared_sum = torch.sum(f_g ** 2, dim=1).unsqueeze(0)
+ f_p_times_f_g_sum = f_p @ f_g.T
dist = torch.sqrt(F.relu(
- f_g_squared_sum - 2*f_g_times_f_p_sum + f_p_squared_sum
+ f_p_squared_sum - 2*f_p_times_f_g_sum + f_g_squared_sum
))
# Ranked accuracy
rank_mask = dist.argsort(1)[:, :num_ranks]
@@ -354,8 +388,8 @@ class Model:
dataloader_config: DataloaderConfiguration
) -> DataLoader:
config: dict = dataloader_config.copy()
+ (self.pr, self.k) = config.pop('batch_size')
if self.is_train:
- (self.pr, self.k) = config.pop('batch_size')
self._log_name = '_'.join(
(self._log_name, str(self.pr), str(self.k)))
triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
@@ -364,7 +398,6 @@ class Model:
collate_fn=self._batch_splitter,
**config)
else: # is_test
- config.pop('batch_size')
return DataLoader(dataset, **config)
def _batch_splitter(
@@ -399,8 +432,10 @@ class Model:
for v in values:
if isinstance(v, str):
strings.append(v)
- elif isinstance(v, (tuple, list, set)):
+ elif isinstance(v, (tuple, list)):
strings.append(self._gen_sig(v))
+ elif isinstance(v, set):
+ strings.append(self._gen_sig(sorted(list(v))))
elif isinstance(v, dict):
strings.append(self._gen_sig(list(v.values())))
else:
diff --git a/train.py b/train.py
index cdb2fb0..d91dcd0 100644
--- a/train.py
+++ b/train.py
@@ -1,14 +1,9 @@
-import os
-
from config import config
from models import Model
from utils.dataset import ClipConditions
+from utils.misc import set_visible_cuda
-# Set environment variable CUDA device(s)
-CUDA_VISIBLE_DEVICES = config['system'].get('CUDA_VISIBLE_DEVICES', None)
-if CUDA_VISIBLE_DEVICES:
- os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES
-
+set_visible_cuda(config['system'])
model = Model(config['system'], config['model'], config['hyperparameter'])
# 3 models for different conditions
@@ -17,8 +12,5 @@ dataset_selectors = {
'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
}
-for selector in dataset_selectors.values():
- model.fit(
- dict(**config['dataset'], **{'selector': selector}),
- config['dataloader']
- )
+
+model.fit_all(config['dataset'], dataset_selectors, config['dataloader'])
diff --git a/utils/misc.py b/utils/misc.py
new file mode 100644
index 0000000..b850830
--- /dev/null
+++ b/utils/misc.py
@@ -0,0 +1,10 @@
+import os
+
+from utils.configuration import SystemConfiguration
+
+
+def set_visible_cuda(config: SystemConfiguration):
+ """Set environment variable CUDA device(s)"""
+ CUDA_VISIBLE_DEVICES = config.get('CUDA_VISIBLE_DEVICES', None)
+ if CUDA_VISIBLE_DEVICES:
+ os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_VISIBLE_DEVICES