constant.device in mog_log_prob remains on cpu #1343

Open
ali-akhavan89 opened this issue Dec 31, 2024 · 1 comment
Labels: bug (Something isn't working), question (Further information is requested)

Comments

ali-akhavan89 commented Dec 31, 2024

I have a custom embedding network, and everything in my script is on CUDA. However, I haven't been able to use 'mdn' in multi-round training (I'm still on SBI 0.23.2). After adding a lot of debugging statements, I found that in the calculations inside mog_log_prob, constant is still on the CPU. It might still be an issue on my side, but I wanted to share this. Here is the printed output for the components of that function:

theta.device = cuda:0 //from my script
experiment_data.device = cuda:0 //from my script

Using SNPE-C with non-atomic loss

theta.device in mog_log_prob = cuda:0
logits_pp.device in mog_log_prob = cuda:0
means_pp.device in mog_log_prob = cuda:0
precisions_pp.device in mog_log_prob = cuda:0
weights.device in mog_log_prob = cuda:0
constant.device in mog_log_prob = cpu //this one
log_det.device in mog_log_prob = cuda:0
theta_minus_mean.device in mog_log_prob = cuda:0
exponent.device in mog_log_prob = cuda:0

ali-akhavan89 added the question (Further information is requested) label Dec 31, 2024
janfb added the bug (Something isn't working) label Jan 2, 2025

janfb (Contributor) commented Jan 2, 2025

Hi @ali-akhavan89, thanks for digging into this and reporting!

Yes, this looks like a bug: constant should be created on the current device. I am just wondering why our GPU tests don't catch this, because they seem to cover multi-round NPE_C with MDN. I will have a look and propose a fix soon.
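
For illustration, here is a minimal sketch of what "creating constant on the current device" could look like in a mixture-of-Gaussians log-probability. This is not sbi's actual implementation: the function name mog_log_prob_sketch, the assumed tensor shapes, and the use of torch.einsum for the quadratic form are all assumptions made for this example. The only point it tries to show is that the constant takes its device and dtype from theta instead of relying on the default device.

```python
import math

import torch


def mog_log_prob_sketch(theta, logits, means, precisions):
    """Log-density of a Gaussian mixture, one value per batch entry.

    Hypothetical helper for illustration only. Assumed shapes:
      theta:      (batch, dim)
      logits:     (batch, n_components)
      means:      (batch, n_components, dim)
      precisions: (batch, n_components, dim, dim)
    """
    _, _, dim = means.shape
    theta = theta.view(-1, 1, dim)

    # Normalized log mixture weights.
    weights = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

    # Create the constant on the same device (and dtype) as the inputs,
    # so it never silently ends up on the CPU.
    two_pi = torch.tensor(2 * math.pi, device=theta.device, dtype=theta.dtype)
    constant = -(dim / 2.0) * torch.log(two_pi)

    # 0.5 * log|P| for each component's precision matrix P.
    log_det = 0.5 * torch.logdet(precisions)

    # Quadratic form (theta - mu)^T P (theta - mu) for each component.
    diff = theta.expand_as(means) - means
    exponent = -0.5 * torch.einsum("bki,bkij,bkj->bk", diff, precisions, diff)

    return torch.logsumexp(weights + constant + log_det + exponent, dim=-1)
```

With this pattern, the result should stay on whatever device theta lives on, so the same function can be called with CPU or CUDA tensors without changes. Since the constant does not depend on any tensor values, computing it as a plain Python float with math.log(2 * math.pi) would be another way to sidestep the device question entirely.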
