From 60b4bea083094edcf8d6b8b1aec3d2494c80f678 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 26 Nov 2022 12:48:42 +0100 Subject: [PATCH 1/6] Add a feature registry for models --- src/Registries/Registries.jl | 4 +- src/Registries/models.jl | 246 +++++++++++++++++++++++++++++++++++ src/datablock/block.jl | 36 +++++ 3 files changed, 285 insertions(+), 1 deletion(-) diff --git a/src/Registries/Registries.jl b/src/Registries/Registries.jl index 49d38ff7c4..ac20e93e5d 100644 --- a/src/Registries/Registries.jl +++ b/src/Registries/Registries.jl @@ -1,6 +1,6 @@ module Registries -using ..FastAI +using ..FastAI: FastAI, BlockLike, Label, LabelMulti, issubblock using ..FastAI.Datasets using ..FastAI.Datasets: DatasetLoader, DataDepLoader, isavailable, loaddata, typify @@ -48,10 +48,12 @@ end include("datasets.jl") include("tasks.jl") include("recipes.jl") +include("models.jl") export datasets, learningtasks, datarecipes, + models, find, info, load diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 8b13789179..0b0f549a9f 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -1 +1,247 @@ + +const _MODELS_DESCRIPTION = """ +A `FeatureRegistry` for models. Allows you to find and load models for various learning +tasks using a unified interface. Call `models()` to see a table view of available models: + +```julia +using FastAI +models() +``` + +Which models are available depends on the loaded packages. For example, FastVision.jl adds +vision models from Metalhead to the registry. Index the registry with a model ID to get more +information about that model: + +```julia +using FastAI: models +using FastVision # loading the package extends the list of available models + +models()["metalhead/resnet18"] +``` + +If you've selected a model, call `load` to then instantiate a model: + +```julia +model = load("metalhead/resnet18") +``` + +By default, `load` loads a default version of the model without any pretrained weights. 
+ +`load(model)` also accepts keyword arguments that allow you to specify variants of the model and +weight checkpoints that should be loaded. + +Loading a checkpoint of pretrained weights: + +- `load(entry; pretrained = true)`: Use any pretrained weights, if they are available. +- `load(entry; checkpoint = "checkpoint-name")`: Use the weights with given name. See + `entry.checkpoints` for available checkpoints (if any). +- `load(entry; pretrained = false)`: Don't use pretrained weights + +Loading a model variant for a specific task: + +- `load(entry; input = ImageTensor, output = OneHotLabel)`: Load a model variant matching + an input and output block. +- `load(entry; variant = "backbone")`: Load a model variant by name. See `entry.variants` for + available variants. +""" + + +""" + struct ModelVariant(; transform, input, output) + +A `ModelVariant` is a model transformation that changes a model so that its input and output +are subblocks (see [`issubblock`](#)) of `blocks = (inblock, outblock)`. + +""" +struct ModelVariant + transformfn::Any # callable + xblock::BlockLike + yblock::BlockLike +end +_default_transform(model, xblock, yblock; kwargs...) = model +ModelVariant(; transform = _default_transform, input = Any, output = Any) = + ModelVariant(transform, input, output) + + +# Registry definition + +function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) + fields = (; + id = Field(String; name = "ID", formatfn = FeatureRegistries.string_format), + description = Field( + String; + name = "Description", + optional = true, + description = "More information about the model", + formatfn = FeatureRegistries.md_format, + ), + backend = Field( + Symbol, + name = "Backend", + default = :flux, + description = "The backend deep learning framework that the model uses. 
The default is `:flux`.", + ), + variants = Field( + Vector{Pair{String,ModelVariant}}, + name = "Variants", + description = "Model variants suitable for different learning tasks", + defaultfn = (row, key) -> Pair{String, ModelVariant}[], + formatfn = d -> join(collect(keys(d)), ", "), + ), + checkpoints = Field( + Vector{String}; + name = "Checkpoints", + description = "Pretrained weight checkpoints that can be loaded for the model", + formatfn = cs -> join(cs, ", "), + defaultfn = (row, key) -> String[], + ), + loadfn = Field( + Any; + name = "Load function", + description = """ + Function that loads the base version of the model, optionally with weights. + It is called with the name of the selected checkpoint fro `checkpoints`, + i.e. `loadfn(checkpoint)`. If no checkpoint is selected, it is called with + `nothing`, i.e. loadfn(`nothing`). + + Any unknown keyword arguments passed to `load`, i.e. + `load(registry[id]; kwargs...)` will be passed along to `loadfn`. + """, + optional = false, + ) + ) + return Registry(fields; name, loadfn = identity, description = description) +end + +""" + _loadmodel(row) + +Load a model specified by `row` from a model registry. + + +""" +function _loadmodel(row; input=Any, output=Any, variant = nothing, checkpoint = nothing, pretrained = !isnothing(checkpoint), kwargs...) + (; loadfn, checkpoints, variants) = row + + # Finding matching configuration + checkpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint) + + pretrained && isnothing(checkpoint) && throw(NoCheckpointFoundError(checkpoints, checkpoint)) + variant = _findvariant(variants, variant, input, output) + isnothing(variant) && throw(NoModelVariantFoundError(variants, input, output, variant)) + + # Loading + basemodel = loadfn(checkpoint, kwargs...) 
+ model = variant.transformfn(basemodel, input, output) + + return model +end + +struct NoModelVariantFoundError <: Exception + variants::Vector{Pair{String, ModelVariant}} + input::BlockLike + output::BlockLike + variant::Union{String, Nothing} +end + +struct NoCheckpointFoundError <: Exception + checkpoints::Vector{String} + checkpoint::Union{String, Nothing} +end + + + +const MODELS = _modelregistry() + + +""" + models() + +$_MODELS_DESCRIPTION +""" +models(; kwargs...) = isempty(kwargs) ? MODELS : filter(MODELS; kwargs...) + + + +function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = nothing) + if isempty(checkpoints) + nothing + elseif !isnothing(name) + i = findfirst(==(name), checkpoints) + isnothing(i) ? nothing : checkpoints[i] + elseif pretrained + first(values(checkpoints)) + else + nothing + end +end + +function _findvariant(variants::Vector{Pair{String,ModelVariant}}, variantname::Union{String, Nothing}, xblock, yblock) + if !isnothing(variantname) + variants = filter(variants) do (name, _) + name == variantname + end + end + i = findfirst(variants) do (_, variant) + issubblock(variant.xblock, xblock) && issubblock(variant.yblock, yblock) + end + isnothing(i) ? nothing : variants[i][2] +end + + +@testset "Model registry" begin + @testset "Basic" begin + @test_nowarn _modelregistry() + reg = _modelregistry() + push!(reg, (; + id = "test", + loadfn = _ -> 1, + )) + end + + @testset "_loadmodel" begin + reg = _modelregistry() + @test_nowarn push!(reg, (; + id = "test", + loadfn = (checkpoint; kwarg = 1) -> (checkpoint, kwarg), + checkpoints = ["checkpoint", "checkpoint2"], + variants = [ + "base" => ModelVariant(), + "ext" => ModelVariant(((ch, k), i, o; kwargs...) 
-> (ch, k + 1), Any, Label), + ] + )) + entry = reg["test"] + @test _loadmodel(entry) == (nothing, 1) + @test _loadmodel(entry; pretrained = true) == ("checkpoint", 1) + @test _loadmodel(entry; checkpoint = "checkpoint2") == ("checkpoint2", 1) + @test_throws NoCheckpointFoundError _loadmodel(entry; checkpoint = "checkpoint3") + + @test _loadmodel(entry; output = Label) == (nothing, 2) + @test _loadmodel(entry; variant = "ext") == (nothing, 2) + @test _loadmodel(entry; pretrained = true, output = Label) == ("checkpoint", 2) + @test_throws NoModelVariantFoundError _loadmodel(entry; input = Label) + end + + @testset "_findvariant" begin + vars = ["1" => ModelVariant(identity, Any, Any), "2" => ModelVariant(identity, Any, Label)] + # no restrictions => select first variant + @test _findvariant(vars, nothing, Any, Any) == vars[1][2] + # name => select named variant + @test _findvariant(vars, "2", Any, Any) == vars[2][2] + # name not found => nothing + @test _findvariant(vars, "3", Any, Any) === nothing + # restrict block => select matching + @test _findvariant(vars, nothing, Any, Label) == vars[2][2] + # restrict block not found => nothing + @test _findvariant(vars, nothing, Any, LabelMulti) === nothing + end + + @testset "_findcheckpoint" begin + chs = ["check1", "check2"] + @test _findcheckpoint(chs) === nothing + @test _findcheckpoint(chs, pretrained = true) === "check1" + @test _findcheckpoint(chs, pretrained = true, name = "check2") === "check2" + @test _findcheckpoint(chs, pretrained = true, name = "check3") === nothing + end +end diff --git a/src/datablock/block.jl b/src/datablock/block.jl index 462d3ece10..c2423df49f 100644 --- a/src/datablock/block.jl +++ b/src/datablock/block.jl @@ -131,3 +131,39 @@ and other diagrams. 
""" blockname(block::Block) = string(nameof(typeof(block))) blockname(blocks::Tuple) = "(" * join(map(blockname, blocks), ", ") * ")" + +const BlockLike = Union{<:AbstractBlock, Type{<:AbstractBlock}, <:Tuple, Type{Any}} + +""" + function issubblock(subblock, superblock) + +Predicate whether `subblock` is a subblock of `superblock`. This means that `subblock` is + +- a subtype of a type `superblock <: Type{AbstractBlock}` +- an instance of a subtype of `superblock <: Type{AbstractBlock}` +- equal to `superblock` + +Both arguments can also be tuples. In that case, each element of the tuple `subblock` is +compared recursively against the elements of the tuple `superblock`. +""" +function issubblock end + +issubblock(_, _) = false +issubblock(sub::BlockLike, super::Type{Any}) = true +issubblock(sub::Tuple, super::Tuple) = + (length(sub) == length(super)) && all(map(issubblock, sub, super)) +issubblock(sub::Type{<:AbstractBlock}, super::Type{<:AbstractBlock}) = sub <: super +issubblock(sub::AbstractBlock, super::Type{<:AbstractBlock}) = issubblock(typeof(sub), super) +issubblock(sub::AbstractBlock, super::AbstractBlock) = sub == super + +@testset "issubblock" begin + @test issubblock(Label, Any) + @test issubblock((Label,), (Any,)) + @test issubblock((Label,), Any) + @test !issubblock(Label, (Any,)) + @test issubblock(Label{String}, Label) + @test !issubblock(Label, Label{String}) + @test issubblock(Label{Int}(1:10), Label{Int}) + @test issubblock(Label{Int}(1:10), Label{Int}(1:10)) + @test !issubblock(Label{Int}, Label{Int}(1:10)) +end From 936cd1c479ebad7881fb4557ba85b30aa1b18c11 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 26 Nov 2022 12:58:01 +0100 Subject: [PATCH 2/6] Use 1.6 supported syntax --- src/Registries/models.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 0b0f549a9f..49735e38b5 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -122,7 +122,7 @@ Load 
a model specified by `row` from a model registry. """ function _loadmodel(row; input=Any, output=Any, variant = nothing, checkpoint = nothing, pretrained = !isnothing(checkpoint), kwargs...) - (; loadfn, checkpoints, variants) = row + loadfn, checkpoints, variants = row.loadfn, row.checkpoints, row.variants # 1.6 support # Finding matching configuration checkpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint) From 89c8a6133a78423097f07dae928deee96224052f Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sat, 26 Nov 2022 19:15:07 +0100 Subject: [PATCH 3/6] Fix model variant printing --- src/Registries/models.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 49735e38b5..4f4b896f4d 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -87,7 +87,7 @@ function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) name = "Variants", description = "Model variants suitable for different learning tasks", defaultfn = (row, key) -> Pair{String, ModelVariant}[], - formatfn = d -> join(collect(keys(d)), ", "), + formatfn = d -> join(first.(d), ", "), ), checkpoints = Field( Vector{String}; From 4a0e6e01a631fc014619bdf66304ecc3754bdcf4 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sun, 27 Nov 2022 11:05:44 +0100 Subject: [PATCH 4/6] Use correct `load` function in model registry. Formats and adds more docs --- src/Registries/models.jl | 161 +++++++++++++++++++++------------------ 1 file changed, 85 insertions(+), 76 deletions(-) diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 4f4b896f4d..71abe69e6f 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -1,4 +1,8 @@ +# # Model registry +# +# This file defines [`models`](#), a feature registry for models. +# ## Registry definition const _MODELS_DESCRIPTION = """ A `FeatureRegistry` for models. 
Allows you to find and load models for various learning @@ -46,13 +50,23 @@ Loading a model variant for a specific task: available variants. """ - """ - struct ModelVariant(; transform, input, output) + struct ModelVariant(; transform, xblock, yblock) A `ModelVariant` is a model transformation that changes a model so that its input and output -are subblocks (see [`issubblock`](#)) of `blocks = (inblock, outblock)`. +are subblocks (see [`issubblock`](#)) of `blocks = (xblock, yblock)`. + +The model transformation function `transform` takes a model and two concrete _instances_ +of the variant's compatible blocks, returning a transformed model. + + `transform(model, xblock, yblock)` + +- `model` is the original model that is transformed +- `xblock` is the [`Block`](#) of the data that is input to the model. +- `yblock` is the [`Block`](#) of the data that the model outputs. +If you're working with a [`SupervisedTask`](#) `task`, these blocks correspond to +`inputblock = getblocks(task).x` and `outputblock = getblocks(task).y` """ struct ModelVariant transformfn::Any # callable @@ -60,74 +74,63 @@ struct ModelVariant yblock::BlockLike end _default_transform(model, xblock, yblock; kwargs...) = model -ModelVariant(; transform = _default_transform, input = Any, output = Any) = - ModelVariant(transform, input, output) - - -# Registry definition +function ModelVariant(; transform = _default_transform, xblock = Any, yblock = Any) + ModelVariant(transform, xblock, yblock) +end function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) fields = (; - id = Field(String; name = "ID", formatfn = FeatureRegistries.string_format), - description = Field( - String; - name = "Description", - optional = true, - description = "More information about the model", - formatfn = FeatureRegistries.md_format, - ), - backend = Field( - Symbol, - name = "Backend", - default = :flux, - description = "The backend deep learning framework that the model uses. 
The default is `:flux`.", - ), - variants = Field( - Vector{Pair{String,ModelVariant}}, - name = "Variants", - description = "Model variants suitable for different learning tasks", - defaultfn = (row, key) -> Pair{String, ModelVariant}[], - formatfn = d -> join(first.(d), ", "), - ), - checkpoints = Field( - Vector{String}; - name = "Checkpoints", - description = "Pretrained weight checkpoints that can be loaded for the model", - formatfn = cs -> join(cs, ", "), - defaultfn = (row, key) -> String[], - ), - loadfn = Field( - Any; - name = "Load function", - description = """ - Function that loads the base version of the model, optionally with weights. - It is called with the name of the selected checkpoint fro `checkpoints`, - i.e. `loadfn(checkpoint)`. If no checkpoint is selected, it is called with - `nothing`, i.e. loadfn(`nothing`). - - Any unknown keyword arguments passed to `load`, i.e. - `load(registry[id]; kwargs...)` will be passed along to `loadfn`. - """, - optional = false, - ) - ) - return Registry(fields; name, loadfn = identity, description = description) + id = Field(String; name = "ID", formatfn = FeatureRegistries.string_format), + description = Field(String; + name = "Description", + optional = true, + description = "More information about the model", + formatfn = FeatureRegistries.md_format), + backend = Field(Symbol, + name = "Backend", + default = :flux, + description = "The backend deep learning framework that the model uses. The default is `:flux`."), + variants = Field(Vector{Pair{String, ModelVariant}}, + name = "Variants", + optional = false, + description = "Model variants suitable for different learning tasks. See `?ModelVariant` for more details.", + formatfn = d -> join(first.(d), ", ")), + checkpoints = Field(Vector{String}; + name = "Checkpoints", + description = """ + Pretrained weight checkpoints that can be loaded for the model. 
Checkpoints are listed as a + `Vector{String}` and `loadfn` should take care of loading the selected checkpoint""", + formatfn = cs -> join(cs, ", "), + defaultfn = (row, key) -> String[]), + loadfn = Field(Any; + name = "Load function", + description = """ + Function that loads the base version of the model, optionally with weights. + It is called with the name of the selected checkpoint fro `checkpoints`, + i.e. `loadfn(checkpoint)`. If no checkpoint is selected, it is called with + `nothing`, i.e. loadfn(`nothing`). + + Any unknown keyword arguments passed to `load`, i.e. + `load(registry[id]; kwargs...)` will be passed along to `loadfn`. + """, + optional = false)) + return Registry(fields; name, loadfn = _loadmodel, description = description) end """ _loadmodel(row) Load a model specified by `row` from a model registry. - - """ -function _loadmodel(row; input=Any, output=Any, variant = nothing, checkpoint = nothing, pretrained = !isnothing(checkpoint), kwargs...) +function _loadmodel(row; input = Any, output = Any, variant = nothing, checkpoint = nothing, + pretrained = !isnothing(checkpoint), kwargs...) 
loadfn, checkpoints, variants = row.loadfn, row.checkpoints, row.variants # 1.6 support # Finding matching configuration checkpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint) - pretrained && isnothing(checkpoint) && throw(NoCheckpointFoundError(checkpoints, checkpoint)) + pretrained && isnothing(checkpoint) && + throw(NoCheckpointFoundError(checkpoints, checkpoint)) variant = _findvariant(variants, variant, input, output) isnothing(variant) && throw(NoModelVariantFoundError(variants, input, output, variant)) @@ -138,6 +141,7 @@ function _loadmodel(row; input=Any, output=Any, variant = nothing, checkpoint = return model end +# ### Errors struct NoModelVariantFoundError <: Exception variants::Vector{Pair{String, ModelVariant}} input::BlockLike @@ -150,11 +154,8 @@ struct NoCheckpointFoundError <: Exception checkpoint::Union{String, Nothing} end - - const MODELS = _modelregistry() - """ models() @@ -162,8 +163,6 @@ $_MODELS_DESCRIPTION """ models(; kwargs...) = isempty(kwargs) ? MODELS : filter(MODELS; kwargs...) - - function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = nothing) if isempty(checkpoints) nothing @@ -177,7 +176,8 @@ function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = end end -function _findvariant(variants::Vector{Pair{String,ModelVariant}}, variantname::Union{String, Nothing}, xblock, yblock) +function _findvariant(variants::Vector{Pair{String, ModelVariant}}, + variantname::Union{String, Nothing}, xblock, yblock) if !isnothing(variantname) variants = filter(variants) do (name, _) name == variantname @@ -189,28 +189,34 @@ function _findvariant(variants::Vector{Pair{String,ModelVariant}}, variantname:: isnothing(i) ? 
nothing : variants[i][2] end +# ## Tests @testset "Model registry" begin @testset "Basic" begin @test_nowarn _modelregistry() reg = _modelregistry() push!(reg, (; - id = "test", - loadfn = _ -> 1, - )) + id = "test", + loadfn = _ -> 1, + variants = ["base" => ModelVariant()])) + + @test load(reg["test"]) == 1 + @test_throws NoCheckpointFoundError load(reg["test"], pretrained = true) end @testset "_loadmodel" begin reg = _modelregistry() - @test_nowarn push!(reg, (; - id = "test", - loadfn = (checkpoint; kwarg = 1) -> (checkpoint, kwarg), - checkpoints = ["checkpoint", "checkpoint2"], - variants = [ - "base" => ModelVariant(), - "ext" => ModelVariant(((ch, k), i, o; kwargs...) -> (ch, k + 1), Any, Label), - ] - )) + @test_nowarn push!(reg, + (; + id = "test", + loadfn = (checkpoint; kwarg = 1) -> (checkpoint, kwarg), + checkpoints = ["checkpoint", "checkpoint2"], + variants = [ + "base" => ModelVariant(), + "ext" => ModelVariant(((ch, k), i, o; kwargs...) -> (ch, + k + 1), + Any, Label), + ])) entry = reg["test"] @test _loadmodel(entry) == (nothing, 1) @test _loadmodel(entry; pretrained = true) == ("checkpoint", 1) @@ -224,7 +230,10 @@ end end @testset "_findvariant" begin - vars = ["1" => ModelVariant(identity, Any, Any), "2" => ModelVariant(identity, Any, Label)] + vars = [ + "1" => ModelVariant(identity, Any, Any), + "2" => ModelVariant(identity, Any, Label), + ] # no restrictions => select first variant @test _findvariant(vars, nothing, Any, Any) == vars[1][2] # name => select named variant @@ -233,7 +242,7 @@ end @test _findvariant(vars, "3", Any, Any) === nothing # restrict block => select matching @test _findvariant(vars, nothing, Any, Label) == vars[2][2] - # restrict block not found => nothing + # restrict block not found => nothing @test _findvariant(vars, nothing, Any, LabelMulti) === nothing end From 9aef6d268e877d5f32d355fa35af4d4ff4777427 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sun, 4 Dec 2022 12:12:17 +0100 Subject: [PATCH 5/6] Change 
`ModelVariant` API Now handles both loading checkpoints and possible transformations. This makes it easier to integrate with third-party model libraries that likewise handle both with a single function. --- src/Registries/models.jl | 101 ++++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/src/Registries/models.jl b/src/Registries/models.jl index 71abe69e6f..f236f8a06b 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -50,33 +50,39 @@ Loading a model variant for a specific task: available variants. """ + +# ## `ModelVariant` interface """ - struct ModelVariant(; transform, xblock, yblock) + abstract type ModelVariant + +A `ModelVariant` handles loading a model, optionally with pretrained weights and +transforming it so that it can be used for specific learning tasks. + -A `ModelVariant` is a model transformation that changes a model so that its input and output are subblocks (see [`issubblock`](#)) of `blocks = (xblock, yblock)`. -The model transformation function `transform` takes a model and two concrete _instances_ -of the variant's compatible blocks, returning a transformed model. +## Interface - `transform(model, xblock, yblock)` +- [`compatibleblocks`](#)`(variant)` returns a tuple `(xblock, yblock)` of [`BlockLike`](#) that + are compatible with the model. This means that a variant can be used for a task with + input and output blocks `blocks`, if [`issubblock`](#)`(blocks, compatibleblocks(variant))`. +- [`loadvariant`](#)`(::ModelVariant, xblock, yblock, checkpoint; kwargs...)` loads a model + compatible with block instances `xblock` and `yblock`, with (optionally) weights + from `checkpoint`. +""" +abstract type ModelVariant end -- `model` is the original model that is transformed -- `xblock` is the [`Block`](#) of the data that is input to the model. -- `yblock` is the [`Block`](#) of the data that the model outputs. 
+""" + compatibleblocks(::ModelVariant) -If you're working with a [`SupervisedTask`](#) `task`, these blocks correspond to -`inputblock = getblocks(task).x` and `outputblock = getblocks(task).y` +Indicate compatible input and output block for a model variant. """ -struct ModelVariant - transformfn::Any # callable - xblock::BlockLike - yblock::BlockLike -end -_default_transform(model, xblock, yblock; kwargs...) = model -function ModelVariant(; transform = _default_transform, xblock = Any, yblock = Any) - ModelVariant(transform, xblock, yblock) -end +function compatibleblocks end + +function loadvariant end + + +# ## Model registry creation function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) fields = (; @@ -102,18 +108,7 @@ function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) `Vector{String}` and `loadfn` should take care of loading the selected checkpoint""", formatfn = cs -> join(cs, ", "), defaultfn = (row, key) -> String[]), - loadfn = Field(Any; - name = "Load function", - description = """ - Function that loads the base version of the model, optionally with weights. - It is called with the name of the selected checkpoint fro `checkpoints`, - i.e. `loadfn(checkpoint)`. If no checkpoint is selected, it is called with - `nothing`, i.e. loadfn(`nothing`). - - Any unknown keyword arguments passed to `load`, i.e. - `load(registry[id]; kwargs...)` will be passed along to `loadfn`. - """, - optional = false)) + ) return Registry(fields; name, loadfn = _loadmodel, description = description) end @@ -124,7 +119,7 @@ Load a model specified by `row` from a model registry. """ function _loadmodel(row; input = Any, output = Any, variant = nothing, checkpoint = nothing, pretrained = !isnothing(checkpoint), kwargs...) 
- loadfn, checkpoints, variants = row.loadfn, row.checkpoints, row.variants # 1.6 support + checkpoints, variants = row.checkpoints, row.variants # 1.6 support # Finding matching configuration checkpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint) @@ -135,25 +130,27 @@ function _loadmodel(row; input = Any, output = Any, variant = nothing, checkpoin isnothing(variant) && throw(NoModelVariantFoundError(variants, input, output, variant)) # Loading - basemodel = loadfn(checkpoint, kwargs...) - model = variant.transformfn(basemodel, input, output) - - return model + return loadvariant(variant, input, output, checkpoint; kwargs...) end # ### Errors + +# TODO: Implement Base.showerror struct NoModelVariantFoundError <: Exception - variants::Vector{Pair{String, ModelVariant}} + variants::Vector input::BlockLike output::BlockLike variant::Union{String, Nothing} end +# TODO: Implement Base.showerror struct NoCheckpointFoundError <: Exception checkpoints::Vector{String} checkpoint::Union{String, Nothing} end +# ## Create the default registry instance + const MODELS = _modelregistry() """ @@ -163,6 +160,8 @@ $_MODELS_DESCRIPTION """ models(; kwargs...) = isempty(kwargs) ? MODELS : filter(MODELS; kwargs...) 
+# ## Helpers + function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = nothing) if isempty(checkpoints) nothing @@ -176,7 +175,7 @@ function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = end end -function _findvariant(variants::Vector{Pair{String, ModelVariant}}, +function _findvariant(variants::Vector, variantname::Union{String, Nothing}, xblock, yblock) if !isnothing(variantname) variants = filter(variants) do (name, _) @@ -184,23 +183,31 @@ function _findvariant(variants::Vector{Pair{String, ModelVariant}}, end end i = findfirst(variants) do (_, variant) - issubblock(variant.xblock, xblock) && issubblock(variant.yblock, yblock) + v_xblock, v_yblock = compatibleblocks(variant) + issubblock(v_xblock, xblock) && issubblock(v_yblock, yblock) end isnothing(i) ? nothing : variants[i][2] end # ## Tests +struct MockVariant <: ModelVariant + model + blocks +end + +compatibleblocks(variant::MockVariant) = variant.blocks +loadvariant(variant::MockVariant, x, y, ch) = (ch, variant.model) + @testset "Model registry" begin @testset "Basic" begin @test_nowarn _modelregistry() reg = _modelregistry() push!(reg, (; id = "test", - loadfn = _ -> 1, - variants = ["base" => ModelVariant()])) + variants = ["base" => MockVariant(1, (Any, Any))])) - @test load(reg["test"]) == 1 + @test load(reg["test"]) == (nothing, 1) @test_throws NoCheckpointFoundError load(reg["test"], pretrained = true) end @@ -212,10 +219,8 @@ end loadfn = (checkpoint; kwarg = 1) -> (checkpoint, kwarg), checkpoints = ["checkpoint", "checkpoint2"], variants = [ - "base" => ModelVariant(), - "ext" => ModelVariant(((ch, k), i, o; kwargs...) 
-> (ch, - k + 1), - Any, Label), + "base" => MockVariant(1, (Any, Any)), + "ext" => MockVariant(2, (Any, Label)), ])) entry = reg["test"] @test _loadmodel(entry) == (nothing, 1) @@ -231,8 +236,8 @@ end @testset "_findvariant" begin vars = [ - "1" => ModelVariant(identity, Any, Any), - "2" => ModelVariant(identity, Any, Label), + "1" => MockVariant(1, (Any, Any)), + "2" => MockVariant(1, (Any, Label)), ] # no restrictions => select first variant @test _findvariant(vars, nothing, Any, Any) == vars[1][2] From 85c88c36012eaaa59dab86e46a28ade0a71a42f2 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Fri, 3 Feb 2023 16:12:59 +0100 Subject: [PATCH 6/6] Model registry now has a field :loadfn A `loadfn([checkpoint])` holds the default loading function for a model. As a result, the :variants field no longer has to be populated. --- src/Registries/models.jl | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/Registries/models.jl b/src/Registries/models.jl index f236f8a06b..5f284dd030 100644 --- a/src/Registries/models.jl +++ b/src/Registries/models.jl @@ -96,9 +96,14 @@ function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION) name = "Backend", default = :flux, description = "The backend deep learning framework that the model uses. The default is `:flux`."), + loadfn = Field(Any, + name = "Load function", + optional = false, + description = "A function `loadfn(checkpoint)` that loads a default version of the model, possibly with `checkpoint` weights.", + ), variants = Field(Vector{Pair{String, ModelVariant}}, name = "Variants", - optional = false, + default = Pair{String, ModelVariant}[], description = "Model variants suitable for different learning tasks. 
See `?ModelVariant` for more details.", formatfn = d -> join(first.(d), ", ")), checkpoints = Field(Vector{String}; @@ -123,14 +128,22 @@ function _loadmodel(row; input = Any, output = Any, variant = nothing, checkpoin # Finding matching configuration checkpoint = _findcheckpoint(checkpoints; pretrained, name = checkpoint) - - pretrained && isnothing(checkpoint) && + if (pretrained && isnothing(checkpoint)) throw(NoCheckpointFoundError(checkpoints, checkpoint)) - variant = _findvariant(variants, variant, input, output) - isnothing(variant) && throw(NoModelVariantFoundError(variants, input, output, variant)) + end - # Loading - return loadvariant(variant, input, output, checkpoint; kwargs...) + # If no variant is asked for, use the base model loading function that only takes + # care of the checkpoint. + if isnothing(variant) && input === Any && output === Any + return row.loadfn(checkpoint) + # If a variant is specified, either by name (through `variant`) or through block + # constraints `input` or `output`, try to find a matching variant. + # The matching variant then takes care of loading the checkpoint. + else + variant = _findvariant(variants, variant, input, output) + isnothing(variant) && throw(NoModelVariantFoundError(variants, input, output, variant)) + return loadvariant(variant, input, output, checkpoint; kwargs...) + end end # ### Errors @@ -197,17 +210,15 @@ struct MockVariant <: ModelVariant end compatibleblocks(variant::MockVariant) = variant.blocks -loadvariant(variant::MockVariant, x, y, ch) = (ch, variant.model) +loadvariant(variant::MockVariant, _, _, ch) = (ch, variant.model) @testset "Model registry" begin @testset "Basic" begin @test_nowarn _modelregistry() reg = _modelregistry() - push!(reg, (; - id = "test", - variants = ["base" => MockVariant(1, (Any, Any))])) + push!(reg, (; id = "test", loadfn = (checkpoint,) -> checkpoint)) - @test load(reg["test"]) == (nothing, 1) + @test load(reg["test"]) === nothing @test_throws NoCheckpointFoundError load(reg["test"], pretrained = true) end