Skip to content

Commit

Permalink
Update gnn_reranking.py
Browse files Browse the repository at this point in the history
  • Loading branch information
layumi authored Aug 17, 2023
1 parent ad62bfe commit 4dae9cd
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion GPU-Re-Ranking/gnn_reranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@



def gnn_reranking(X_q, X_g, k1, k2):
def gnn_reranking(X_q, X_g, k1, k2, lamb=0.3):
query_num, gallery_num = X_q.shape[0], X_g.shape[0]
original_cos = torch.mm(X_q, X_g.t())

X_u = torch.cat((X_q, X_g), axis = 0)
original_score = torch.mm(X_u, X_u.t())
Expand All @@ -48,6 +49,7 @@ def gnn_reranking(X_q, X_g, k1, k2):

cosine_similarity = torch.mm(A[:query_num,], A[query_num:, ].t())
del A, S
cosine_similarity = (1-lamb)*cosine_similarity + lamb*original_cos

L = torch.sort(-cosine_similarity, dim = 1)[1]
L = L.data.cpu().numpy()
Expand Down

0 comments on commit 4dae9cd

Please sign in to comment.