Skip to content

Commit

Permalink
A few optimizations that speedup training
Browse files Browse the repository at this point in the history
A few optimizations that speedup training
  • Loading branch information
hvoss-techfak authored Oct 8, 2024
1 parent 3a6d740 commit 3fa41fb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
36 changes: 22 additions & 14 deletions oml/miners/inbatch_hard_tri.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import List

import numpy as np
import torch
from torch import Tensor

from oml.interfaces.miners import ITripletsMinerInBatch, TTripletsIds
from oml.utils.misc import find_value_ids
from oml.utils.misc_torch import pairwise_dist


Expand Down Expand Up @@ -57,24 +56,33 @@ def _sample_from_distmat(distmat: Tensor, labels: List[int]) -> TTripletsIds:
``(anchor, positive, negative)``
"""
ids_all = set(range(len(labels)))

ids_anchor, ids_pos, ids_neg = [], [], []
labels = torch.tensor(labels)

for i_anch, label in enumerate(labels):
ids_label = set(find_value_ids(it=labels, value=label))
# Ensure labels are a torch tensor
labels = labels.to(distmat.device)

ids_pos_cur = np.array(list(ids_label - {i_anch}), int)
ids_neg_cur = np.array(list(ids_all - ids_label), int)
batch_size = labels.size(0)

i_pos = ids_pos_cur[distmat[i_anch, ids_pos_cur].argmax()]
i_neg = ids_neg_cur[distmat[i_anch, ids_neg_cur].argmin()]
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1) # Shape [batch_size, batch_size]
label_not_equal = ~label_equal

ids_anchor.append(i_anch)
ids_pos.append(i_pos)
ids_neg.append(i_neg)
# Get the hardest positives: argmax over the distance matrix where labels match (i.e. hardest positive)
dist_pos = distmat.clone()
dist_pos[label_not_equal] = -float("inf") # Set non-positives to -inf
hardest_pos_idx = torch.argmax(dist_pos, dim=1)

return ids_anchor, ids_pos, ids_neg
# Get the hardest negatives: argmin over the distance matrix where labels don't match (i.e. hardest negative)
dist_neg = distmat.clone()
dist_neg[label_equal] = float("inf") # Set non-negatives to +inf
hardest_neg_idx = torch.argmin(dist_neg, dim=1)

# Return anchor indices, positive indices, and negative indices
ids_anchor = list(range(batch_size))
ids_pos = hardest_pos_idx
ids_neg = hardest_neg_idx

return ids_anchor, ids_pos.cpu().tolist(), ids_neg.cpu().tolist()


__all__ = ["HardTripletsMiner"]
10 changes: 8 additions & 2 deletions oml/samplers/balance.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import Counter
from collections import Counter, defaultdict
from typing import Iterator, List, Union

import numpy as np
Expand Down Expand Up @@ -64,7 +64,13 @@ def __init__(self, labels: Union[List[int], np.ndarray], n_labels: int, n_instan
self._unq_labels = unq_labels

labels = np.array(labels)
self.lbl2idx = {label: np.arange(len(labels))[labels == label].tolist() for label in set(labels)}

lbl2idx = defaultdict(list)

for idx, label in enumerate(labels):
lbl2idx[label].append(idx)

self.lbl2idx = dict(lbl2idx)

self._batches_in_epoch = len(self._unq_labels) // self.n_labels

Expand Down

0 comments on commit 3fa41fb

Please sign in to comment.