Add a feature registry for models #267

Open
wants to merge 7 commits into
base: master
4 changes: 3 additions & 1 deletion src/Registries/Registries.jl
@@ -1,6 +1,6 @@
module Registries

using ..FastAI
using ..FastAI: FastAI, BlockLike, Label, LabelMulti, issubblock
using ..FastAI.Datasets
using ..FastAI.Datasets: DatasetLoader, DataDepLoader, isavailable, loaddata, typify

@@ -48,10 +48,12 @@ end
include("datasets.jl")
include("tasks.jl")
include("recipes.jl")
include("models.jl")

export datasets,
learningtasks,
datarecipes,
models,
find,
info,
load
271 changes: 271 additions & 0 deletions src/Registries/models.jl
@@ -1 +1,272 @@
# # Model registry
#
# This file defines [`models`](#), a feature registry for models.

# ## Registry definition

const _MODELS_DESCRIPTION = """
A `FeatureRegistry` for models. Allows you to find and load models for various learning
tasks using a unified interface. Call `models()` to see a table view of available models:

```julia
using FastAI
models()
```

Which models are available depends on the loaded packages. For example, FastVision.jl adds
vision models from Metalhead to the registry. Index the registry with a model ID to get more
information about that model:

```julia
using FastAI: models
using FastVision # loading the package extends the list of available models

models()["metalhead/resnet18"]
```

Once you have selected a model, call `load` on its registry entry to instantiate it:

```julia
model = load(models()["metalhead/resnet18"])
```

By default, `load` instantiates the default version of the model without any pretrained weights.

`load(entry)` also accepts keyword arguments that let you specify which variant of the model and
which weight checkpoint should be loaded.

Loading a checkpoint of pretrained weights:

- `load(entry; pretrained = true)`: Use the first available checkpoint of pretrained weights.
  Throws a `NoCheckpointFoundError` if the model has no checkpoints.
- `load(entry; checkpoint = "checkpoint-name")`: Use the weights with the given name. See
  `entry.checkpoints` for available checkpoints (if any).
- `load(entry; pretrained = false)`: Don't use pretrained weights.

Loading a model variant for a specific task:

- `load(entry; input = ImageTensor, output = OneHotLabel)`: Load a model variant matching
  an input and output block.
- `load(entry; variant = "backbone")`: Load a model variant by name. See `entry.variants` for
  available variants.
"""


# ## `ModelVariant` interface
"""
    abstract type ModelVariant

A `ModelVariant` handles loading a model, optionally with pretrained weights, and
transforming it so that it can be used for specific learning tasks.

A variant matches a task with blocks `blocks = (xblock, yblock)` if its compatible blocks
(see [`compatibleblocks`](#)) are subblocks (see [`issubblock`](#)) of `blocks`.

## Interface

- [`compatibleblocks`](#)`(variant)` returns a tuple `(xblock, yblock)` of [`BlockLike`](#)s that
  the variant is compatible with. A variant is selected for requested input and output
  blocks `blocks` if [`issubblock`](#)`(compatibleblocks(variant), blocks)`.
- [`loadvariant`](#)`(::ModelVariant, xblock, yblock, checkpoint; kwargs...)` loads a model
  compatible with block instances `xblock` and `yblock`, with (optionally) weights
  from `checkpoint`.
"""
abstract type ModelVariant end

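"""
    compatibleblocks(::ModelVariant)

Return the tuple `(xblock, yblock)` of input and output blocks that a model variant is
compatible with.
"""
function compatibleblocks end

"""
    loadvariant(::ModelVariant, xblock, yblock, checkpoint; kwargs...)

Load the model for a variant so that it is compatible with `xblock` and `yblock`,
optionally with weights from `checkpoint` (`nothing` means no pretrained weights).
"""
function loadvariant end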


# ## Model registry creation

function _modelregistry(; name = "Models", description = _MODELS_DESCRIPTION)
    fields = (;
        id = Field(String; name = "ID", formatfn = FeatureRegistries.string_format),
        description = Field(String;
            name = "Description",
            optional = true,
            description = "More information about the model",
            formatfn = FeatureRegistries.md_format),
        backend = Field(Symbol,
            name = "Backend",
            default = :flux,
            description = "The backend deep learning framework that the model uses. The default is `:flux`."),
        loadfn = Field(Any,
            name = "Load function",
            optional = false,
            description = "A function `loadfn(checkpoint)` that loads a default version of the model, possibly with `checkpoint` weights.",
        ),
        variants = Field(Vector{Pair{String, ModelVariant}},
            name = "Variants",
            default = Pair{String, ModelVariant}[],
            description = "Model variants suitable for different learning tasks. See `?ModelVariant` for more details.",
            formatfn = d -> join(first.(d), ", ")),
        checkpoints = Field(Vector{String};
            name = "Checkpoints",
            description = """
            Pretrained weight checkpoints that can be loaded for the model. Checkpoints are listed as a
            `Vector{String}` and `loadfn` should take care of loading the selected checkpoint""",
            formatfn = cs -> join(cs, ", "),
            defaultfn = (row, key) -> String[]),
    )
    return Registry(fields; name, loadfn = _loadmodel, description = description)
end

"""
    _loadmodel(row)

Load a model specified by `row` from a model registry.
"""
function _loadmodel(row; input = Any, output = Any, variant = nothing, checkpoint = nothing,
                    pretrained = !isnothing(checkpoint), kwargs...)
    checkpoints, variants = row.checkpoints, row.variants # 1.6 support

    # Find a matching checkpoint. The requested name stays in `checkpoint` so that
    # the error below can report it.
    selected = _findcheckpoint(checkpoints; pretrained, name = checkpoint)
    if pretrained && isnothing(selected)
        throw(NoCheckpointFoundError(checkpoints, checkpoint))
    end

    # If no variant is asked for, use the base model loading function that only takes
    # care of the checkpoint.
    if isnothing(variant) && input === Any && output === Any
        return row.loadfn(selected)
    # If a variant is specified, either by name (through `variant`) or through block
    # constraints `input` or `output`, try to find a matching variant.
    else
        # Keep the requested name in `variant` so that the error can report it.
        found = _findvariant(variants, variant, input, output)
        isnothing(found) && throw(NoModelVariantFoundError(variants, input, output, variant))
        return loadvariant(found, input, output, selected; kwargs...)
    end
end

# ### Errors

# TODO: Implement Base.showerror
struct NoModelVariantFoundError <: Exception
    variants::Vector
    input::BlockLike
    output::BlockLike
    variant::Union{String, Nothing}
end

# TODO: Implement Base.showerror
struct NoCheckpointFoundError <: Exception
    checkpoints::Vector{String}
    checkpoint::Union{String, Nothing}
end

# ## Create the default registry instance

const MODELS = _modelregistry()

"""
    models()

$_MODELS_DESCRIPTION
"""
models(; kwargs...) = isempty(kwargs) ? MODELS : filter(MODELS; kwargs...)

# ## Helpers

function _findcheckpoint(checkpoints::AbstractVector; pretrained = false, name = nothing)
    if isempty(checkpoints)
        nothing
    elseif !isnothing(name)
        # An explicit name takes precedence over `pretrained`.
        i = findfirst(==(name), checkpoints)
        isnothing(i) ? nothing : checkpoints[i]
    elseif pretrained
        # No name given: fall back to the first listed checkpoint.
        first(checkpoints)
    else
        nothing
    end
end

function _findvariant(variants::Vector,
                      variantname::Union{String, Nothing}, xblock, yblock)
    # Restrict the candidates to the named variant, if a name was given.
    if !isnothing(variantname)
        variants = filter(variants) do (name, _)
            name == variantname
        end
    end
    # Select the first candidate whose blocks are subblocks of the requested blocks.
    i = findfirst(variants) do (_, variant)
        v_xblock, v_yblock = compatibleblocks(variant)
        issubblock(v_xblock, xblock) && issubblock(v_yblock, yblock)
    end
    isnothing(i) ? nothing : variants[i][2]
end

# ## Tests

struct MockVariant <: ModelVariant
    model
    blocks
end

compatibleblocks(variant::MockVariant) = variant.blocks
loadvariant(variant::MockVariant, _, _, ch) = (ch, variant.model)

@testset "Model registry" begin
    @testset "Basic" begin
        @test_nowarn _modelregistry()
        reg = _modelregistry()
        push!(reg, (; id = "test", loadfn = (checkpoint,) -> checkpoint))

        @test load(reg["test"]) === nothing
        @test_throws NoCheckpointFoundError load(reg["test"], pretrained = true)
    end

    @testset "_loadmodel" begin
        reg = _modelregistry()
        @test_nowarn push!(reg,
                           (;
                            id = "test",
                            loadfn = (checkpoint; kwarg = 1) -> (checkpoint, kwarg),
                            checkpoints = ["checkpoint", "checkpoint2"],
                            variants = [
                                "base" => MockVariant(1, (Any, Any)),
                                "ext" => MockVariant(2, (Any, Label)),
                            ]))
        entry = reg["test"]
        @test _loadmodel(entry) == (nothing, 1)
        @test _loadmodel(entry; pretrained = true) == ("checkpoint", 1)
        @test _loadmodel(entry; checkpoint = "checkpoint2") == ("checkpoint2", 1)
        @test_throws NoCheckpointFoundError _loadmodel(entry; checkpoint = "checkpoint3")

        @test _loadmodel(entry; output = Label) == (nothing, 2)
        @test _loadmodel(entry; variant = "ext") == (nothing, 2)
        @test _loadmodel(entry; pretrained = true, output = Label) == ("checkpoint", 2)
        @test_throws NoModelVariantFoundError _loadmodel(entry; input = Label)
    end

    @testset "_findvariant" begin
        vars = [
            "1" => MockVariant(1, (Any, Any)),
            "2" => MockVariant(1, (Any, Label)),
        ]
        # no restrictions => select first variant
        @test _findvariant(vars, nothing, Any, Any) == vars[1][2]
        # name => select named variant
        @test _findvariant(vars, "2", Any, Any) == vars[2][2]
        # name not found => nothing
        @test _findvariant(vars, "3", Any, Any) === nothing
        # restrict block => select matching
        @test _findvariant(vars, nothing, Any, Label) == vars[2][2]
        # restrict block not found => nothing
        @test _findvariant(vars, nothing, Any, LabelMulti) === nothing
    end

    @testset "_findcheckpoint" begin
        chs = ["check1", "check2"]
        @test _findcheckpoint(chs) === nothing
        @test _findcheckpoint(chs, pretrained = true) === "check1"
        @test _findcheckpoint(chs, pretrained = true, name = "check2") === "check2"
        @test _findcheckpoint(chs, pretrained = true, name = "check3") === nothing
    end
end
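
The checkpoint resolution rules exercised by the `_findcheckpoint` tests can be summarised in a stand-alone sketch. It does not depend on FastAI. `resolve_checkpoint` is a hypothetical name chosen for illustration, and it only mirrors the behaviour of `_findcheckpoint` as written in this diff.

```julia
# Stand-alone sketch. `resolve_checkpoint` is a hypothetical stand-in that
# mirrors the rules of `_findcheckpoint` from this diff.
function resolve_checkpoint(checkpoints::AbstractVector{<:AbstractString};
                            pretrained::Bool = false,
                            name::Union{AbstractString, Nothing} = nothing)
    # A model without checkpoints can never resolve one.
    isempty(checkpoints) && return nothing
    if !isnothing(name)
        # An explicit name wins over `pretrained`; unknown names yield `nothing`.
        i = findfirst(==(name), checkpoints)
        return isnothing(i) ? nothing : checkpoints[i]
    end
    # Without a name, `pretrained = true` picks the first listed checkpoint.
    return pretrained ? first(checkpoints) : nothing
end

chs = ["check1", "check2"]
results = (
    default = resolve_checkpoint(chs),                                      # nothing
    pretrained = resolve_checkpoint(chs; pretrained = true),                # "check1"
    named = resolve_checkpoint(chs; name = "check2"),                       # "check2"
    unknown = resolve_checkpoint(chs; pretrained = true, name = "check3"),  # nothing
    empty = resolve_checkpoint(String[]; pretrained = true),                # nothing
)
```

In `_loadmodel`, a `nothing` result combined with `pretrained = true` is what triggers `NoCheckpointFoundError`, which is why requesting an unknown checkpoint name throws rather than silently loading without weights.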
36 changes: 36 additions & 0 deletions src/datablock/block.jl
@@ -131,3 +131,39 @@ and other diagrams.
"""
blockname(block::Block) = string(nameof(typeof(block)))
blockname(blocks::Tuple) = "(" * join(map(blockname, blocks), ", ") * ")"

const BlockLike = Union{<:AbstractBlock, Type{<:AbstractBlock}, <:Tuple, Type{Any}}

"""
    issubblock(subblock, superblock)

Return whether `subblock` is a subblock of `superblock`. This is the case if `subblock` is

- a subtype of `superblock`, where `superblock::Type{<:AbstractBlock}`,
- an instance of a subtype of `superblock`, where `superblock::Type{<:AbstractBlock}`, or
- equal to `superblock`.

Any `BlockLike` is also a subblock of `Any`.

Both arguments can also be tuples. In that case, the tuples must have the same length and
each element of the tuple `subblock` is compared recursively against the corresponding
element of the tuple `superblock`.
"""
function issubblock end

issubblock(_, _) = false
issubblock(sub::BlockLike, super::Type{Any}) = true
issubblock(sub::Tuple, super::Tuple) =
    (length(sub) == length(super)) && all(map(issubblock, sub, super))
issubblock(sub::Type{<:AbstractBlock}, super::Type{<:AbstractBlock}) = sub <: super
issubblock(sub::AbstractBlock, super::Type{<:AbstractBlock}) = issubblock(typeof(sub), super)
issubblock(sub::AbstractBlock, super::AbstractBlock) = sub == super

@testset "issubblock" begin
    @test issubblock(Label, Any)
    @test issubblock((Label,), (Any,))
    @test issubblock((Label,), Any)
    @test !issubblock(Label, (Any,))
    @test issubblock(Label{String}, Label)
    @test !issubblock(Label, Label{String})
    @test issubblock(Label{Int}(1:10), Label{Int})
    @test issubblock(Label{Int}(1:10), Label{Int}(1:10))
    @test !issubblock(Label{Int}, Label{Int}(1:10))
end
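
The dispatch rules behind `issubblock` can be tried out without FastAI. The following is a minimal sketch under stated assumptions: `DemoBlock`, `DemoLabel`, `DemoImage`, `DemoBlockLike` and `subblock` are hypothetical stand-ins for `AbstractBlock`, `Label`, an unrelated block type, `BlockLike` and `issubblock`. The method set copies the one in this diff.

```julia
# Hypothetical stand-ins for FastAI's block types, for illustration only.
abstract type DemoBlock end
struct DemoLabel{T} <: DemoBlock
    classes::UnitRange{T}
end
struct DemoImage <: DemoBlock end

const DemoBlockLike = Union{DemoBlock, Type{<:DemoBlock}, Tuple, Type{Any}}

# Same method set as `issubblock` in the diff, under a different name.
subblock(_, _) = false                           # fallback: not a subblock
subblock(::DemoBlockLike, ::Type{Any}) = true    # `Any` is a wildcard on the super side
subblock(sub::Tuple, super::Tuple) =
    length(sub) == length(super) && all(map(subblock, sub, super))
subblock(sub::Type{<:DemoBlock}, super::Type{<:DemoBlock}) = sub <: super
subblock(sub::DemoBlock, super::Type{<:DemoBlock}) = subblock(typeof(sub), super)
subblock(sub::DemoBlock, super::DemoBlock) = sub == super

# Types, instances, and tuples.
subblock(DemoLabel{Int}, DemoLabel)      # true: more specific type
subblock(DemoLabel, DemoLabel{Int})      # false: less specific type
subblock(DemoLabel(1:3), DemoLabel{Int}) # true: instance of the type
subblock(DemoImage, DemoLabel)           # false: unrelated block types

# `Any` only acts as a wildcard on the super side, so the relation is asymmetric.
subblock((Any, DemoLabel), (Any, Any))   # true
subblock((Any, Any), (Any, DemoLabel))   # false
```

The last two calls show why `_findvariant` passes the variant's blocks as the first argument: a variant declared with `(Any, Any)` is not selected once the caller constrains `output` to a concrete block.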