This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

In DifferentiableAdam, sqrt() is non-differentiable at zero #125

Open
rickyloynd-microsoft opened this issue Dec 10, 2021 · 6 comments

Comments

@rickyloynd-microsoft

rickyloynd-microsoft commented Dec 10, 2021

When using higher with Adam as the inner optimizer, calling the outer loss.backward() sometimes raises the following torch error:

RuntimeError: Function 'SqrtBackward' returned nan values in its 0th output.

The problem occurs when the exp_avg_sq tensor in DifferentiableAdam contains a zero, in which case exp_avg_sq.sqrt() is non-differentiable.
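
This can be reproduced outside of higher with a few lines of plain PyTorch; the tensors below are just hypothetical stand-ins for the optimizer state, not higher's actual code:

import torch

torch.autograd.set_detect_anomaly(True)

# Hypothetical stand-in for the exp_avg_sq state: one entry is exactly zero.
exp_avg_sq = torch.tensor([0.0, 4.0], requires_grad=True)
denom = exp_avg_sq.sqrt()

# When the gradient arriving at sqrt() is zero at the zero entry, the sqrt
# backward computes 0 / (2 * sqrt(0)) = 0 / 0 = nan there.
loss = (denom * torch.tensor([0.0, 1.0])).sum()
loss.backward()  # RuntimeError: Function 'SqrtBackward' returned nan values... ('SqrtBackward0' on newer PyTorch)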

The problem disappears when I add a tiny value to exp_avg_sq before applying the sqrt():

exp_avg_sq = exp_avg_sq + 1e-16

But I don’t know whether this would cause other problems.
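
A quick check with the same toy tensors as above (again just a sketch, not higher's code) shows the backward pass staying finite once the small constant is added before the sqrt:

import torch

exp_avg_sq = torch.tensor([0.0, 4.0], requires_grad=True)
denom = (exp_avg_sq + 1e-16).sqrt()  # sqrt input is now strictly positive
loss = (denom * torch.tensor([0.0, 1.0])).sum()
loss.backward()
print(exp_avg_sq.grad)  # tensor([0.0000, 0.2500]) -- no nan, backward succeeds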

Is the _maybe_mask function designed to deal with zeros in exp_avg_sq?

I'm using the latest pip-installed version (higher 0.2.1).

@rickyloynd-microsoft
Author

Forgot to mention, anomaly detection needs to be enabled to see the runtime error:

torch.autograd.set_detect_anomaly(True)

@zhiqihuang

@rickyloynd-microsoft I'm facing the same problem. Did you solve it without adding a small value?

@rickyloynd-microsoft
Author

Yes, the solution above has been working for me.

@rickyloynd-microsoft
Author

rickyloynd-microsoft commented Feb 2, 2022

Although I'm having another issue now with Adam's internal state not getting persisted (and detached) between rollouts (#114).

@yuanhaitao

> Yes, the solution above has been working for me.

Can you tell me how to solve the problem?

@yuanhaitao

> @rickyloynd-microsoft I'm facing the same problem. Did you solve it without adding a small value?

@zhiqihuang Hey, did you solve this problem? I'm running into it too.
