diff --git a/src/Registries/Registries.jl b/src/Registries/Registries.jl index 49d38ff7c4..ac20e93e5d 100644 --- a/src/Registries/Registries.jl +++ b/src/Registries/Registries.jl @@ -1,6 +1,6 @@ module Registries -using ..FastAI +using ..FastAI: FastAI, BlockLike, Label, LabelMulti, issubblock using ..FastAI.Datasets using ..FastAI.Datasets: DatasetLoader, DataDepLoader, isavailable, loaddata, typify @@ -48,10 +48,12 @@ end include("datasets.jl") include("tasks.jl") include("recipes.jl") +include("models.jl") export datasets, learningtasks, datarecipes, + models, find, info, load diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 8b13789179..5f284dd030 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -1 +1,272 @@ +# # Model registry +# +# This file defines [`models`](#), a feature registry for models. +# ## Registry definition + +const _MODELS_DESCRIPTION = """ +A `FeatureRegistry` for models. Allows you to find and load models for various learning +tasks using a unified interface. Call `models()` to see a table view of available models: + +```julia +using FastAI +models() +``` + +Which models are available depends on the loaded packages. For example, FastVision.jl adds +vision models from Metalhead to the registry. Index the registry with a model ID to get more +information about that model: + +```julia +using FastAI: models +using FastVision # loading the package extends the list of available models + +models()["metalhead/resnet18"] +``` + +If you've selected a model, call `load` to then instantiate a model: + +```julia +model = load("metalhead/resnet18") +``` + +By default, `load` loads a default version of the model without any pretrained weights. + +`load(model)` also accepts keyword arguments that allow you to specify variants of the model and +weight checkpoints that should be loaded. + +Loading a checkpoint of pretrained weights: + +- `load(entry; pretrained = true)`: Use any pretrained weights, if they are available. +- `load(entry; checkpoint = "checkpoint-name")`: Use the weights with given name. See + `entry.checkpoints` for available checkpoints (if any). +- `load(entry; pretrained = false)`: Don't use pretrained weights + +Loading a model variant for a specific task: + +- `load(entry; input = ImageTensor, output = OneHotLabel)`: Load a model variant matching + an input and output block. +- `load(entry; variant = "backbone"): Load a model variant by name. See `entry.variants` for + available variants. +""" + + +# ## `ModelVariant` interface +""" + abstract type ModelVariant + +A `ModelVariant` handles loading a model, optionally with pretrained weights and +transforming it so that it can be used for specific learning tasks. + + +are subblocks (see [`issubblock`](#)) of `blocks = (xblock, yblock)`. + +## Interface + +- [`compatibleblocks`](#)`(variant)` returns a tuple `(xblock, yblock)` of [`BlockLike`](#) that + are compatible with the model. This means that a variant can be used for a task with + input and output blocks `blocks`, if [`issubblock`](#)`(blocks, compatibleblocks(variant))`. +- [`loadvariant`](#)`(::ModelVariant, xblock, yblock, checkpoint; kwargs...)` loads a model + compatible with block instances `xblock` and `yblock`, with (optionally) weights + from `checkpoint`. +""" +abstract type ModelVariant end + +""" + compatibleblocks(::ModelVariant) + +Indicate compatible input and output block for a model variant. +""" +function compatibleblocks end + +function loadvariant end + + +# ## Model registry creation + +function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) + fields = (; + id = Field(String; name = "ID", formatfn = FeatureRegistries.string_format), + description = Field(String; + name = "Description", + optional = true, + description = "More information about the model", + formatfn = FeatureRegistries.md_format), + backend = Field(Symbol, + name = "Backend", + default = :flux, + description = "The backend deep learning framework that the model uses. The default is `:flux`."), + loadfn = Field(Any, + name = "Load function", + optional = false, + description = "A function `loadfn(checkpoint)` that loads a default version of the model, possibly with `checkpoint` weights.", + ), + variants = Field(Vector{Pair{String, ModelVariant}}, + name = "Variants", + default = Pair{String, ModelVariant}[], + description = "Model variants suitable for different learning tasks. See `?ModelVariant` for more details.", + formatfn = d -> join(first.(d), ", ")), + checkpoints = Field(Vector{String}; + name = "Checkpoints", + description = """ + Pretrained weight checkpoints that can be loaded for the model. Checkpoints are listed as a + `Vector{String}` and `loadfn` should take care of loading the selected checkpoint""", + formatfn = cs -> join(cs, ", "), + defaultfn = (row, key) -> String[]), + ) + return Registry(fields; name, loadfn = _loadmodel, description = description) +end + +""" + _loadmodel(row) + +Load a model specified by `row` from a model registry. +""" +function _loadmodel(row; input = Any, output = Any, variant = nothing, checkpoint = nothing, + pretrained = !isnothing(checkpoint), kwargs...) + checkpoints, variants = row.checkpoints, row.variants # 1.6 support + + # Finding matching configuration + checkpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint) + if (pretrained && isnothing(checkpoint)) + throw(NoCheckpointFoundError(checkpoints, checkpoint)) + end + + # If no variant is asked for, use the base model loading function that only takes + # care of the checkpoint. + if isnothing(variant) && input === Any && output === Any + return row.loadfn(checkpoint) + # If a variant is specified, either by name (through `variant`) or through block + # constraints `input` or `output`, try to find a matching variant. + # care of the checkpoint. + else + variant = _findvariant(variants, variant, input, output) + isnothing(variant) && throw(NoModelVariantFoundError(variants, input, output, variant)) + return loadvariant(variant, input, output, checkpoint; kwargs...) + end +end + +# ### Errors + +# TODO: Implement Base.showerror +struct NoModelVariantFoundError <: Exception + variants::Vector + input::BlockLike + output::BlockLike + variant::Union{String, Nothing} +end + +# TODO: Implement Base.showerror +struct NoCheckpointFoundError <: Exception + checkpoints::Vector{String} + checkpoint::Union{String, Nothing} +end + +# ## Create the default registry instance + +const MODELS = _modelregistry() + +""" + models() + +$_MODELS_DESCRIPTION +""" +models(; kwargs...) = isempty(kwargs) ? MODELS : filter(MODELS; kwargs...) + +# ## Helpers + +function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = nothing) + if isempty(checkpoints) + nothing + elseif !isnothing(name) + i = findfirst(==(name), checkpoints) + isnothing(i) ? nothing : checkpoints[i] + elseif pretrained + first(values(checkpoints)) + else + nothing + end +end + +function _findvariant(variants::Vector, + variantname::Union{String, Nothing}, xblock, yblock) + if !isnothing(variantname) + variants = filter(variants) do (name, _) + name == variantname + end + end + i = findfirst(variants) do (_, variant) + v_xblock, v_yblock = compatibleblocks(variant) + issubblock(v_xblock, xblock) && issubblock(v_yblock, yblock) + end + isnothing(i) ? nothing : variants[i][2] +end + +# ## Tests + +struct MockVariant <: ModelVariant + model + blocks +end + +compatibleblocks(variant::MockVariant) = variant.blocks +loadvariant(variant::MockVariant, _, _, ch) = (ch, variant.model) + +@testset "Model registry" begin + @testset "Basic" begin + @test_nowarn _modelregistry() + reg = _modelregistry() + push!(reg, (; id = "test", loadfn = (checkpoint,) -> checkpoint)) + + @test load(reg["test"]) === nothing + @test_throws NoCheckpointFoundError load(reg["test"], pretrained = true) + end + + @testset "_loadmodel" begin + reg = _modelregistry() + @test_nowarn push!(reg, + (; + id = "test", + loadfn = (checkpoint; kwarg = 1) -> (checkpoint, kwarg), + checkpoints = ["checkpoint", "checkpoint2"], + variants = [ + "base" => MockVariant(1, (Any, Any)), + "ext" => MockVariant(2, (Any, Label)), + ])) + entry = reg["test"] + @test _loadmodel(entry) == (nothing, 1) + @test _loadmodel(entry; pretrained = true) == ("checkpoint", 1) + @test _loadmodel(entry; checkpoint = "checkpoint2") == ("checkpoint2", 1) + @test_throws NoCheckpointFoundError _loadmodel(entry; checkpoint = "checkpoint3") + + @test _loadmodel(entry; output = Label) == (nothing, 2) + @test _loadmodel(entry; variant = "ext") == (nothing, 2) + @test _loadmodel(entry; pretrained = true, output = Label) == ("checkpoint", 2) + @test_throws NoModelVariantFoundError _loadmodel(entry; input = Label) + end + + @testset "_findvariant" begin + vars = [ + "1" => MockVariant(1, (Any, Any)), + "2" => MockVariant(1, (Any, Label)), + ] + # no restrictions => select first variant + @test _findvariant(vars, nothing, Any, Any) == vars[1][2] + # name => select named variant + @test _findvariant(vars, "2", Any, Any) == vars[2][2] + # name not found => nothing + @test _findvariant(vars, "3", Any, Any) === nothing + # restrict block => select matching + @test _findvariant(vars, nothing, Any, Label) == vars[2][2] + # restrict block not found => nothing + @test _findvariant(vars, nothing, Any, LabelMulti) === nothing + end + + @testset "_findcheckpoint" begin + chs = ["check1", "check2"] + @test _findcheckpoint(chs) === nothing + @test _findcheckpoint(chs, pretrained = true) === "check1" + @test _findcheckpoint(chs, pretrained = true, name = "check2") === "check2" + @test _findcheckpoint(chs, pretrained = true, name = "check3") === nothing + end +end diff --git a/src/datablock/block.jl b/src/datablock/block.jl index 462d3ece10..c2423df49f 100644 --- a/src/datablock/block.jl +++ b/src/datablock/block.jl @@ -131,3 +131,39 @@ and other diagrams. """ blockname(block::Block) = string(nameof(typeof(block))) blockname(blocks::Tuple) = "(" * join(map(blockname, blocks), ", ") * ")" + +const BlockLike = Union{<:AbstractBlock, Type{<:AbstractBlock}, <:Tuple, Type{Any}} + +""" + function issubblock(subblock, superblock) + +Predicate whether `subblock` is a subblock of `superblock`. This means that `subblock` is + +- a subtype of a type `superblock <: Type{AbstractBlock}` +- an instance of a subtype of `superblock <: Type{AbstractBlock}` +- equal to `superblock` + +Both arguments can also be tuples. In that case, each element of the tuple `subblock` is +compared recursively against the elements of the tuple `superblock`. +""" +function issubblock end + +issubblock(_, _) = false +issubblock(sub::BlockLike, super::Type{Any}) = true +issubblock(sub::Tuple, super::Tuple) = + (length(sub) == length(super)) && all(map(issubblock, sub, super)) +issubblock(sub::Type{<:AbstractBlock}, super::Type{<:AbstractBlock}) = sub <: super +issubblock(sub::AbstractBlock, super::Type{<:AbstractBlock}) = issubblock(typeof(sub), super) +issubblock(sub::AbstractBlock, super::AbstractBlock) = sub == super + +@testset "issubblock" begin + @test issubblock(Label, Any) + @test issubblock((Label,), (Any,)) + @test issubblock((Label,), Any) + @test !issubblock(Label, (Any,)) + @test issubblock(Label{String}, Label) + @test !issubblock(Label, Label{String}) + @test issubblock(Label{Int}(1:10), Label{Int}) + @test issubblock(Label{Int}(1:10), Label{Int}(1:10)) + @test !issubblock(Label{Int}, Label{Int}(1:10)) +end