summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-12 23:16:55 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-12 23:16:55 +0800
commit333bacd26376250d8606d737f01361853527f451 (patch)
tree7ae35d386b9a3a18caa361361f6a77226868e7f9
parent544cd7d3408191a9cabb5e0f2e6e83e2a2a7782e (diff)
Bug fixesxla
1. Replace inplace Leaky ReLU in auto-encoder classifier with non-inplace one 2. Replace rank number with get_ordinal method in xmp
-rw-r--r--models/auto_encoder.py2
-rw-r--r--models/model.py8
2 files changed, 4 insertions, 6 deletions
diff --git a/models/auto_encoder.py b/models/auto_encoder.py
index 64c52e3..0c247f1 100644
--- a/models/auto_encoder.py
+++ b/models/auto_encoder.py
@@ -126,7 +126,7 @@ class AutoEncoder(nn.Module):
f_c_dim = embedding_dims[1]
self.classifier = nn.Sequential(
- nn.LeakyReLU(0.2, inplace=True),
+ nn.LeakyReLU(0.2),
BasicLinear(f_c_dim, num_class)
)
diff --git a/models/model.py b/models/model.py
index b86a050..689fc70 100644
--- a/models/model.py
+++ b/models/model.py
@@ -129,14 +129,12 @@ class Model:
para_loader = pl.ParallelLoader(dataloader, [device])
self._train_loop(
- rank,
para_loader.per_device_loader(device),
rgb_pn, optimizer, scheduler, writer
)
def _train_loop(
self,
- rank: int,
dataloader: pl.PerDeviceLoader,
rgb_pn: RGBPartNet,
optimizer: optim.Adam,
@@ -170,10 +168,10 @@ class Model:
# Write losses to TensorBoard
writer.add_scalar(
- f'[xla:{rank}]Loss/all', loss.item(), iter_i + 1
+ f'[xla:{xm.get_ordinal()}]Loss/all', loss.item(), iter_i + 1
)
writer.add_scalars(
- f'[xla:{rank}]Loss/details', dict(zip([
+ f'[xla:{xm.get_ordinal()}]Loss/details', dict(zip([
'Cross reconstruction loss', 'Pose similarity loss',
'Canonical consistency loss', 'Batch All triplet loss'
], metrics)),
@@ -181,7 +179,7 @@ class Model:
)
if iter_i % 100 == 99:
- print('[xla:{0}]({1:5d})'.format(rank, iter_i + 1),
+ print('[xla:{0}]({1:5d})'.format(xm.get_ordinal(), iter_i + 1),
'loss: {:6.3f}'.format(loss),
'(xrecon = {:f}, pose_sim = {:f},'
' cano_cons = {:f}, ba_trip = {:f})'.format(*metrics),