I have a custom embedding network, and everything in my script is on CUDA. However, I haven't been able to use 'mdn' in multi-round training (I'm still on SBI 0.23.2). After adding a lot of debugging statements, I realized that in the calculation of mog_log_prob, constant is still on the CPU. It might still be an issue on my side, but I wanted to share it. Here is the output of the print statements for the other components of that function:
theta.device = cuda:0 //from my script
experiment_data.device = cuda:0 //from my script
Using SNPE-C with non-atomic loss
theta.device in mog_log_prob = cuda:0
logits_pp.device in mog_log_prob = cuda:0
means_pp.device in mog_log_prob = cuda:0
precisions_pp.device in mog_log_prob = cuda:0
weights.device in mog_log_prob = cuda:0
constant.device in mog_log_prob = cpu //this one
log_det.device in mog_log_prob = cuda:0
theta_minus_mean.device in mog_log_prob = cuda:0
exponent.device in mog_log_prob = cuda:0
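The one-print-per-tensor check above can be condensed into a small helper that reports only the tensors whose device differs from a reference tensor such as theta. This is a minimal sketch; `find_device_mismatches` is a hypothetical name and is not part of sbi or PyTorch. It relies only on the standard `Tensor.device` attribute.

```python
from __future__ import annotations

import torch


def find_device_mismatches(reference: torch.Tensor, **tensors: torch.Tensor) -> dict[str, str]:
    """Hypothetical debugging helper: return {name: device} for every tensor
    that is NOT on the same device as `reference`."""
    return {
        name: str(t.device)
        for name, t in tensors.items()
        if t.device != reference.device
    }


# Usage: with everything on one device, nothing is reported.
theta = torch.zeros(3)
constant = torch.tensor([1.0])
print(find_device_mismatches(theta, constant=constant))  # {}
```

Placed inside a function under investigation, a call like `find_device_mismatches(theta, constant=constant, log_det=log_det)` would have flagged `constant` directly instead of requiring a scan of every printed line.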
Hi @ali-akhavan89, thanks for digging into this and reporting it!
Yes, this looks like a bug: constant should be created on the current device. I am just wondering why our GPU tests don't catch this, because they seem to cover multi-round NPE_C with MDN. I will have a look and propose a fix soon.
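A minimal sketch of the kind of fix described in the reply, assuming the function combines the components named in the printout (mixture weights, a normalizing constant, log-determinants and a quadratic exponent) in the usual way for a Gaussian mixture parameterized by precision matrices. `mog_log_prob_sketch` is a hypothetical stand-in, not sbi's actual implementation. The relevant line creates the constant with `device=theta.device` (and the matching dtype), so it follows theta onto the GPU instead of defaulting to the CPU.

```python
import math

import torch


def mog_log_prob_sketch(
    theta: torch.Tensor,          # (batch, dim)
    logits_pp: torch.Tensor,      # (batch, K) unnormalized mixture logits
    means_pp: torch.Tensor,       # (batch, K, dim)
    precisions_pp: torch.Tensor,  # (batch, K, dim, dim) precision matrices
) -> torch.Tensor:
    """Log density of a Gaussian mixture. Hypothetical sketch, not sbi's code."""
    _, _, dim = means_pp.shape
    theta = theta.view(-1, 1, dim)  # add a component axis for broadcasting

    # Normalize the mixture logits into log weights.
    weights = logits_pp - torch.logsumexp(logits_pp, dim=-1, keepdim=True)
    # Create the normalizing constant on theta's device and dtype, so it is
    # not left behind on the CPU when the inputs live on CUDA.
    two_pi = torch.tensor(2 * math.pi, device=theta.device, dtype=theta.dtype)
    constant = -(dim / 2.0) * torch.log(two_pi)
    # 0.5 * log|P| for each component's precision matrix P.
    log_det = 0.5 * torch.logdet(precisions_pp)
    theta_minus_mean = theta.expand_as(means_pp) - means_pp
    # Quadratic form d^T P d for each batch element and component.
    exponent = -0.5 * torch.einsum(
        "bki,bkij,bkj->bk", theta_minus_mean, precisions_pp, theta_minus_mean
    )
    return torch.logsumexp(weights + constant + log_det + exponent, dim=-1)


# Usage: the same call works on CPU or GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch, k, dim = 4, 3, 2
a = torch.randn(batch, k, dim, dim, device=device)
precisions = a @ a.transpose(-1, -2) + torch.eye(dim, device=device)  # positive definite
log_prob = mog_log_prob_sketch(
    torch.randn(batch, dim, device=device),
    torch.randn(batch, k, device=device),
    torch.randn(batch, k, dim, device=device),
    precisions,
)
print(log_prob.device)  # same device as the inputs
```

An alternative design is to keep the constant as a plain Python float (for example via `math.log(2 * math.pi)`), which broadcasts against tensors on any device. Creating it explicitly on `theta.device` mirrors the fix suggested in the reply more literally.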