-
-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1273 from mzgubic/mz/number_rrules
number adjoints to rrules
- Loading branch information
Showing
3 changed files
with
99 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,69 @@ | ||
# Zygote `@adjoint` for `Base.literal_pow(^, x, Val(p))` on numbers. The pullback returns | ||
# `nothing` for the `^` and `Val` arguments and, for `x`, Δ times | ||
# `conj(p * Base.literal_pow(^, x, Val(p-1)))`. | ||
@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} = | ||
Base.literal_pow(^,x,Val(p)), | ||
Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing) | ||
# Reverse rule for `convert(T, x)` with a real target type and a real input: the | ||
# cotangent Δ is handed back unchanged for `x`, while the function and `T` get `NoTangent()`. | ||
function ChainRulesCore.rrule( | ||
    ::ZygoteRuleConfig, ::typeof(convert), T::Type{<:Real}, x::Real | ||
) | ||
    function convert_pullback(Δ) | ||
        return (NoTangent(), NoTangent(), Δ) | ||
    end | ||
    return convert(T, x), convert_pullback | ||
end | ||
|
||
# Reverse rule for `Base.literal_pow(^, x, Val(p))`: only `x` receives a cotangent, | ||
# namely Δ times `conj(p * Base.literal_pow(^, x, Val(p-1)))`. | ||
function ChainRulesCore.rrule( | ||
    ::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{p} | ||
) where {p} | ||
    literal_pow_pullback(Δ) = ( | ||
        NoTangent(), | ||
        NoTangent(), | ||
        Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), | ||
        NoTangent(), | ||
    ) | ||
    return Base.literal_pow(^,x,Val(p)), literal_pow_pullback | ||
end | ||
|
||
# `@adjoint` rules for real conversions: `convert(T, x)` and the constructor call `T(x)`. | ||
# Both pullbacks return `nothing` for `T` and pass ȳ through for `x`. | ||
@adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), ȳ -> (nothing, ȳ) | ||
@adjoint (T::Type{<:Real})(x::Real) = T(x), ȳ -> (nothing, ȳ) | ||
# Reverse rule for calling a real type as a constructor, `T(x)`: the cotangent Δ is | ||
# returned unchanged for `x` and the type itself gets `NoTangent()`. | ||
function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Real}, x::Real) | ||
    function Real_pullback(Δ) | ||
        return (NoTangent(), Δ) | ||
    end | ||
    return T(x), Real_pullback | ||
end | ||
|
||
# Define constructor rules for each member type of the `Core.BuiltinInts` union. | ||
for T in Base.uniontypes(Core.BuiltinInts) | ||
# `@adjoint` form: the pullback returns Δ for `x`. | ||
@adjoint (::Type{T})(x::Core.BuiltinInts) = T(x), Δ -> (Δ,) | ||
# `@eval` form: `$T` is interpolated, so a separate `rrule` method is defined per type. | ||
@eval function ChainRulesCore.rrule(::ZygoteRuleConfig, ::Type{$T}, x::Core.BuiltinInts) | ||
IntX_pullback(Δ) = (NoTangent(), Δ) | ||
return $T(x), IntX_pullback | ||
end | ||
end | ||
|
||
# `@adjoint` for n-ary `+` on numbers: every argument receives the same cotangent Δ. | ||
@adjoint Base.:+(xs::Number...) = +(xs...), Δ -> map(_ -> Δ, xs) | ||
# Reverse rule for n-ary `+` on numbers: Δ is replicated once per argument, preceded | ||
# by `NoTangent()` for the `+` function itself. | ||
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(+), xs::Number...) | ||
    function plus_pullback(Δ) | ||
        per_arg = map(_ -> Δ, xs) | ||
        return (NoTangent(), per_arg...) | ||
    end | ||
    return +(xs...), plus_pullback | ||
end | ||
|
||
# `@adjoint` for `a // b`: the pullback returns the cotangents for `a` and `b`, in that order. | ||
@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, - c̄ * a // b // b)) | ||
# Reverse rule for `a // b`: the pullback returns `NoTangent()` for `//`, followed by | ||
# the cotangents for `a` and `b`. | ||
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(//), a, b) | ||
    function divide_pullback(r̄) | ||
        da = r̄ * 1//b | ||
        db = - r̄ * a // b // b | ||
        return (NoTangent(), da, db) | ||
    end | ||
    return a // b, divide_pullback | ||
end | ||
|
||
# Complex Numbers | ||
|
||
# `@adjoint` for the two-argument complex constructor: the cotangent c̄ is split into | ||
# its real part (for `re`) and imaginary part (for `im`); the type gets `nothing`. | ||
@adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄)) | ||
# Reverse rule for the two-argument complex constructor `T(r, i)`: the real part of c̄ | ||
# goes to `r`, the imaginary part goes to `i`, and the type gets `NoTangent()`. | ||
function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Complex}, r, i) | ||
    function Complex_pullback(c̄) | ||
        dr = real(c̄) | ||
        di = imag(c̄) | ||
        return (NoTangent(), dr, di) | ||
    end | ||
    return T(r, i), Complex_pullback | ||
