diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 17:05:10 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-03-17 17:05:10 +0800 |
commit | 9723bf4b92b8c3fb779da6f4990ac702a68c7cc6 (patch) | |
tree | 0f1feba180f807143eb52fb29a701b3e1b67a5f9 /supervised | |
parent | d3cef6cf4d9a8c9afd2875bf76f072826a050f9b (diff) |
Add CSV logger
Diffstat (limited to 'supervised')
-rw-r--r-- | supervised/baseline.py | 28 | ||||
-rw-r--r-- | supervised/utils.py | 46 |
2 files changed, 60 insertions, 14 deletions
diff --git a/supervised/baseline.py b/supervised/baseline.py index e21aeee..dc92408 100644 --- a/supervised/baseline.py +++ b/supervised/baseline.py @@ -14,7 +14,7 @@ from datautils import color_distortion, Clip, RandomGaussianBlur from models import CIFARResNet50, ImageNetResNet50 from optimizers import LARS from schedulers import LinearWarmupAndCosineAnneal, LinearLR -from utils import training_log +from utils import training_log, setup_logging, EPOCH_LOGGER, BATCH_LOGGER def build_parser(): @@ -78,10 +78,21 @@ def build_parser(): train_group.add_argument('--weight_decay', default=1e-6, type=float, help='Weight decay (l2 regularization) (default: 1e-6)') + logging_group = parser.add_argument_group('Logging config') + logging_group.add_argument('--log_dir', default='logs', type=str, + help="Path to log directory (default: 'logs')") + logging_group.add_argument('--tensorboard_dir', default='runs', type=str, + help="Path to tensorboard directory (default: 'runs')") + logging_group.add_argument('--checkpoint_dir', default='checkpoints', type=str, + help='Path to checkpoints directory ' + "(default: 'checkpoints')") + args = parser.parse_args() args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - args.checkpoint_root = os.path.join('checkpoints', args.codename) - args.tensorboard_root = os.path.join('runs', args.codename) + args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv') + args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv') + args.tensorboard_root = os.path.join(args.tensorboard_dir, args.codename) + args.checkpoint_root = os.path.join(args.checkpoint_dir, args.codename) return args @@ -208,7 +219,7 @@ def configure_optimizer(args, model): return optimizer -@training_log +@training_log(EPOCH_LOGGER) def load_checkpoint(args, model, optimizer): checkpoint_path = os.path.join(args.checkpoint_root, f'{args.restore_epoch:04d}.pt') checkpoint = torch.load(checkpoint_path) @@ -275,7 +286,7 @@ def eval(args, test_loader, model, loss_fn): yield batch, loss.item(), accuracy.item() -@training_log +@training_log(BATCH_LOGGER) def batch_logger(args, writer, batch, epoch, loss, lr): global_batch = epoch * args.num_train_batches + batch writer.add_scalar('Batch loss/train', loss, global_batch + 1) @@ -292,7 +303,7 @@ def batch_logger(args, writer, batch, epoch, loss, lr): } -@training_log +@training_log(EPOCH_LOGGER) def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy): train_loss_mean = train_loss.mean().item() test_loss_mean = test_loss.mean().item() @@ -311,8 +322,7 @@ def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy): def save_checkpoint(args, epoch_log, model, optimizer): - if not os.path.exists(args.checkpoint_root): - os.makedirs(args.checkpoint_root) + os.makedirs(args.checkpoint_root, exist_ok=True) torch.save(epoch_log | { 'model_state_dict': model.state_dict(), @@ -323,6 +333,8 @@ def save_checkpoint(args, epoch_log, model, optimizer): if __name__ == '__main__': args = build_parser() set_seed(args) + setup_logging(BATCH_LOGGER, args.batch_log_filename) + setup_logging(EPOCH_LOGGER, args.epoch_log_filename) train_set, test_set = prepare_dataset(args) train_loader, test_loader = create_dataloader(args, train_set, test_set) diff --git a/supervised/utils.py b/supervised/utils.py index c477544..95d865a 100644 --- a/supervised/utils.py +++ b/supervised/utils.py @@ -1,7 +1,41 @@ -def training_log(func): - def wrapper(*args, **kwargs): - result = func(*args, **kwargs) - print(result) - return result +import logging +import os - return wrapper +EPOCH_LOGGER = 'epoch_logger' +BATCH_LOGGER = 'batch_logger' + + +def setup_logging(name="log", + filename=None, + stream_log_level="INFO", + file_log_level="INFO"): + logger = logging.getLogger(name) + logger.setLevel("INFO") + formatter = logging.Formatter( + '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S' + ) + stream_handler = logging.StreamHandler() + stream_handler.setLevel(getattr(logging, stream_log_level)) + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + if filename is not None: + os.makedirs(os.path.dirname(filename), exist_ok=True) + file_handler = logging.FileHandler(filename) + file_handler.setLevel(getattr(logging, file_log_level)) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + return logger + + +def training_log(name): + def log_this(function): + logger = logging.getLogger(name) + + def wrapper(*args, **kwargs): + output = function(*args, **kwargs) + logger.info(','.join(map(str, output.values()))) + return output + + return wrapper + + return log_this |