From 113c0976808d1e790a881f3572f123dd2724ec54 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 30 Apr 2023 11:00:15 -0400 Subject: [PATCH 1/3] allow multiple returns in withgradient --- src/compiler/interface.jl | 30 +++++++++++++++++++++++++++++- test/features.jl | 18 ++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index cce7c4d6d..7e72ce4b5 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -119,7 +119,27 @@ julia> ∇ == gradient(/, 1, 2) # explicit mode true julia> w = [3.0]; +``` + +If `f` returns a Tuple or NamedTuple, then it calculates +`gradient(first∘f, args...)` but returns the whole `f(args...)`: + +```jldoctest; setup=:(using Zygote) +julia> withgradient([1,2,4]) do x + z = 1 ./ x + sum(z), z + end +(val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],)) + +julia> withgradient(3.0, 4.0) do x, y + (div = x/y, mul = x*y) + end +(val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875)) +``` + +Also supports implicit mode: +```jldoctest; setup=:(using Zygote) julia> res = withgradient(() -> sum(abs2, w), Params([w])) # implicit mode (val = 9.0, grad = Grads(...)) @@ -130,7 +150,15 @@ julia> res.grad[w] """ function withgradient(f, args...) y, back = pullback(f, args...) - grad = back(sensitivity(y)) + grad = if y isa Tuple + dy = (sensitivity(first(y)), map(_ -> nothing, Base.tail(y))...) + back(dy) + elseif y isa NamedTuple + dy = (sensitivity(first(y)), map(_ -> nothing, Base.tail(y))...) + back(NamedTuple{propertynames(y), typeof(dy)}(dy)) + else + back(sensitivity(y)) + end results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad) (val=y, grad=results) end diff --git a/test/features.jl b/test/features.jl index 0499987d8..908ae5815 100644 --- a/test/features.jl +++ b/test/features.jl @@ -866,3 +866,21 @@ end end @test gradient(f760, 3)[1] ≈ 123.93054835019153 end + +@testset "withgradient" begin + @test withgradient([1,2,4]) do x + z = 1 ./ x + sum(z), z + end == (val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],)) + + @test withgradient(3.0, 4.0) do x, y + (div = x/y, mul = x*y) + end == (val = (div = 0.75, mul = 12.0), grad = (0.25, -0.1875)) + + f3(x) = sum(sin, x), sum(cos, x), sum(tan, x) + g1 = gradient(first∘f3, [1,2,3.0]) + y2, g2 = withgradient(first∘f3, [1,2,3.0]) + y3, g3 = withgradient(f3, [1,2,3.0]) + @test g1[1] ≈ g2[1] ≈ g3[1] +end + From c7ed3fdbb4b870cd622dcfd2fe75c34208ff22b3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 30 Apr 2023 12:13:50 -0400 Subject: [PATCH 2/3] fix doctest --- src/compiler/interface.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 7e72ce4b5..43e47c6c3 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -115,10 +115,8 @@ as a named tuple. julia> y, ∇ = withgradient(/, 1, 2) (val = 0.5, grad = (0.5, -0.25)) -julia> ∇ == gradient(/, 1, 2) # explicit mode +julia> ∇ == gradient(/, 1, 2) true - -julia> w = [3.0]; ``` If `f` returns a Tuple or NamedTuple, then it calculates @@ -140,7 +138,9 @@ julia> withgradient(3.0, 4.0) do x, y Also supports implicit mode: ```jldoctest; setup=:(using Zygote) -julia> res = withgradient(() -> sum(abs2, w), Params([w])) # implicit mode +julia> w = [3.0]; + +julia> res = withgradient(() -> sum(abs2, w), Params([w])) (val = 9.0, grad = Grads(...)) julia> res.grad[w] From e0d3d8b1a785ec291f0a41da3f12cad51d80eb6b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 19 Jun 2023 18:31:14 -0400 Subject: [PATCH 3/3] better words --- src/compiler/interface.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 43e47c6c3..c09d6db31 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -119,13 +119,15 @@ julia> ∇ == gradient(/, 1, 2) true ``` -If `f` returns a Tuple or NamedTuple, then it calculates -`gradient(first∘f, args...)` but returns the whole `f(args...)`: +Allows you to capture auxillary outputs, in addition to the scalar +used by `gradient`. To do this, `f` must return a Tuple or NamedTuple. +Then it calculates `grad = gradient(first∘f, args...) +but returns the whole `val = f(args...)`: ```jldoctest; setup=:(using Zygote) julia> withgradient([1,2,4]) do x z = 1 ./ x - sum(z), z + sum(z), z # here z is an auxillary output end (val = (1.75, [1.0, 0.5, 0.25]), grad = ([-1.0, -0.25, -0.0625],))