from typing import Union, Optional
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from utils.configuration import (DataloaderConfiguration,
                                 DatasetConfiguration,
                                 HyperparameterConfiguration,
                                 ModelConfiguration)
from utils.dataset import CASIAB
from utils.sampler import TripletSampler


class Model:
def __init__(
self,
model_config: ModelConfiguration,
hyperparameter_config: HyperparameterConfiguration
):
self.meta = model_config
self.hp = hyperparameter_config
self.curr_iter = self.meta['restore_iter']
self.is_train: bool = True
self.dataset_metadata: Optional[DatasetConfiguration] = None
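        # pr / k define the triplet batch shape: pr subjects per batch,
        # k clips per subject and condition; both are set once a training
        # dataloader has been parsed.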
self.pr: Optional[int] = None
self.k: Optional[int] = None
self._model_sig: str = self._make_signature(self.meta, ['restore_iter'])
self._hp_sig: str = self._make_signature(self.hp)
        self._dataset_sig: str = 'undefined'

@property
def signature(self) -> str:
return '_'.join((self._model_sig, str(self.curr_iter), self._hp_sig,
                         self._dataset_sig, str(self.batch_size)))

@property
def batch_size(self) -> int:
if self.is_train:
if self.pr and self.k:
return self.pr * self.k
raise AttributeError('No dataset loaded')
else:
            return 1

def _batch_splitter(
self,
batch: list[dict[str, Union[str, torch.Tensor]]]
) -> list[tuple[dict[str, list[Union[str, torch.Tensor]]],
dict[str, list[Union[str, torch.Tensor]]]]]:
"""
Disentanglement cannot be processed on different subjects at the
same time, we need to load `pr` subjects one by one. The batch
splitter will return a pr-length list of tuples (with 2 dicts
containing k-length lists of labels, conditions, view and
k-length tensor of clip data, representing condition 1 and
condition 2 respectively).
"""
_batch = []
for i in range(0, self.pr * self.k * 2, self.k * 2):
_batch.append((default_collate(batch[i:i + self.k]),
default_collate(batch[i + self.k:i + self.k * 2])))
        return _batch

def fit(
self,
dataset_config: DatasetConfiguration,
dataloader_config: DataloaderConfiguration,
):
self.is_train = True
dataset = self._parse_dataset_config(dataset_config)
dataloader = self._parse_dataloader_config(dataset, dataloader_config)
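        # NOTE: the breaks below keep this a smoke test; only the first
        # subject of the first batch is touched until the actual training
        # step lands.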
for iter_i, samples_batched in enumerate(dataloader):
for sub_i, (subject_c1, subject_c2) in enumerate(samples_batched):
                pass  # TODO: forward/backward pass on (subject_c1, subject_c2)
if sub_i == 0:
break
if iter_i == 0:
                break

def _parse_dataset_config(
self,
dataset_config: DatasetConfiguration
    ) -> CASIAB:
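        # root_dir and cache_on are excluded from the signature, presumably
        # because they only affect where and how data is loaded, not what
        # the model learns.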
self._dataset_sig = self._make_signature(
dataset_config,
popped_keys=['root_dir', 'cache_on']
)
config: dict = dataset_config.copy()
name = config.pop('name')
        if name == 'CASIA-B':
            return CASIAB(**config, is_train=self.is_train)
        elif name == 'FVG':
            # TODO: add FVG support
            raise NotImplementedError('FVG dataset is not supported yet')
        raise ValueError('Invalid dataset: {0}'.format(name))

def _parse_dataloader_config(
self,
        dataset: CASIAB,
dataloader_config: DataloaderConfiguration
) -> DataLoader:
config: dict = dataloader_config.copy()
if self.is_train:
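            # In training, 'batch_size' is a (pr, k) pair: pr subjects per
            # batch, k clips per subject and condition.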
(self.pr, self.k) = config.pop('batch_size')
triplet_sampler = TripletSampler(dataset, (self.pr, self.k))
return DataLoader(dataset,
batch_sampler=triplet_sampler,
collate_fn=self._batch_splitter,
**config)
else: # is_test
config.pop('batch_size')
            return DataLoader(dataset, **config)

@staticmethod
def _make_signature(config: dict,
popped_keys: Optional[list] = None) -> str:
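        """
        Flatten a configuration dict into an underscore-joined string,
        skipping the keys listed in `popped_keys`. Tuple and list values
        are joined element-wise; all other values are str()-converted.
        """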
_config = config.copy()
for (key, value) in config.items():
if popped_keys and key in popped_keys:
_config.pop(key)
continue
if isinstance(value, str):
pass
elif isinstance(value, (tuple, list)):
_config[key] = '_'.join([str(v) for v in value])
else:
_config[key] = str(value)
return '_'.join(_config.values())
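

# A minimal usage sketch (hypothetical values). Only the configuration keys
# referenced in this module are shown ('restore_iter', 'name', 'root_dir',
# 'cache_on', 'batch_size'); the real configurations in utils.configuration
# may require additional fields.
if __name__ == '__main__':
    model = Model(
        model_config={'restore_iter': 0},
        hyperparameter_config={'lr': 1e-4},  # hypothetical hyperparameter
    )
    model.fit(
        dataset_config={
            'name': 'CASIA-B',
            'root_dir': 'data/CASIA-B',  # hypothetical path
            'cache_on': False,
        },
        dataloader_config={
            'batch_size': (4, 8),  # (pr, k)
            'num_workers': 0,
        },
    )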