diff options
author | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-19 14:02:04 +0800 |
---|---|---|
committer | Jordan Gong <jordan.gong@protonmail.com> | 2022-08-19 14:02:04 +0800 |
commit | 26420733f98292639b9addb02e73fd8f12ee82e7 (patch) | |
tree | 971d3e5223be2261562e6ca404378aa34aaee453 | |
parent | bff36e9337bd9493e95588b8e342431eb31184f6 (diff) |
Add dependencies
- Pytorch Lighting (pytorch-lighting)
- Pytorch image models (timm)
-rw-r--r-- | environment.yml | 9 | ||||
-rw-r--r-- | posrecon/main.py | 31 | ||||
-rw-r--r-- | simclr/config-ddp.example.yaml | 37 | ||||
-rw-r--r-- | supervised/config-resnet-ddp.yaml | 29 |
4 files changed, 104 insertions, 2 deletions
diff --git a/environment.yml b/environment.yml index 7796ffe..a46316e 100644 --- a/environment.yml +++ b/environment.yml @@ -1,6 +1,7 @@ name: pytorch-stable channels: - pytorch + - conda-forge - defaults dependencies: - _libgcc_mutex=0.1=main @@ -17,8 +18,8 @@ dependencies: - brotlipy=0.7.0=py310h7f8727e_1002 - bzip2=1.0.8=h7b6447c_0 - c-ares=1.18.1=h7f8727e_0 - - ca-certificates=2022.07.19=h06a4308_0 - - certifi=2022.6.15=py310h06a4308_0 + - ca-certificates=2022.6.15=ha878542_0 + - certifi=2022.6.15=py310hff52083_0 - cffi=1.15.1=py310h74dc2b5_0 - charset-normalizer=2.0.4=pyhd3eb1b0_0 - cryptography=37.0.1=py310h9ce1e76_0 @@ -31,6 +32,7 @@ dependencies: - executing=0.8.3=pyhd3eb1b0_0 - ffmpeg=4.3=hf484d3e_0 - freetype=2.11.0=h70c0345_0 + - future=0.18.2=py310hff52083_5 - giflib=5.2.1=h7b6447c_0 - gmp=6.2.1=h295c915_3 - gnutls=3.6.15=he1e5248_0 @@ -112,7 +114,9 @@ dependencies: - python=3.10.4=h12debd9_0 - python-dateutil=2.8.2=pyhd3eb1b0_0 - python-fastjsonschema=2.15.1=pyhd3eb1b0_0 + - python_abi=3.10=2_cp310 - pytorch=1.12.1=py3.10_cuda11.3_cudnn8.3.2_0 + - pytorch-lightning=0.8.5=py_0 - pytorch-mutex=1.0=cuda - pyzmq=23.2.0=py310h6a678d5_0 - readline=8.1.2=h7f8727e_1 @@ -152,3 +156,4 @@ dependencies: - zstd=1.5.2=ha4553b6_0 - pip: - pyyaml==6.0 + - timm==0.6.7 diff --git a/posrecon/main.py b/posrecon/main.py new file mode 100644 index 0000000..0660ea0 --- /dev/null +++ b/posrecon/main.py @@ -0,0 +1,31 @@ +from typing import Callable, Iterable + +import torch +from torch.utils.data import Dataset + +from libs.logging import Loggers +from libs.utils import Trainer, BaseConfig + + +class PosReconTrainer(Trainer): + def __init__(self, *args, **kwargs): + super(PosReconTrainer, self).__init__(*args, **kwargs) + + @staticmethod + def _prepare_dataset(dataset_config: BaseConfig.DatasetConfig) -> tuple[Dataset, Dataset]: + pass + + @staticmethod + def _init_models(dataset: str) -> Iterable[tuple[str, torch.nn.Module]]: + pass + + @staticmethod + def _configure_optimizers(models: Iterable[tuple[str, torch.nn.Module]], optim_config: BaseConfig.OptimConfig) -> \ + Iterable[tuple[str, torch.optim.Optimizer]]: + pass + + def train(self, num_iters: int, loss_fn: Callable, logger: Loggers, device: torch.device): + pass + + def eval(self, loss_fn: Callable, device: torch.device): + pass diff --git a/simclr/config-ddp.example.yaml b/simclr/config-ddp.example.yaml new file mode 100644 index 0000000..67ee54f --- /dev/null +++ b/simclr/config-ddp.example.yaml @@ -0,0 +1,37 @@ +codename: cifar10-simclr-96-lars-warmup-ddp-example +seed: -1 +num_iters: 26042 +log_dir: logs +checkpoint_dir: checkpoints + +hid_dim: 2048 +out_dim: 128 +temp: 0.5 + +dataset: cifar10 +dataset_dir: dataset +crop_size: 32 +crop_scale_range: + - 0.8 + - 1 +hflip_prob: 0.5 +distort_strength: 0.5 +#gauss_ker_scale: 10 +#gauss_sigma_range: +# - 0.1 +# - 2 +#gauss_prob: 0.5 + +batch_size: 96 +num_workers: 2 + +optim: lars +lr: 0.5 +momentum: 0.9 +#betas: +# - 0.9 +# - 0.999 +weight_decay: 1.0e-06 + +sched: warmup-anneal +warmup_iters: 2604
\ No newline at end of file diff --git a/supervised/config-resnet-ddp.yaml b/supervised/config-resnet-ddp.yaml new file mode 100644 index 0000000..6fc8007 --- /dev/null +++ b/supervised/config-resnet-ddp.yaml @@ -0,0 +1,29 @@ +codename: cifar10-resnet-128-adam-warmup-anneal-ddp +seed: -1 +num_iters: 200 +log_dir: logs +checkpoint_dir: checkpoints + +backbone: resnet +label_smooth: 0 + +dataset: cifar10 +dataset_dir: dataset +crop_size: 32 +crop_scale_range: + - 0.8 + - 1 +hflip_prob: 0.5 + +batch_size: 128 +num_workers: 2 + +optim: adam +lr: 0.001 +betas: + - 0.9 + - 0.999 +weight_decay: 1.0e-06 + +sched: warmup-anneal +warmup_iters: 5
\ No newline at end of file |