03. Computer Vision - Why do we torch.argmax(dim=1) in logits->prediction probs->prediction labels? #477
theanilsomani started this conversation in General
In the code below, why do we use torch.argmax(dim=1) at the step "Turn predictions from logits -> prediction probabilities -> predictions labels"? What is the reason behind it?
# Make predictions with trained model
y_preds = []
model_2.eval()
with torch.inference_mode():
    for X, y in tqdm(test_dataloader, desc="Making predictions"):
        # Send data and targets to target device
        X, y = X.to(device), y.to(device)
        # Do the forward pass
        y_logit = model_2(X)
        # Turn predictions from logits -> prediction probabilities -> predictions labels
        y_pred = torch.softmax(y_logit, dim=1).argmax(dim=1)
        # Put predictions on CPU for evaluation
        y_preds.append(y_pred.cpu())
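To make the question concrete, here is a minimal sketch of what that line does. It assumes the model output has shape [batch_size, num_classes], which is the usual layout for a classifier but is not shown in the snippet above. The logit values are made up for illustration. Under that assumption, dim=1 is the class dimension, so argmax(dim=1) returns one class index per sample, while dim=0 would compare values across the samples in the batch instead.

```python
import torch

# Hypothetical logits for a batch of 3 samples and 4 classes: shape [3, 4].
# Rows are samples (dim=0), columns are classes (dim=1).
y_logit = torch.tensor([
    [2.0, 0.5, -1.0, 0.1],
    [0.2, 0.1, 3.0, -0.5],
    [-1.0, 4.0, 0.0, 0.3],
])

# softmax over dim=1 turns each row into probabilities that sum to 1.
y_prob = torch.softmax(y_logit, dim=1)
row_sums = y_prob.sum(dim=1)

# argmax over dim=1 picks the index of the largest value in each row,
# i.e. the predicted class label for each sample.
labels_from_probs = y_prob.argmax(dim=1)

# softmax preserves the ordering within a row, so taking argmax of the
# raw logits gives the same labels.
labels_from_logits = y_logit.argmax(dim=1)

# argmax over dim=0 answers a different question: for each class,
# which sample in the batch scored highest. The result has one entry per class.
labels_dim0 = y_logit.argmax(dim=0)

print(labels_from_probs)   # tensor([0, 2, 1])
print(labels_from_logits)  # tensor([0, 2, 1])
print(labels_dim0.shape)   # torch.Size([4])
```

So the result of argmax(dim=1) has shape [batch_size], which is the shape you want when comparing against the target labels y.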