From de911a563fc503114559d7e0e7f710db090cec0d Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Sat, 9 Jan 2021 21:54:10 +0800 Subject: Add prototype predict function --- models/model.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 5 deletions(-) (limited to 'models/model.py') diff --git a/models/model.py b/models/model.py index 54f3441..ba29ede 100644 --- a/models/model.py +++ b/models/model.py @@ -8,6 +8,7 @@ import torch.optim as optim from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm from models.rgb_part_net import RGBPartNet from utils.configuration import DataloaderConfiguration, \ @@ -88,10 +89,7 @@ class Model: self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, 500, 0.9) self.writer = SummaryWriter(self.log_name) - if not self.disable_acc: - if torch.cuda.device_count() > 1: - self.rgb_pn = nn.DataParallel(self.rgb_pn) - self.rgb_pn = self.rgb_pn.to(self.device) + self._accelerate() self.rgb_pn.train() # Init weights at first iter @@ -143,6 +141,57 @@ class Model: self.writer.close() break + def _accelerate(self): + if not self.disable_acc: + if torch.cuda.device_count() > 1: + self.rgb_pn = nn.DataParallel(self.rgb_pn) + self.rgb_pn = self.rgb_pn.to(self.device) + + def predict( + self, + iter_: int, + dataset_config: DatasetConfiguration, + dataloader_config: DataloaderConfiguration, + ): + self.is_train = False + dataset = self._parse_dataset_config(dataset_config) + dataloader = self._parse_dataloader_config(dataset, dataloader_config) + hp = self.hp.copy() + _, _ = hp.pop('lr'), hp.pop('betas') + dataset_name = dataset_config.get('name', 'CASIA-B') + if dataset_name == 'CASIA-B': + self.rgb_pn = RGBPartNet(124 - self.train_size, + self.in_channels, + **hp) + elif dataset_name == 'FVG': + # TODO + pass + else: + raise ValueError('Invalid dataset: {0}'.format(dataset_name)) + + self._accelerate() + + self.rgb_pn.eval() + # Load checkpoint at iter_ + self.curr_iter = iter_ + checkpoint = torch.load(self.checkpoint_name) + self.rgb_pn.load_state_dict(checkpoint['model_state_dict']) + + labels, conditions, views, features = [], [], [], [] + for sample in tqdm(dataloader, desc='Transforming', unit='clips'): + label, condition, view, clip = sample.values() + feature = self.rgb_pn(clip).detach().cpu().numpy() + labels.append(label) + conditions.append(condition) + views.append(view) + features.append(feature) + labels = np.asarray(labels) + conditions = np.asarray(conditions) + views = np.asarray(views) + features = np.asarray(features) + + # TODO Implement evaluation function here + @staticmethod def init_weights(m): if isinstance(m, nn.modules.conv._ConvNd): @@ -167,7 +216,7 @@ class Model: ) self.log_name = '_'.join((self.log_name, self._dataset_sig)) config: dict = dataset_config.copy() - name = config.pop('name') + name = config.pop('name', 'CASIA-B') if name == 'CASIA-B': return CASIAB(**config, is_train=self.is_train) elif name == 'FVG': -- cgit v1.2.3