From f68422d72b3963930247ab5e9b647a5eb22d51cb Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Fri, 22 Jul 2022 18:38:00 -0600 Subject: [PATCH 1/3] Update grad.jl --- src/lib/grad.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lib/grad.jl b/src/lib/grad.jl index a522d685a..e07c0c342 100644 --- a/src/lib/grad.jl +++ b/src/lib/grad.jl @@ -73,8 +73,7 @@ julia> hessian(sin, pi/2) """ hessian(f, x) = hessian_dual(f, x) -hessian_dual(f, x::AbstractArray) = forward_jacobian(x -> gradient(f, x)[1], x)[2] - +hessian_dual(f, x::AbstractArray) = ForwardDiff.jacobian(x -> gradient(f, x)[1], x) hessian_dual(f, x::Number) = ForwardDiff.derivative(x -> gradient(f, x)[1], x) """ @@ -234,11 +233,11 @@ end diaghessian(f, args...) -> Tuple Diagonal part of the Hessian. Returns a tuple containing, for each argument `x`, -`h` of the same shape with `h[i] = Hᵢᵢ = ∂²y/∂x[i]∂x[i]`. +`h` of the same shape with `h[i] = Hᵢᵢ = ∂²y/∂x[i]∂x[i]`. The original evaluation `y = f(args...)` must give a real number `y`. For one vector argument `x`, this is equivalent to `(diag(hessian(f,x)),)`. -Like [`hessian`](@ref) it uses ForwardDiff over Zygote. +Like [`hessian`](@ref) it uses ForwardDiff over Zygote. !!! warning For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`. From 2ebf9a0b4f21a22cd0feafa0de6f18ae02095138 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Mon, 25 Jul 2022 20:36:14 -0600 Subject: [PATCH 2/3] add testing --- test/utils.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/utils.jl b/test/utils.jl index 40b2e85b7..317da6f82 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -3,13 +3,19 @@ using ForwardDiff using Zygote: hessian_dual, hessian_reverse @testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse] + function f(x, bias) + hessian = hess(x->sum(x.^3), x) + return hessian * x .+ bias + end if hess == hessian_dual @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] # original docstring version + @test gradient(b->sum(f(rand(3),b)),rand(3))[1] ≈ [1, 1, 1] else @test_broken hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] # can't differentiate ∇getindex @test_broken hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] + @test_broken gradient(b->sum(f(rand(3),b)),rand(3))[1] ≈ [1, 1, 1] # jacobian is not differentiable end @test hess(x -> sum(x.^3), [1 2; 3 4]) ≈ Diagonal([6, 18, 12, 24]) @test hess(sin, pi/2) ≈ -1 @@ -133,7 +139,7 @@ using ForwardDiff g3(x) = sum(abs2,ForwardDiff.jacobian(f,x)) out,back = Zygote.pullback(g3,[2.0,3.2]) @test back(1.0)[1] == ForwardDiff.gradient(g3,[2.0,3.2]) - + # From https://github.com/FluxML/Zygote.jl/issues/1218 f1218(x::AbstractVector,y::AbstractVector) = sum(x)*sum(y) gradf1218(x,y) = ForwardDiff.gradient(x->f1218(x,y), x)[1] From f88b34387d4d99aca7fca50e2249f7e75029e0ac Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Mon, 25 Jul 2022 21:37:29 -0600 Subject: [PATCH 3/3] change names --- test/utils.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index 317da6f82..fac4226db 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -3,19 +3,19 @@ using ForwardDiff using Zygote: hessian_dual, hessian_reverse @testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse] - function f(x, bias) - hessian = hess(x->sum(x.^3), x) - return hessian * x .+ bias + function f1(x, bias) + h = hess(x -> sum(x.^3), x) + return h * x .+ bias end if hess == hessian_dual @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] # original docstring version - @test gradient(b->sum(f(rand(3),b)),rand(3))[1] ≈ [1, 1, 1] + @test gradient(b->sum(f1(rand(3),b)),rand(3))[1] ≈ [1, 1, 1] else @test_broken hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] # can't differentiate ∇getindex @test_broken hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] - @test_broken gradient(b->sum(f(rand(3),b)),rand(3))[1] ≈ [1, 1, 1] # jacobian is not differentiable + @test_broken gradient(b->sum(f1(rand(3),b)),rand(3))[1] ≈ [1, 1, 1] # jacobian is not differentiable end @test hess(x -> sum(x.^3), [1 2; 3 4]) ≈ Diagonal([6, 18, 12, 24]) @test hess(sin, pi/2) ≈ -1