aboutsummaryrefslogtreecommitdiff
path: root/libs/datautils.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2022-08-13 23:38:43 +0800
committerJordan Gong <jordan.gong@protonmail.com>2022-08-13 23:38:43 +0800
commit957a2a46e7725184776c3c72860e8215164cc4ef (patch)
tree43e098595db4ee332bca5f6caecfbd02369debbe /libs/datautils.py
parent1b8f01ce9706905c36c6f11ed9deac8548ad7341 (diff)
Implement distributed data parallel via torch elastic launcher
Diffstat (limited to 'libs/datautils.py')
-rw-r--r--libs/datautils.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/libs/datautils.py b/libs/datautils.py
index 6a7c506..53222a8 100644
--- a/libs/datautils.py
+++ b/libs/datautils.py
@@ -125,3 +125,14 @@ class TwinTransform:
v1 = self.transform(x)
v2 = self.transform(x)
return v1, v2
+
+
+class ContinuousSampler(torch.utils.data.sampler.Sampler):
+ def __init__(self, sampler):
+ super(ContinuousSampler, self).__init__(sampler)
+ self.base_sampler = sampler
+
+ def __iter__(self):
+ while True:
+ for batch in self.base_sampler:
+ yield batch