From ebb2f93ac01f40d00968daaf9a2ad96c24ce7ab3 Mon Sep 17 00:00:00 2001 From: Jordan Gong Date: Mon, 8 Aug 2022 19:32:51 +0800 Subject: Optimize batching --- simclr/main.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) (limited to 'simclr/main.py') diff --git a/simclr/main.py b/simclr/main.py index 93a2d83..b170355 100644 --- a/simclr/main.py +++ b/simclr/main.py @@ -294,12 +294,11 @@ class SimCLRTrainer(Trainer): train_loader = iter(self.train_loader) model.train() for iter_ in range(self.restore_iter, num_iters): - (input1, input2), _ = next(train_loader) - input1, input2 = input1.to(device), input2.to(device) + input_, _ = next(train_loader) + input_ = torch.cat(input_).to(device) model.zero_grad() - output1 = model(input1) - output2 = model(input2) - train_loss, train_accuracy = loss_fn(output1, output2) + output = model(input_) + train_loss, train_accuracy = loss_fn(output) train_loss.backward() optim.step() self.log(logger, self.BatchLogRecord( @@ -327,11 +326,10 @@ class SimCLRTrainer(Trainer): model = self.models['model'] model.eval() with torch.no_grad(): - for (input1, input2), _ in self.test_loader: - input1, input2 = input1.to(device), input2.to(device) - output1 = model(input1) - output2 = model(input2) - loss, accuracy = loss_fn(output1, output2) + for input_, _ in self.test_loader: + input_ = torch.cat(input_).to(device) + output = model(input_) + loss, accuracy = loss_fn(output) yield loss.item(), accuracy.item() -- cgit v1.2.3