diff --git a/GPU-Re-Ranking/gnn_reranking.py b/GPU-Re-Ranking/gnn_reranking.py index e3f9180..d57567d 100644 --- a/GPU-Re-Ranking/gnn_reranking.py +++ b/GPU-Re-Ranking/gnn_reranking.py @@ -23,8 +23,9 @@ -def gnn_reranking(X_q, X_g, k1, k2): +def gnn_reranking(X_q, X_g, k1, k2, lamb=0.3): query_num, gallery_num = X_q.shape[0], X_g.shape[0] + original_cos = torch.mm(X_q, X_g.t()) X_u = torch.cat((X_q, X_g), axis = 0) original_score = torch.mm(X_u, X_u.t()) @@ -48,6 +49,7 @@ def gnn_reranking(X_q, X_g, k1, k2): cosine_similarity = torch.mm(A[:query_num,], A[query_num:, ].t()) del A, S + cosine_similarity = (1-lamb)*cosine_similarity + lamb*original_cos L = torch.sort(-cosine_similarity, dim = 1)[1] L = L.data.cpu().numpy()