diff --git a/src/lib/base.jl b/src/lib/base.jl index 1a85cc56c..8401dc5a7 100644 --- a/src/lib/base.jl +++ b/src/lib/base.jl @@ -47,6 +47,45 @@ end end end +# This rule behaves much like the getindex adjoint, +# just with an (internal) ordinal index instead of a key. +function _pullback(cx::AContext, ::typeof(iterate), d::Dict, i) + iter = iterate(d, i) + function dict_iterate_pullback(Δ) + (iter === nothing || Δ === nothing) && return + k, v = iter[1] + _, dv = Δ[1] + accum_param(cx, v, dv) === nothing && return + grad = grad_mut(cx, d) + grad[k] = accum(get(grad, k, nothing), dv) + return (nothing, grad, nothing) + end + return iter, dict_iterate_pullback +end + +# ...while this one is to avoid duplicating code or differentiating skip_deleted. +# The alternative would be to write a rule for the private _iterate(::Dict, i). +function _pullback(cx::AContext, ::typeof(iterate), d::Dict) + # Calculation of i is the same used in iterate(::Dict) + return _pullback(cx, iterate, d, Base.skip_deleted(d, d.idxfloor)) +end + +function _pullback(cx::AContext, ::typeof(iterate), vi::Base.ValueIterator{<:Dict}, i::Int) + iter = iterate(vi, i) + function values_iterate_pullback(Δ) + (iter === nothing || Δ === nothing) && return + v, dv = iter[1], Δ[1] + accum_param(cx, v, dv) === nothing && return + # Same as vi.dict.keys[i], but without reaching into Dict internals. + # Iterating the dict instead of keys() is to hit the rules above in nested AD. + k = iterate(vi.dict, i)[1][1] + grad = grad_mut(cx, vi.dict) + grad[k] = accum(get(grad, k, nothing), dv) + return (nothing, (; dict = grad), nothing) + end + return iter, values_iterate_pullback +end + # Channels grad_mut(ch::Channel) = Channel(ch.sz_max) diff --git a/test/lib/base.jl b/test/lib/base.jl index 5186483da..74f129f6d 100644 --- a/test/lib/base.jl +++ b/test/lib/base.jl @@ -10,4 +10,36 @@ @test result1 == result2 end + + @testset "Dict iteration" begin + # https://github.com/FluxML/Zygote.jl/issues/1065 + function sumkv(d) + s = zero(d["c"]) + for (k, v) in d + s += v + k == :b && (s += v) + end + return sum(s) + end + + function sumvals(d) + s = zero(d["c"]) + for v in values(d) + s += v + end + return sum(s) + end + + d_num = Dict(:a => 3, :b => 4, "c" => 5) + d_arr = Dict(:a => [3], :b => [4], "c" => [5]) + ps = d_arr |> values |> collect |> Params + + @test gradient(sumkv, d_num)[1] == Dict(:a => 1, :b => 2, "c" => 1) + grads = gradient(() -> sumkv(d_arr), ps) + @test (grads[d_arr[:a]], grads[d_arr[:b]], grads[d_arr["c"]]) == ([1], [2], [1]) + + @test gradient(sumvals, d_num)[1] == Dict(:a => 1, :b => 1, "c" => 1) + grads = gradient(() -> sumvals(d_arr), ps) + @test (grads[d_arr[:a]], grads[d_arr[:b]], grads[d_arr["c"]]) == ([1], [1], [1]) + end end