summaryrefslogtreecommitdiff
path: root/eval.py
diff options
context:
space:
mode:
authorJordan Gong <jordan.gong@protonmail.com>2021-01-12 11:29:02 +0800
committerJordan Gong <jordan.gong@protonmail.com>2021-01-12 11:29:02 +0800
commit966d4431c037b0c4641aa2a5fc22f05be064b331 (patch)
tree0239ba89d31857a7f086acf627fc1bbf167855a9 /eval.py
parent7825f978f198e56958703f0d08f7ccbd8cef49ca (diff)
parent36cf502afe9b93efe31c244030270b0a62e644b8 (diff)
Merge branch 'master' into python3.8
# Conflicts: # models/model.py
Diffstat (limited to 'eval.py')
-rw-r--r--eval.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/eval.py b/eval.py
new file mode 100644
index 0000000..7b68220
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,30 @@
+import numpy as np
+
+from config import config
+from models import Model
+from utils.dataset import ClipConditions
+from utils.misc import set_visible_cuda
+
+set_visible_cuda(config['system'])
+model = Model(config['system'], config['model'], config['hyperparameter'])
+
+dataset_selectors = {
+ 'nm': {'conditions': ClipConditions({r'nm-0\d'})},
+ 'bg': {'conditions': ClipConditions({r'nm-0\d', r'bg-0\d'})},
+ 'cl': {'conditions': ClipConditions({r'nm-0\d', r'cl-0\d'})},
+}
+
+accuracy = model.predict_all(config['model']['total_iter'], config['dataset'],
+ dataset_selectors, config['dataloader'])
+rank = 5
+np.set_printoptions(formatter={'float': '{:5.2f}'.format})
+for n in range(rank):
+ print(f'===Rank-{n + 1} Accuracy===')
+ for (condition, accuracy_c) in accuracy.items():
+ acc_excl_identical_view = accuracy_c[:, :, n].fill_diagonal_(0)
+ num_gallery_views = (acc_excl_identical_view != 0).sum()
+ acc_each_angle = acc_excl_identical_view.sum(0) / num_gallery_views
+ print('{0}: {1} mean: {2:5.2f}'.format(
+ condition, acc_each_angle.cpu().numpy() * 100,
+ acc_each_angle.mean() * 100)
+ )