aboutsummaryrefslogtreecommitdiff
path: root/supervised
diff options
context:
space:
mode:
Diffstat (limited to 'supervised')
-rw-r--r--supervised/baseline.py13
1 files changed, 6 insertions, 7 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 818841c..a9d9862 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -1,20 +1,19 @@
+import argparse
+import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Callable
-path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
-sys.path.insert(0, path)
-
-import argparse
-import os
-
import torch
import yaml
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import transforms
+path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
+sys.path.insert(0, path)
+
from libs.datautils import Clip
from libs.schedulers import LinearLR
from libs.utils import Trainer, BaseConfig
@@ -253,7 +252,7 @@ class SupBaselineTrainer(Trainer):
model = self.models['model']
model.eval()
with torch.no_grad():
- for batch, (images, targets) in enumerate(self.test_loader):
+ for images, targets in self.test_loader:
images, targets = images.to(device), targets.to(device)
output = model(images)
loss = loss_fn(output, targets)