diff --git a/oml/miners/inbatch_hard_tri.py b/oml/miners/inbatch_hard_tri.py index 6d99747e3..8919f2db6 100644 --- a/oml/miners/inbatch_hard_tri.py +++ b/oml/miners/inbatch_hard_tri.py @@ -1,10 +1,9 @@ from typing import List -import numpy as np +import torch from torch import Tensor from oml.interfaces.miners import ITripletsMinerInBatch, TTripletsIds -from oml.utils.misc import find_value_ids from oml.utils.misc_torch import pairwise_dist @@ -57,24 +56,33 @@ def _sample_from_distmat(distmat: Tensor, labels: List[int]) -> TTripletsIds: ``(anchor, positive, negative)`` """ - ids_all = set(range(len(labels))) - ids_anchor, ids_pos, ids_neg = [], [], [] + labels = torch.tensor(labels) - for i_anch, label in enumerate(labels): - ids_label = set(find_value_ids(it=labels, value=label)) + # Ensure labels are a torch tensor + labels = labels.to(distmat.device) - ids_pos_cur = np.array(list(ids_label - {i_anch}), int) - ids_neg_cur = np.array(list(ids_all - ids_label), int) + batch_size = labels.size(0) - i_pos = ids_pos_cur[distmat[i_anch, ids_pos_cur].argmax()] - i_neg = ids_neg_cur[distmat[i_anch, ids_neg_cur].argmin()] + label_equal = labels.unsqueeze(0) == labels.unsqueeze(1) # Shape [batch_size, batch_size] + label_not_equal = ~label_equal - ids_anchor.append(i_anch) - ids_pos.append(i_pos) - ids_neg.append(i_neg) + # Get the hardest positives: argmax over the distance matrix where labels match (i.e. hardest positive) + dist_pos = distmat.clone() + dist_pos[label_not_equal] = -float("inf") # Set non-positives to -inf + hardest_pos_idx = torch.argmax(dist_pos, dim=1) - return ids_anchor, ids_pos, ids_neg + # Get the hardest negatives: argmin over the distance matrix where labels don't match (i.e. hardest negative) + dist_neg = distmat.clone() + dist_neg[label_equal] = float("inf") # Set non-negatives to +inf + hardest_neg_idx = torch.argmin(dist_neg, dim=1) + + # Return anchor indices, positive indices, and negative indices + ids_anchor = list(range(batch_size)) + ids_pos = hardest_pos_idx + ids_neg = hardest_neg_idx + + return ids_anchor, ids_pos.cpu().tolist(), ids_neg.cpu().tolist() __all__ = ["HardTripletsMiner"] diff --git a/oml/samplers/balance.py b/oml/samplers/balance.py index 63d5924f0..a610e1c36 100644 --- a/oml/samplers/balance.py +++ b/oml/samplers/balance.py @@ -1,4 +1,4 @@ -from collections import Counter +from collections import Counter, defaultdict from typing import Iterator, List, Union import numpy as np @@ -64,7 +64,13 @@ def __init__(self, labels: Union[List[int], np.ndarray], n_labels: int, n_instan self._unq_labels = unq_labels labels = np.array(labels) - self.lbl2idx = {label: np.arange(len(labels))[labels == label].tolist() for label in set(labels)} + + lbl2idx = defaultdict(list) + + for idx, label in enumerate(labels): + lbl2idx[label].append(idx) + + self.lbl2idx = dict(lbl2idx) self._batches_in_epoch = len(self._unq_labels) // self.n_labels