This paper is available on arXiv under a CC 4.0 license.
Authors:
(1) Andrey Zhmoginov, Google Research, azhmogin@google.com;
(2) Mark Sandler, Google Research, sandler@google.com;
(3) Max Vladymyrov, Google Research, mxv@google.com.
In this work we propose a HyperTransformer, a transformer-based model for few-shot learning that generates the weights of a convolutional neural network (CNN) directly from support samples. Since the dependence of a small generated CNN model on a specific task is encoded by a high-capacity transformer model, we effectively decouple the complexity of the large task space from the complexity of individual tasks. Our method is particularly effective for small target CNN architectures, where learning a fixed universal task-independent embedding is not optimal and better performance is attained when the information about the task can modulate all model parameters. For larger models, we discover that generating the last layer alone allows us to produce results that are competitive with or better than those obtained with state-of-the-art methods, while being end-to-end differentiable. Finally, we extend our approach to a semi-supervised regime that utilizes unlabeled samples in the support set to further improve few-shot performance.
In few-shot learning, the conventional machine learning paradigm of fitting a parametric model to training data is taken to the limit of extreme data scarcity, where entire categories are introduced with just one or a few examples. A generic approach to this problem uses training data to identify the parameters φ of a learner aφ that, given a small batch of examples for a particular task (called a support set), can solve this task on unseen data (called a query set).
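The support/query protocol can be made concrete with a small episode sampler. This is a minimal sketch assuming the common N-way K-shot convention (sample N classes, then K support examples and a fixed number of query examples per class); the function name `sample_episode` and its parameters are illustrative and not part of the paper.

```python
import random
from collections import defaultdict


def sample_episode(dataset, n_way, k_shot, n_query, rng):
    """Sample an N-way K-shot episode from a list of (x, label) pairs.

    Labels are remapped to 0..n_way-1, so the learner only ever sees
    episode-local class indices.
    """
    by_class = defaultdict(list)
    for x, y in dataset:
        by_class[y].append(x)
    # A class is usable only if it has enough examples for both sets.
    eligible = sorted(c for c, xs in by_class.items() if len(xs) >= k_shot + n_query)
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for new_label, c in enumerate(classes):
        xs = rng.sample(by_class[c], k_shot + n_query)
        support += [(x, new_label) for x in xs[:k_shot]]  # given to the learner
        query += [(x, new_label) for x in xs[k_shot:]]    # used for evaluation
    return support, query


# Toy usage: 10 classes with 20 examples each, one 5-way 1-shot episode.
toy = [(f"img{c}_{i}", c) for c in range(10) for i in range(20)]
support, query = sample_episode(toy, n_way=5, k_shot=1, n_query=15, rng=random.Random(0))
print(len(support), len(query))  # 5 75
```

Remapping labels per episode is what forces the learner to infer the task from the support set rather than memorize global class identities.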
One broad family of few-shot image classification methods, frequently referred to as metric-based learning, relies on pretraining an embedding eφ(·) and then using some distance in the embedding space to label query samples based on their closeness to known labeled support samples. These methods have proved effective on numerous benchmarks (see Tian et al. (2020) for a review and references); however, the capabilities of the learner are limited by the capacity of the architecture itself, since these methods try to build a universal embedding function.
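As one concrete instance of the metric-based recipe, the sketch below labels each query by its nearest class centroid in a fixed embedding space. Using centroids and Euclidean distance is an assumption (one common variant, not the only one), and the `embed` callable is a hypothetical stand-in for a pretrained eφ(·). Only the centroids change from task to task; the embedding itself is shared across all tasks.

```python
import math


def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def nearest_centroid_predict(support, queries, embed):
    """Label each query with the class whose mean support embedding is closest.

    support: list of (x, label); queries: list of x; embed: x -> list[float].
    """
    sums, counts = {}, {}
    for x, y in support:
        e = embed(x)
        if y not in sums:
            sums[y], counts[y] = [0.0] * len(e), 0
        sums[y] = [s + v for s, v in zip(sums[y], e)]
        counts[y] += 1
    centroids = {y: [s / counts[y] for s in sums[y]] for y in sums}
    preds = []
    for q in queries:
        e = embed(q)
        preds.append(min(centroids, key=lambda y: euclidean(e, centroids[y])))
    return preds


# Toy usage with an identity "embedding" on 2-D points.
support = [((0.0, 0.0), 0), ((0.2, 0.1), 0), ((5.0, 5.0), 1), ((4.8, 5.1), 1)]
queries = [(0.1, -0.1), (5.2, 4.9)]
preds = nearest_centroid_predict(support, queries, embed=list)
print(preds)  # [0, 1]
```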
On the other hand, optimization-based methods such as the seminal MAML algorithm (Finn et al., 2017) can fine-tune the embedding eφ by performing additional SGD updates on all parameters φ of the model producing it. This partially addresses the constraints of metric-based methods by learning a new embedding for each new task. However, in many of these methods, all the knowledge extracted during training on different tasks and describing the learner aφ still has to “fit” into the same number of parameters as the model itself. This limitation becomes more severe as the target models get smaller and the richness of the task set increases.
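The task-specific fine-tuning step of optimization-based methods can be illustrated with a toy model. This sketch assumes a 1-D linear model y = w·x + b with a squared-error loss and shows only the inner loop: a few gradient steps on all parameters, starting from an initialization that MAML would meta-learn (the outer meta-update is omitted, and the initialization here is hypothetical). The adapted model has exactly as many parameters as the one it started from.

```python
def mse(params, data):
    """Mean squared error of the linear model y = w*x + b."""
    w, b = params
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def adapt(params, support, lr=0.05, steps=20):
    """Inner-loop adaptation: plain gradient descent on all parameters."""
    w, b = params
    n = len(support)
    for _ in range(steps):
        # Gradients of the mean squared error with respect to w and b.
        gw = sum(2 * (w * x + b - y) * x for x, y in support) / n
        gb = sum(2 * (w * x + b - y) for x, y in support) / n
        w, b = w - lr * gw, b - lr * gb
    return w, b


# Hypothetical meta-learned initialization and a support set drawn from y = 2x + 1.
init = (0.0, 0.0)
support = [(1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]
adapted = adapt(init, support)
print(mse(init, support), mse(adapted, support))
```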
In this paper we propose a new few-shot learning approach that allows us to decouple the complexity of the task space from the complexity of individual tasks. The main idea is to use a transformer model (Vaswani et al., 2017) that, given a few-shot task episode, generates an entire inference model by producing all model weights in a single pass. This allows us to encode the intricacies of the available training data inside the transformer model, while still producing specialized tiny models that can solve individual tasks. By reducing the size of the generated model and moving the computational overhead to the transformer-based weight generator, we can lower the cost of inference on new images. This can reduce the overall computation cost in cases where the tasks change infrequently and hence the weight generator is only used sporadically.
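The cost argument is a matter of amortization: the weight generator runs once per task, while the small generated model runs once per image. The sketch below compares this with running a larger fixed model on every image; all costs are in arbitrary, hypothetical units, and the function names are illustrative.

```python
def generated_model_cost(n_tasks, images_per_task, gen_cost, small_cost):
    """Total cost when weights are generated once per task and then reused."""
    return n_tasks * (gen_cost + images_per_task * small_cost)


def fixed_model_cost(n_tasks, images_per_task, large_cost):
    """Total cost of running a single larger model on every image."""
    return n_tasks * images_per_task * large_cost


def break_even_images(gen_cost, small_cost, large_cost):
    """Images per task above which generating a small model is cheaper.

    Assumes large_cost > small_cost.
    """
    return gen_cost / (large_cost - small_cost)


# Hypothetical units: one generator pass costs 1000, the small model 1 per image,
# the large model 10 per image.
print(generated_model_cost(3, 1000, gen_cost=1000, small_cost=1))  # 6000
print(fixed_model_cost(3, 1000, large_cost=10))                    # 30000
print(break_even_images(1000, 1, 10))
```

With fewer images per task than the break-even point, the generator pass dominates the total, which is why the savings depend on tasks changing infrequently.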
We start by observing that the self-attention mechanism is well suited to be the underlying mechanism of a few-shot CNN weight generator. In contrast with earlier CNN-based (Zhao et al., 2020) or BiLSTM-based (Ravi & Larochelle, 2017) approaches, the vanilla[1] transformer model is invariant to sample permutations and can handle unbalanced datasets with a varying number of samples per category. Furthermore, we demonstrate that a single-layer self-attention model can replicate a simplified gradient-descent-based learning algorithm. Using a transformer model to generate the logits layer on top of a conventionally learned embedding, we achieve competitive results on several common few-shot learning benchmarks. By varying transformer parameters, we demonstrate that this high performance can be attributed to the additional capacity of the transformer model, which decouples its complexity from that of the generated CNN.
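The permutation property can be checked numerically. The sketch below implements single-head dot-product self-attention with no mask and no positional encoding (random weights; dimensions and names are illustrative) and verifies that permuting the input tokens permutes the outputs in the same way. Strictly, per-token outputs are permutation-equivariant, while any quantity aggregated over tokens, such as the mean output, is invariant. The same function accepts any number of tokens, so classes with different numbers of samples need no special handling.

```python
import math
import random


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def self_attention(tokens, wq, wk, wv):
    """Single-head dot-product self-attention without masking or positional encoding."""
    q = [matvec(wq, t) for t in tokens]
    k = [matvec(wk, t) for t in tokens]
    v = [matvec(wv, t) for t in tokens]
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out


rng = random.Random(0)
dim, n_tokens = 4, 5


def rand_matrix():
    return [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(dim)]


wq, wk, wv = rand_matrix(), rand_matrix(), rand_matrix()
tokens = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_tokens)]
perm = [3, 0, 4, 1, 2]

out = self_attention(tokens, wq, wk, wv)
out_perm = self_attention([tokens[i] for i in perm], wq, wk, wv)

# Output j of the permuted run matches output perm[j] of the original run.
max_err = max(abs(a - b) for j, i in enumerate(perm) for a, b in zip(out_perm[j], out[i]))
print(max_err < 1e-9)  # True
```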
We then extend our method to support unlabeled samples by concatenating a special input token to every unlabeled example to encode the fact that its class is unknown. In our experiments outlined in Section 4.3, we observe that adding unlabeled samples can significantly improve model performance. Interestingly, the full benefit of using additional data is only realized if the transformer uses two or more layers. This result is consistent with the basic mechanism described in Section 3.2, where we show that a transformer model with at least two layers can encode a nearest-neighbor-style algorithm that associates unlabeled samples with similar labeled examples. In essence, by training the weight generator to produce CNN models with the best possible performance on a query set, we teach the transformer to utilize unlabeled samples without having to manually introduce additional optimization objectives. Our approach could be further generalized to partially known sample labels (when the true label is known to belong to some set of classes), but this is left for future work.
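One simple way to realize the "unknown class" token is sketched below: every sample becomes an embedding concatenated with a one-hot label vector that has one extra slot, and all unlabeled samples set that extra slot. The one-hot layout and the name `make_tokens` are assumptions for illustration; the text above only states that a special token is concatenated to unlabeled examples, and a learned embedding could be used instead.

```python
def make_tokens(labeled, unlabeled, n_way):
    """Build transformer input tokens for a semi-supervised support set.

    labeled:   list of (embedding, label) pairs with labels in 0..n_way-1
    unlabeled: list of embeddings whose labels are unknown
    Each token is the embedding followed by a one-hot vector of length
    n_way + 1; index n_way is the (assumed) "label unknown" slot.
    """
    tokens = []
    for emb, label in labeled:
        tag = [0.0] * (n_way + 1)
        tag[label] = 1.0
        tokens.append(list(emb) + tag)
    for emb in unlabeled:
        tag = [0.0] * (n_way + 1)
        tag[n_way] = 1.0  # same special marker for every unlabeled sample
        tokens.append(list(emb) + tag)
    return tokens


tokens = make_tokens([([0.1, 0.2], 0), ([0.3, 0.4], 1)], [[0.5, 0.6]], n_way=2)
print(tokens[2])  # [0.5, 0.6, 0.0, 0.0, 1.0]
```

Because the tokens are consumed as an unordered set, labeled and unlabeled samples can appear in any order and in any number.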
Finally, we explore the capability of our approach to generate all weights of the CNN model, adjusting both the logits layer and all intermediate layers producing the sample embedding. We show that by generating all layers we can improve both the training and test accuracies[2] of CNN models below a certain size. Interestingly, however, generating the logits layer alone appears to be sufficient above a certain model size threshold (see Figure 3). This threshold is expected to depend on the variability and complexity of the training tasks.
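To get a sense of how much more the generator must produce when it emits all layers rather than only the logits layer, the sketch below counts parameters for a hypothetical CNN: four 3×3 convolutional layers of equal width, followed by global pooling and a dense logits layer (biases included, normalization parameters ignored). This architecture is an assumption for illustration, not the one used in the paper.

```python
def conv_params(kernel, c_in, c_out):
    """Weights of a kernel x kernel convolution plus one bias per output channel."""
    return kernel * kernel * c_in * c_out + c_out


def cnn_param_counts(channels, n_way, kernel=3, in_channels=3):
    """Return (convolutional parameters, logits-layer parameters)."""
    conv, c_in = 0, in_channels
    for c_out in channels:
        conv += conv_params(kernel, c_in, c_out)
        c_in = c_out
    # Dense logits layer on globally pooled features: weights plus biases.
    logits = c_in * n_way + n_way
    return conv, logits


for width in (8, 64):
    conv, logits = cnn_param_counts([width] * 4, n_way=5)
    print(width, conv, logits, f"{logits / (conv + logits):.2%}")
```

For width 8 and a 5-way task, the logits layer holds 45 of 2021 parameters (about 2%); for width 64 it holds 325 of 112901 (under 0.3%).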
In addition to decoupling the complexity of the task distribution from the complexity of individual tasks, another important advantage of our method is that it allows end-to-end learning without relying on complex nested gradient optimization or other meta-learning approaches in which the number of unrolled steps is large. In contrast with these methods, our optimization is done in a single loop of updates to the transformer (and feature extractor) parameters.
The paper is structured as follows. In Section 2, we discuss the few-shot learning problem setup and highlight related work. Section 3 introduces our approach, discusses the motivation for choosing an attention-based model and shows how our approach can be used to meta-learn semi-supervised learning algorithms. In Section 4, we discuss our experimental results. Finally, in Section 5, we provide concluding remarks.
[1] Without attention masking or positional encodings.
[2] As discussed in Section 4.2, HT with a high training accuracy can be a practical approach to model personalization under the assumption that real tasks come from the distribution seen at the training time.