aboutsummaryrefslogtreecommitdiff
path: root/simclr
diff options
context:
space:
mode:
Diffstat (limited to 'simclr')
-rw-r--r--simclr/main.py18
1 files changed, 8 insertions, 10 deletions
diff --git a/simclr/main.py b/simclr/main.py
index 93a2d83..b170355 100644
--- a/simclr/main.py
+++ b/simclr/main.py
@@ -294,12 +294,11 @@ class SimCLRTrainer(Trainer):
train_loader = iter(self.train_loader)
model.train()
for iter_ in range(self.restore_iter, num_iters):
- (input1, input2), _ = next(train_loader)
- input1, input2 = input1.to(device), input2.to(device)
+ input_, _ = next(train_loader)
+ input_ = torch.cat(input_).to(device)
model.zero_grad()
- output1 = model(input1)
- output2 = model(input2)
- train_loss, train_accuracy = loss_fn(output1, output2)
+ output = model(input_)
+ train_loss, train_accuracy = loss_fn(output)
train_loss.backward()
optim.step()
self.log(logger, self.BatchLogRecord(
@@ -327,11 +326,10 @@ class SimCLRTrainer(Trainer):
model = self.models['model']
model.eval()
with torch.no_grad():
- for (input1, input2), _ in self.test_loader:
- input1, input2 = input1.to(device), input2.to(device)
- output1 = model(input1)
- output2 = model(input2)
- loss, accuracy = loss_fn(output1, output2)
+ for input_, _ in self.test_loader:
+ input_ = torch.cat(input_).to(device)
+ output = model(input_)
+ loss, accuracy = loss_fn(output)
yield loss.item(), accuracy.item()