
This article is part of our coverage of the latest in AI research.
One of the longstanding problems of machine learning is the memorization of wrong correlations. Here’s an example: Say you are developing a deep neural network to classify images as land birds or seabirds. You train the model on thousands of labeled images, and it performs very well on both the training and test sets. However, when you show the model a new picture of a wounded seabird being treated at a veterinarian, it misclassifies it as a land bird.
It turns out that you had trained your model on images of seabirds in the wild, flying over the sea. As a result, instead of learning the features of seabirds, the model learned to detect images that contain vast expanses of water. Since the new image contained no water, the model mistook it for a land bird.
This is an example of a machine learning model learning a spurious correlation between features and labels. ML models are lazy: they will often chase the shortest path to their goal. In this case, the model memorized the wrong feature in its training data, the presence of water in seabird pictures.
The main drawback of memorizing spurious correlations is the lack of generalization. The model can give a false sense of progress but might fail to work well in real-world situations. The bird classification anecdote is a benign example. But spurious correlations can cause harm when machine learning models are used in critical applications, such as healthcare or autonomous vehicles.
How do you keep machine learning models from learning spurious correlations? A new paper by researchers at Université de Montréal and FAIR at Meta explores the dynamics of memorization in machine learning models and how it leads to learning spurious correlations. They also propose a new paradigm called “memorization-aware training” (MAT), which can help prevent ML models from learning spurious correlations during training.
The problem with ERM
The standard method for training neural networks is empirical risk minimization (ERM), a learning paradigm that seeks to minimize the model’s loss over a training dataset. Stochastic gradient descent (SGD), the optimization algorithm used throughout machine learning and deep learning, is what typically solves the ERM objective.
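As a rough illustration, here is a minimal sketch of ERM trained with SGD in PyTorch: the average cross-entropy loss over a set of training examples is driven down by gradient steps. The toy model, data, and hyperparameters are placeholders, not the paper’s setup.

```python
# Minimal sketch of empirical risk minimization (ERM) trained with SGD.
# The model, data, and hyperparameters are illustrative placeholders.
import torch
from torch import nn

# Toy data: 256 examples with 32 features and binary labels.
X = torch.randn(256, 32)
y = torch.randint(0, 2, (256,))

model = nn.Linear(32, 2)               # stand-in for a deep network
loss_fn = nn.CrossEntropyLoss()        # per-example loss l(f(x), y)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(10):
    optimizer.zero_grad()
    logits = model(X)
    loss = loss_fn(logits, y)          # ERM objective: average loss over the data
    loss.backward()
    optimizer.step()
```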
The problem with ERM is that it can drive models to quickly capture spurious correlations instead of learning the true patterns that explain the underlying distribution of the problem. When the spurious correlation is very prominent (e.g., the water pixels in the seabird example), the model will stop its learning before it has the chance to properly learn the real useful patterns (e.g., the bird pixels in the images). This results in poor generalization because the spurious features can be absent in real situations while the useful features are always present (e.g., a seabird away from water).
If a model has enough parameters, it might even memorize example-specific features that are unique to individual data points and do not generalize to other examples. These features are not related to the core attributes that are truly predictive of the target variable.
To verify whether the model has learned spurious correlations, it must be evaluated on a held-out set that contains minority examples, instances that do not conform to the simple explanations that neural networks learn from the majority of the training data. For example, consider a model that classifies images of cows and camels. If most cows in a training set appear on grass and most camels on sand, a minority example would be a cow on sand or a camel on grass.
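To make the idea concrete, here is a small sketch of how such groups could be defined as (label, background) pairs and how the minority combinations surface; the data is made up purely for illustration.

```python
# Illustrative grouping of examples by (label, spurious attribute).
# The data below is fabricated for the cow/camel example.
from collections import Counter

labels     = ["cow", "cow", "cow", "camel", "camel", "camel", "cow", "camel"]
background = ["grass", "grass", "sand", "sand", "sand", "grass", "grass", "sand"]

groups = list(zip(labels, background))
counts = Counter(groups)

# Minority examples belong to the rarest (label, background) combinations,
# here ("cow", "sand") and ("camel", "grass").
minority_groups = [g for g, c in counts.items() if c == min(counts.values())]
print(minority_groups)
```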
Memorization-aware training
While held-out examples can help spot signs of memorized spurious correlations, the paper goes a step further and proposes a method that uses minority examples to guide the model toward learning generalizable patterns.
Called memorization-aware training (MAT), the technique uses held-out predictions to modify the model’s logits, the raw predictions that a neural network outputs before they are converted into probabilities.
Specifically, MAT modifies the ERM objective by introducing a per-example logit shift based on “calibrated held-out probabilities.” The calibrated probabilities adjust the training focus by increasing the loss for examples with incorrect held-out predictions and lowering it for examples with correct held-out predictions. By incorporating these probabilities into the loss function, the training algorithm discourages the model from memorizing spurious correlations and prioritizes minority or hard-to-classify examples, which often suffer from poor generalization.
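One plausible way to realize such a logit shift, sketched below in PyTorch, is to add the log of the calibrated held-out probabilities to the model’s logits before computing the cross-entropy loss: a correct, confident held-out prediction then lowers the loss, while an incorrect one raises it. The exact formulation in the paper may differ, and the held-out probabilities are assumed to be precomputed (by the auxiliary model described next).

```python
# Hedged sketch of a per-example logit shift in the spirit of MAT.
# The exact loss in the paper may differ; names here are illustrative.
import torch
import torch.nn.functional as F

def mat_loss(logits, targets, heldout_probs, eps=1e-8):
    """Cross-entropy on logits shifted by log held-out probabilities.

    logits:        (batch, num_classes) raw model outputs
    targets:       (batch,) integer class labels
    heldout_probs: (batch, num_classes) calibrated probabilities from an
                   auxiliary held-out model, assumed precomputed
    """
    shifted = logits + torch.log(heldout_probs + eps)
    return F.cross_entropy(shifted, targets)

# Toy usage with random tensors.
logits = torch.randn(4, 2)
targets = torch.randint(0, 2, (4,))
heldout_probs = torch.softmax(torch.randn(4, 2), dim=1)
print(mat_loss(logits, targets, heldout_probs))
```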
To calculate the held-out probabilities, MAT uses an auxiliary model, trained using cross-risk minimization (XRM). XRM is a training technique designed to discover different environments within a dataset by training two networks on random halves of the training data. The key idea is to encourage each network to learn a biased classifier, and then use the errors made by one model on the other’s data, or cross-mistakes, to annotate training and validation examples.
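The cross-mistake step might look roughly like the following sketch, which assumes two biased networks have already been trained on disjoint halves of the training data; the function name and structure are illustrative, not the paper’s code.

```python
# Rough sketch of the cross-mistake idea behind XRM: each example is judged
# by the network that did NOT train on it. Names are illustrative.
import torch
from torch import nn

def cross_mistakes(model_a, model_b, X, y, idx_half_a, idx_half_b):
    """Flag examples that the other half's biased classifier gets wrong."""
    flags = torch.zeros(len(y), dtype=torch.bool)
    with torch.no_grad():
        preds_on_a = model_b(X[idx_half_a]).argmax(dim=1)  # B judges A's half
        preds_on_b = model_a(X[idx_half_b]).argmax(dim=1)  # A judges B's half
    flags[idx_half_a] = preds_on_a != y[idx_half_a]
    flags[idx_half_b] = preds_on_b != y[idx_half_b]
    # Flagged examples are cross-mistakes: candidates for the minority
    # environment used to annotate training and validation data.
    return flags

# Toy usage with linear stand-ins for the two biased networks.
X = torch.randn(8, 16)
y = torch.randint(0, 2, (8,))
idx_a, idx_b = torch.arange(0, 4), torch.arange(4, 8)
model_a, model_b = nn.Linear(16, 2), nn.Linear(16, 2)
print(cross_mistakes(model_a, model_b, X, y, idx_a, idx_b))
```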
To track the effectiveness of MAT, you can compare a trained model’s average accuracy with its worst-group accuracy (WGA). (WGA measures a model’s accuracy on the subgroup where it performs the worst. It is a crucial metric for assessing the model’s robustness, particularly when dealing with spurious correlations and imbalanced datasets.)
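Computing the gap is straightforward; the sketch below uses made-up predictions, labels, and group IDs purely to show the calculation.

```python
# Average accuracy vs. worst-group accuracy (WGA) on placeholder data.
import numpy as np

preds  = np.array([1, 1, 0, 0, 1, 0, 1, 0])
labels = np.array([1, 1, 0, 1, 1, 0, 0, 0])
groups = np.array([0, 0, 0, 1, 1, 1, 2, 2])   # e.g. (class, background) group IDs

average_acc = (preds == labels).mean()
worst_group_acc = min(
    (preds[groups == g] == labels[groups == g]).mean()
    for g in np.unique(groups)
)
print(f"average accuracy: {average_acc:.2f}, worst-group accuracy: {worst_group_acc:.2f}")
```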
With classic training methods, the gap between average accuracy and WGA can be large. With MAT, the gap shrinks (albeit at the cost of a small penalty to average accuracy), thus reflecting the real capabilities of the model.
While the industry is super-focused on areas such as large language models (LLMs), it is refreshing to see continued work on the fundamentals of ML. Techniques such as MAT can be crucial for real-world ML applications, where you expect your model to deal with all kinds of surprises.