This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

How does one return an adapted model without using the context manager? #119

Open
brando90 opened this issue Nov 2, 2021 · 8 comments
brando90 commented Nov 2, 2021

I want to adapt my model and then pass the adapted model around in my code. How do I do this?

My guess is that the best way is to avoid the context manager for the inner loop but somehow still obtain the adapted model.

brando90 commented Nov 3, 2021

In particular, I'd also like to do this during testing, so the memory footprint of tracking the k inner-adaptation steps is really bad for me. I just want to get the weights after adaptation and throw everything else away, with no extra memory footprint at all.
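One way to do this (a sketch; `extract_adapted_model` and the toy usage below are hypothetical helpers, not part of higher) is to copy the adapted parameters into a detached, regular copy of the base model, so the inner-loop computation graph can be freed:

```python
import copy

import torch
import torch.nn as nn


def extract_adapted_model(base_model: nn.Module, adapted_params) -> nn.Module:
    """Copy adapted (possibly graph-tracking) parameters into a fresh,
    detached copy of base_model, so the inner-loop graph can be freed."""
    adapted = copy.deepcopy(base_model)
    with torch.no_grad():
        for p_dst, p_src in zip(adapted.parameters(), adapted_params):
            p_dst.copy_(p_src.detach())  # detach strips the inner-loop graph
    return adapted


# toy usage: pretend the list below came from fmodel.parameters() after adaptation
torch.manual_seed(0)
base = nn.Linear(3, 2)
pretend_adapted = [p + 1.0 for p in base.parameters()]
final = extract_adapted_model(base, pretend_adapted)
# final's parameters are ordinary leaf tensors with no graph attached
assert all(p.grad_fn is None for p in final.parameters())
```

After this, the fmodel and diffopt (and their graph) can simply go out of scope.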

brando90 commented Nov 3, 2021

Relevant: https://github.com/facebookresearch/higher/blob/main/higher/__init__.py

@_contextmanager
def innerloop_ctx(
    model: _torch.nn.Module,
    opt: _torch.optim.Optimizer,
    device: _typing.Optional[_torch.device] = None,
    copy_initial_weights: bool = True,
    override: _typing.Optional[optim._OverrideType] = None,
    track_higher_grads: bool = True
):
    r"""
...
    """
    fmodel = monkeypatch(
        model, 
        device, 
        copy_initial_weights=copy_initial_weights,
        track_higher_grads=track_higher_grads
    )
    diffopt = optim.get_diff_optim(
        opt,
        model.parameters(),
        fmodel=fmodel,
        device=device,
        override=override,
        track_higher_grads=track_higher_grads
    )
    yield fmodel, diffopt


__all__: list = ["innerloop_ctx"]
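Note that `innerloop_ctx` has no code after its `yield`, so the with-block performs no teardown; the context-manager form is equivalent to a plain function returning the pair. A minimal stdlib illustration (`make_pair`/`make_pair_plain` are toy stand-ins, not higher APIs):

```python
from contextlib import contextmanager


@contextmanager
def make_pair():
    # nothing after the yield, just like innerloop_ctx above: no cleanup on exit
    yield ("fmodel", "diffopt")


# context-manager form
with make_pair() as (a, b):
    pass


# equivalent plain-function form
def make_pair_plain():
    return ("fmodel", "diffopt")


a2, b2 = make_pair_plain()
assert (a2, b2) == ("fmodel", "diffopt")
```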

brando90 commented Nov 3, 2021

And here is my current usage with the context manager:

        T, L = 1, len(self.layer_names)
        spt_x_t, spt_y_t, qry_x_t, qry_y_t = task
        self.base_model.eval()
        with higher.innerloop_ctx(self.base_model, self.inner_opt, copy_initial_weights=self.args.copy_initial_weights,
                                  track_higher_grads=self.args.track_higher_grads) as (fmodel, diffopt):
            diffopt.fo = self.fo
            for i_inner in range(self.args.nb_inner_train_steps):
                # base/child model forward pass
                spt_logits_t = fmodel(spt_x_t)
                inner_loss = self.args.criterion(spt_logits_t, spt_y_t)
                # inner-opt update
                diffopt.step(inner_loss)

            x = qry_x_t
            if torch.cuda.is_available():
                x = x.cuda()

brando90 commented Nov 3, 2021

Something like this should work:

from typing import Optional

import torch
from torch import nn

from higher import monkeypatch
from higher.patch import _MonkeyPatchBase
from higher.optim import DifferentiableOptimizer, get_diff_optim


def get_diff_optimizer_and_functional_model(model: nn.Module,
                                            opt: torch.optim.Optimizer,
                                            copy_initial_weights: bool,
                                            track_higher_grads: bool,
                                            override: Optional = None,
                                            device: Optional[torch.device] = None) \
        -> tuple[_MonkeyPatchBase, DifferentiableOptimizer]:
    """
    Creates a functional model (for higher) and a differentiable optimizer (for higher).
    Replaces higher's context manager, i.e. it returns what this would:
            with higher.innerloop_ctx(base_model, inner_opt, copy_initial_weights=args.copy_initial_weights,
                                      track_higher_grads=args.track_higher_grads) as (fmodel, diffopt):

    ref:
        - https://github.com/facebookresearch/higher/blob/main/higher/__init__.py
        - https://stackoverflow.com/questions/60311183/what-does-the-copy-initial-weights-documentation-mean-in-the-higher-library-for
        - https://github.com/facebookresearch/higher/issues/119

    :param model: base model to monkey-patch into a functional (stateless) model.
    :param opt: the (stateful) inner optimizer to make differentiable.
    :param copy_initial_weights: do NOT set to True if you want to train the base model's initialization; details:
        https://stackoverflow.com/questions/60311183/what-does-the-copy-initial-weights-documentation-mean-in-the-higher-library-for
    :param track_higher_grads: set to False during meta-testing (the code sets it automatically only for meta-test).
    :param override: optional overrides for the differentiable optimizer's hyperparameters.
    :param device: device for the functional model's (copied) parameters.
    :return: the functional model and the differentiable optimizer.
    """
    # - Create a monkey-patched, stateless version of the module.
    fmodel: _MonkeyPatchBase = monkeypatch(
        model,
        device,
        copy_initial_weights=copy_initial_weights,
        track_higher_grads=track_higher_grads
    )
    # - Construct a differentiable version of the existing optimizer.
    diffopt: DifferentiableOptimizer = get_diff_optim(
        opt,
        model.parameters(),
        fmodel=fmodel,
        device=device,
        override=override,
        track_higher_grads=track_higher_grads
    )
    return fmodel, diffopt

def get_maml_adapted_model_with_higher(args: Namespace,
                                       base_model: nn.Module,
                                       inner_opt: optim.Optimizer,
                                       task: list[Tensor],
                                       training: bool) -> "_MonkeyPatchBase":
    """
    Return an adapted model using MAML with PyTorch's higher lib.

    Decision on .eval() vs .train():
        - When training, we do base_model.train() because that is what the official higher MAML omniglot code
        does. That is probably fine: even if the model collects BN stats from different tasks during training,
        it's not a big deal (it can only improve or worsen performance, but at least it does not "cheat" when
        reporting meta-test accuracy results).
        - When meta-testing, we always do .eval() to avoid task information jumping illegally from one place to
        another. When the model solves a task (via the spt, qry sets) it only uses the BN stats from training
        (if it has them) or the current batch statistics (due to mdl.eval()).

    ref:
        - official higher maml omniglot: https://github.com/facebookresearch/higher/blob/main/examples/maml-omniglot.py
    """
    spt_x_t, spt_y_t, qry_x_t, qry_y_t = task
    # - get fmodel and diffopt ready for inner adaptation
    base_model.train() if training else base_model.eval()
    # self.base_model.train() if self.args.split == 'train' else self.base_model.eval()
    fmodel, diffopt = get_diff_optimizer_and_functional_model(base_model,
                                                              inner_opt,
                                                              copy_initial_weights=args.copy_initial_weights,
                                                              track_higher_grads=args.track_higher_grads)
    # - do inner adaptation using the task's support set
    diffopt.fo = args.fo
    for i_inner in range(args.nb_inner_train_steps):
        # base model forward pass
        spt_logits_t = fmodel(spt_x_t)
        inner_loss = args.criterion(spt_logits_t, spt_y_t)
        # inner-opt update
        diffopt.step(inner_loss)
    return fmodel

This just returns the functional model without using the context manager.

brando90 commented Nov 3, 2021

If the context manager is doing something subtle that I also need to do, please let me know!

brando90 commented Nov 4, 2021

It doesn't seem to work for some reason, even though the code runs; the model diverges:

>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5941, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5941, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)
eval_loss=0.9438660621643067, eval_acc=0.6169230967760087
args.meta_learner.lr_inner=0.01
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(50.2733, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5754, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1.5779e+14, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(67053832., grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1.4722e+13, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1.8879e+13, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(109863.9609, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(352.0122, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.7310, grad_fn=<NormBackward1>)
>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)
>>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(926.6426, grad_fn=<NormBackward1>)
eval_loss_sanity=nan, eval_acc_santiy=0.20000000298023224

brando90 commented Nov 4, 2021

Memory issues, perhaps? #75

brando90 commented Nov 4, 2021

My divergence is caused by the data leakage issue: #107

Removing the .train() flag is what causes the errors.
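For anyone hitting the same thing: the leakage path is BatchNorm's running statistics. In .train() mode a forward pass updates them with the current task's batch, while .eval() leaves them untouched. A minimal demonstration (plain PyTorch, no higher; the constant batch makes the updated stats deterministic):

```python
import torch
import torch.nn as nn

# a batch whose statistics differ from BN's initial running stats (mean 0, var 1)
x = torch.full((8, 4), 5.0)

bn_train = nn.BatchNorm1d(4)
bn_train.train()
_ = bn_train(x)  # train mode: running stats absorb this batch's statistics

bn_eval = nn.BatchNorm1d(4)
bn_eval.eval()
_ = bn_eval(x)   # eval mode: running stats are left untouched

# with default momentum=0.1: running_mean = 0.9 * 0 + 0.1 * 5 = 0.5
assert torch.allclose(bn_train.running_mean, torch.full((4,), 0.5))
assert torch.allclose(bn_eval.running_mean, torch.zeros(4))
```

So a forward pass in .train() mode lets task-specific batch statistics persist into later evaluations, which is the "task info jumping illegally" mentioned above.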
