about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--  .gitignore            1
-rw-r--r--  .idea/csv-plugin.xml  23
-rw-r--r--  supervised/baseline.py  28
-rw-r--r--  supervised/utils.py   46
4 files changed, 84 insertions(+), 14 deletions(-)
diff --git a/.gitignore b/.gitignore
index 08f506b..c545d59 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
/checkpoints/
/dataset/
/runs/
+/logs/
diff --git a/.idea/csv-plugin.xml b/.idea/csv-plugin.xml
new file mode 100644
index 0000000..3263370
--- /dev/null
+++ b/.idea/csv-plugin.xml
@@ -0,0 +1,23 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+ <component name="CsvFileAttributes">
+ <option name="attributeMap">
+ <map>
+ <entry key="/logs/batch-cifar10-resnet50-256-lars-warmup.csv">
+ <value>
+ <Attribute>
+ <option name="separator" value="," />
+ </Attribute>
+ </value>
+ </entry>
+ <entry key="/logs/epoch-cifar10-resnet50-256-lars-warmup.csv">
+ <value>
+ <Attribute>
+ <option name="separator" value="," />
+ </Attribute>
+ </value>
+ </entry>
+ </map>
+ </option>
+ </component>
+</project>
\ No newline at end of file
diff --git a/supervised/baseline.py b/supervised/baseline.py
index e21aeee..dc92408 100644
--- a/supervised/baseline.py
+++ b/supervised/baseline.py
@@ -14,7 +14,7 @@ from datautils import color_distortion, Clip, RandomGaussianBlur
from models import CIFARResNet50, ImageNetResNet50
from optimizers import LARS
from schedulers import LinearWarmupAndCosineAnneal, LinearLR
-from utils import training_log
+from utils import training_log, setup_logging, EPOCH_LOGGER, BATCH_LOGGER
def build_parser():
@@ -78,10 +78,21 @@ def build_parser():
train_group.add_argument('--weight_decay', default=1e-6, type=float,
help='Weight decay (l2 regularization) (default: 1e-6)')
+ logging_group = parser.add_argument_group('Logging config')
+ logging_group.add_argument('--log_dir', default='logs', type=str,
+ help="Path to log directory (default: 'logs')")
+ logging_group.add_argument('--tensorboard_dir', default='runs', type=str,
+ help="Path to tensorboard directory (default: 'runs')")
+ logging_group.add_argument('--checkpoint_dir', default='checkpoints', type=str,
+ help='Path to checkpoints directory '
+ "(default: 'checkpoints')")
+
args = parser.parse_args()
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
- args.checkpoint_root = os.path.join('checkpoints', args.codename)
- args.tensorboard_root = os.path.join('runs', args.codename)
+ args.batch_log_filename = os.path.join(args.log_dir, f'batch-{args.codename}.csv')
+ args.epoch_log_filename = os.path.join(args.log_dir, f'epoch-{args.codename}.csv')
+ args.tensorboard_root = os.path.join(args.tensorboard_dir, args.codename)
+ args.checkpoint_root = os.path.join(args.checkpoint_dir, args.codename)
return args
@@ -208,7 +219,7 @@ def configure_optimizer(args, model):
return optimizer
-@training_log
+@training_log(EPOCH_LOGGER)
def load_checkpoint(args, model, optimizer):
checkpoint_path = os.path.join(args.checkpoint_root, f'{args.restore_epoch:04d}.pt')
checkpoint = torch.load(checkpoint_path)
@@ -275,7 +286,7 @@ def eval(args, test_loader, model, loss_fn):
yield batch, loss.item(), accuracy.item()
-@training_log
+@training_log(BATCH_LOGGER)
def batch_logger(args, writer, batch, epoch, loss, lr):
global_batch = epoch * args.num_train_batches + batch
writer.add_scalar('Batch loss/train', loss, global_batch + 1)
@@ -292,7 +303,7 @@ def batch_logger(args, writer, batch, epoch, loss, lr):
}
-@training_log
+@training_log(EPOCH_LOGGER)
def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy):
train_loss_mean = train_loss.mean().item()
test_loss_mean = test_loss.mean().item()
@@ -311,8 +322,7 @@ def epoch_logger(args, writer, epoch, train_loss, test_loss, test_accuracy):
def save_checkpoint(args, epoch_log, model, optimizer):
- if not os.path.exists(args.checkpoint_root):
- os.makedirs(args.checkpoint_root)
+ os.makedirs(args.checkpoint_root, exist_ok=True)
torch.save(epoch_log | {
'model_state_dict': model.state_dict(),
@@ -323,6 +333,8 @@ def save_checkpoint(args, epoch_log, model, optimizer):
if __name__ == '__main__':
args = build_parser()
set_seed(args)
+ setup_logging(BATCH_LOGGER, args.batch_log_filename)
+ setup_logging(EPOCH_LOGGER, args.epoch_log_filename)
train_set, test_set = prepare_dataset(args)
train_loader, test_loader = create_dataloader(args, train_set, test_set)
diff --git a/supervised/utils.py b/supervised/utils.py
index c477544..95d865a 100644
--- a/supervised/utils.py
+++ b/supervised/utils.py
@@ -1,7 +1,41 @@
-def training_log(func):
- def wrapper(*args, **kwargs):
- result = func(*args, **kwargs)
- print(result)
- return result
+import logging
+import os
- return wrapper
+EPOCH_LOGGER = 'epoch_logger'
+BATCH_LOGGER = 'batch_logger'
+
+
+def setup_logging(name="log",
+ filename=None,
+ stream_log_level="INFO",
+ file_log_level="INFO"):
+ logger = logging.getLogger(name)
+ logger.setLevel("INFO")
+ formatter = logging.Formatter(
+ '%(asctime)s.%(msecs)03d,%(name)s,%(message)s', '%Y-%m-%d %H:%M:%S'
+ )
+ stream_handler = logging.StreamHandler()
+ stream_handler.setLevel(getattr(logging, stream_log_level))
+ stream_handler.setFormatter(formatter)
+ logger.addHandler(stream_handler)
+ if filename is not None:
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ file_handler = logging.FileHandler(filename)
+ file_handler.setLevel(getattr(logging, file_log_level))
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+ return logger
+
+
+def training_log(name):
+ def log_this(function):
+ logger = logging.getLogger(name)
+
+ def wrapper(*args, **kwargs):
+ output = function(*args, **kwargs)
+ logger.info(','.join(map(str, output.values())))
+ return output
+
+ return wrapper
+
+ return log_this