Add EnzymeRules #103
base: master
For some reason, I can't seem to get the extension to work. Package precompilation fails with the error:

ERROR: The following 1 direct dependency failed to precompile:
AbstractFFTs [621f4979-c628-5d54-868e-fcf4e3e8185c]
Failed to precompile AbstractFFTs [621f4979-c628-5d54-868e-fcf4e3e8185c] to "/home/runner/.julia/compiled/v1.9/AbstractFFTs/jl_mYHZQL".
ERROR: LoadError: ArgumentError: Package AbstractFFTs does not have LinearAlgebra in its dependencies:
- You may have a partially installed environment. Try `Pkg.instantiate()`
  to ensure all packages in the environment are installed.
- Or, if you have AbstractFFTs checked out for development and have
  added LinearAlgebra as a dependency but haven't updated your primary
  environment's manifest file, try `Pkg.resolve()`.
- Otherwise you may need to report an issue with AbstractFFTs

This happens although LinearAlgebra is clearly listed as both a dep and a weak dep. Weirder still, if I activate the project, it now says it's empty, whereas if I remove this extension, it shows the dependencies:

julia> using Pkg; Pkg.activate(".")
  Activating project at `~/projects/AbstractFFTs.jl`

julia> Pkg.status()
Project AbstractFFTs v1.3.1
Status `~/projects/AbstractFFTs.jl/Project.toml` (empty project)

shell> head ./Project.toml
name = "AbstractFFTs"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.3.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

@KristofferC I've never had this problem with my extensions before. Do you know what could cause this?
Never mind: it seems extensions cannot have weak deps that are also deps. In this case, the dep needs to be loaded within the extension from the main package; see e.g. JuliaStats/LogExpFunctions.jl#63.
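The pattern of reaching a regular dependency through the parent package can be sketched as follows. This is a minimal sketch, not the PR's code: two modules in one file stand in for the package and its extension so the snippet runs with only the LinearAlgebra stdlib, and the names `ParentPkg`, `ParentPkgExt`, `scaled_norm`, and `ext_energy` are hypothetical. In a real extension, which is its own top-level module, the relative path `..ParentPkg` would become the package name (we assume something like `using AbstractFFTs: LinearAlgebra`).

```julia
# Stand-in for the parent package: LinearAlgebra is a regular dependency
# (listed under [deps] only) and is loaded the usual way.
module ParentPkg
using LinearAlgebra
export scaled_norm
scaled_norm(x) = norm(x) / length(x)
end

# Stand-in for the extension module (hypothetical name). It does not declare
# LinearAlgebra as its own weak dependency; it reuses the binding that the
# parent package already created with `using LinearAlgebra`.
module ParentPkgExt
using ..ParentPkg: LinearAlgebra   # real extension: `using ParentPkg: LinearAlgebra`
ext_energy(x) = LinearAlgebra.dot(x, x)
end
```

Both modules then refer to the same `LinearAlgebra` module object, so nothing extra has to be listed under `[weakdeps]`.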
Codecov Report
Patch coverage has no change and project coverage change: -8.48%
@@ Coverage Diff @@
## master #103 +/- ##
==========================================
- Coverage 87.08% 78.60% -8.48%
==========================================
Files 3 4 +1
Lines 209 229 +20
==========================================
- Hits 182 180 -2
- Misses 27 49 +22
If #67 is merged, we could add rules for
y::DuplicatedOrBatchDuplicated{<:StridedArray{T}},
p::Const{<:AbstractFFTs.Plan{T}}, | ||
x::DuplicatedOrBatchDuplicated{<:StridedArray{T}}, | ||
) where {T} |
I wish the type `T` could be restricted to a finite set, e.g. BLAS number types; otherwise, it may produce incorrect gradients for user-defined extensions. Generally speaking, I feel "generic" AD is not a good practice.
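Restricting the element type as suggested here can be done with dispatch alone. The sketch below uses `LinearAlgebra.BlasFloat` (the union of `Float32`, `Float64`, `ComplexF32`, and `ComplexF64`); `linear_forward_rule` is a hypothetical stand-in for a forward rule, not Enzyme's `EnzymeRules` API, and simply applies the transform to the primal and to the tangent.

```julia
using LinearAlgebra: BlasFloat   # Union{Float32, Float64, ComplexF32, ComplexF64}

# Hypothetical forward rule for a linear transform F: returns the primal
# output F(x) and the tangent F(dx). The method only accepts BLAS element
# types, so any other element type (BigFloat, a tropical number, ...) hits a
# MethodError instead of receiving a possibly wrong derivative.
function linear_forward_rule(F, x::AbstractArray{T}, dx::AbstractArray{T}) where {T<:BlasFloat}
    return F(x), F(dx)
end
```

The trade-off is the one debated below: unsupported types fail loudly, and supporting a new type requires adding a method rather than trusting the generic one.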
The pushforward of a linear operator is always the operator itself. And as far as I know, every definition of an FFT is a linear operator. So I can see no reason why this rule should be problematic for forward mode.
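The linearity argument can be checked numerically. This sketch uses a hypothetical naive O(n²) DFT rather than any library's FFT: for a linear map, the directional derivative at any point x in direction dx is just the map applied to dx, so a central finite difference should agree with `naive_dft(dx)` up to rounding.

```julia
# Naive O(n^2) DFT with the usual sign convention:
# y[k] = sum_j x[j] * exp(-2πi (j-1)(k-1) / n), using 1-based indices.
function naive_dft(x::AbstractVector)
    n = length(x)
    return [sum(x[j] * cis(-2π * (j - 1) * (k - 1) / n) for j in 1:n) for k in 1:n]
end

# Because the DFT is linear, its pushforward at any point x maps a tangent dx
# to naive_dft(dx). Compare that with a central finite difference at x.
x  = [1.0, 2.0, -1.0, 0.5]
dx = [0.3, -0.2, 0.1, 0.4]
h  = 1e-6

finite_diff = (naive_dft(x .+ h .* dx) .- naive_dft(x .- h .* dx)) ./ (2h)
pushforward = naive_dft(dx)

# For a linear map the central difference is exact in exact arithmetic,
# so only floating-point rounding separates the two.
max_err = maximum(abs.(finite_diff .- pushforward))
```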
For example, I may want to extend FFT to tropical numbers, which are not real numbers. Such a transform is linear but does not have an inverse. Then your rule would give me incorrect gradients without throwing an error. I have seen too many incorrect gradients in previous AD frameworks such as Zygote when handling complex numbers.
I agree it is good to have a generic backward routine there, but please constrain the interfaces to concrete types when porting it to an AD engine. It should not be difficult for users to extend the list of supported types in the future. Defining FFT rules on BLAS types would be good enough to cover most use cases. For non-BLAS types, honestly, we cannot make any assumptions. The Julia community needs an AD engine with provable correctness; I think that is also one of the goals of Enzyme.
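To make concrete what a generic reverse rule assumes: for a linear map y = F x, the pullback returns x̄ = Fᴴ ȳ. For the unnormalized DFT, Fᴴ is the unnormalized backward transform (what AbstractFFTs calls `bfft`), which equals n·F⁻¹. The sketch below checks both identities with a hypothetical naive O(n²) transform, not any library's implementation. The derivation relies on complex conjugation and on an inverse existing, which is exactly the structure a tropical "DFT" lacks.

```julia
using LinearAlgebra: dot

# Naive O(n^2) transform: s = -1 gives the forward DFT F, s = +1 gives the
# unnormalized backward transform B (the operation AbstractFFTs calls bfft).
function dft_sign(x::AbstractVector, s::Int)
    n = length(x)
    return [sum(x[j] * cis(s * 2π * (j - 1) * (k - 1) / n) for j in 1:n) for k in 1:n]
end

x    = ComplexF64[1.0, 2.0im, -1.0, 0.5]
ybar = ComplexF64[0.5, -1.0, 1.0im, 2.0]   # incoming cotangent

# Adjoint identity <ybar, F x> = <B ybar, x>, i.e. F^H = B,
# so a pullback can return xbar = B(ybar). (`dot` conjugates its first argument.)
lhs = dot(ybar, dft_sign(x, -1))
rhs = dot(dft_sign(ybar, +1), x)

# B undoes F up to the factor n, i.e. B = n * inv(F).
roundtrip = dft_sign(dft_sign(x, -1), +1)
```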
> I may want to extend FFT to tropical numbers
Is this really an FFT per se? I would consider a DFT generalized to some other ring to be a different transform.
> Is this really an FFT per se? I would consider a DFT generalized to some other ring to be a different transform.
Since Julia does not have a good trait system, I think it is in general impossible to restrict users to the inputs a function is designed for. This is what I meant by a lack of provable correctness.
It has been a big issue that none of the Julia libraries (except Enzyme) can provide reliable gradients. They claim too much for untested use cases, like complex numbers and tropical numbers. There has been a belief that "it is cool if the code works in cases where it is not expected to work". But no, untested rules are not reliable; they can break with any future change even if they work now. Rules must be concrete and tested: they are easy to extend but hard to debug.
By that argument, no AD rules should be defined here anyway, since downstream a user could define a custom Plan that doesn't do any kind of FFT at all. Then even with BLAS number types and strided arrays, any rule we write here would be wrong.
The counterargument is that if a user adds a method of a function whose properties are well-documented, other code should be able to assume and depend on those properties when calling the method for arbitrary inputs.
Taken to its logical conclusion, wouldn't your principle require that rules are never defined for abstract types, and further, that the type of every argument is concrete and known to the rule implementer?
> wouldn't your principle require that rules are never defined for abstract types, and further, that the type of every argument is concrete and known to the rule implementer?
A big YES. I do not think many people need the backward rules for non-BLAS types. You may want to support, e.g., the extended-precision floats defined in DoubleFloats.jl. I would argue that in these use cases, users can port the generic rule to the AD framework with little effort. The rule can be generic, but when it is ported to the AD framework, it should be concrete.
We have to choose between supporting more data types and ensuring correctness. I really wish there were a trait system through which users could tell the compiler "this element type is a field"; then users could use the rule with more confidence. Facts obvious to you, like "FFT should work on fields rather than other rings", may not be obvious to others.
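Julia has no built-in trait system, but the "Holy traits" dispatch pattern can approximate the opt-in wished for here. The sketch below is hypothetical (`FieldTrait`, `IsField`, `NotField`, `fieldtrait`, and `traited_forward_rule` are illustrative names, not part of AbstractFFTs or Enzyme): BLAS types are declared fields, unknown types get an error instead of a silent rule, and a user can opt a type in with a single method definition.

```julia
# Hypothetical trait: does the element type behave like a field, so that the
# identities behind a generic linear-transform rule can be trusted?
abstract type FieldTrait end
struct IsField <: FieldTrait end
struct NotField <: FieldTrait end

# Conservative default: unknown element types are not assumed to be fields.
fieldtrait(::Type) = NotField()
# The BLAS number types are opted in up front.
fieldtrait(::Type{<:Union{Float32, Float64, ComplexF32, ComplexF64}}) = IsField()

# Entry point: look up the trait of the element type and dispatch on it.
function traited_forward_rule(F, x::AbstractArray{T}, dx::AbstractArray{T}) where {T}
    return traited_forward_rule(fieldtrait(T), F, x, dx)
end

# Field case: the pushforward of a linear map is the map itself.
traited_forward_rule(::IsField, F, x, dx) = (F(x), F(dx))

# Anything not declared a field fails loudly instead of returning a guess.
traited_forward_rule(::NotField, F, x, dx) =
    throw(ArgumentError("no rule: element type $(eltype(x)) is not declared a field"))
```

For example, after a user defines `fieldtrait(::Type{BigFloat}) = IsField()`, the same call with `BigFloat` arrays takes the generic path. The conservative default matches the reviewer's preference; the cost is that every new number type must be opted in explicitly.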
> The counterargument is that if a user adds a method of a function whose properties are well-documented, other code should be able to assume and depend on those properties when calling the method for arbitrary inputs.
To differentiate a long program, I let the code fly and see where it falls, and I add new rules to the AD engine to keep it flying. A missing rule is not a problem for me. So when using a new element type, such as complex numbers, symbolic types, finite-field algebra, or the tropical numbers mentioned above, I will probably not check whether each function's properties hold as documented.
> Then even with BLAS number types and strided arrays, any rule we write here would be wrong.
Julia emits a warning when a method definition overwrites an existing one. Also, type piracy is not difficult to avoid.
In any case, if we have ChainRules I think we should have the corresponding EnzymeRules.
If users make the questionable choice of overriding `fft` to compute an unrelated function, then it is up to them to override the EnzymeRules/ChainRules as well.
I've paused work on this until EnzymeTestUtils (EnzymeAD/Enzyme.jl#782) is registered, which will make reliably testing these rules much more straightforward.
Coming back to this, I think Enzyme rules should only be defined here abstractly for cases where we know they will not break downstream code that Enzyme would otherwise have handled fine. So I agree with the following restrictions:
These rules are considerably stricter than the ChainRules, and for good reason. ChainRules are by convention often defined to cover up indexing and mutating code to help Zygote and Diffractor, but this comes at the cost of doing the wrong thing for lots of types, hence the Rules for
Any hope of having this PR revived? Enzyme has come a long way lately, and FFT support would be another great step forward.
I'm afraid I don't have the bandwidth now to revive this. I still think #103 (comment) is the right way forward, and this PR is a good starting point for someone who wants to take it on.
Will fix #99