The surprising effectiveness of self-supervised learning
Self-supervised learning, or SSL, used to be a pretty niche technique in machine learning. But thanks to LLMs, it’s become one of the most important ways to train complex machine learning models. It’s quite an interesting technique, and — as the title suggests — surprisingly effective. So I thought I’d say something about it.
SSL rests on two principles: learning from unlabelled data, and learning through auxiliary tasks. The first of these just means the model isn't told what it's learning. That is, if you give it a picture of a cat, you don't tell it that it's a picture of a cat. Or if you give it some text, you don't tell it anything about the meaning of the text. Instead, it has to work these things out for itself. Which is generally quite challenging, at least compared to supervised learning, where all the data are labelled. But given that most data in the world are unlabelled, it's often necessary.
The second — learning through auxiliary[1] tasks — means that you don't train the model to do what you want it to do. Instead, you train it to do some other task, and then you tweak the resulting model for your intended use case. A prominent example can be seen in the training of text-based LLMs. These learn from large quantities of unlabelled text, and they — at least initially — do this through the auxiliary task of predicting the next word[2]. When you use an LLM, you don't want to predict the next word in a sentence; rather, you want to prompt it with a question, in response to which you expect an answer. But if you want to train the LLM on vast quantities of data, then the auxiliary task of next word prediction provides a route to do this, since most data in the world does not take the form of question-answer pairs.
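To make this a little more concrete, here's a minimal PyTorch sketch of the next-word prediction objective. The ToyLanguageModel below is just a placeholder for a much larger transformer, and the data are random token IDs rather than real text; the point is simply that the training target at each position is the token that comes next.

```python
# Minimal sketch of next-word (next-token) prediction.
# ToyLanguageModel is a placeholder, not any real LLM architecture.
import torch
import torch.nn as nn

vocab_size, embed_dim = 1000, 64

class ToyLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, embed_dim, batch_first=True)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, tokens):                    # tokens: (batch, seq_len)
        hidden, _ = self.rnn(self.embed(tokens))  # (batch, seq_len, embed_dim)
        return self.head(hidden)                  # logits over the vocabulary

model = ToyLanguageModel()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

tokens = torch.randint(0, vocab_size, (8, 33))    # a batch of dummy token sequences
inputs, targets = tokens[:, :-1], tokens[:, 1:]   # targets are the inputs shifted by one

optimiser.zero_grad()
logits = model(inputs)
loss = loss_fn(logits.reshape(-1, vocab_size), targets.reshape(-1))
loss.backward()
optimiser.step()
```

Notice that no labels appear anywhere: the "label" at each position is just the next token in the raw text.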
And this shows the power of SSL. By learning to predict the next word in a sentence, and doing this across most of the text on the internet, the model implicitly learns about the underlying structure of language. Which is a pretty big thing. The resulting model can then be tweaked to solve the actual task of answering questions, and this can be done using a quantity of curated and labelled examples that is tiny compared to the masses of unlabelled text initially used to train it. Or to put it another way, SSL allows us to build LLMs that can answer questions, something there's not enough data to train them to do directly.
And I'd say this is quite surprising. If you'd asked me a few years ago whether we could train general-purpose AI systems by asking them to predict the next word in everyone's internet ramblings, I wouldn't have thought the answer was "yes". But apparently it is.
Whilst SSL has become famous thanks to text-based LLMs, it’s not restricted to these. It’s also increasingly used in large image-based models such as vision transformers, which face the same problem in terms of access to sufficient labelled data. That is, whilst there are a huge number of images (mostly of cats) on the internet, comparatively few of them are labelled with explicit descriptions that can be used within the training process. So again, it’s desirable to learn from unlabelled data, and again this can be done using auxiliary tasks.
There are many different auxiliary tasks you can use for image-based SSL. Probably the simplest ones to understand are in-painting and colourisation. In-painting involves taking an image, removing a section of it, and then training the model to generate a new image that contains the missing part. Colourisation involves taking an image, making it greyscale by removing all the colour information, and then training the model to turn it back into the original colour version. That is, in both cases, the auxiliary task the model is being trained to do is to replace information that was removed from the image. Similar to text-based LLMs, the resulting model can then be tweaked[3] using a relatively small set of labelled data to carry out a different task — such as image classification. Again, it's a bit surprising this works, but it does.
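The nice thing about both tasks is that the training targets come for free from the unlabelled images themselves. Here's a rough PyTorch sketch of how the training pairs can be manufactured; the patch size and greyscale weights are arbitrary choices for illustration, not taken from any particular system.

```python
# Building in-painting and colourisation training pairs from unlabelled images.
# Assumes a batch of RGB images as a (batch, 3, H, W) tensor with values in [0, 1].
import torch

def make_inpainting_pair(images, patch=16):
    """Blank out a square patch; the model must reconstruct the original image."""
    corrupted = images.clone()
    _, _, h, w = images.shape
    top = torch.randint(0, h - patch, (1,)).item()
    left = torch.randint(0, w - patch, (1,)).item()
    corrupted[:, :, top:top + patch, left:left + patch] = 0.0
    return corrupted, images              # (model input, reconstruction target)

def make_colourisation_pair(images):
    """Convert to greyscale; the model must restore the original colours."""
    weights = torch.tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
    greyscale = (images * weights).sum(dim=1, keepdim=True)
    return greyscale, images              # (model input, reconstruction target)

images = torch.rand(8, 3, 64, 64)         # dummy unlabelled images
corrupted, target = make_inpainting_pair(images)
grey, target = make_colourisation_pair(images)
```

Either pair can then be fed to a reconstruction loss, such as the mean squared error between the model's output and the original image, with no human-provided labels required at any point.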
But SSL isn’t just useful for building large models from large unlabelled data sets. It’s also handy for dealing with small data sets, where it can be used to squeeze more information out of training data by first pretraining on an auxiliary task and then using the resulting model as a starting point for training on the actual task. In this case, the auxiliary task is being used to extract useful knowledge about the underlying characteristics of the data. During pretraining, this knowledge gets implicitly embedded in the model’s initial parameter values, and can then be leveraged when training the model to solve its intended task. Or to put it another way, it embeds a different perspective on the data that is complementary to the target task.
An example of this is the work of Heba El-Shimy, one of my PhD students, who is working on the problem of detecting cancerous polyps in colonoscopy videos using capsule networks[4]. This is an area in which there is relatively little good-quality labelled data available. It's also an area that doesn't seem to benefit from transfer learning, in which a model is first trained on a large generic image dataset such as ImageNet using supervised learning, and is then fine-tuned on the data of interest. This is probably because images of a person's insides bear little resemblance to the cats, planes and such like that make up generic datasets. And basically this means the only real option is to train the model from scratch using just the available data.
[Figure: an example of a pretrained capsule network solving the colourisation and in-painting tasks.]

Heba experimented with three auxiliary tasks. In addition to in-painting and colourisation, she also used something called contrastive learning, which involves training the model to recognise different variants of the same image produced through cropping, rotation and such like. The most effective approach turned out to be a combination of in-painting and contrastive learning, demonstrating that it can be beneficial to carry out multiple auxiliary tasks in order to maximise the information that's extracted during pretraining.
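To give a flavour of what contrastive learning actually optimises, here's a small PyTorch sketch of a SimCLR-style loss. This is one common way of doing it, not necessarily the exact setup used in Heba's work: the embeddings of two augmented views of the same image are pulled together, while all the other images in the batch act as negatives.

```python
# Sketch of a SimCLR-style contrastive (NT-Xent) loss.
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    """z1[i] and z2[i] are embeddings of two augmented views of the same image;
    every other embedding in the batch serves as a negative example."""
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # (2N, d) unit vectors
    sim = z @ z.t() / temperature                        # pairwise similarities
    n = z1.shape[0]
    sim.fill_diagonal_(float('-inf'))                    # never contrast an item with itself
    # For row i, the "correct class" is the index of its other view.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)

# Dummy embeddings; in practice these come from the encoder applied to two
# randomly augmented views (crops, rotations, etc.) of the same batch of images.
z1, z2 = torch.randn(8, 128), torch.randn(8, 128)
loss = nt_xent_loss(z1, z2)
```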
You can find more details on her work in this preprint[5], but essentially Heba found that SSL led to a significant increase in the accuracy of capsule networks, achieving this without requiring any additional data. And this is important because, in many fields, data is both limited and quite specialised — meaning that larger, more generic datasets cannot readily be used for pretraining.
So, SSL can be a powerful technique regardless of the nature of the data it's being applied to, and it's shown benefit when applied to data of different modalities and datasets of different sizes. I guess one thing to be careful of is information leaks[6], especially when applying SSL using the same dataset that is going to be used for supervised training (as in the last example). In this case, it's important to split the data into train and test sets prior to applying SSL, and only use the training data for SSL; otherwise information about the test set can leak into the pretrained weights.
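In code terms, the fix is just to do the split first and keep the test set out of both stages of training. A minimal sketch, using made-up stand-in data and scikit-learn's train_test_split:

```python
# Split before any self-supervised pretraining, so the test images never
# influence the pretrained weights. The arrays below are dummy stand-ins.
import numpy as np
from sklearn.model_selection import train_test_split

images = np.random.rand(500, 3, 64, 64)         # stand-in for the real images
labels = np.random.randint(0, 2, size=500)      # stand-in for the real labels

train_imgs, test_imgs, train_labels, test_labels = train_test_split(
    images, labels, test_size=0.2, random_state=42)

# 1. Self-supervised pretraining: uses train_imgs only (labels ignored).
# 2. Supervised fine-tuning: uses train_imgs and train_labels.
# 3. Final evaluation: test_imgs and test_labels, untouched until now.
```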
[1] I'm using the term auxiliary task here since I think it is more self-explanatory. However, a more typical term used in the literature is pretext task. The actual/target task is also known as the downstream task.
[2] This is specifically for GPT-style LLMs, i.e. most of the current bunch. Earlier LLMs like BERT, which are used for tasks such as classification rather than text generation, instead used missing word prediction. That is, you mask out a word in a sentence, and ask it to fill it in — a bit like the in-painting task I mention later in this post.
[3] In case you're wondering how this is implemented, it typically takes the form of an encoder-decoder architecture, which is something I talked about in Deep Dips #2: Embeddings and latent spaces. In effect, the encoder part is trained to compress the input image into a latent representation, and the decoder part is trained to carry out the auxiliary task. Once the encoder is sufficiently trained, the decoder is thrown away, and the encoder is bolted on to whatever's needed to carry out the downstream task, e.g. a classification head. This new component can then be trained in a supervised fashion using the labelled data.
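As a rough PyTorch-style sketch of that pattern (the layer sizes here are arbitrary and chosen purely for illustration):

```python
# Pretrain an encoder-decoder on the auxiliary task, then swap the decoder
# for a classification head. Layer sizes are illustrative only.
import torch.nn as nn

encoder = nn.Sequential(                        # compresses the image into a latent map
    nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
)
decoder = nn.Sequential(                        # reconstructs the image for the pretext task
    nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
    nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
)

pretext_model = nn.Sequential(encoder, decoder)
# ... train pretext_model on in-painting/colourisation using unlabelled data ...

# Then discard the decoder and attach a classification head to the
# now-pretrained encoder, which is fine-tuned on the small labelled set.
classifier = nn.Sequential(
    encoder,
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(64, 2),                           # e.g. polyp vs. no polyp
)
```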
[4] This is an unconventional deep learning architecture from Geoffrey Hinton's group. The idea is to form more stable representations of image elements by explicitly capturing how they're transformed within an image. It also has some nice characteristics in terms of explainability, which has driven interest in using it for medical applications. Check out this review of capsule networks in medical image analysis.
[5] Currently in press. Presented at the 9th International Conference on Information System Design & Intelligent Applications, where Heba won the best paper award.
[6] See my pitfalls guide for lots more about this sort of thing.