diff options
Diffstat (limited to 'simclr')
-rw-r--r-- | simclr/main.py | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/simclr/main.py b/simclr/main.py index 93a2d83..b170355 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -294,12 +294,11 @@ class SimCLRTrainer(Trainer): train_loader = iter(self.train_loader) model.train() for iter_ in range(self.restore_iter, num_iters): - (input1, input2), _ = next(train_loader) - input1, input2 = input1.to(device), input2.to(device) + input_, _ = next(train_loader) + input_ = torch.cat(input_).to(device) model.zero_grad() - output1 = model(input1) - output2 = model(input2) - train_loss, train_accuracy = loss_fn(output1, output2) + output = model(input_) + train_loss, train_accuracy = loss_fn(output) train_loss.backward() optim.step() self.log(logger, self.BatchLogRecord( @@ -327,11 +326,10 @@ class SimCLRTrainer(Trainer): model = self.models['model'] model.eval() with torch.no_grad(): - for (input1, input2), _ in self.test_loader: - input1, input2 = input1.to(device), input2.to(device) - output1 = model(input1) - output2 = model(input2) - loss, accuracy = loss_fn(output1, output2) + for input_, _ in self.test_loader: + input_ = torch.cat(input_).to(device) + output = model(input_) + loss, accuracy = loss_fn(output) yield loss.item(), accuracy.item() |