aboutsummaryrefslogtreecommitdiff
path: root/supervised/baseline.py
diff options
context:
space:
mode:
Diffstat (limited to 'supervised/baseline.py')
-rw-r--r--supervised/baseline.py28
1 files changed, 20 insertions, 8 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py
index e21aeee..dc92408 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -14,7 +14,7 @@ from datautils import color_distortion, Clip, RandomGaussianBlur
from models import CIFARResNet50, ImageNetResNet50
from optimizers import LARS
from schedulers import LinearWarmupAndCosineAnneal, LinearLR
-from utils import training_log
+from utils import training_log, setup_logging, EPOCH_LOGGER, BATCH_LOGGER
def build_parser():
@@ -78,10 +78,21 @@ def build_parser():
train_group.add_argument('--weight_decay', default=1e-6, type=float,
help='Weight decay (l2 regularization) (default: 1e-6)')
+ logging_group = parser.add_argument_group('Logging config')
+ logging_group.add_argument('--log_dir', default='logs', type=str,
+ help="Path to log directory (default: 'logs')")
+ logging_group.add_argument('--tensorboard_dir', default='runs', type=str,
+ help="Path to tensorboard directory (default: 'runs')")
+ logging_group.add_argument('--checkpoint_dir', default='checkpoints', type=str,
+ help='Path to checkpoints directory '
+ "(default: 'checkpoints')")
+
args = parser.parse_args()
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- args.checkpoint_root = os.path.join('checkpoints', args.codename)
- args.tensorboard_root = os.path.join('runs', args.codename)
+ args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv')
+ args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv')
+ args.tensorboard_root = os.path.join(args.tensorboard_dir, args.codename)
+ args.checkpoint_root = os.path.join(args.checkpoint_dir, args.codename)
return args
@@ -208,7 +219,7 @@ def configure_optimizer(args, model):
return optimizer
-@training_log
+@training_log(EPOCH_LOGGER)
def load_checkpoint(args, model, optimizer):
checkpoint_path = os.path.join(args.checkpoint_root, f'{args.restore_epoch:04d}.pt')
checkpoint = torch.load(checkpoint_path)
@@ -275,7 +286,7 @@ def eval(args, test_loader, model, loss_fn):
yield batch, loss.item(), accuracy.item()
-@training_log
+@training_log(BATCH_LOGGER)
def batch_logger(args, writer, batch, epoch, loss, lr):
global_batch = epoch * args.num_train_batches + batch
writer.add_scalar('Batch loss/train', loss, global_batch + 1)
@@ -292,7 +303,7 @@ def batch_logger(args, writer, batch, epoch, loss, lr):
}
-@training_log
+@training_log(EPOCH_LOGGER)
def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy):
train_loss_mean = train_loss.mean().item()
test_loss_mean = test_loss.mean().item()
@@ -311,8 +322,7 @@ def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy):
def save_checkpoint(args, epoch_log, model, optimizer):
- if not os.path.exists(args.checkpoint_root):
- os.makedirs(args.checkpoint_root)
+ os.makedirs(args.checkpoint_root, exist_ok=True)
torch.save(epoch_log | {
'model_state_dict': model.state_dict(),
@@ -323,6 +333,8 @@ def save_checkpoint(args, epoch_log, model, optimizer):
if __name__ == '__main__':
args = build_parser()
set_seed(args)
+ setup_logging(BATCH_LOGGER, args.batch_log_filename)
+ setup_logging(EPOCH_LOGGER, args.epoch_log_filename)
train_set, test_set = prepare_dataset(args)
train_loader, test_loader = create_dataloader(args, train_set, test_set)