summaryrefslogtreecommitdiff
path: root/models/model.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-09 21:54:10 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-09 21:54:10 +0800
commitde911a563fc503114559d7e0e7f710db090cec0d (patch)
treeb0720a7bf4a5de4b2c7d529e8a45f5c50dd023fd /models/model.py
parent62f14a6ef0d902b9ffd4e57427a40663e2e5c2ad (diff)
Add prototype predict function
Diffstat (limited to 'models/model.py')
-rw-r--r--models/model.py59
1 files changed, 54 insertions, 5 deletions
diff --git a/models/model.py b/models/model.py
index 54f3441..ba29ede 100644
--- a/models/model.py
+++ b/models/model.py
@@ -8,6 +8,7 @@ import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
from models.rgb_part_net import RGBPartNet
from utils.configuration import DataloaderConfiguration, \
@@ -88,10 +89,7 @@ class Model:
self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9)
self.writer = SummaryWriter(self.log_name)
- if not self.disable_acc:
- if torch.cuda.device_count() > 1:
- self.rgb_pn = nn.DataParallel(self.rgb_pn)
- self.rgb_pn = self.rgb_pn.to(self.device)
+ self._accelerate()
self.rgb_pn.train()
# Init weights at first iter
@@ -143,6 +141,57 @@ class Model:
self.writer.close()
break
+ def _accelerate(self):
+ if not self.disable_acc:
+ if torch.cuda.device_count() > 1:
+ self.rgb_pn = nn.DataParallel(self.rgb_pn)
+ self.rgb_pn = self.rgb_pn.to(self.device)
+
+ def predict(
+ self,
+ iter_: int,
+ dataset_config: DatasetConfiguration,
+ dataloader_config: DataloaderConfiguration,
+ ):
+ self.is_train = False
+ dataset = self._parse_dataset_config(dataset_config)
+ dataloader = self._parse_dataloader_config(dataset, dataloader_config)
+ hp = self.hp.copy()
+ _, _ = hp.pop('lr'), hp.pop('betas')
+ dataset_name = dataset_config.get('name', 'CASIA-B')
+ if dataset_name == 'CASIA-B':
+ self.rgb_pn = RGBPartNet(124 - self.train_size,
+ self.in_channels,
+ **hp)
+ elif dataset_name == 'FVG':
+ # TODO
+ pass
+ else:
+ raise ValueError('Invalid dataset: {0}'.format(dataset_name))
+
+ self._accelerate()
+
+ self.rgb_pn.eval()
+ # Load checkpoint at iter_
+ self.curr_iter = iter_
+ checkpoint = torch.load(self.checkpoint_name)
+ self.rgb_pn.load_state_dict(checkpoint['model_state_dict'])
+
+ labels, conditions, views, features = [], [], [], []
+ for sample in tqdm(dataloader, desc='Transforming', unit='clips'):
+ label, condition, view, clip = sample.values()
+ feature = self.rgb_pn(clip).detach().cpu().numpy()
+ labels.append(label)
+ conditions.append(condition)
+ views.append(view)
+ features.append(feature)
+ labels = np.asarray(labels)
+ conditions = np.asarray(conditions)
+ views = np.asarray(views)
+ features = np.asarray(features)
+
+ # TODO Implement evaluation function here
+
@staticmethod
def init_weights(m):
if isinstance(m, nn.modules.conv._ConvNd):
@@ -167,7 +216,7 @@ class Model:
)
self.log_name = '_'.join((self.log_name, self._dataset_sig))
config: dict = dataset_config.copy()
- name = config.pop('name')
+ name = config.pop('name', 'CASIA-B')
if name == 'CASIA-B':
return CASIAB(**config, is_train=self.is_train)
elif name == 'FVG':