Skip to content

Commit

Permalink
Merge pull request #1273 from mzgubic/mz/number_rrules
Browse files Browse the repository at this point in the history
number adjoints to rrules
  • Loading branch information
ToucheSir authored Aug 1, 2022
2 parents c822e9e + cdadaff commit b9530c7
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.42"
version = "0.6.43"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
68 changes: 54 additions & 14 deletions src/lib/number.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,69 @@
@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
Base.literal_pow(^,x,Val(p)),
Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing)
function ChainRulesCore.rrule(
::ZygoteRuleConfig, ::typeof(convert), T::Type{<:Real}, x::Real
)
convert_pullback(Δ) = (NoTangent(), NoTangent(), Δ)
return convert(T, x), convert_pullback
end

function ChainRulesCore.rrule(
::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{p}
) where {p}
function literal_pow_pullback(Δ)
dx = Δ * conj(p * Base.literal_pow(^,x,Val(p-1)))
return (NoTangent(), NoTangent(), dx, NoTangent())
end
return Base.literal_pow(^,x,Val(p)), literal_pow_pullback
end

@adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), ȳ -> (nothing, ȳ)
@adjoint (T::Type{<:Real})(x::Real) = T(x), ȳ -> (nothing, ȳ)
function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Real}, x::Real)
Real_pullback(Δ) = (NoTangent(), Δ)
return T(x), Real_pullback
end

for T in Base.uniontypes(Core.BuiltinInts)
@adjoint (::Type{T})(x::Core.BuiltinInts) = T(x), Δ -> (Δ,)
@eval function ChainRulesCore.rrule(::ZygoteRuleConfig, ::Type{$T}, x::Core.BuiltinInts)
IntX_pullback(Δ) = (NoTangent(), Δ)
return $T(x), IntX_pullback
end
end

@adjoint Base.:+(xs::Number...) = +(xs...), Δ -> map(_ -> Δ, xs)
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(+), xs::Number...)
plus_pullback(Δ) = (NoTangent(), map(_ -> Δ, xs)...)
return +(xs...), plus_pullback
end

@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, -* a // b // b))
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(//), a, b)
divide_pullback(r̄) = (NoTangent(), r̄ * 1//b, -* a // b // b)
return a // b, divide_pullback
end

# Complex Numbers

@adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄))
function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Complex}, r, i)
Complex_pullback(c̄) = (NoTangent(), real(c̄), imag(c̄))
return T(r, i), Complex_pullback
end

# we define these here because ChainRules.jl only defines them for x::Union{Real,Complex}

@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)
@adjoint real(x::Number) = real(x), r̄ -> (real(r̄),)
@adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),)
@adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,)
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(abs2), x::Number)
abs2_pullback(Δ) = (NoTangent(), real(Δ)*(x + x))
return abs2(x), abs2_pullback
end

function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(real), x::Number)
real_pullback(r̄) = (NoTangent(), real(r̄))
return real(x), real_pullback
end

function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(conj), x::Number)
conj_pullback(c̄) = (NoTangent(), conj(c̄))
return conj(x), conj_pullback
end

# for real x, ChainRules pulls back a zero real adjoint, whereas we treat x
# as embedded in the complex numbers and pull back a pure imaginary adjoint
@adjoint imag(x::Real) = zero(x), ī -> (real(ī)*im,)
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(imag), x::Number)
imag_pullback(ī) = (NoTangent(), real(ī)*im)
return imag(x), imag_pullback
end
51 changes: 44 additions & 7 deletions test/lib/number.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,44 @@
@testset "nograds" begin
@test gradient(floor, 1) === (0.0,)
@test gradient(ceil, 1) === (0.0,)
@test gradient(round, 1) === (0.0,)
@test gradient(hash, 1) === nothing
@test gradient(div, 1, 2) === nothing
end #testset
@testset "number.jl" begin
@testset "nograds" begin
@test gradient(floor, 1) === (0.0,)
@test gradient(ceil, 1) === (0.0,)
@test gradient(round, 1) === (0.0,)
@test gradient(hash, 1) === nothing
@test gradient(div, 1, 2) === nothing
end

@testset "basics" begin
@test gradient(Base.literal_pow, ^, 3//2, Val(-5))[2] isa Rational

@test gradient(convert, Rational, 3.14) == (nothing, 1.0)
@test gradient(convert, Rational, 2.3) == (nothing, 1.0)
@test gradient(convert, UInt64, 2) == (nothing, 1.0)
@test gradient(convert, BigFloat, π) == (nothing, 1.0)

@test gradient(Rational, 2) == (1//1,)

@test gradient(Bool, 1) == (1.0,)
@test gradient(Int32, 2) == (1.0,)
@test gradient(UInt16, 2) == (1.0,)

@test gradient(+, 2.0, 3, 4.0, 5.0) == (1.0, 1.0, 1.0, 1.0)

@test gradient(//, 3, 2) == (1//2, -3//4)
end

@testset "Complex numbers" begin
@test gradient(imag, 3.0) == (0.0,)
@test gradient(imag, 3.0 + 3.0im) == (0.0 + 1.0im,)

@test gradient(conj, 3.0) == (1.0,)
@test gradient(real conj, 3.0 + 1im) == (1.0 + 0im,)

@test gradient(real, 3.0) == (1.0,)
@test gradient(real, 3.0 + 1im) == (1.0 + 0im,)

@test gradient(abs2, 3.0) == (2*3.0,)
@test gradient(abs2, 3.0+2im) == (2*3.0 + 2*2.0im,)

@test gradient(real Complex, 3.0, 2.0) == (1.0, 0.0)
end
end

2 comments on commit b9530c7

@ToucheSir
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/65427

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.43 -m "<description of version>" b9530c73ac4765e57a6df5f38c3b81a4a63fae95
git push origin v0.6.43

Please sign in to comment.