end | ||
|
||
# we define these here because ChainRules.jl only defines them for x::Union{Real,Complex} | ||
|
||
# `@adjoint` rules for `abs2`, `real`, `conj` and `imag` on any `Number`. Each pullback | ||
# returns a 1-tuple holding the cotangent for `x`. | ||
@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),) | ||
@adjoint real(x::Number) = real(x), r̄ -> (real(r̄),) | ||
@adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),) | ||
@adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,) | ||
# Reverse rule for `abs2`: the cotangent for `x` is the real part of Δ times `x + x`. | ||
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(abs2), x::Number) | ||
    function abs2_pullback(Δ) | ||
        dx = real(Δ)*(x + x) | ||
        return (NoTangent(), dx) | ||
    end | ||
    return abs2(x), abs2_pullback | ||
end | ||
|
||
# Reverse rule for `real`: the cotangent for `x` is the real part of the incoming r̄. | ||
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(real), x::Number) | ||
    function real_pullback(r̄) | ||
        return (NoTangent(), real(r̄)) | ||
    end | ||
    return real(x), real_pullback | ||
end | ||
|
||
# Reverse rule for `conj`: the cotangent for `x` is the conjugate of the incoming c̄. | ||
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(conj), x::Number) | ||
    function conj_pullback(c̄) | ||
        return (NoTangent(), conj(c̄)) | ||
    end | ||
    return conj(x), conj_pullback | ||
end | ||
|
||
# for real x, ChainRules pulls back a zero real adjoint, whereas we treat x | ||
# as embedded in the complex numbers and pull back a pure imaginary adjoint | ||
# `@adjoint` for `imag` on reals: the primal is `zero(x)` and the pullback returns | ||
# `real(ī)*im` as the cotangent for `x`. | ||
@adjoint imag(x::Real) = zero(x), ī -> (real(ī)*im,) | ||
# Reverse rule for `imag` on any `Number`: the cotangent for `x` is `real(ī)*im`. | ||
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(imag), x::Number) | ||
    function imag_pullback(ī) | ||
        return (NoTangent(), real(ī)*im) | ||
    end | ||
    return imag(x), imag_pullback | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,44 @@ | ||
# Checks functions whose gradient at an integer input is expected to be exactly | ||
# `(0.0,)` (floor, ceil, round) or `nothing` (hash, div). | ||
@testset "nograds" begin | ||
@test gradient(floor, 1) === (0.0,) | ||
@test gradient(ceil, 1) === (0.0,) | ||
@test gradient(round, 1) === (0.0,) | ||
@test gradient(hash, 1) === nothing | ||
@test gradient(div, 1, 2) === nothing | ||
end #testset | ||
# Test suite for the number rules: each case compares `gradient` with an expected value. | ||
@testset "number.jl" begin | ||
# Functions whose gradient is expected to be exactly `(0.0,)` or `nothing`. | ||
@testset "nograds" begin | ||
@test gradient(floor, 1) === (0.0,) | ||
@test gradient(ceil, 1) === (0.0,) | ||
@test gradient(round, 1) === (0.0,) | ||
@test gradient(hash, 1) === nothing | ||
@test gradient(div, 1, 2) === nothing | ||
end | ||
|
||
# literal_pow, convert, type constructors, n-ary `+` and `//`. | ||
@testset "basics" begin | ||
# The gradient with respect to a Rational base is expected to remain a Rational. | ||
@test gradient(Base.literal_pow, ^, 3//2, Val(-5))[2] isa Rational | ||
|
||
# convert: `nothing` for the target type, 1.0 for the converted value. | ||
@test gradient(convert, Rational, 3.14) == (nothing, 1.0) | ||
@test gradient(convert, Rational, 2.3) == (nothing, 1.0) | ||
@test gradient(convert, UInt64, 2) == (nothing, 1.0) | ||
@test gradient(convert, BigFloat, π) == (nothing, 1.0) | ||
|
||
# Calling a real type as a constructor. | ||
@test gradient(Rational, 2) == (1//1,) | ||
|
||
# Built-in integer (and Bool) constructors. | ||
@test gradient(Bool, 1) == (1.0,) | ||
@test gradient(Int32, 2) == (1.0,) | ||
@test gradient(UInt16, 2) == (1.0,) | ||
|
||
# n-ary `+` with mixed Int and Float64 arguments. | ||
@test gradient(+, 2.0, 3, 4.0, 5.0) == (1.0, 1.0, 1.0, 1.0) | ||
|
||
# Rational division: expected gradients for numerator and denominator. | ||
@test gradient(//, 3, 2) == (1//2, -3//4) | ||
end | ||
|
||
# imag, conj, real, abs2 and the Complex constructor on real and complex inputs. | ||
@testset "Complex numbers" begin | ||
@test gradient(imag, 3.0) == (0.0,) | ||
@test gradient(imag, 3.0 + 3.0im) == (0.0 + 1.0im,) | ||
|
||
@test gradient(conj, 3.0) == (1.0,) | ||
@test gradient(real ∘ conj, 3.0 + 1im) == (1.0 + 0im,) | ||
|
||
@test gradient(real, 3.0) == (1.0,) | ||
@test gradient(real, 3.0 + 1im) == (1.0 + 0im,) | ||
|
||
@test gradient(abs2, 3.0) == (2*3.0,) | ||
@test gradient(abs2, 3.0+2im) == (2*3.0 + 2*2.0im,) | ||
|
||
@test gradient(real ∘ Complex, 3.0, 2.0) == (1.0, 0.0) | ||
end | ||
end |
b9530c7
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register
b9530c7
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/65427
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via: