Rework regularized layers #73
Conversation
- Only one struct named `Regularized`; every regularized layer is a particular case of it
- Specific constructors for `SparseArgmax`, `SoftArgmax`, and `RegularizedFrankWolfe`
- Now we can also use `Regularized` with a custom optimizer (we may need to test this feature)
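The unified design can be sketched as follows (a Python illustration with hypothetical names, not the actual InferOpt.jl API, which is written in Julia): one `Regularized` container holds the regularizer Ω and an optimizer solving `θ ↦ argmax_y θᵀy - Ω(y)`, and `SoftArgmax` is just a particular instance where the regularized problem has the softmax as closed-form solution.

```python
import math

class Regularized:
    """Regularized layer: theta -> argmax_y theta^T y - Omega(y).

    Holds the regularizer Omega and an optimizer that solves the
    regularized problem. (Illustrative sketch, not the InferOpt API.)
    """
    def __init__(self, omega, optimizer):
        self.omega = omega          # the regularizer Omega
        self.optimizer = optimizer  # solves argmax theta^T y - Omega(y)

    def __call__(self, theta):
        return self.optimizer(theta)

def _softmax(theta):
    # Closed-form optimizer when Omega is the negative entropy.
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    s = sum(exps)
    return [e / s for e in exps]

def _neg_entropy(y):
    return sum(p * math.log(p) for p in y if p > 0)

# "SoftArgmax" becomes a particular instance of Regularized:
soft_argmax_layer = Regularized(_neg_entropy, _softmax)
probs = soft_argmax_layer([1.0, 2.0, 3.0])
```

A `SparseArgmax` instance would follow the same pattern with a quadratic Ω and a simplex-projection optimizer, and a custom optimizer can be plugged in without defining a new struct.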
Codecov Report

@@            Coverage Diff             @@
##             main      #73      +/-  ##
==========================================
- Coverage   80.57%   80.33%   -0.25%
==========================================
  Files          19       20       +1
  Lines        345      356      +11
==========================================
+ Hits         278      286       +8
- Misses        67       70       +3

View full report in Codecov by Sentry.
@@ -32,30 +32,14 @@ Some values you can tune:

  See the documentation of FrankWolfe.jl for details.
  """
- struct RegularizedGeneric{M,RF,RG,FWK}
+ struct FrankWolfeOptimizer{M,RF,RG,FWK}
      maximizer::M
I would rather call this FrankWolfeConcaveMaximizer
  """
  struct Regularized{O,R}
      Ω::R
      optimizer::O
I would rather call this `concave_maximizer`, to differentiate it from the `(linear_)maximizer` used elsewhere.
  TODO
  """
  function RegularizedFrankWolfe(linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs=NamedTuple())
      # TODO: add a warning if DifferentiableFrankWolfe is not imported?
Good idea
@@ -9,7 +9,7 @@ Relies on the Frank-Wolfe algorithm to minimize a concave objective on a polytope

  Since this is a conditional dependency, you need to run `import DifferentiableFrankWolfe` before using `RegularizedGeneric`.

  # Fields
- - `maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
+ - `linear_maximizer::M`: linear maximization oracle `θ -> argmax_{x ∈ C} θᵀx`, implicitly defines the polytope `C`
Maybe we should use `linear_maximizer` throughout InferOpt?
  """
      optimizer: θ ⟼ argmax θᵀy - Ω(y)
  """
  struct Regularized{O,R}
Do we also need the linear maximizer as a field, for when the layer is called outside of training? It would make sense to me to modify the behavior of `Perturbed` as well, so that the standard forward pass just calls the naked linear maximizer.
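The proposed behavior could look roughly like this (a Python sketch with hypothetical names and a Monte-Carlo perturbation, for illustration only): the layer smooths the combinatorial oracle during training, but the standard forward pass just calls the naked linear maximizer.

```python
import random

def linear_maximizer(theta):
    # One-hot argmax: the "naked" combinatorial oracle.
    i = max(range(len(theta)), key=lambda j: theta[j])
    return [1.0 if j == i else 0.0 for j in range(len(theta))]

class Perturbed:
    """Perturbed layer: smoothed in training, exact argmax at inference."""
    def __init__(self, maximizer, nb_samples=100, epsilon=1.0, seed=0):
        self.maximizer = maximizer
        self.nb_samples = nb_samples
        self.epsilon = epsilon
        self.rng = random.Random(seed)

    def __call__(self, theta, training=False):
        if not training:
            # Standard forward pass: just call the naked linear maximizer.
            return self.maximizer(theta)
        # Training pass: average the oracle over Gaussian perturbations.
        n = len(theta)
        acc = [0.0] * n
        for _ in range(self.nb_samples):
            z = [self.rng.gauss(0.0, 1.0) for _ in range(n)]
            y = self.maximizer([t + self.epsilon * zi for t, zi in zip(theta, z)])
            acc = [a + yi for a, yi in zip(acc, y)]
        return [a / self.nb_samples for a in acc]

layer = Perturbed(linear_maximizer)
hard = layer([0.2, 1.5, -0.3])                  # one-hot at inference
soft = layer([0.2, 1.5, -0.3], training=True)   # smoothed during training
```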
@@ -10,8 +10,12 @@ function soft_argmax(z::AbstractVector; kwargs...)
      return s
  end

- @traitimpl IsRegularized{typeof(soft_argmax)}
+ # @traitimpl IsRegularized{typeof(soft_argmax)}
In the trash
@@ -10,11 +10,15 @@ function sparse_argmax(z::AbstractVector; kwargs...)
      return p
  end

- @traitimpl IsRegularized{typeof(sparse_argmax)}
+ # @traitimpl IsRegularized{typeof(sparse_argmax)}
In the trash
What do you have in mind? I think we can use a basic QP solver from JuMP, or write our own with FISTA.
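For the quadratic regularizer, no external QP solver is strictly needed: `sparse_argmax(θ)` solves `argmax_y θᵀy - ½‖y‖²` over the simplex, i.e. the Euclidean projection of `θ` onto the simplex, which has a well-known O(n log n) closed form. A Python sketch (for illustration; the actual implementation would be in Julia):

```python
def simplex_projection(theta):
    """Euclidean projection of theta onto the probability simplex.

    Solves argmax_y theta^T y - 0.5 * ||y||^2 over the simplex
    (the QP behind sparse_argmax) via the sort-based closed form.
    """
    n = len(theta)
    u = sorted(theta, reverse=True)
    cssv = 0.0
    lam = 0.0
    for i in range(n):
        cssv += u[i]
        t = (cssv - 1.0) / (i + 1)
        if u[i] - t > 0:
            lam = t  # last threshold with a positive coordinate
    # Shift by the threshold and clip at zero: the result is sparse.
    return [max(v - lam, 0.0) for v in theta]

p = simplex_projection([0.0, 1.1, 0.3])  # sparse: first coordinate is zeroed
```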
Every regularized layer is now handled by the `Regularized` struct, and not under the `IsRegularized` trait anymore (partially addresses "Get rid of SimpleTraits? #68"). Every regularized layer is now a particular instance of `Regularized`, with specific constructors for `SparseArgmax`, `SoftArgmax`, and `RegularizedFrankWolfe`. Now we can also use `Regularized` with a custom optimizer (addresses "Other solvers than FW for RegularizedGeneric #62").

TODO:
- Test `Regularized` with a custom optimizer
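A minimal check of the custom-optimizer path might compare an iterative optimizer against a known closed form (a Python sketch with assumed names; the real test would be in Julia). Here a mirror-ascent optimizer for the entropy-regularized problem should recover the softmax:

```python
import math

def softmax(theta):
    # Closed-form solution of argmax theta^T y + H(y) on the simplex.
    m = max(theta)
    e = [math.exp(t - m) for t in theta]
    s = sum(e)
    return [x / s for x in e]

def custom_optimizer(theta, steps=2000, lr=0.1):
    """Entropic mirror ascent on the simplex for argmax theta^T y + H(y)."""
    y = [1.0 / len(theta)] * len(theta)
    for _ in range(steps):
        # Gradient of theta^T y - sum y log y is theta - log y - 1.
        g = [t - math.log(p) - 1.0 for t, p in zip(theta, y)]
        # Multiplicative-weights update, then renormalize onto the simplex.
        w = [p * math.exp(lr * gi) for p, gi in zip(y, g)]
        s = sum(w)
        y = [x / s for x in w]
    return y

theta = [0.5, -1.0, 2.0]
y_custom = custom_optimizer(theta)   # iterative "custom optimizer"
y_closed = softmax(theta)            # closed-form reference
```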