Demystifying AI Scaling Laws

AI scaling laws are front and center again, thanks to recent groundbreaking announcements by AI startups claiming parity with their longer-established and far better-resourced counterparts.

Some of the key concepts behind AI scaling were well articulated in 2020 by a team of OpenAI researchers, several of whom later co-founded Anthropic (see https://arxiv.org/pdf/2001.08361).

The researchers identified “precise power-law scalings for performance as a function of training time, context length, dataset size, model size, and compute budget.” In essence, AI model performance improves consistently with increases in model size, dataset size, and compute power.

Although the commercial landscape of AI has evolved significantly since 2020, these scaling laws continue to hold, with profound implications for the AI infrastructure that supports the training and inference processes users increasingly rely on.

Pre-training scaling, post-training scaling, and test-time scaling refer to different stages and enhancements in the lifecycle of machine learning models, particularly large-scale deep learning models.

As pre-training and post-training scaling see diminishing returns, a third technique, test-time scaling, is gradually emerging: the model dynamically allocates compute during inference, so improvement is no longer limited to parameter optimization.

[Image: NVIDIA CEO Jensen Huang discusses AI scaling laws during CES 2025]

Pre-Training Scaling

This is the initial phase where a model is trained on a large, general dataset (often unsupervised or semi-supervised) to learn foundational patterns or representations.

During pre-training, the model gains broad knowledge, such as understanding language, image patterns, or generic tasks, which can be transferred to more specific tasks.

This phase typically involves massive datasets and computational resources, and utilizes self-supervised learning, such as predicting masked tokens in language models or reconstructing data.

Each of the scaling phases has its own continuous-improvement steps, which in the case of pre-training scaling include:

1. Larger datasets and more diverse sources

Example 1: Common Crawl for Language Models

Many large language models, such as OpenAI’s GPT-3 and Google’s PaLM, are pre-trained on datasets sourced from Common Crawl, which contains vast amounts of web pages, blogs, forums, and other web content.

These datasets are further supplemented with specialized data, such as Wikipedia, GitHub projects, books, academic papers, and multilingual corpora, ensuring the model learns a broad range of language structures and domain-specific knowledge.

Example 2: LAION Dataset for Vision-Language Models

Models such as OpenCLIP and Stable Diffusion are trained on the LAION dataset, which contains over a billion image-text pairs scraped from the internet, significantly enhancing the diversity and quality of visual and textual training data.

2. Innovations in architectures (e.g., transformers, sparse attention)

Example 1: Transformers

Transformers, introduced by the seminal “Attention is All You Need” paper from Google in 2017, revolutionized pre-training by enabling parallel processing of sequence data and learning long-range dependencies in text, replacing older RNN-based architectures (see https://arxiv.org/abs/1706.03762).

Models like BERT, GPT-3, and T5 are based on transformer architectures.

Example 2: Sparse Attention

Sparse transformers, such as Longformer and BigBird, reduce the computational cost of attention mechanisms by limiting attention to local or important subsets of data rather than the entire input sequence.

This enables pre-training on much longer sequences, such as entire documents or videos, which was computationally infeasible with standard transformers.
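
As a rough illustration of the idea (not Longformer's or BigBird's actual implementation), the PyTorch sketch below masks attention so that each token attends only to a fixed-size local window; the window size and tensor dimensions are arbitrary, and a real sparse implementation would avoid materializing the full score matrix.

```python
# Minimal sketch of sliding-window ("local") attention for a single head.
# Real sparse transformers also add global/random attention and never build
# the full (seq_len x seq_len) score matrix; this version does, for clarity.
import torch
import torch.nn.functional as F

def local_attention(q, k, v, window=4):
    # q, k, v: (seq_len, d) tensors for a single attention head
    seq_len, d = q.shape
    scores = q @ k.T / d ** 0.5                       # (seq_len, seq_len)
    # Mask out positions farther than `window` tokens away
    idx = torch.arange(seq_len)
    mask = (idx[None, :] - idx[:, None]).abs() > window
    scores = scores.masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ v              # (seq_len, d)

q = k = v = torch.randn(16, 8)
out = local_attention(q, k, v, window=2)
print(out.shape)  # torch.Size([16, 8])
```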

3. Efficiency improvements, such as low-rank adaptation or parallel processing

Example 1: Low-Rank Adaptation (LoRA)

LoRA modifies pre-trained models by adding low-rank matrices during fine-tuning, allowing for efficient adaptation to downstream tasks with fewer parameters.

For instance, LoRA has been applied in fine-tuning large language models like GPT-3 to perform task-specific operations with minimal additional computational costs.
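
A minimal sketch of the idea in PyTorch is shown below, with the frozen pre-trained weight represented by an ordinary linear layer and an illustrative rank and scaling factor; this is a simplified illustration, not the reference LoRA implementation.

```python
# Minimal LoRA sketch: a frozen linear layer augmented with a trainable
# low-rank update (rank r). Dimensions, rank, and scaling are illustrative.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, in_dim, out_dim, r=8, alpha=16):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim)
        self.base.weight.requires_grad_(False)   # frozen pre-trained weight
        self.base.bias.requires_grad_(False)
        self.A = nn.Parameter(torch.randn(r, in_dim) * 0.01)  # trainable
        self.B = nn.Parameter(torch.zeros(out_dim, r))        # trainable
        self.scale = alpha / r

    def forward(self, x):
        # Frozen path plus the low-rank correction (B @ A) applied to x
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

layer = LoRALinear(768, 768)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # only the low-rank matrices A and B are trained
```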

Example 2: Parallel Processing with TPU and GPU Clusters

Google’s TPU Pods and NVIDIA DGX Systems enable massive parallelism for pre-training large models by distributing computations across hundreds or thousands of accelerators.

For example, Google’s PaLM was trained across thousands of TPU v4 chips, using data and model parallelism to split the workload across devices and drastically reduce training time.

Example 3: Mixed-Precision Training

Tools like NVIDIA Apex and PyTorch’s native AMP enable mixed-precision training, running most operations in half precision (FP16 or BF16) while keeping master weights in FP32, reducing memory usage and speeding up training while maintaining model accuracy.

Mixed-precision training has been widely used in models like Megatron-LM and Turing-NLG.
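
The sketch below shows the same idea using PyTorch's built-in automatic mixed precision (an alternative to Apex), with a placeholder model and random data; it assumes a CUDA device is available.

```python
# Minimal mixed-precision training loop with PyTorch AMP.
# The model, data, and hyperparameters are placeholders.
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(1024, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()  # scales the loss to avoid FP16 gradient underflow

for step in range(100):
    x = torch.randn(32, 1024, device="cuda")
    y = torch.randint(0, 10, (32,), device="cuda")
    optimizer.zero_grad()
    with autocast():                     # run the forward pass in reduced precision
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()        # backward pass on the scaled loss
    scaler.step(optimizer)               # unscales gradients, then steps
    scaler.update()
```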

Post-Training Scaling

After pre-training, the model is fine-tuned on a smaller, task-specific dataset to adapt it to a particular application (e.g., sentiment analysis, medical diagnosis).

The purpose of this phase is to customize the pre-trained model to perform well on specific downstream tasks.

Fine-tuning typically requires less data and computational power than pre-training, and it involves supervised learning or reinforcement learning with a well-defined goal.

Continuous improvements in post-training include:

1. Prompt tuning or adapter layers to reduce the computational overhead

Example 1: Prompt Tuning

Prompt tuning involves freezing the pre-trained model and learning only a small set of task-specific parameters called “soft prompts.” For example, in T5 (Text-to-Text Transfer Transformer), prompts can be designed to guide the model to perform tasks like summarization or translation without retraining the entire model.

This significantly reduces computational costs and storage needs because the base model remains unchanged, and only the prompts are fine-tuned.

Relatedly, OpenAI’s GPT-3 achieves task-specific performance through few-shot or zero-shot prompting, where examples or instructions are appended directly to the input text; these prompts are hand-written rather than learned, which eliminates the need for fine-tuning and allows users to leverage the same pre-trained model for diverse tasks.
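
The following PyTorch sketch illustrates the soft-prompt idea: the base model is frozen and only a small table of prompt embeddings is trainable. The wrapper, prompt length, and the assumption that the base model accepts input embeddings directly are all illustrative, not taken from the T5 or GPT-3 codebases.

```python
# Minimal soft-prompt sketch: frozen base model + trainable prompt embeddings
# prepended to the input embeddings. `base_model` is a placeholder that is
# assumed to accept embeddings of shape (batch, seq_len, embed_dim).
import torch
import torch.nn as nn

class SoftPromptWrapper(nn.Module):
    def __init__(self, base_model, embed_dim, n_prompt_tokens=20):
        super().__init__()
        self.base_model = base_model
        for p in self.base_model.parameters():
            p.requires_grad_(False)                  # freeze the base model
        self.prompt = nn.Parameter(torch.randn(n_prompt_tokens, embed_dim) * 0.02)

    def forward(self, input_embeds):
        # input_embeds: (batch, seq_len, embed_dim) token embeddings
        batch = input_embeds.size(0)
        prompt = self.prompt.unsqueeze(0).expand(batch, -1, -1)
        # Only `self.prompt` receives gradients during fine-tuning
        return self.base_model(torch.cat([prompt, input_embeds], dim=1))
```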

Example 2: Adapter Layers

Adapters are lightweight, task-specific modules inserted into a pre-trained model (e.g., BERT) without modifying the original weights. For instance, the AdapterHub framework enables researchers to fine-tune adapters for various NLP tasks, reducing the computational overhead compared to full fine-tuning.

This method has been used to adapt large models for specific tasks like question answering and sentiment analysis while keeping the main model reusable for other applications.

Adapters have also been applied to multimodal models like CLIP to specialize them for domain-specific image or text classification tasks while retaining the general-purpose pre-trained model.
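
A minimal adapter module looks roughly like the PyTorch sketch below: a small bottleneck projection with a residual connection, inserted after a frozen sub-layer. The bottleneck size is illustrative, and the exact placement and details differ across adapter frameworks such as AdapterHub.

```python
# Minimal adapter sketch: down-project, nonlinearity, up-project, plus a
# residual connection so the frozen pre-trained representation is preserved.
import torch.nn as nn

class Adapter(nn.Module):
    def __init__(self, hidden_dim, bottleneck_dim=64):
        super().__init__()
        self.down = nn.Linear(hidden_dim, bottleneck_dim)  # down-projection
        self.up = nn.Linear(bottleneck_dim, hidden_dim)    # up-projection
        self.act = nn.GELU()

    def forward(self, x):
        # Only the adapter parameters are trained; x comes from a frozen layer
        return x + self.up(self.act(self.down(x)))
```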

2. Incorporating task-specific feedback, such as reinforcement learning from human feedback (RLHF)

Example 1: OpenAI’s ChatGPT

Reinforcement Learning from Human Feedback (RLHF) is central to ChatGPT’s development. After pre-training on a broad dataset, the model was fine-tuned with human feedback to align its responses with user preferences.

The process involved:

1. Training a reward model using labeled examples of preferred responses.

2. Optimizing the model with Proximal Policy Optimization (PPO), a reinforcement learning algorithm, based on the reward model’s feedback.

RLHF significantly improves response quality and aligns the model’s behavior with human values and expectations.
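
To make the first step concrete, the sketch below trains a reward model with a pairwise preference loss: preferred ("chosen") responses should score higher than rejected ones. The encoder is a placeholder for a pre-trained transformer, and this is a simplified illustration rather than OpenAI's actual training code; the PPO optimization step that follows is omitted.

```python
# Minimal reward-model sketch for RLHF: a scalar "reward head" on top of a
# placeholder encoder, trained so that r(chosen) > r(rejected).
import torch
import torch.nn as nn
import torch.nn.functional as F

class RewardModel(nn.Module):
    def __init__(self, encoder, hidden_dim):
        super().__init__()
        self.encoder = encoder              # e.g., a pre-trained transformer
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, tokens):
        h = self.encoder(tokens)            # assumed pooled output: (batch, hidden_dim)
        return self.head(h).squeeze(-1)     # scalar reward per response

def preference_loss(reward_model, chosen, rejected):
    # Pairwise loss: push the reward of the preferred response above the rejected one
    r_chosen = reward_model(chosen)
    r_rejected = reward_model(rejected)
    return -F.logsigmoid(r_chosen - r_rejected).mean()
```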

Example 2: Alignment of GPT-4 for Safety and Usability

GPT-4 employs RLHF to refine the model’s ability to handle sensitive or complex queries responsibly. Feedback from human evaluators helps the model avoid harmful content and provide safer, more accurate responses.

Example 3: DeepMind’s Sparrow

DeepMind’s chatbot, Sparrow, incorporates RLHF to refine its conversational capabilities. It uses feedback from humans to ensure that its responses are informative, aligned with ethical guidelines, and less likely to spread misinformation.

Test-Time Scaling

Test-time scaling refers to techniques applied during the inference phase (when the trained model is used for prediction) to enhance performance without altering the model’s parameters.

The purpose of this phase is to improve the model’s adaptability and accuracy during real-world usage or deployment.

Continuous improvements include:

1. Dynamic Test-Time Augmentation: Augmenting input data dynamically to improve predictions (e.g., applying image transformations or paraphrasing text inputs).

Example 1: Image Transformations in Vision Models (EfficientNet)

At inference, test-time augmentations like cropping, flipping, rotation, or color jittering are applied dynamically to an input image. The model predicts the class probabilities for each augmented version of the image, and the final prediction is derived by averaging these probabilities. This improves robustness and accuracy without retraining the model.
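
A minimal version of this procedure, assuming a generic PyTorch image classifier and a handful of simple augmentations, might look like the following; the specific transforms and their number are arbitrary.

```python
# Minimal test-time augmentation sketch: average softmax outputs over several
# augmented views of the input. The classifier `model` is a placeholder.
import torch
import torchvision.transforms.functional as TF

def predict_with_tta(model, image):
    # image: (C, H, W) tensor; build a few simple augmented views
    views = [
        image,
        TF.hflip(image),
        TF.rotate(image, 10),
        TF.rotate(image, -10),
    ]
    model.eval()
    with torch.no_grad():
        probs = [torch.softmax(model(v.unsqueeze(0)), dim=-1) for v in views]
    return torch.stack(probs).mean(dim=0)   # averaged class probabilities
```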

Example 2: Text Paraphrasing in NLP Models

During inference for tasks like sentiment analysis or translation, multiple paraphrased versions of the input are generated and passed through the model. Predictions from these augmented inputs are combined (e.g., averaged or majority-voted) to improve reliability and reduce sensitivity to specific wording.

2. Ensemble Techniques: Combining outputs of multiple models or using different layers of the same model.

Example 1: Model Ensembles for Classification

In applications like autonomous driving, deep ensembles of vision models such as ResNet, EfficientNet, and Vision Transformers are trained independently on the same task. At test time, their predictions are combined (e.g., by averaging or weighted voting) to increase accuracy and reliability, reducing the risk of errors from a single model.
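
The combination step itself is simple; assuming a list of independently trained PyTorch classifiers, an unweighted ensemble average can be sketched as follows.

```python
# Minimal deep-ensemble sketch: average the predicted class probabilities of
# several independently trained models. The models themselves are placeholders.
import torch

def ensemble_predict(models, x):
    with torch.no_grad():
        probs = [torch.softmax(m(x), dim=-1) for m in models]
    avg = torch.stack(probs).mean(dim=0)     # simple unweighted average
    return avg.argmax(dim=-1), avg           # predicted classes and probabilities
```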

Example 2: Layer Ensembles in Transformer Models

During inference, outputs from multiple transformer layers (e.g., middle and top layers) are combined to generate predictions. This technique has been used in BERT Layer Ensembles (BERT-based Question Answering systems) to leverage the complementary strengths of different layers, improving overall prediction quality.

3. Test-Time Optimization: Adjusting specific aspects of the model (like normalization layers or attention weights) in real-time to better fit the given input.

Example 1: Batch Normalization Recalibration

At test time, in segmentation models such as DeepLabV3, the statistics of batch normalization layers (mean and variance) can be recalibrated using the actual test data rather than relying on precomputed training statistics. This ensures better alignment between the model and the distribution of test-time inputs.
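
A simple form of this recalibration, assuming a PyTorch model with standard BatchNorm layers and an unlabeled loader of test batches, is sketched below; real test-time adaptation methods are usually more careful (e.g., blending training and test statistics).

```python
# Minimal batch-norm recalibration sketch: reset the BatchNorm running
# statistics and re-estimate them from a few unlabeled test batches.
import torch
import torch.nn as nn

def recalibrate_batchnorm(model, test_loader, n_batches=10):
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.reset_running_stats()          # forget the training statistics
    model.train()                            # BN updates running stats in train mode
    with torch.no_grad():
        for i, (x, _) in enumerate(test_loader):
            if i >= n_batches:
                break
            model(x)                         # forward passes update BN statistics
    model.eval()
    return model
```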

Example 2: Attention Weight Adjustment

Models like Transformer-XL adapt attention at inference time by caching hidden states from previous segments, effectively extending the attention context to match the content and length of the input sequence. This reduces over-reliance on a fixed context window and improves long-sequence understanding.

4. Scaling Mechanisms: Adjusting the computational budget or inference depth dynamically based on the complexity of the input, optimizing resource use and accuracy.

Example 1: Dynamic Depth Scaling in Neural Networks

Early-exit ("dynamic depth") networks, such as MSDNet or BranchyNet, adjust how many layers are evaluated based on the complexity of each input image. For simple images, shallow layers suffice, reducing computation, while more complex images utilize deeper layers to preserve accuracy. (EfficientNet’s compound scaling trades off depth, width, and resolution in a similar spirit, but that choice is fixed at design time rather than per input.)
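
A toy early-exit classifier, with placeholder backbone blocks, per-stage heads, and an arbitrary confidence threshold, might be sketched as follows (it assumes a single input, i.e., batch size 1).

```python
# Minimal early-exit sketch: intermediate classifiers after each backbone block
# let inference stop once a prediction is confident enough.
import torch
import torch.nn as nn

class EarlyExitNet(nn.Module):
    def __init__(self, blocks, heads, threshold=0.9):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)   # backbone stages (at least one)
        self.heads = nn.ModuleList(heads)     # one classifier head per stage
        self.threshold = threshold

    def forward(self, x):
        # Assumes batch size 1 so a single confidence value decides the exit
        for block, head in zip(self.blocks, self.heads):
            x = block(x)
            probs = torch.softmax(head(x), dim=-1)
            conf, pred = probs.max(dim=-1)
            if conf.item() >= self.threshold:  # confident: exit early
                return pred, probs
        return pred, probs                     # fall through to the last exit
```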

Example 2: Progressive Inference in Language Models

For large language models like GPT-3, adaptive (early-exit) inference can dynamically adjust the amount of computation, such as the number of transformer layers processed, based on the complexity of the query. Simple queries use fewer layers, optimizing inference speed and resource usage.

Example 3: Neural Architecture Search (NAS) with Dynamic Scaling

At test time, NAS-derived model families such as MnasNet (or once-for-all networks) allow computation to be scaled by lowering the input resolution or selecting a smaller subnetwork, ensuring a trade-off between latency and accuracy.
