diff options
-rw-r--r-- | supervised/baseline.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 818841c..a9d9862 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -1,20 +1,19 @@ +import argparse +import os import sys from dataclasses import dataclass from pathlib import Path from typing import Iterable, Callable -path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) -sys.path.insert(0, path) - -import argparse -import os - import torch import yaml from torch.utils.data import Dataset from torchvision.datasets import CIFAR10, CIFAR100 from torchvision.transforms import transforms +path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) +sys.path.insert(0, path) + from libs.datautils import Clip from libs.schedulers import LinearLR from libs.utils import Trainer, BaseConfig @@ -253,7 +252,7 @@ class SupBaselineTrainer(Trainer): model = self.models['model'] model.eval() with torch.no_grad(): - for batch, (images, targets) in enumerate(self.test_loader): + for images, targets in self.test_loader: images, targets = images.to(device), targets.to(device) output = model(images) loss = loss_fn(output, targets) |