From 4dae9cdf42f71c72a44a64fb23bfc470c501085f Mon Sep 17 00:00:00 2001 From: Zhedong Zheng Date: Thu, 17 Aug 2023 12:12:26 +0800 Subject: [PATCH] Update gnn_reranking.py --- GPU-Re-Ranking/gnn_reranking.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/GPU-Re-Ranking/gnn_reranking.py b/GPU-Re-Ranking/gnn_reranking.py index e3f9180..d57567d 100644 --- a/GPU-Re-Ranking/gnn_reranking.py +++ b/GPU-Re-Ranking/gnn_reranking.py @@ -23,8 +23,9 @@ -def gnn_reranking(X_q, X_g, k1, k2): +def gnn_reranking(X_q, X_g, k1, k2, lamb=0.3): query_num, gallery_num = X_q.shape[0], X_g.shape[0] + original_cos = torch.mm(X_q, X_g.t()) X_u = torch.cat((X_q, X_g), axis = 0) original_score = torch.mm(X_u, X_u.t()) @@ -48,6 +49,7 @@ def gnn_reranking(X_q, X_g, k1, k2): cosine_similarity = torch.mm(A[:query_num,], A[query_num:, ].t()) del A, S + cosine_similarity = (1-lamb)*cosine_similarity + lamb*original_cos L = torch.sort(-cosine_similarity, dim = 1)[1] L = L.data.cpu().numpy()