-rw-r--r--   models/model.py   142
-rw-r--r--   test/model.py       4
-rw-r--r--   train.py           12
3 files changed, 110 insertions, 48 deletions
diff --git a/models/model.py b/models/model.py
index ba29ede..cabfa97 100644
--- a/models/model.py
+++ b/models/model.py
@@ -14,7 +14,7 @@ from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
HyperparameterConfiguration, DatasetConfiguration, ModelConfiguration, \
SystemConfiguration
-from utils.dataset import CASIAB
+from utils.dataset import CASIAB, ClipConditions, ClipViews, ClipClasses
from utils.sampler import TripletSampler
@@ -57,21 +57,30 @@ class Model:
self._hp_sig: str = self._make_signature(self.hp)
self._dataset_sig: str = 'undefined'
self._log_sig: str = '_'.join((self._model_sig, self._hp_sig))
- self.log_name: str = os.path.join(self.log_dir, self._log_sig)
+ self._log_name: str = os.path.join(self.log_dir, self._log_sig)
self.rgb_pn: Optional[RGBPartNet] = None
self.optimizer: Optional[optim.Adam] = None
self.scheduler: Optional[optim.lr_scheduler.StepLR] = None
self.writer: Optional[SummaryWriter] = None
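+ # CASIA-B evaluation convention, mirrored by the regexes below:
+ # nm-01~04 clips form the gallery; nm-05/06, bg-01/02 and cl-01/02
+ # clips form the three probe subsets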
+ self.CASIAB_GALLERY_SELECTOR = {
+ 'selector': {'conditions': ClipConditions({r'nm-0[1-4]'})}
+ }
+ self.CASIAB_PROBE_SELECTORS = {
+ 'nm': {'selector': {'conditions': ClipConditions({r'nm-0[5-6]'})}},
+ 'bg': {'selector': {'conditions': ClipConditions({r'bg-0[1-2]'})}},
+ 'cl': {'selector': {'conditions': ClipConditions({r'cl-0[1-2]'})}},
+ }
+
@property
- def signature(self) -> str:
+ def _signature(self) -> str:
return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
self._dataset_sig, str(self.pr), str(self.k)))
@property
- def checkpoint_name(self) -> str:
- return os.path.join(self.checkpoint_dir, self.signature)
+ def _checkpoint_name(self) -> str:
+ return os.path.join(self.checkpoint_dir, self._signature)
def fit(
self,
@@ -87,8 +96,8 @@ class Model:
self.rgb_pn = RGBPartNet(self.train_size, self.in_channels, **hp)
self.optimizer = optim.Adam(self.rgb_pn.parameters(), lr, betas)
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
- self.writer = SummaryWriter(self.log_name)
-
+ self.writer = SummaryWriter(self._log_name)
+ # Try to accelerate computation with CUDA or other available devices
self._accelerate()
self.rgb_pn.train()
@@ -96,7 +105,7 @@ class Model:
if self.curr_iter == 0:
self.rgb_pn.apply(self.init_weights)
else: # Load saved state dicts
- checkpoint = torch.load(self.checkpoint_name)
+ checkpoint = torch.load(self._checkpoint_name)
iter_, loss = checkpoint['iter'], checkpoint['loss']
print('{0:5d} loss: {1:.3f}'.format(iter_, loss))
self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
@@ -135,7 +144,7 @@ class Model:
'model_state_dict': self.rgb_pn.state_dict(),
'optim_state_dict': self.optimizer.state_dict(),
'loss': loss,
- }, self.checkpoint_name)
+ }, self._checkpoint_name)
if self.curr_iter == self.total_iter:
self.writer.close()
@@ -151,46 +160,98 @@ class Model:
self,
iter_: int,
dataset_config: DatasetConfiguration,
+ dataset_selectors: dict[
+ str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
+ ],
dataloader_config: DataloaderConfiguration,
):
self.is_train = False
- dataset = self._parse_dataset_config(dataset_config)
- dataloader = self._parse_dataloader_config(dataset, dataloader_config)
+ # Split the dataset into gallery and probe subsets
+ gallery_dataloader, probe_dataloaders = self._split_gallery_probe(
+ dataset_config, dataloader_config
+ )
+ # Locate checkpoints of the pretrained models saved at iter_
+ checkpoints = self._load_pretrained(
+ iter_, dataset_config, dataset_selectors
+ )
+ # Init model
hp = self.hp.copy()
- _, _ = hp.pop('lr'), hp.pop('betas')
+ hp.pop('lr')
+ hp.pop('betas')
+ self.rgb_pn = RGBPartNet(ae_in_channels=self.in_channels, **hp)
+ # Try to accelerate computation with CUDA or other available devices
+ self._accelerate()
+
+ self.rgb_pn.eval()
+ gallery_samples, probe_samples = [], {}
+
+ # Gallery
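+ # Any of the three checkpoints can transform the gallery (each model
+ # was trained on nm clips); the first in insertion order is used here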
+ self.rgb_pn.load_state_dict(torch.load(list(checkpoints.values())[0]))
+ for sample in tqdm(gallery_dataloader,
+ desc='Transforming gallery', unit='clips'):
+ clip = sample.pop('clip').to(self.device)
+ feature = self.rgb_pn(clip).detach().cpu()
+ gallery_samples.append({**sample, **{'feature': feature}})
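+ # default_collate merges the list of per-clip dicts into one dict of
+ # batched values (tensors are stacked, strings are kept as lists)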
+ gallery_samples = default_collate(gallery_samples)
+
+ # Probe
+ for (name, dataloader) in probe_dataloaders.items():
+ self.rgb_pn.load_state_dict(torch.load(checkpoints[name]))
+ probe_samples[name] = []
+ for sample in tqdm(dataloader,
+ desc=f'Transforming probe {name}', unit='clips'):
+ clip = sample.pop('clip').to(self.device)
+ feature = self.rgb_pn(clip).detach().cpu()
+ probe_samples[name].append({**sample, **{'feature': feature}})
+ for (k, v) in probe_samples.items():
+ probe_samples[k] = default_collate(v)
+
+ # TODO Implement evaluation function here
+
+ def _load_pretrained(
+ self,
+ iter_: int,
+ dataset_config: DatasetConfiguration,
+ dataset_selectors: dict[
+ str, dict[str, Union[ClipClasses, ClipConditions, ClipViews]]
+ ]
+ ) -> dict[str, str]:
+ checkpoints = {}
+ self.curr_iter = iter_
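+ # _checkpoint_name is derived from _signature, which includes both
+ # curr_iter and _dataset_sig, so set them before reading the property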
+ for (k, v) in dataset_selectors.items():
+ self._dataset_sig = self._make_signature(
+ dict(**dataset_config, **v),
+ popped_keys=['root_dir', 'cache_on']
+ )
+ checkpoints[k] = self._checkpoint_name
+ return checkpoints
+
+ def _split_gallery_probe(
+ self,
+ dataset_config: DatasetConfiguration,
+ dataloader_config: DataloaderConfiguration,
+ ) -> tuple[DataLoader, dict[str, DataLoader]]:
dataset_name = dataset_config.get('name', 'CASIA-B')
if dataset_name == 'CASIA-B':
- self.rgb_pn = RGBPartNet(124 - self.train_size,
- self.in_channels,
- **hp)
+ gallery_dataset = self._parse_dataset_config(
+ dict(**dataset_config, **self.CASIAB_GALLERY_SELECTOR)
+ )
+ gallery_dataloader = self._parse_dataloader_config(
+ gallery_dataset, dataloader_config
+ )
+ probe_datasets = {k: self._parse_dataset_config(
+ dict(**dataset_config, **v)
+ ) for (k, v) in self.CASIAB_PROBE_SELECTORS.items()}
+ probe_dataloaders = {k: self._parse_dataloader_config(
+ v, dataloader_config
+ ) for (k, v) in probe_datasets.items()}
elif dataset_name == 'FVG':
# TODO
- pass
+ gallery_dataloader = None
+ probe_dataloaders = None
else:
raise ValueError('Invalid dataset: {0}'.format(dataset_name))
- self._accelerate()
-
- self.rgb_pn.eval()
- # Load checkpoint at iter_
- self.curr_iter = iter_
- checkpoint = torch.load(self.checkpoint_name)
- self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
-
- labels, conditions, views, features = [], [], [], []
- for sample in tqdm(dataloader, desc='Transforming', unit='clips'):
- label, condition, view, clip = sample.values()
- feature = self.rgb_pn(clip).detach().cpu().numpy()
- labels.append(label)
- conditions.append(condition)
- views.append(view)
- features.append(feature)
- labels = np.asarray(labels)
- conditions = np.asarray(conditions)
- views = np.asarray(views)
- features = np.asarray(features)
-
- # TODO Implement evaluation function here
+ return gallery_dataloader, probe_dataloaders
@staticmethod
def init_weights(m):
@@ -214,7 +275,7 @@ class Model:
dataset_config,
popped_keys=['root_dir', 'cache_on']
)
- self.log_name = '_'.join((self.log_name, self._dataset_sig))
+ self._log_name = '_'.join((self._log_name, self._dataset_sig))
config: dict = dataset_config.copy()
name = config.pop('name', 'CASIA-B')
if name == 'CASIA-B':
@@ -232,7 +293,8 @@ class Model:
config: dict = dataloader_config.copy()
if self.is_train:
(self.pr, self.k) = config.pop('batch_size')
- self.log_name = '_'.join((self.log_name, str(self.pr), str(self.k)))
+ self._log_name = '_'.join(
+ (self._log_name, str(self.pr), str(self.k)))
triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
return DataLoader(dataset,
batch_sampler=triplet_sampler,
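Since the new evaluation path leans on default_collate to batch the per-clip sample dicts, here is a minimal, self-contained sketch of that behaviour (hypothetical field values; standard PyTorch collation semantics):

    import torch
    from torch.utils.data.dataloader import default_collate

    samples = [
        {'condition': 'nm-05', 'view': '000', 'feature': torch.zeros(3)},
        {'condition': 'nm-06', 'view': '018', 'feature': torch.ones(3)},
    ]
    batch = default_collate(samples)
    print(batch['feature'].shape)  # torch.Size([2, 3]) -- tensors stacked
    print(batch['condition'])      # ['nm-05', 'nm-06'] -- strings as lists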
diff --git a/test/model.py b/test/model.py
index 472c210..f679908 100644
--- a/test/model.py
+++ b/test/model.py
@@ -11,11 +11,11 @@ def test_default_signature():
model = Model(conf['system'], conf['model'], conf['hyperparameter'])
casiab = model._parse_dataset_config(conf['dataset'])
model._parse_dataloader_config(casiab, conf['dataloader'])
- assert model.log_name == os.path.join(
+ assert model._log_name == os.path.join(
'runs', 'logs', 'RGB-GaitPart_80000_64_128_128_64_1_2_4_True_True_32_5_'
'3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_0.2_0.0001_0.9_'
'0.999_CASIA-B_74_30_15_3_64_32_8_16')
- assert model.signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_'
+ assert model._signature == ('RGB-GaitPart_80000_0_64_128_128_64_1_2_4_True_'
'True_32_5_3_3_3_3_3_2_1_1_1_1_1_0_2_3_4_16_256_'
'0.2_0.0001_0.9_0.999_CASIA-B_74_30_15_3_64_32_'
'8_16')
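(The extra '_0_' in the expected _signature is str(self.curr_iter) at initialization, per the property above; _log_name omits it.)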
diff --git a/train.py b/train.py
index d921839..cdb2fb0 100644
--- a/train.py
+++ b/train.py
@@ -12,12 +12,12 @@ if CUDA_VISIBLE_DEVICES:
model = Model(config['system'], config['model'], config['hyperparameter'])
# 3 models for different conditions
-dataset_selectors = [
- {'conditions': ClipConditions({r'nm-0\d'})},
- {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
- {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
-]
-for selector in dataset_selectors:
+dataset_selectors = {
+ 'nm': {'conditions': ClipConditions({r'nm-0\d'})},
+ 'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
+ 'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
+}
+for selector in dataset_selectors.values():
model.fit(
dict(**config['dataset'], **{'selector': selector}),
config['dataloader']
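Note that the selector keys ('nm', 'bg', 'cl') deliberately match Model.CASIAB_PROBE_SELECTORS, so _load_pretrained can pair each probe subset with the checkpoint of the model trained under the corresponding conditions. A minimal sketch of the matching evaluation call follows; the evaluation method's name lies outside the visible hunk, so `predict` is a hypothetical placeholder, while the parameter order is taken from the hunk above:

    # Hypothetical entry-point name; argument order from models/model.py
    model.predict(
        80000,                  # iter_: checkpoint iteration to load
        config['dataset'],      # dataset_config
        dataset_selectors,      # keys must match CASIAB_PROBE_SELECTORS
        config['dataloader'],   # dataloader_config
    )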