follow @functor -> @layer changes in Flux (#62)
* adapt_structure

* adapt compat

* fixes

* doc fixes

* cleanup
CarloLucibello authored Mar 17, 2024
1 parent 232e84c commit e62cb63
Showing 5 changed files with 19 additions and 7 deletions.
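For context, a minimal sketch of the upstream change this commit tracks: Flux moved layer registration from the `@functor` macro (from Functors.jl) to `Flux.@layer`, which also wires up pretty-printing and parameter handling. The `MyLayer` type below is hypothetical:

```julia
using Flux

struct MyLayer
    dense::Dense
end

# Old registration:  Flux.@functor MyLayer
# New registration, which this commit adapts Tsunami to:
Flux.@layer MyLayer

m = MyLayer(Dense(2 => 3))
m |> f64   # parameter-type and device movement work through the same mechanism
```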
2 changes: 1 addition & 1 deletion Project.toml
@@ -22,7 +22,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Adapt = "3"
Adapt = "3, 4"
BSON = "0.3.6"
ChainRulesCore = "1"
Crayons = "4"
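The widened compat entry accepts either major series of Adapt.jl (3.x or 4.x) under semver; a quick sanity check of what the resolver actually picked (hypothetical session):

```julia
using Pkg
Pkg.status("Adapt")   # prints the resolved Adapt.jl version
```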
1 change: 1 addition & 0 deletions docs/src/trainer.md
@@ -14,4 +14,5 @@ Tsunami.fit!
Tsunami.FitState
Tsunami.test
Tsunami.validate
+Tsunami.Foil
```
4 changes: 2 additions & 2 deletions src/Tsunami.jl
@@ -30,8 +30,6 @@ include("utils.jl")
include("stats.jl")
# export Stats

include("show.jl")

include("fluxmodule.jl")
export FluxModule
# train_step,
@@ -40,6 +38,8 @@ export FluxModule
# predict_step,
# configure_optimizers

include("show.jl")

include("hooks.jl")
# export on_before_update,
# on_before_backprop,
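The reorder is needed because `include` is sequential and `show.jl` now defines a method that dispatches on `FluxModule` (see src/show.jl below), so `fluxmodule.jl` must come first. A minimal sketch of the constraint, with hypothetical names:

```julia
module Sketch

abstract type FluxModule end        # stands in for fluxmodule.jl

# A later include can now dispatch on the type; with the includes
# reversed this line would fail with an UndefVarError:
describe(::FluxModule) = "a Tsunami model"

end
```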
12 changes: 9 additions & 3 deletions src/fluxmodule.jl
@@ -4,7 +4,10 @@
An abstract type for Flux models.
A `FluxModule` helps organise your code and provides a standard interface for training.
-A `FluxModule` comes with `functor` already implemented.
+A `FluxModule` comes with the functionality provided by `Flux.@layer`
+(cpu/gpu movement, parameter management, etc.) and the ability to interact with
+[`Trainer`](@ref) and `Optimisers.jl`.
+You can change the trainables by implementing `Optimisers.trainables`.
Types inheriting from `FluxModule` have to be mutable. They also
@@ -62,13 +65,16 @@ model, fit_state = Tsunami.fit(model, trainer, train_dataloader)
"""
abstract type FluxModule end

-function Functors.functor(::Type{<:FluxModule}, m::T) where T
+
+function Functors.functor(::Type{T}, m) where {T<:FluxModule}
childr = (; (f => getfield(m, f) for f in fieldnames(T))...)
Tstripped = Base.typename(T).wrapper # remove all parameters. From https://discourse.julialang.org/t/stripping-parameter-from-parametric-types/8293/16
-re = x -> Tstripped(x...)
+re = Base.splat(Tstripped)
return childr, re
end

+Adapt.adapt_structure(to, m::FluxModule) = Functors.fmap(Adapt.adapt(to), m)
+
Base.show(io::IO, mime::MIME"text/plain", m::FluxModule) = fluxshow(io, mime, m)
Base.show(io::IO, m::FluxModule) = shortshow(io, m)

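A minimal sketch of what the new `functor` overload and the `Adapt.adapt_structure` method give you, assuming Tsunami is loaded; `Net` is a hypothetical user model:

```julia
using Adapt, Flux, Functors, Tsunami

mutable struct Net <: FluxModule
    encoder::Dense
    decoder::Dense
end

net = Net(Dense(4 => 2), Dense(2 => 4))

# `functor` splits the struct into a NamedTuple of children plus a
# reconstructor; `Base.splat(Net)` rebuilds it as `Net(children...)`:
children, re = Functors.functor(net)
net2 = re(children)                # a fresh Net with the same fields

# `adapt_structure` walks those children with `fmap`, so an adaptor
# applies to every array inside the model:
net_adapted = Adapt.adapt(Array, net)   # no-op on CPU; a GPU array type would move it
```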
7 changes: 6 additions & 1 deletion src/show.jl
@@ -1,4 +1,6 @@
-function fluxshow(io::IO, m::MIME"text/plain", x::T) where T
+# Show methods that Flux defines through `@layer`
+# https://github.com/FluxML/Flux.jl/blob/master/src/layers/show.jl#L4
+function fluxshow(io::IO, m::MIME"text/plain", x)
if get(io, :typeinfo, nothing) === nothing # e.g. top level in REPL
Flux._big_show(io, x)
elseif !get(io, :compact, false) # e.g. printed inside a Vector, but not a Matrix
@@ -7,6 +9,9 @@ function fluxshow(io::IO, m::MIME"text/plain", x::T) where T
show(io, x)
end
end
+# Don't show Chain(Tuple(...)), always splat that. And ignore Recur's non-trainable state:
+Flux._show_children(x::FluxModule) = Flux._flat_children(trainable(x))
+

function shortshow(io::IO, x::T) where T
str = string(T.name.name)
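Roughly how the two display paths behave, continuing the hypothetical `Net` from the earlier sketch:

```julia
# Top-level REPL display takes the MIME"text/plain" path and calls
# Flux._big_show, printing the expanded, parameter-counted layer tree:
show(stdout, MIME"text/plain"(), net)

# Plain two-arg `show` (e.g. for elements inside a container) takes
# the compact `shortshow` path and prints a short one-line form:
print(stdout, net)
```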
