Is there data leakage in the maml-omniglot example? #107

In the maml-omniglot.py example code, net.train() is used for the meta-test phases (link). Does this not cause data leakage of meta-test data via the statistics of nn.BatchNorm2d (net contains several nn.BatchNorm2d layers)?
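For context, here is a minimal, hedged sketch (mine, not code from maml-omniglot.py) of the behavior the question turns on: in train mode, nn.BatchNorm2d both normalizes with the current batch's statistics and folds those statistics into its running buffers.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
print(bn.running_mean)          # tensor([0., 0., 0.]) at initialization

bn.train()                      # same mode the example uses at meta-test time
meta_test_batch = torch.randn(8, 3, 28, 28) + 5.0   # stand-in for meta-test data
_ = bn(meta_test_batch)

# With the default momentum of 0.1, the running mean is now roughly
# 0.1 * 5.0 = 0.5 per channel: meta-test statistics have been folded
# into the running buffers.
print(bn.running_mean)
```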
I think there is. However, in my experience, without it the model diverges and explodes after the adaptation steps (e.g. 5 steps of the inner optimizer). My runs showing this are on mini-ImageNet rather than Omniglot, but I am 100% sure this issue is causing it.
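If anyone wants to verify this on their own runs, here is a small hypothetical helper (the name and structure are mine, not from the example) that reports which BatchNorm running buffers a single train-mode forward pass actually modifies:

```python
import torch

def bn_buffers_changed(net: torch.nn.Module, inputs: torch.Tensor) -> list:
    """Names of running-stat buffers changed by one train-mode forward pass."""
    before = {k: v.clone() for k, v in net.state_dict().items()
              if "running_mean" in k or "running_var" in k}
    net.train()
    with torch.no_grad():
        net(inputs)  # e.g. a meta-test batch
    return [k for k, v in before.items()
            if not torch.equal(v, net.state_dict()[k])]
```

A non-empty result on a meta-test batch is exactly the buffer update the original question calls leakage.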
@SunHaozhe did you ever fix this...? I think the only way to fix it (and idk if it will work) is to either …
Created a pull request: #122
@SunHaozhe I don't think this is an issue anymore, because the leakage claim is likely wrong: during training in meta-learning we want to use batch statistics, since tasks have different distributions. But then how do we retain determinism at inference time? See: https://discuss.pytorch.org/t/how-does-one-use-the-mean-and-std-from-training-in-batch-norm/136029/5

In summary, .train() uses batch statistics. Yes, it updates the running mean with "cheating" means, but those are never actually used while .train() is set, since during training the network uses batch statistics anyway. As long as you don't save the model as a checkpoint with these cheated means, it doesn't matter. Your code does become less deterministic, though.
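If the stored buffers are the only worry, one possible remedy (my suggestion based on this thread, not something maml-omniglot.py does) is to disable running-statistic tracking when constructing the layers, so there are no stored means to "cheat" into or to checkpoint:

```python
import torch.nn as nn

# With track_running_stats=False the layer always normalizes with the current
# batch's statistics, in both .train() and .eval(); no buffers are kept.
bn = nn.BatchNorm2d(64, track_running_stats=False)
assert bn.running_mean is None and bn.running_var is None
```

The trade-off is the one noted above: inference stops being deterministic for a fixed input, because normalization depends on the rest of the batch.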