summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/model.py142
1 files changed, 102 insertions, 40 deletions
diff --git a/models/model.py b/models/model.py
index ba29ede..cabfa97 100644
--- a/models/model.py
+++ b/models/model.py
@@ -14,7 +14,7 @@ from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
-from utils.dataset import CASIAB
+from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
@@ -57,21 +57,30 @@ class Model:
self._hp_sig: str = self._make_signature(self.hp)
self._dataset_sig: str = 'undefined'
self._log_sig: str = '_'.join((self._model_sig, self._hp_sig))
- self.log_name: str = os.path.join(self.log_dir, self._log_sig)
+ self._log_name: str = os.path.join(self.log_dir, self._log_sig)
self.rgb_pn: Optional[RGBPartNet] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
+ self.CASIAB_GALLERY_SELECTOR = {
+ 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
+ }
+ self.CASIAB_PROBE_SELECTORS = {
+ 'nm': {'selector': {'conditions': ClipConditions({r'nm-0[5-6]'})}},
+ 'bg': {'selector': {'conditions': ClipConditions({r'bg-0[1-2]'})}},
+ 'cl': {'selector': {'conditions': ClipConditions({r'cl-0[1-2]'})}},
+ }
+
@property
- def signature(self) -> str:
+ def _signature(self) -> str:
return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
self._dataset_sig, str(self.pr), str(self.k)))
@property
- def checkpoint_name(self) -> str:
- return os.path.join(self.checkpoint_dir, self.signature)
+ def _checkpoint_name(self) -> str:
+ return os.path.join(self.checkpoint_dir, self._signature)
def fit(
self,
@@ -87,8 +96,8 @@ class Model:
self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas)
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
- self.writer = SummaryWriter(self.log_name)
-
+ self.writer = SummaryWriter(self._log_name)
+ # Try to accelerate computation using CUDA or others
self._accelerate()
self.rgb_pn.train()
@@ -96,7 +105,7 @@ class Model:
if self.curr_iter == 0:
self.rgb_pn.apply(self.init_weights)
else: # Load saved state dicts
- checkpoint = torch.load(self.checkpoint_name)
+ checkpoint = torch.load(self._checkpoint_name)
iter_, loss = checkpoint['iter'], checkpoint['loss']
print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
@@ -135,7 +144,7 @@ class Model:
'model_state_dict': self.rgb_pn.state_dict(),
'optim_state_dict': self.optimizer.state_dict(),
'loss': loss,
- }, self.checkpoint_name)
+ }, self._checkpoint_name)
if self.curr_iter == self.total_iter:
self.writer.close()
@@ -151,46 +160,98 @@ class Model:
self,
iter_: int,
dataset_config: DatasetConfiguration,
+ dataset_selectors: dict[
+ str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
+ ],
dataloader_config: DataloaderConfiguration,
):
self.is_train = False
- dataset = self._parse_dataset_config(dataset_config)
- dataloader = self._parse_dataloader_config(dataset, dataloader_config)
+ # Split gallery and probe dataset
+ gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
+ dataset_config, dataloader_config
+ )
+ # Get pretrained models at iter_
+ checkpoints = self._load_pretrained(
+ iter_, dataset_config, dataset_selectors
+ )
+ # Init models
hp = self.hp.copy()
- _, _ = hp.pop('lr'), hp.pop('betas')
+ hp.pop('lr'), hp.pop('betas')
+ self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **hp)
+ # Try to accelerate computation using CUDA or others
+ self._accelerate()
+
+ self.rgb_pn.eval()
+ gallery_samples, probe_samples = [], {}
+
+ # Gallery
+ self.rgb_pn.load_state_dict(torch.load(list(checkpoints.values())[0]))
+ for sample in tqdm(gallery_dataloader,
+ desc='Transforming gallery', unit='clips'):
+ clip = sample.pop('clip').to(self.device)
+ feature = self.rgb_pn(clip).detach().cpu()
+ gallery_samples.append({**sample, **{'feature': feature}})
+ gallery_samples = default_collate(gallery_samples)
+
+ # Probe
+ for (name, dataloader) in probe_dataloaders.items():
+ self.rgb_pn.load_state_dict(torch.load(checkpoints[name]))
+ probe_samples[name] = []
+ for sample in tqdm(dataloader,
+ desc=f'Transforming probe {name}', unit='clips'):
+ clip = sample.pop('clip').to(self.device)
+ feature = self.rgb_pn(clip)
+ probe_samples[name].append({**sample, **{'feature': feature}})
+ for (k, v) in probe_samples.items():
+ probe_samples[k] = default_collate(v)
+
+ # TODO Implement evaluation function here
+
+ def _load_pretrained(
+ self,
+ iter_: int,
+ dataset_config: DatasetConfiguration,
+ dataset_selectors: dict[
+ str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
+ ]
+ ) -> dict[str, str]:
+ checkpoints = {}
+ self.curr_iter = iter_
+ for (k, v) in dataset_selectors.items():
+ self._dataset_sig = self._make_signature(
+ dict(**dataset_config, **v),
+ popped_keys=['root_dir', 'cache_on']
+ )
+ checkpoints[k] = self._checkpoint_name
+ return checkpoints
+
+ def _split_gallery_probe(
+ self,
+ dataset_config: DatasetConfiguration,
+ dataloader_config: DataloaderConfiguration,
+ ) -> tuple[DataLoader, dict[str: DataLoader]]:
dataset_name = dataset_config.get('name', 'CASIA-B')
if dataset_name == 'CASIA-B':
- self.rgb_pn = RGBPartNet(124 - self.train_size,
- self.in_channels,
- **hp)
+ gallery_dataset = self._parse_dataset_config(
+ dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR)
+ )
+ gallery_dataloader = self._parse_dataloader_config(
+ gallery_dataset, dataloader_config
+ )
+ probe_datasets = {k: self._parse_dataset_config(
+ dict(**dataset_config, **v)
+ ) for (k, v) in self.CASIAB_PROBE_SELECTORS.items()}
+ probe_dataloaders = {k: self._parse_dataloader_config(
+ v, dataloader_config
+ ) for (k, v) in probe_datasets.items()}
elif dataset_name == 'FVG':
# TODO
- pass
+ gallery_dataloader = None
+ probe_dataloaders = None
else:
raise ValueError('Invalid dataset: {0}'.format(dataset_name))
- self._accelerate()
-
- self.rgb_pn.eval()
- # Load checkpoint at iter_
- self.curr_iter = iter_
- checkpoint = torch.load(self.checkpoint_name)
- self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
-
- labels, conditions, views, features = [], [], [], []
- for sample in tqdm(dataloader, desc='Transforming', unit='clips'):
- label, condition, view, clip = sample.values()
- feature = self.rgb_pn(clip).detach().cpu().numpy()
- labels.append(label)
- conditions.append(condition)
- views.append(view)
- features.append(feature)
- labels = np.asarray(labels)
- conditions = np.asarray(conditions)
- views = np.asarray(views)
- features = np.asarray(features)
-
- # TODO Implement evaluation function here
+ return gallery_dataloader, probe_dataloaders
@staticmethod
def init_weights(m):
@@ -214,7 +275,7 @@ class Model:
dataset_config,
popped_keys=['root_dir', 'cache_on']
)
- self.log_name = '_'.join((self.log_name, self._dataset_sig))
+ self._log_name = '_'.join((self._log_name, self._dataset_sig))
config: dict = dataset_config.copy()
name = config.pop('name', 'CASIA-B')
if name == 'CASIA-B':
@@ -232,7 +293,8 @@ class Model:
config: dict = dataloader_config.copy()
if self.is_train:
(self.pr, self.k) = config.pop('batch_size')
- self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k)))
+ self._log_name = '_'.join(
+ (self._log_name, str(self.pr), str(self.k)))
triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
return DataLoader(dataset,
batch_sampler=triplet_sampler,