diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-07-14 15:52:18 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-07-14 16:04:04 +0800 |
commit | 7246e6b698f12bc70004c0a4c7a1a8641573cd10 (patch) | |
tree | 0ddd93a5aa28cc9498054bcfb9f605aec89b4446 /supervised/baseline.py | |
parent | 4e10bb682262d61b0331536805ecdc6b58e7219c (diff) |
Code cleaning
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r-- | supervised/baseline.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index 818841c..a9d9862 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -1,20 +1,19 @@ +import argparse +import os import sys from dataclasses import dataclass from pathlib import Path from typing import Iterable, Callable -path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) -sys.path.insert(0, path) - -import argparse -import os - import torch import yaml from torch.utils.data import Dataset from torchvision.datasets import CIFAR10, CIFAR100 from torchvision.transforms import transforms +path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) +sys.path.insert(0, path) + from libs.datautils import Clip from libs.schedulers import LinearLR from libs.utils import Trainer, BaseConfig @@ -253,7 +252,7 @@ class SupBaselineTrainer(Trainer): model = self.models['model'] model.eval() with torch.no_grad(): - for batch, (images, targets) in enumerate(self.test_loader): + for images, targets in self.test_loader: images, targets = images.to(device), targets.to(device) output = model(images) loss = loss_fn(output, targets) |