From cf8251e56332bfe89d2078b06c92d8d489352266 Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Tue, 4 Feb 2020 13:21:50 -0800 Subject: [PATCH] fix grads --- lml.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lml.py b/lml.py index 1e1382c..0e1d020 100755 --- a/lml.py +++ b/lml.py @@ -164,7 +164,8 @@ def backward(ctx, grad_output): dx = torch.zeros_like(x) if single: dx = dx.squeeze() - return dx + grads = tuple([dx] + [None]*5) + return grads Hinv = 1./(1./y + 1./(1.-y)) dnu = bdot(Hinv, grad_output)/Hinv.sum(dim=1)