Accelerated Training of
Transformer Models
Kaarthik Sivashanmugam – Principal Engineering Manager
Sherlock Huang – Principal Engineer
Azure AI - Frameworks
Agenda
ONNX Runtime for Training
Introduction
Integration with training frameworks
Acceleration & Native Capabilities
Memory usage and execution optimizations
Mixed precision training, distributed training parallelism modes, gradient checkpointing, AdaSum, DeepSpeed ZeRO
Training Recipes & Perf Results
Pretraining and finetuning: BERT, GPT-2, Turing
Demo: ONNX Runtime Training in Azure Databricks
Intro: ONNX, ONNX Runtime
ONNX: an open and interoperable format for ML models
ONNX IR (intermediate representation)
ONNX Operator schema
Operation type
Attributes
Inputs/outputs
Shape inference function
https://onnx.ai/
https://github.com/onnx/onnx/blob/master/docs/Operators.md
Example: a Gemm node in an ONNX graph. Graph inputs: X (batch x 128), weight (128 x 256), bias (256); graph output: Y (batch x 256).
Gemm ONNX spec — Inputs: A (batch x 128), B (128 x 256), C (256); Outputs: Y (batch x 256); Attributes: alpha: 0.7, beta: 0.5
ONNX Model
Graph composed of computational nodes
Built-in and custom operators
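For reference, Gemm computes Y = alpha * (A @ B) + beta * C. Below is a minimal sketch of building the one-node graph from the example above with the standard onnx.helper API; the graph and file names are illustrative.

import onnx
from onnx import TensorProto, helper

# Gemm: Y = alpha * (A @ B) + beta * C
gemm = helper.make_node(
    "Gemm",
    inputs=["X", "weight", "bias"],   # bound to A, B, C in the operator schema
    outputs=["Y"],
    alpha=0.7,
    beta=0.5,
)

graph = helper.make_graph(
    nodes=[gemm],
    name="gemm_example",
    inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", 128]),
            helper.make_tensor_value_info("weight", TensorProto.FLOAT, [128, 256]),
            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [256])],
    outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", 256])],
)

model = helper.make_model(graph)
onnx.checker.check_model(model)     # validates the graph against the operator schemas
onnx.save(model, "gemm_example.onnx")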
ONNX Runtime (ORT)
Cross-platform accelerator for training and inference
Core part of the ML stack at Microsoft, supporting innovations from the company and the industry
ORT Training
Adopted by first-party (1P) and third-party (3P) workloads for acceleration
Current focus on large transformer models (based on demand and acceleration needs)
Extensible; supports PyTorch, Keras/TensorFlow, …
ONNX Runtime for Training
Training & ORT Acceleration
Define Model → Train Loop: Get Data Batch → Compute Loss → Compute Gradients & Update Weights → Evaluate
Acceleration scope: Create ORTTrainer using the model; ORTTrainer.train_step(); Checkpoint
import torch
from onnxruntime.training import ORTTrainer, optim
# Model definition
class NeuralNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
...
def forward(self, x):
...
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
criterion = torch.nn.functional.cross_entropy
model_description = {
    'inputs': [('data', ['in', 'batch_size']),
               ('target', ['label_x_batch_size'])],
    'outputs': [('loss', [], True),
                ('output', ['out', 'batch_size'])]
}
optimizer_config = optim.AdamConfig(lr=learning_rate)
# model, model description, optimizer configuration, loss function
trainer = ORTTrainer(model, model_description, optimizer_config, criterion)
# Training Loop
for t in range(1000):
# forward + backward + weight update
loss, y_pred = trainer.train_step(x, y)
ORT in PyTorch
Side by side: PyTorch + ONNX Runtime backend (the ORTTrainer snippet above) vs. plain PyTorch (the snippet below)
import torch
# Model definition
class NeuralNet(torch.nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
...
def forward(self, x):
...
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# Training Loop
for t in range(1000):
# forward
y_pred = model(x)
loss = criterion(y_pred, y)
# reset gradient buffer
optimizer.zero_grad()
# backward
loss.backward()
# weight update
optimizer.step()
ORT Frontend Adapters
PyTorch Script → PyTorch ORTTrainer → To ONNX → ONNX Runtime (ORT TrainingSession Python API); GPU buffer
TF/Keras Script → TF ORTTrainer → To ONNX → ONNX Runtime (ORT TrainingSession Python API); GPU buffer
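The "To ONNX" step in both adapter paths is a conversion of the framework model into the ONNX format. A minimal sketch using PyTorch's standard torch.onnx.export is shown below; ORTTrainer performs an equivalent conversion internally, and the NeuralNet body, input/output names, and file name here are illustrative fill-ins for the parts elided in the earlier snippet.

import torch

class NeuralNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = NeuralNet(input_size=784, hidden_size=500, num_classes=10)
dummy_input = torch.randn(1, 784)  # example input used to trace the graph

# Export the PyTorch graph to ONNX; the ORT frontend adapter hands the
# resulting graph to the ORT TrainingSession.
torch.onnx.export(
    model,
    dummy_input,
    "neuralnet.onnx",
    input_names=["data"],
    output_names=["output"],
    dynamic_axes={"data": {0: "batch_size"}, "output": {0: "batch_size"}},
)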
Acceleration & Native Capabilities
Contributors to ORT Acceleration
▪ Graph Optimizations: static graph optimization techniques like constant folding and redundant node elimination
▪ Optimal Gradient Graph: memory and compute optimized using global knowledge of data dependencies
▪ Memory Efficiency: static graph used for preallocation of memory for weights and gradients; memory reuse
▪ CUDA Kernel Optimizations: op fusion; reimplemented cuDNN kernels; removed redundant computation
▪ Other Training Capabilities: mixed precision training, distributed training parallelism modes, gradient checkpointing, AdaSum, DeepSpeed ZeRO
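As an illustration of the graph-optimization levels referenced above, ORT's standard Python API exposes them on SessionOptions. The sketch below uses an inference session purely for illustration (the training path applies the same family of graph transformations internally); the model file name is the one from the export sketch earlier and is illustrative.

import onnxruntime as ort

# Graph optimization levels control constant folding, redundant node
# elimination, and op fusion applied when ORT loads a graph.
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
so.optimized_model_filepath = "neuralnet_optimized.onnx"  # dump the optimized graph for inspection

session = ort.InferenceSession("neuralnet.onnx", sess_options=so)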
Native Capabilities in ORT
▪ Mixed Precision Training: 16-bit and 32-bit FP types to make training faster and use less memory
▪ Distributed Training Modes: data, horizontal, and pipeline parallelism
▪ Gradient Accumulation: computed gradients are accumulated into a gradient buffer using partial execution of the graph repeated for N steps; averaged gradients are used in the optimizer for weight updates
▪ Gradient Checkpointing: stashed activations often dominate memory consumption in training; recompute discarded activations when needed, trading off memory usage vs. computation cost
▪ AdaSum: combines gradients in a novel way to improve convergence; model converges faster
▪ DeepSpeed ZeRO (Zero Redundancy Optimizer): optimizer state partitioning, gradient partitioning, parameter partitioning
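A minimal sketch of enabling several of these capabilities through ORTTrainerOptions, reusing the model, model_description, and criterion from the earlier ORTTrainer example. The option names follow the experimental onnxruntime.training API of this period and may differ between versions; the AdaSum and ZeRO keys in particular are assumptions and are shown commented out.

from onnxruntime.training import ORTTrainer, ORTTrainerOptions, optim

# model, model_description, and criterion come from the earlier ORTTrainer slide.
opts = ORTTrainerOptions({
    'device': {'id': 'cuda:0'},
    'mixed_precision': {'enabled': True},          # FP16 compute with FP32 master weights
    'batch': {'gradient_accumulation_steps': 4},   # accumulate gradients over 4 micro-steps
    'distributed': {
        'world_rank': 0,
        'world_size': 1,
        # 'enable_adasum': True,                        # AdaSum gradient combining (assumed key)
        # 'deepspeed_zero_optimization': {'stage': 1},  # ZeRO stage 1 = optimizer state partitioning;
        #                                               # stage 2 adds gradients, stage 3 adds parameters
    },
})

optimizer_config = optim.AdamConfig(lr=1e-4)
trainer = ORTTrainer(model, model_description, optimizer_config, criterion, options=opts)
loss, y_pred = trainer.train_step(x, y)            # same training-loop call as before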
Code Sample & Training Recipes
BERT Pretraining using ORT
https://github.com/microsoft/onnxruntime-training-examples/
Training Recipes
▪ BERT Pretraining
▪ Nvidia’s implementation of BERT pretraining accelerated using ORT
▪ https://github.com/microsoft/onnxruntime-training-examples/tree/master/nvidia-bert
▪ GPT-2 Finetuning
▪ Finetuning of Hugging Face GPT-2 model
▪ https://github.com/microsoft/onnxruntime-training-examples/tree/master/huggingface-gpt2
▪ Turing Finetuning
▪ Finetuning of Microsoft Turing model for abstractive text summarization, sentiment analysis and suggested reply scenarios
▪ https://github.com/microsoft/Turing-NLR (private preview)
Performance Improvement Results
BERT Pretraining in 4xDGX-2

                              PyTorch 1.5 with   PyTorch 1.5 with   % Gain with
                              NGC 20.03-py3      ONNX Runtime       ONNX Runtime
Phase 1 Throughput (ex/sec)   11522.1            12826.2            11.32%
Phase 2 Throughput (ex/sec)   2150.0             2464.1             14.61%
Phase 1 time (hours)          11.12              9.99               10.16%
Phase 2 time (hours)          6.62               5.77               12.84%
Total time (hours)            17.74              15.76              11.16%
PyTorch w/ ORT can train with 2x the local batch size as PyTorch w/o ORT
(global batch size was kept the same for comparison)
Perf Improvement with ORT

Model (Scenario) / # Params        Perf improvement w/ ORT
Turing* (pretraining) / 340M       1.4x
Turing* (pretraining) / 350M       1.2x
RoBERTa XL (pretraining) / 500M    3x
RoBERTa XL (finetuning) / 500M     1.2x
RoBERTa XXL (pretraining) / 1B     7x
GPT-2 M (pretraining) / 345M       1.2x

* https://msturing.org/
Demo: ONNX Runtime Training in Azure Databricks
https://github.com/skaarthik/onnxruntime-training-databricks
Summary
▪ Optimize and accelerate model training using ONNX Runtime (ORT)
▪ ORT is used in training very large models used in various Microsoft products/services
▪ https://github.com/microsoft/onnxruntime
▪ https://github.com/microsoft/onnxruntime-training-examples
Feedback
Your feedback is important to us.
Don’t forget to rate
and review the sessions.