Selfie: Self-supervised Pretraining for Image Embedding #63

Open
howardyclo opened this issue Aug 17, 2019 · 1 comment
Metadata

TL;DR

  1. This paper aims to translate the success of language model pre-training from text to images by proposing a BERT-like self-supervised learning method called Selfie.
  2. Selfie combines BERT-style masked prediction with the CPC (contrastive predictive coding) loss, which is novel.
  3. The related work is well covered.

Method

  • Two stages: pre-training (the focus of this work) and fine-tuning.
  • Pre-train the first 3 blocks of ResNet-50, and fine-tune the whole ResNet-50.
  • High-level idea:
    • (1) Given an image, split it into 3x3 (for example) patches.
    • (2) Randomly mask out 3 patches, for example, the 3rd, 4th, and 8th patches.
    • (3) Task: Given the context patches (1st, 2nd, 5th, 6th, 7th, 9th), predict which patch was masked out at a given location (one masked location at a time).
    • (4) The prediction is framed as a classification task instead of regression (i.e., generating a patch), since regression is sensitive to small changes in the image.
  • Model detail:
    • Patch processing network (Pnet): The first 3 blocks of ResNet + average pooling, which encode each of the 9 patches independently. Now we have 9 feature vectors.
    • Encoder: Encode the context feature vectors (1st, 2nd, 5th, 6th, 7th, 9th) into a "context vector (u)" with an "Attention Pooling Layer".
    • Choosing which masked-out location to predict: Add the "position embedding" of a masked-out patch (e.g., the 4th) to the context vector (u) to get a "query vector (v)".
    • (Pointer-based) Decoder: Point to the masked-out patch (e.g., the 4th) based on the query vector (v), i.e., compute the dot product (similarity) between each pair <i-th feature vector, v> (i=3, 4, 8), and maximize the similarity of <4th feature vector, v> with a cross-entropy loss over these similarities.
    • Attention Pooling Layer:
      • Can be thought of as a generalized average/max pooling layer.
      • Here they use Transformer layers for pooling.
      • The attention blocks follow the self-attention in BERT.
    • Positional embedding: Decomposed into row and column embeddings, which are summed together.
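
The steps above can be sketched in plain Python. This is only a toy illustration of the masking and pointer-style decoding, not the paper's implementation: `toy_pnet` (a mean/max feature) and `mean_pool` (a plain average) are hypothetical stand-ins for the ResNet blocks and the Transformer-based attention pooling, the row/column embeddings are random rather than learned, and patch indices are 0-based here while the notes count from 1.

```python
import math
import random

def split_into_patches(image, grid):
    """Split a square image (list of rows) into grid x grid patches, row-major."""
    size = len(image) // grid
    return [[row[c * size:(c + 1) * size] for row in image[r * size:(r + 1) * size]]
            for r in range(grid) for c in range(grid)]

def toy_pnet(patch):
    """Hypothetical stand-in for the Pnet: a 2-d feature (mean, max) per patch."""
    flat = [x for row in patch for x in row]
    return [sum(flat) / len(flat), max(flat)]

def mean_pool(vectors):
    """Hypothetical stand-in for attention pooling: element-wise average."""
    return [sum(v[d] for v in vectors) / len(vectors) for d in range(len(vectors[0]))]

def position_embedding(index, grid, row_emb, col_emb):
    """Position embedding as the sum of a row embedding and a column embedding."""
    r, c = divmod(index, grid)
    return [a + b for a, b in zip(row_emb[r], col_emb[c])]

def pointer_loss(query, candidates, target):
    """Cross-entropy over dot-product scores between the query and each candidate."""
    scores = [sum(q * f for q, f in zip(query, feat)) for feat in candidates]
    m = max(scores)  # subtract the max for a numerically stable log-sum-exp
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[target]

def selfie_step(image, grid, num_masked, rng):
    """One toy pre-training step: mask patches, then score each masked location."""
    dim = 2  # feature size produced by toy_pnet
    features = [toy_pnet(p) for p in split_into_patches(image, grid)]
    masked = sorted(rng.sample(range(grid * grid), num_masked))
    context = [i for i in range(grid * grid) if i not in masked]
    u = mean_pool([features[i] for i in context])  # context vector u
    # Random embeddings for illustration; in a real model these would be learned.
    row_emb = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(grid)]
    col_emb = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(grid)]
    losses = []
    for target, pos in enumerate(masked):  # predict one masked location at a time
        pe = position_embedding(pos, grid, row_emb, col_emb)
        v = [a + b for a, b in zip(u, pe)]  # query vector v = u + position embedding
        losses.append(pointer_loss(v, [features[i] for i in masked], target))
    return masked, context, losses
```

Calling `selfie_step(image, 3, 3, random.Random(0))` on a 6x6 image returns 3 masked indices, 6 context indices, and one loss per masked location. The candidates for each prediction are only the masked patches, which is what makes the masked patches from the same image act as distractors for one another.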

Experiment Setting

  • Dataset: CIFAR-10, ImageNet 32x32 and 224x224.
  • Patch size: 8x8 for 32x32 and 32x32 for 224x224.
  • The full data (50K for CIFAR-10 & 1.2M for ImageNet) is used for the pre-training stage.
  • Subsets of 5%, 10%, 20%, and 100% of the labeled data are used for the fine-tuning stage.
  • For CIFAR-10, the 10% subset is replaced with an 8% subset (4000 examples), following the AutoAugment and Realistic Eval of SSL papers.
  • They also describe model training and hyper-parameters in detail.
  • Train model for 120K steps.
  • The baseline ResNet achieves strong accuracy: 95.5% on CIFAR-10 and 76.9% on ImageNet 224x224.
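
As a quick sanity check on the sizes above, the patch counts and the 8% subset size can be worked out directly. This sketch assumes the patches are non-overlapping and tile the image exactly, which the notes do not state explicitly; `patch_grid` is a hypothetical helper name.

```python
def patch_grid(image_size, patch_size):
    """Return (patches per side, total patches) for non-overlapping square patches."""
    if image_size % patch_size != 0:
        raise ValueError("patch size must divide the image size evenly")
    per_side = image_size // patch_size
    return per_side, per_side * per_side

print(patch_grid(32, 8))     # (4, 16): 32x32 images with 8x8 patches
print(patch_grid(224, 32))   # (7, 49): 224x224 images with 32x32 patches
print(50_000 * 8 // 100)     # 4000: 8% of the 50K CIFAR-10 training images
```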

Key Experiment Results and Findings

  • Selfie pre-training has a regularizing effect when 10% of the data is labeled.
  • Selfie pre-training improves performance across all labeled-data fractions, from 5% to 100%.
  • Pre-training benefits more when there is less labeled data.
  • Self-attention as the last layer helps fine-tuning performance.

Fine-tuning Sensitivity and Mismatch to Pre-training

  • There are difficulties in transferring a pre-trained model across tasks, e.g., from ImageNet to CIFAR-10.
  • For the 100% subset of ImageNet 224x224, additional tuning of the pre-training phase using a development set is needed to achieve the reported result.
  • Input mismatch exists between pre-training (the Pnet only sees the patches) and fine-tuning (the Pnet sees the whole image).

Personal Thoughts

  • I wonder whether the patch size is too small for a ResNet-based Pnet to encode, especially the 8x8 patches.
  • The attention pooling layer and positional embedding are not clear here; it may help to refer to the BERT paper.
  • Where is the CPC loss? The paper only describes a cross-entropy loss. I may also need to refer to the CPC paper.
  • The hyper-parameter selection in the "Reporting results" section is not clear to me. Does it apply to the pre-training stage or the fine-tuning stage?
  • The performance gain from SSL becomes marginal as the amount of labeled data grows, especially if the task is easy (the gain on CIFAR-10 is smaller than on ImageNet).
  • The challenge of making SSL more useful (i.e., achieving a larger performance gain with fully labeled data) still exists.
  • Although they didn't compare to other SSL approaches (e.g., for CPC, they said that CPC uses ResNet-171), they emphasize that their baseline ResNet is stronger than previous ones, presumably to make the experimental results more convincing to the reader.
  • They did not show whether the regularization effect still holds when the data is fully labeled.
  • There's a semi-supervised learning approach called UDA that achieves 94.7% accuracy on the 8% CIFAR-10 subset, while this paper only achieves 80.3%, suggesting that semi-supervised learning may be more promising than self-supervised learning (?). Or maybe the two can be combined, as shown in the S^4L paper.
  • Other authors worth noting: Avital Oliver (author of MixMatch and S^4L, ...) and Xiaohua Zhai (author of Revisiting SSL and S^4L). Let's read their papers!
