aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--environment.yml9
-rw-r--r--posrecon/main.py31
-rw-r--r--simclr/config-ddp.example.yaml37
-rw-r--r--supervised/config-resnet-ddp.yaml29
4 files changed, 104 insertions, 2 deletions
diff --git a/environment.yml b/environment.yml
index 7796ffe..a46316e 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,6 +1,7 @@
name: pytorch-stable
channels:
- pytorch
+ - conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
@@ -17,8 +18,8 @@ dependencies:
- brotlipy=0.7.0=py310h7f8727e_1002
- bzip2=1.0.8=h7b6447c_0
- c-ares=1.18.1=h7f8727e_0
- - ca-certificates=2022.07.19=h06a4308_0
- - certifi=2022.6.15=py310h06a4308_0
+ - ca-certificates=2022.6.15=ha878542_0
+ - certifi=2022.6.15=py310hff52083_0
- cffi=1.15.1=py310h74dc2b5_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- cryptography=37.0.1=py310h9ce1e76_0
@@ -31,6 +32,7 @@ dependencies:
- executing=0.8.3=pyhd3eb1b0_0
- ffmpeg=4.3=hf484d3e_0
- freetype=2.11.0=h70c0345_0
+ - future=0.18.2=py310hff52083_5
- giflib=5.2.1=h7b6447c_0
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
@@ -112,7 +114,9 @@ dependencies:
- python=3.10.4=h12debd9_0
- python-dateutil=2.8.2=pyhd3eb1b0_0
- python-fastjsonschema=2.15.1=pyhd3eb1b0_0
+ - python_abi=3.10=2_cp310
- pytorch=1.12.1=py3.10_cuda11.3_cudnn8.3.2_0
+ - pytorch-lightning=0.8.5=py_0
- pytorch-mutex=1.0=cuda
- pyzmq=23.2.0=py310h6a678d5_0
- readline=8.1.2=h7f8727e_1
@@ -152,3 +156,4 @@ dependencies:
- zstd=1.5.2=ha4553b6_0
- pip:
- pyyaml==6.0
+ - timm==0.6.7
diff --git a/posrecon/main.py b/posrecon/main.py
new file mode 100644
index 0000000..0660ea0
--- /dev/null
+++ b/posrecon/main.py
@@ -0,0 +1,31 @@
+from typing import Callable, Iterable
+
+import torch
+from torch.utils.data import Dataset
+
+from libs.logging import Loggers
+from libs.utils import Trainer, BaseConfig
+
+
+class PosReconTrainer(Trainer):
+ def __init__(self, *args, **kwargs):
+ super(PosReconTrainer, self).__init__(*args, **kwargs)
+
+ @staticmethod
+ def _prepare_dataset(dataset_config: BaseConfig.DatasetConfig) -> tuple[Dataset, Dataset]:
+ pass
+
+ @staticmethod
+ def _init_models(dataset: str) -> Iterable[tuple[str, torch.nn.Module]]:
+ pass
+
+ @staticmethod
+ def _configure_optimizers(models: Iterable[tuple[str, torch.nn.Module]], optim_config: BaseConfig.OptimConfig) -> \
+ Iterable[tuple[str, torch.optim.Optimizer]]:
+ pass
+
+ def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device):
+ pass
+
+ def eval(self, loss_fn: Callable, device: torch.device):
+ pass
diff --git a/simclr/config-ddp.example.yaml b/simclr/config-ddp.example.yaml
new file mode 100644
index 0000000..67ee54f
--- /dev/null
+++ b/simclr/config-ddp.example.yaml
@@ -0,0 +1,37 @@
+codename: cifar10-simclr-96-lars-warmup-ddp-example
+seed: -1
+num_iters: 26042
+log_dir: logs
+checkpoint_dir: checkpoints
+
+hid_dim: 2048
+out_dim: 128
+temp: 0.5
+
+dataset: cifar10
+dataset_dir: dataset
+crop_size: 32
+crop_scale_range:
+ - 0.8
+ - 1
+hflip_prob: 0.5
+distort_strength: 0.5
+#gauss_ker_scale: 10
+#gauss_sigma_range:
+# - 0.1
+# - 2
+#gauss_prob: 0.5
+
+batch_size: 96
+num_workers: 2
+
+optim: lars
+lr: 0.5
+momentum: 0.9
+#betas:
+# - 0.9
+# - 0.999
+weight_decay: 1.0e-06
+
+sched: warmup-anneal
+warmup_iters: 2604 \ No newline at end of file
diff --git a/supervised/config-resnet-ddp.yaml b/supervised/config-resnet-ddp.yaml
new file mode 100644
index 0000000..6fc8007
--- /dev/null
+++ b/supervised/config-resnet-ddp.yaml
@@ -0,0 +1,29 @@
+codename: cifar10-resnet-128-adam-warmup-anneal-ddp
+seed: -1
+num_iters: 200
+log_dir: logs
+checkpoint_dir: checkpoints
+
+backbone: resnet
+label_smooth: 0
+
+dataset: cifar10
+dataset_dir: dataset
+crop_size: 32
+crop_scale_range:
+ - 0.8
+ - 1
+hflip_prob: 0.5
+
+batch_size: 128
+num_workers: 2
+
+optim: adam
+lr: 0.001
+betas:
+ - 0.9
+ - 0.999
+weight_decay: 1.0e-06
+
+sched: warmup-anneal
+warmup_iters: 5 \ No newline at end of file