aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-07-14 15:52:18 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-07-14 16:04:04 +0800
commit7246e6b698f12bc70004c0a4c7a1a8641573cd10 (patch)
tree0ddd93a5aa28cc9498054bcfb9f605aec89b4446
parent4e10bb682262d61b0331536805ecdc6b58e7219c (diff)
Code cleaning
-rw-r--r--supervised/baseline.py13
1 files changed, 6 insertions, 7 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index 818841c..a9d9862 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -1,20 +1,19 @@
+import argparse
+import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Callable
-path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
-sys.path.insert(0, path)
-
-import argparse
-import os
-
import torch
import yaml
from torch.utils.data import Dataset
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.transforms import transforms
+path = str(Path(Path(__file__).parent.absolute()).parent.absolute())
+sys.path.insert(0, path)
+
from libs.datautils import Clip
from libs.schedulers import LinearLR
from libs.utils import Trainer, BaseConfig
@@ -253,7 +252,7 @@ class SupBaselineTrainer(Trainer):
model = self.models['model']
model.eval()
with torch.no_grad():
- for batch, (images, targets) in enumerate(self.test_loader):
+ for images, targets in self.test_loader:
images, targets = images.to(device), targets.to(device)
output = model(images)
loss = loss_fn(output, targets)