Unsupervised Data Pruning: Enhancing Model Performance with Less
Not every increase in data volume guarantees a more accurate model. So, how do we select the right data?
Scaling laws have been noted across various domains such as images, text, and speech. The question arises: is merely increasing the number of parameters the key to building superior models? If not, what alternative strategies can be employed?
What are scaling laws, and why are they concerning?
Recently, the surge in the number of model parameters has been remarkable, with major companies striving to develop increasingly sophisticated models. This has resulted in lower errors on benchmark datasets and the appearance of unexpected behaviors. But what exactly is a scaling law?
In essence, the scaling law posits that “test error frequently diminishes according to a power law in relation to the quantity of training data, model size, or compute resources.” In simpler terms, enhancing a model's performance typically necessitates an increase in one of these three elements: the number of training examples, the count of parameters, or the duration of training sessions.
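As a purely illustrative sketch of what such a power law implies, consider how slowly the error falls when the exponent is small (the constants below are made up for illustration, not taken from any paper):

```python
# Toy power law: test error ~ a * D^(-alpha), with a small exponent alpha,
# as is typical of empirically reported scaling laws (values are made up).
a, alpha = 1.0, 0.1

def test_error(dataset_size):
    """Test error predicted by the toy power law as a function of dataset size D."""
    return a * dataset_size ** (-alpha)

for d in [1e6, 1e7, 1e8, 1e9]:
    print(f"D = {d:.0e} -> predicted error = {test_error(d):.3f}")

# With alpha = 0.1, a 10x increase in data lowers the error by only
# ~21% (10**-0.1 ≈ 0.79): this is the weak, expensive scaling discussed here.
```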
Early technical research indicated that test loss can decline as a power function of the training dataset size. This idea was formalized in 2017, when Hestness et al. examined it across several machine-learning fields, including machine translation, language modeling, image processing, and speech recognition.
The scaling law was later articulated in an OpenAI paper demonstrating that augmenting the model size, dataset, or computational resources leads to enhanced performance.
In their findings, they stated:
> Model performance is primarily influenced by scale, which encompasses three factors: the number of model parameters N (excluding embeddings), the dataset size D, and the compute resources C allocated for training. Performance is minimally affected by other architectural hyperparameters like depth versus width. — source
In summary, they contend that performance exhibits a power-law correlation with each of these three factors. Moreover, if the parameter count N increases, the dataset size D must also rise concurrently; otherwise, overfitting may occur. There exists a relationship between N and D, necessitating a fivefold increase in data for every eightfold increase in parameters.
This principle was exemplified in the development of GPT-3 and subsequent models, with Google’s PaLM exceeding 500 billion parameters. While these models have demonstrated remarkable abilities, speculation persists regarding whether simply escalating parameter counts will yield general intelligence. Will it?
No. In brief, neural networks function as pattern-matching systems (or universal approximators). They identify patterns encountered during training. A larger neural network trained on more data can indeed retain more patterns and recognize more of them. But is data infinite? The answer is no.
Such power-law scaling has prompted considerable investments in data collection, computational resources, and energy consumption. However, this scaling is weak and unsustainable. — source
As noted in the OpenAI article, a substantial increase in data, parameters, or compute is needed to achieve merely a 2-3% reduction in error. For instance, scaling vision transformers required about two billion images to improve accuracy on ImageNet by only a few points.
To summarize, we have observed a trend suggesting that more data is better. But is this the sole method?
“Can we achieve exponential scaling instead, with an effective strategy for choosing training examples?”
It is worth noting that much of the data may be redundant. Models frequently encounter numerous similar examples, and datasets are often compiled by indiscriminately downloading thousands of examples from the internet. Previous research has shown that training examples can be ranked by difficulty, from easy and redundant to challenging. By eliminating easy, redundant examples that consume training cycles without contributing much to learning, the dataset can be shrunk while preserving performance.
Unsupervised data pruning: can you remove useless training data points without knowing the labels?
Previous studies have left several questions unanswered. For instance, can a power law of error concerning data be defined that allows for an exponential reduction in examples without compromising performance? Furthermore, the strategies mentioned necessitate labeled examples, which can be both time-consuming and costly. Therefore, an optimal strategy should be unsupervised.
Recently, a paper addressing these questions was published, the result of a collaboration among Meta AI, Stanford, and the University of Tübingen:
Beyond neural scaling laws: beating power law scaling via data pruning
The authors began with the observation that the scaling law is inefficient: the exponents of the power law are nearly zero (indicating suboptimal resource use). Additionally, increasing parameters or data results in a minimal error reduction. What we desire is the ability to prune a dataset without affecting model performance, even when the dataset lacks labels (as labeling data is one of the most labor-intensive and costly processes). How?
The authors investigated this possibility within a teacher-student framework. In this training setup, a large pre-trained model (the teacher) is paired with a smaller model (the student); the framework is not limited to CNNs, although they are commonly used. The student is trained on the teacher’s outputs (probabilities rather than the original labels, also known as soft labels).
In essence, the authors used CIFAR-10 as their dataset, derived probabilities from a teacher model, and trained a student model for several epochs using the teacher’s output probabilities as labels. They subsequently calculated the margin between student and teacher outputs, which offers a measure of the learning challenge posed by each example.
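A minimal sketch of what such a margin computation could look like in PyTorch is shown below. It rests on my own assumptions (a `teacher` model and a briefly trained `student` model, and a margin defined as the student’s probability for the teacher’s predicted class minus its highest probability on any other class); the paper’s exact metric may differ in detail.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def example_difficulty(teacher, student, loader, device="cpu"):
    """Score examples by a teacher-student margin: small (or negative)
    margins indicate hard examples, large margins indicate easy ones.
    This is a sketch; the paper's exact margin definition may differ."""
    scores = []
    teacher.eval(); student.eval()
    for x, _ in loader:                       # the original labels are not needed
        x = x.to(device)
        t_prob = F.softmax(teacher(x), dim=1)
        s_prob = F.softmax(student(x), dim=1)
        t_class = t_prob.argmax(dim=1)        # teacher's predicted class
        # probability the student assigns to the teacher's class ...
        p_target = s_prob.gather(1, t_class.unsqueeze(1)).squeeze(1)
        # ... minus the student's highest probability on any other class
        s_prob_other = s_prob.clone()
        s_prob_other.scatter_(1, t_class.unsqueeze(1), float("-inf"))
        margin = p_target - s_prob_other.max(dim=1).values
        scores.append(margin.cpu())
    return torch.cat(scores)                  # one difficulty score per example
```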
The first notable finding is that when the training set contains a large number of examples, retaining the difficult examples while pruning the easy ones is preferable. Conversely, when the training set is small, it is more advantageous to remove the challenging examples. This may seem counterintuitive. The authors explain that easy examples provide basic information about the target function, essentially capturing the dataset's general patterns. In contrast, difficult examples deliver more detailed insights that could be overlooked in larger datasets.
> Intuitively, in the limited data scenario, modeling outliers is challenging since the fundamentals are not adequately captured; thus, retaining easy examples is crucial for achieving moderate error. However, with a larger dataset, learning the easy examples is straightforward, making the challenge of modeling outliers more significant. — source
In other words, with scarce data, the model benefits from grasping general patterns, while abundant data allows difficult examples to aid in delineating decision boundaries between classes.
From an information-theoretic standpoint, the authors suggest that data pruning enhances the information gained from each individual example by filtering out uninformative examples.
Data pruning enhances transfer learning
One of the motivations behind exploring scaling laws is the pursuit of foundation models. A foundation model is a broad model (transformer, vision transformer, etc.) trained on a vast amount of unlabeled data, which can subsequently be adapted for various downstream tasks.
In simpler terms, a large, general-purpose model is trained on a massive dataset and then fine-tuned for different tasks. Examples include BERT and GPT-3 for text-related tasks and ResNet for computer vision tasks (DALL-E and Stable Diffusion also incorporate a pre-trained language model).
Training a foundation model is exceedingly costly, and thus far, efforts have focused on augmenting the number of parameters and the amount of training data. However, studies, including DeepMind’s Chinchilla, indicate that refining the training data might yield more benefits. The authors of this study pondered: Could data pruning enhance transfer learning?
The authors employed a pre-trained vision transformer (ViT) and fine-tuned it on a pruned subset of 10% of CIFAR-10. This method outperformed fine-tuning the ViT on the entire CIFAR-10 dataset. Additionally, they pre-trained ResNet50 on various pruned subsets of ImageNet (a reduced version) and then fine-tuned it on CIFAR-10. The results indicated that training on a pruned dataset yielded superior performance compared to using the entire ImageNet.
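To make this setup concrete, here is a hedged sketch of how one might build a pruned 10% subset of CIFAR-10 from per-example difficulty scores and feed it to a fine-tuning loop. The scores below are random placeholders (in practice they would come from a metric such as the margin sketch above), and whether to keep the hard or the easy examples depends on how much data is retained, per the findings discussed earlier.

```python
import torch
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms

def pruned_subset(dataset, scores, keep_fraction, keep_hard=True):
    """Keep a fraction of the dataset according to per-example difficulty
    scores (lower margin = harder example, as in the sketch above)."""
    n_keep = int(keep_fraction * len(dataset))
    order = torch.argsort(scores)                  # hardest (lowest score) first
    idx = order[:n_keep] if keep_hard else order[-n_keep:]
    return Subset(dataset, idx.tolist())

train_set = datasets.CIFAR10(root="data", train=True, download=True,
                             transform=transforms.ToTensor())
scores = torch.rand(len(train_set))                # placeholder difficulty scores
pruned = pruned_subset(train_set, scores, keep_fraction=0.10, keep_hard=True)
loader = DataLoader(pruned, batch_size=128, shuffle=True)
# `loader` can then be used to fine-tune the pre-trained ViT as usual,
# at roughly a tenth of the per-epoch cost of the full dataset.
```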
Thus, intriguingly, pruning pre-training data for an upstream task can sustain high performance on a different downstream task. Collectively, these findings highlight the potential of data pruning in enhancing transfer learning during both pre-training and fine-tuning stages.
Scaling the approach on a large dataset
Prior pruning studies have focused on small datasets, yet understanding how these principles generalize to larger datasets is crucial. Consequently, the authors benchmarked various previous approaches on ImageNet to assess their impact on model performance, selecting eight distinct methods.
The results showed that, when only a fraction of the most challenging examples was retained, most of these metrics outperformed random pruning. However, while these methods had performed well on smaller datasets, few matched the performance achieved by training on the complete dataset. The authors further noted:
> We discovered that all pruning metrics exacerbate class imbalance, leading to diminished performance. Many pruning metrics do not scale effectively to ImageNet, and those that do require substantial computational resources. Moreover, all these metrics necessitate labels, limiting their applicability for pruning data in large foundation models trained on extensive unlabeled datasets. Hence, there is a clear demand for simple, scalable, self-supervised pruning metrics. — source
The authors proposed the following solution (a rough code sketch follows the list):
- First, use a self-supervised pre-trained model, SwAV, to extract a low-dimensional representation of each example in the dataset.
- Then, employ k-means clustering to group the representations.
- Next, calculate the distance to the cluster center using cosine distance. Examples closer to the center are deemed easy, while those farther away are classified as difficult.
- Finally, one can decide to prune a certain percentage of easy or difficult examples as needed.
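Below is a minimal sketch of this prototype-based pruning pipeline, assuming the SwAV embeddings have already been extracted offline; the cluster count, keep fraction, and random embeddings used here are arbitrary illustrative choices, not the paper’s settings.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import normalize

def prototype_difficulty(embeddings, n_clusters=100, seed=0):
    """Self-supervised difficulty score: cosine distance of each embedding
    to its k-means cluster centroid. Larger distance = harder example."""
    emb = normalize(embeddings)                        # unit-norm rows
    km = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10).fit(emb)
    centroids = normalize(km.cluster_centers_)
    # cosine distance to the assigned centroid = 1 - cosine similarity
    sims = np.sum(emb * centroids[km.labels_], axis=1)
    return 1.0 - sims

def prune_indices(scores, keep_fraction, keep_hard=True):
    """Indices of the examples to keep after pruning by difficulty."""
    n_keep = int(keep_fraction * len(scores))
    order = np.argsort(scores)                         # easiest first
    return order[-n_keep:] if keep_hard else order[:n_keep]

# Example with random embeddings standing in for precomputed SwAV features:
embeddings = np.random.randn(50_000, 128).astype(np.float32)
scores = prototype_difficulty(embeddings, n_clusters=100)
keep = prune_indices(scores, keep_fraction=0.8, keep_hard=True)
```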
> Our self-supervised prototype metric matches or surpasses the performance of the best supervised metric, memorization, until only 70-80% of the data is retained, despite the fact that our metric does not utilize labels and is considerably simpler and cheaper to compute than many previously proposed supervised metrics. — source
The results equaled the state-of-the-art technique of memorization, which requires labels and is slower to compute.
Conclusions
The authors demonstrate how data pruning can influence errors comparably to the scaling law. Additionally, they reveal that unsupervised learning can yield a coreset (a subset of a dataset allowing model training with equivalent performance to the full dataset). This approach is cost-effective, scalable, and label-free.
Looking ahead, the authors suggest that this method can be further refined to enable even more aggressive pruning, which would be immensely beneficial for training large foundation models. They also propose:
> If highly pruned versions of these datasets can be employed to train a multitude of different models, one might envision these carefully selected data subsets as foundational datasets, where the initial computational investment in data pruning can be amortized across efficiency improvements in training numerous downstream models, similar to how the initial computational costs of training foundation models are offset by efficiency gains in fine-tuning for various downstream tasks. — source
In conclusion, reducing the dataset size prior to training conserves time and resources (less labeling work). Furthermore, minimizing overrepresented populations may assist in addressing or identifying biases during training.
What are your thoughts? Have you experimented with dataset pruning?
If you found this topic intriguing:
Explore my other articles, and feel free to subscribe for notifications on new publications. You can also connect with me on LinkedIn.
Here’s the link to my GitHub repository, where I plan to compile resources and code related to machine learning, artificial intelligence, and more.
GitHub - SalvatoreRa/tutorial: Tutorials on machine learning, artificial intelligence, data science…
Tutorials on machine learning, artificial intelligence, data science with mathematical explanations and reusable code (in Python).
or you may be interested in one of my recent articles:
Microsoft BioGPT: Towards the ChatGPT of life science?
BioGPT achieves the SOTA in various biomedical NLP tasks.
Everything but everything you need to know about ChatGPT
What is known, the latest updates, its impacts, and changes—all in one article.