blob: 445badda2cf3cea23cec726d95c0da3af8a8ca5c (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
|
from typing import TypedDict, Optional, Union
import torch
from utils.dataset import ClipLabels, ClipConditions, ClipViews
class SystemConfiguration(TypedDict):
    """Schema of the ``system`` section of the top-level configuration.

    Groups runtime/environment settings. How each value is consumed is not
    visible in this module; the per-field notes below are hedged.
    """

    # Torch device for the run; presumably where the model and tensors are
    # placed -- confirm with the training code.
    device: torch.device
    # Named after the CUDA environment variable; presumably a GPU id list
    # (e.g. "0,1") exported to os.environ by the caller -- TODO confirm.
    CUDA_VISIBLE_DEVICES: str
    # Output location as a plain string path; what gets written there
    # (checkpoints, logs, ...) is decided elsewhere.
    save_path: str
class DatasetConfiguration(TypedDict):
    """Schema of the ``dataset`` section of the top-level configuration.

    Field meanings beyond their types are inferred from names only and are
    marked as assumptions to be confirmed against the dataset loader.
    """

    # Dataset identifier.
    name: str
    # Location of the dataset as a string path (presumably on local disk).
    path: str
    # presumably the number of subjects/samples used for training -- confirm.
    train_size: int
    # presumably how many frames are sampled from each clip -- confirm.
    num_sampled_frames: int
    # presumably a minimum frame count below which clips are dropped -- confirm.
    discard_threshold: int
    # Optional str-keyed filter whose values are ClipLabels, ClipConditions
    # or ClipViews (from utils.dataset); None presumably means "no
    # filtering" -- verify against the loader.
    selector: Optional[
        dict[str, Union[ClipLabels, ClipConditions, ClipViews]]
    ]
    # Number of channels per input frame.
    num_input_channels: int
    # Two-int frame size; axis order (height, width vs. width, height) is
    # not visible here -- TODO confirm.
    frame_size: tuple[int, int]
    # presumably toggles caching of loaded clips -- confirm.
    cache_on: bool
class DataloaderConfiguration(TypedDict):
    """Schema of the ``dataloader`` section of the top-level configuration."""

    # Batch specification given as two ints rather than one; presumably
    # (subjects per batch, samples per subject) for a grouped sampler --
    # confirm with the sampler implementation.
    batch_size: tuple[int, int]
    # presumably forwarded to torch.utils.data.DataLoader(num_workers=...).
    num_workers: int
    # presumably forwarded to torch.utils.data.DataLoader(pin_memory=...).
    pin_memory: bool
class HyperparameterConfiguration(TypedDict):
    """Schema of the ``hyperparameter`` section of the top-level configuration.

    Meanings beyond the declared types are inferred from field names and are
    marked as assumptions.
    """

    # Width of a hidden/feature dimension of the model -- confirm which layer.
    hidden_dim: int
    # Learning rate. Annotated as float: learning rates are fractional
    # (e.g. 1e-4), which an ``int`` annotation would reject under a type
    # checker, while ``float`` still accepts plain ints.
    lr: float
    # Pair of floats; presumably the Adam (beta1, beta2) coefficients --
    # confirm against the optimizer construction.
    betas: tuple[float, float]
    # presumably selects between hard-example and all-pairs loss reduction;
    # the accepted string values are not visible here -- TODO confirm.
    hard_or_all: str
    # presumably the margin of a margin-based (e.g. triplet) loss -- confirm.
    margin: float
class ModelConfiguration(TypedDict):
    """Schema of the ``model`` section of the top-level configuration."""

    # Model identifier.
    name: str
    # presumably the training iteration to resume from (checkpoint index)
    # -- confirm how 0 / "fresh start" is handled by the caller.
    restore_iter: int
    # presumably the total number of training iterations to run -- confirm.
    total_iter: int
class Configuration(TypedDict):
    """Top-level configuration schema composed of five typed sections.

    Every key is required (the TypedDict is total). Each value is itself a
    TypedDict, so a concrete configuration is a plain nested ``dict``.
    """

    # Runtime/environment settings.
    system: SystemConfiguration
    # Dataset location and sampling settings.
    dataset: DatasetConfiguration
    # Data-loading settings.
    dataloader: DataloaderConfiguration
    # Optimization / loss hyperparameters.
    hyperparameter: HyperparameterConfiguration
    # Model identity and iteration bookkeeping.
    model: ModelConfiguration
|