AWS Machine Learning Blog

Pruning machine learning models with Amazon SageMaker Debugger and Amazon SageMaker Experiments

In the past decade, deep learning has advanced many different areas, such as computer vision and natural language processing. State-of-the-art models now achieve near-human performance in tasks such as image classification. Deep neural networks can achieve this because they consist of millions of parameters that you train on large training datasets. For instance, the BERT (Large) model consists of 340 million parameters, and ResNet-152 consists of 60 million parameters. Training such models from scratch is computationally intensive and can take hours, days, or even weeks.

Data scientists therefore often perform transfer learning, which is the process of applying knowledge gained by solving one problem to a related but different problem. With transfer learning, you fine-tune a pretrained model on a smaller dataset to improve accuracy. In such scenarios, you may not need a large number of parameters in the model; a smaller model may perform just as well.

In the context of machine learning (ML) at the edge, having small models is essential. Due to hardware constraints, factors such as latency, memory footprint, and compute time are as important as model accuracy. For instance, autonomous driving requires models with high accuracy and low latency. In such a scenario, a model that achieves 1% better accuracy but takes twice the time for prediction is not favorable.

Model pruning can significantly reduce model size with little or no loss in accuracy. The idea is simple: identify the redundant parameters in the model that contribute little to the training process, and remove them.

This post demonstrates iterative model pruning with Amazon SageMaker. The post walks through a sample application that uses a pretrained model and iteratively prunes it by more than a factor of three with no significant loss in accuracy.

Model pruning

Model pruning aims to remove weights that don’t contribute much to the training process. Weights are learnable parameters: they are randomly initialized and optimized during the training process. During the forward pass, data passes through the model. The loss function evaluates model output given the labels; during the backward pass, weights are updated to minimize the loss. To do so, the gradients of the loss with respect to the weights are computed, and each weight receives a different update. After a few iterations, certain weights are typically more impactful than others; the goal of pruning is to remove the least important ones without significantly reducing model accuracy. The following diagram illustrates this workflow.

You can use the following heuristics to measure the importance of weights:

  • Magnitudes of weights – Remove weights if their absolute value is smaller than a threshold: smaller weights have less effect on the output.
  • Average activation – If a neuron is mostly inactive throughout training, you can infer that the weights going into its activation function are less relevant.
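The magnitude heuristic is easy to try on a single weight matrix. The following NumPy sketch uses a hypothetical helper, `magnitude_prune`, with an arbitrary threshold of 0.05. It performs non-structured pruning by setting small weights to zero rather than physically removing them.

```python
import numpy as np


def magnitude_prune(weights: np.ndarray, threshold: float):
    """Zero out every weight whose absolute value is below threshold.

    Returns the pruned weights and a boolean mask of the weights that were kept.
    """
    keep = np.abs(weights) >= threshold  # True where the weight survives
    return np.where(keep, weights, 0.0), keep


w = np.array([[0.8, -0.02, 0.3],
              [0.01, -0.5, 0.04]])
pruned, kept = magnitude_prune(w, threshold=0.05)
print(pruned)
print(f"kept {kept.sum()} of {kept.size} weights")  # kept 3 of 6 weights
```

In practice, the threshold is often chosen per layer so that a target fraction of weights is removed. This sketch keeps it fixed for simplicity.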

You can distinguish between non-structured and structured weight pruning:

  • Non-structured pruning removes arbitrary weights (as in the preceding diagram)
  • Structured pruning removes entire convolutional filters and associated channels

Structured pruning is especially relevant for computer vision models, which typically consist of many convolutional layers. A filter is a collection of kernels (one kernel for each input channel), and it produces one feature map, also referred to as an output channel. The following diagram shows three filters that produce three output feature maps. Because the number of input channels in this example is 1, each filter consists of a single kernel. The number of weights that the layer has to learn is 3 x input_channels x kernel_width x kernel_height. For simplicity, the diagram assumes that there is no bias tensor. You can rank the filters to identify the least important ones (for instance, the yellow filter). Removing such a filter reduces the number of parameters by 1 x input_channels x kernel_width x kernel_height.
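The parameter arithmetic can be checked with a small helper. In this sketch, `conv_params` is hypothetical and 3x3 kernels are assumed, because the diagram does not state a kernel size. The second part shows why structured pruning also removes "associated channels": the dropped feature map is one fewer input channel for the next layer, which is assumed here to have four filters.

```python
def conv_params(out_channels: int, in_channels: int, kernel_h: int, kernel_w: int,
                bias: bool = False) -> int:
    """Number of learnable parameters in a 2D convolutional layer."""
    weights = out_channels * in_channels * kernel_h * kernel_w
    return weights + (out_channels if bias else 0)


# Diagram example: 3 filters, 1 input channel, (assumed) 3x3 kernels, no bias
before = conv_params(3, 1, 3, 3)
after = conv_params(2, 1, 3, 3)   # the least important filter is removed
print(before, after, before - after)  # 27 18 9

# A hypothetical next layer with 4 filters loses one of its 3 input channels
next_before = conv_params(4, 3, 3, 3)
next_after = conv_params(4, 2, 3, 3)
print(next_before, next_after)  # 108 72
```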

To rank the importance of each filter, you use the ranking method described in the paper “Pruning Convolutional Neural Networks for Resource Efficient Inference,” which estimates the impact that removing a filter has on the loss. The goal is to remove the filters that have little impact on the loss. A filter receives a low rank if its activation outputs and the corresponding gradients are small. You can estimate filter importance by accumulating the product of activation outputs and gradients throughout training.
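The following NumPy sketch computes a first-order Taylor score for each filter from one batch of activation outputs and their gradients. It follows our reading of the criterion in the cited paper: average the activation-gradient product over each feature map, take the absolute value per example, then average over the batch. The function name and the random test data are hypothetical. Summing these scores over training steps gives the accumulated filter ranks.

```python
import numpy as np


def taylor_filter_scores(activation: np.ndarray, gradient: np.ndarray) -> np.ndarray:
    """One importance score per filter for a single batch.

    activation: outputs of a layer, shape (batch, channels, height, width)
    gradient:   gradient of the loss w.r.t. those outputs, same shape
    """
    product = activation * gradient
    per_example = product.mean(axis=(2, 3))   # average over each feature map -> (batch, channels)
    return np.abs(per_example).mean(axis=0)   # |.| per example, then batch mean -> (channels,)


rng = np.random.default_rng(0)
act = np.maximum(rng.normal(size=(8, 4, 5, 5)), 0.0)  # ReLU-like activation outputs
grad = rng.normal(size=(8, 4, 5, 5))
act[:, 2] = 0.0                                        # filter 2 is never active

scores = taylor_filter_scores(act, grad)
print(scores)  # filter 2 gets a score of 0.0
```

A filter whose feature map is always zero contributes nothing to the product, so it receives the lowest possible score and is the first candidate for removal.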

You then remove the lowest-ranked filters and fine-tune the model to recover from the pruning and regain accuracy. You can repeat these steps several times.

Iterative model pruning on Amazon SageMaker

Amazon SageMaker Debugger allows you to emit tensors from your training jobs and run built-in rules that automatically detect training issues. You can retrieve gradients and activation outputs to compute filter ranks. For more information, see Amazon SageMaker Debugger – Debug Your Machine Learning Models. Amazon SageMaker Experiments lets you customize, visualize, and track ML experiments at scale. For more information, see Amazon SageMaker Experiments – Organize, Track and Compare Your Machine Learning Trainings. For this post, you use it to track the different pruning iterations. You can quickly identify and deploy the model that yields the best accuracy and size trade-off with the Experiments view in Amazon SageMaker Studio.

To walk through iterative model pruning step by step, this post uses a ResNet18 model pretrained on ImageNet and fine-tunes it on the Caltech101 dataset, which has only 101 classes. The images are resized during retraining. ResNet is a convolutional neural network consisting of multiple residual blocks. In ResNet18, a block is made up of two convolutional layers, each followed by a batch normalization layer, with ReLU activations. Skip connections allow the input to bypass the block. ResNet18 is the smallest version of the ResNet models, with about 11 million parameters. In each pruning iteration, you remove the 200 lowest-ranked filters.

The following diagram illustrates the solution workflow:

The steps are as follows:

  1. Start the training job
  2. Acquire the weights, gradients, biases, and activation outputs
  3. Compute filter ranks
  4. Prune low-ranking filters
  5. Set new weights
  6. Start the training job with the pruned model
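The six steps form a loop in which each pruning iteration fine-tunes the model produced by the previous one. The following plain-Python sketch captures only that control flow. `iterative_pruning` and the `train`, `rank`, and `prune` callables are hypothetical stand-ins for the SageMaker training job, the smdebug-based rank computation, and the weight surgery described in the following sections.

```python
def iterative_pruning(model, train, rank, prune, filters_per_iteration=200, iterations=10):
    """Control flow of the six steps; train, rank, and prune are supplied by the caller.

    train(model) -> job                   (steps 1 and 6: run a training job)
    rank(job)    -> [(layer, filter, r)]  (steps 2 and 3: read tensors, compute ranks)
    prune(model, filters) -> new model    (steps 4 and 5: remove filters, set weights)
    """
    job = train(model)
    for _ in range(iterations):
        ranks = rank(job)
        lowest = sorted(ranks, key=lambda item: item[2])[:filters_per_iteration]
        model = prune(model, lowest)
        job = train(model)  # fine-tune the pruned model
    return model


# Toy usage: the "model" is just its number of filters, and filter i has rank i
final = iterative_pruning(
    10,
    train=lambda m: m,
    rank=lambda job: [("layer1.0", i, float(i)) for i in range(job)],
    prune=lambda m, lowest: m - len(lowest),
    filters_per_iteration=2,
    iterations=3,
)
print(final)  # 4
```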

Creating the experiment and debugger hook configuration

Before you implement the solution, create a model pruning experiment. An experiment is a collection of trials, where a trial is a collection of training steps. See the following code:

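from smexperiments.experiment import Experiment

# Create the experiment that groups all pruning iterations (trials)
model_pruning_experiment = Experiment.create(experiment_name="model_pruning_experiment",
                    description="Iterative model pruning of ResNet18 trained on Caltech101")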

In each pruning iteration, you start a new Amazon SageMaker training job, which is a new trial within the experiment. See the following code:

from smexperiments.trial import Trial

# Create a new trial for this pruning iteration
trial = Trial.create(experiment_name="model_pruning_experiment")

You next define the experiment_config, which is a dictionary that is passed to the Amazon SageMaker training job. This allows Amazon SageMaker to associate the training job with the experiment and trial. See the following code:

experiment_config = { "ExperimentName": "model_pruning_experiment",
                        "TrialName":  trial.trial_name,
                        "TrialComponentDisplayName": "Training"}

Before you can start the training job, you must define a debugger hook configuration. Amazon SageMaker Debugger provides default collections for weights, biases, gradients, and losses. For this post, you must also store activation outputs. To retrieve those, create a custom collection in which a regular expression specifies the names of the tensors to include. Because ResNet contains batch norm layers, you also store their running mean and running variance statistics. Tensors are saved every 100 steps, where a step represents one forward and backward pass. See the following code:

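from sagemaker.debugger import DebuggerHookConfig, CollectionConfig

# Custom collection (the name is arbitrary): ReLU outputs, weights, biases, and batch norm statistics
debugger_hook_config = DebuggerHookConfig(
      collection_configs=[
          CollectionConfig(
                name="custom_collection",
                parameters={ "include_regex": ".*relu|.*weight|.*bias|.*running_mean|.*running_var",
                             "save_interval": "100" })])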

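You can test the regular expression locally to check which tensor names the custom collection captures before you launch a job. This sketch assumes that smdebug applies include_regex with standard Python regular-expression matching. The tensor names follow the naming used later in this post; the last two are hypothetical.

```python
import re

# include_regex from the custom collection; assumes standard Python regex matching
INCLUDE_REGEX = r".*relu|.*weight|.*bias|.*running_mean|.*running_var"

tensor_names = [
    "layer1.0.relu_0_output_0",          # activation output
    "ResNet_layer1.0.conv1.weight",      # convolution weight
    "ResNet_layer1.0.bn1.bias",          # batch norm bias
    "ResNet_layer1.0.bn1.running_var",   # batch norm statistic (hypothetical name)
    "CrossEntropyLoss_output_0",         # loss value (hypothetical name)
]
saved = {name: re.match(INCLUDE_REGEX, name) is not None for name in tensor_names}
for name, is_saved in saved.items():
    print(f"{name:35} {is_saved}")
```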
Starting your training job

You are now ready to train the ResNet18 model with Amazon SageMaker. The training loop is defined in the entry_point file; for more information, see the GitHub repo. To emit tensors, pass the debugger hook configuration into the PyTorch estimator. See the following code:

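import sagemaker
from sagemaker.pytorch import PyTorch

# entry_point, framework_version, and py_version are placeholders; see the GitHub repo for the actual values
estimator = PyTorch(role=sagemaker.get_execution_role(),
                    instance_count=1, instance_type="ml.p3.2xlarge",
                    entry_point="train.py", framework_version="1.5.0", py_version="py3",
                    debugger_hook_config=debugger_hook_config)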

After you define the estimator object, you can call fit. This launches a managed ml.p3.2xlarge instance to run your training script. As discussed previously, you pass the experiment_config to the training job. See the following code:

estimator.fit(experiment_config=experiment_config)

Getting gradients, weights, and biases

After the training job is complete, you retrieve its tensors, such as gradients, weights, and biases, with the smdebug library, which provides functions to read and filter tensors. First, create a trial that gives you access to the tensors that the debugger saved (for more information, see the GitHub repo). Note that the term trial means different things in the two services. In Amazon SageMaker Debugger, a trial is an object that lets you query tensors for a given training job. In Amazon SageMaker Experiments, a trial is part of an experiment and represents the collection of training steps involved in a single training job. See the following code:

from smdebug import modes
from smdebug.trials import create_trial

path = estimator.latest_job_debugger_artifacts_path()
smdebug_trial = create_trial(path)

To access tensor values, call smdebug_trial.tensor(). For instance, to get the activation (ReLU) outputs of the first convolutional layer in the first residual block, layer1.0, at step 0, enter the following code:

smdebug_trial.tensor('layer1.0.relu_0_output_0').value(0, mode=modes.TRAIN)

Computing filter ranks

Now that you have access to the tensors, you can compute the filter ranks. Iterate over the available training steps and retrieve the activation outputs and their gradients. For instance, the following code computes the ranks of the filters in the first convolutional layer of layer1.0. For every saved step, it multiplies the activation outputs by their gradients and averages the product over the batch and spatial dimensions, which yields a single value (rank) for each filter:

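import numpy as np
from smdebug import modes

rank = 0
for step in smdebug_trial.steps(mode=modes.TRAIN):
    # Activation outputs and gradients have shape (batch, channels, height, width)
    activation_output = smdebug_trial.tensor('layer1.0.relu_0_output_0').value(step, mode=modes.TRAIN)
    gradient = smdebug_trial.tensor('gradient/layer1.0.relu_ReLU_output').value(step, mode=modes.TRAIN)
    product = activation_output * gradient
    # Average over the batch and spatial dimensions: one value per filter (channel)
    rank += np.mean(product, axis=(0,2,3))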

Afterwards, you normalize the filter ranks and sort them in ascending order.
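The following is a minimal sketch of this normalization and sorting step. It assumes layer-wise L2 normalization of the absolute ranks, which is the scheme proposed in the cited paper. The helper `lowest_ranked_filters` and the example ranks are hypothetical. Normalizing per layer keeps layers with large activations from dominating the global ranking. The function returns tuples in the same (layer, filter, rank) format as the output shown in the next section.

```python
import numpy as np


def lowest_ranked_filters(layer_ranks: dict, n: int):
    """Normalize each layer's ranks and return the n lowest-ranked filters overall.

    layer_ranks maps a layer name to a 1D array with one accumulated rank per filter.
    Returns a list of (layer_name, filter_index, normalized_rank) tuples.
    """
    candidates = []
    for layer, ranks in layer_ranks.items():
        ranks = np.abs(np.asarray(ranks, dtype=float))
        norm = np.linalg.norm(ranks)                   # layer-wise L2 norm
        normalized = ranks / norm if norm > 0 else ranks
        candidates += [(layer, i, float(r)) for i, r in enumerate(normalized)]
    return sorted(candidates, key=lambda c: c[2])[:n]


layer_ranks = {
    "layer1.0.relu_ReLU_output": [0.0, 3.0, 4.0],
    "layer2.0.relu_ReLU_output": [1.0, 0.0],
}
lowest = lowest_ranked_filters(layer_ranks, 3)
print(lowest)
```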

Pruning low-ranking filters

You can now retrieve the lowest-ranked filters. The following output shows that you can prune filters 1, 36, and 127 in layer1.0 because they have a rank of 0.0:

[('layer1.0.relu_ReLU_output', 1, 0.0), 
('layer1.0.relu_ReLU_output', 127, 0.0), 
('layer1.0.relu_ReLU_output', 36, 0.0)]

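To prune those filters, use the smdebug trial to get the weight tensor of the first convolutional layer in layer1.0, and delete the corresponding entries in the first dimension (axis=0). PyTorch stores convolution weights with the shape (out_channels, in_channels, kernel_height, kernel_width), so the first dimension indexes the filters. See the following code: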

# step is the last training step from the preceding loop
weight = smdebug_trial.tensor('ResNet_layer1.0.conv1.weight').value(step, mode=modes.TRAIN)
weight = np.delete(weight, [1,36,127], axis=0)

Setting new weights

You also need to adjust the convolution parameters. The first convolutional layer in layer1.0 of the ResNet18 model has 64 output channels, and after you remove three filters it has only 61. You also need to adjust the subsequent batch norm layer, whose weight, bias, running mean, and running variance each hold one entry per channel. Remove the same entries from their first (and only) dimension (axis=0). See the following code:

# bn1 parameters are 1D arrays with one entry per channel
weight = smdebug_trial.tensor('ResNet_layer1.0.bn1.weight').value(step, mode=modes.TRAIN)
bias = smdebug_trial.tensor('ResNet_layer1.0.bn1.bias').value(step, mode=modes.TRAIN)
weight = np.delete(weight, [1,36,127], axis=0)
bias = np.delete(bias, [1,36,127], axis=0)
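In PyTorch, a Conv2d weight has the shape (out_channels, in_channels, kernel_height, kernel_width), and a BatchNorm2d layer keeps four per-channel arrays: weight, bias, running_mean, and running_var. Removing a filter therefore touches three places:

- axis 0 of the pruned layer's weight
- axis 0 of all four batch norm arrays
- axis 1 of the next convolutional layer's weight, because that layer loses one input channel

The following NumPy sketch shows this on small random arrays. The helper `prune_conv_bn_conv`, the 8-channel shapes, and the filter indices are hypothetical. The sketch ignores convolution biases (the ResNet convolutions have none). It also ignores the extra constraints that skip connections place on a block's output channels.

```python
import numpy as np


def prune_conv_bn_conv(conv1_w, bn1, conv2_w, filters):
    """Remove output filters of conv1 and the parameters that depend on them.

    conv1_w: weight of the pruned layer, PyTorch layout (out, in, kH, kW)
    bn1:     dict of per-channel arrays of the following batch norm layer
    conv2_w: weight of the next convolutional layer, layout (out, in, kH, kW)
    filters: indices of the conv1 output channels to remove
    """
    conv1_w = np.delete(conv1_w, filters, axis=0)                           # the filters themselves
    bn1 = {name: np.delete(v, filters, axis=0) for name, v in bn1.items()}  # per-channel arrays
    conv2_w = np.delete(conv2_w, filters, axis=1)                           # matching input channels
    return conv1_w, bn1, conv2_w


channels = 8
rng = np.random.default_rng(0)
conv1 = rng.normal(size=(channels, channels, 3, 3))
bn = {name: rng.random(channels) for name in ("weight", "bias", "running_mean", "running_var")}
conv2 = rng.normal(size=(channels, channels, 3, 3))

c1, b1, c2 = prune_conv_bn_conv(conv1, bn, conv2, [1, 4, 6])
print(c1.shape, b1["running_var"].shape, c2.shape)  # (5, 8, 3, 3) (5,) (8, 5, 3, 3)
```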

Starting the next pruning iteration

After you prune the 200 lowest-ranked filters, save the new model definition with the latest weights and start the next pruning iteration. You can repeat these steps multiple times, which yields a smaller model in each iteration.


You can track and visualize the iterative model pruning experiment in Amazon SageMaker Studio. The training script uses SageMaker Debugger’s save_scalar method to store the number of model parameters and model accuracy. For more information, see the GitHub repo. The values that save_scalar writes go into a data store that Amazon SageMaker Studio uses to create visualizations. The following graph shows a scatterplot in which the x-axis shows the number of model parameters and the y-axis shows the validation accuracy.

Initially, the model consisted of 11 million parameters. After 11 iterations, the number of parameters was reduced to 706,000. Accuracy increased to 90% and started dropping after eight pruning iterations. The following screenshot shows the experiment view, which lists the different trials and their details.

Running iterative model pruning with a custom rule

In the previous example, accuracy drops when the model has fewer than 4 million parameters, and you want to stop the experiment when you reach this point. SageMaker Debugger provides built-in rules that trigger when model training runs into issues, such as vanishing gradients or a loss that does not decrease. If the model does not have enough capacity (too few parameters), it doesn't learn well, and one symptom is that the loss may not decrease. For more information, see Built-in Rules Provided by Amazon SageMaker Debugger.

For this post, you define a custom rule that compares the accuracy of the current model with the accuracy of the previous training job. For example, if accuracy drops by more than 10 percentage points (0.10), the rule triggers and returns True. You can then set up an Amazon CloudWatch alarm and an AWS Lambda function that stop the training job. This prevents your model pruning experiment from wasting resources on jobs that produce low-quality models. For more information, see the GitHub repo.

The following code shows a high-level overview of the custom rule:

import numpy as np
from smdebug import modes
from smdebug.rules.rule import Rule

class check_accuracy(Rule):
    def __init__(self, base_trial, previous_accuracy=0.0):
        super().__init__(base_trial)
        self.previous_accuracy = float(previous_accuracy)

    def invoke_at_step(self, step):
        # Inputs to the loss function: model outputs and ground-truth labels
        predictions = np.argmax(self.base_trial.tensor('CrossEntropyLoss_0_input_0').value(step, mode=modes.EVAL), axis=1)
        labels = self.base_trial.tensor('CrossEntropyLoss_0_input_1').value(step, mode=modes.EVAL)
        current_accuracy = np.mean(predictions == labels)
        # Trigger if accuracy dropped by more than 0.10 compared with the previous job
        if self.previous_accuracy - current_accuracy > 0.10:
            return True
        return False

The rule is a Python class that inherits from the smdebug Rule class and takes the accuracy of the previous training job as an argument. The class implements the function invoke_at_step, which is called every time the tensors of a new step are available. With smdebug, you get the inputs to the loss function, which are the model predictions and the labels. You can then compute the accuracy and compare it with that of the previous training job. For the full implementation of the rule, see the GitHub repo.
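The comparison inside invoke_at_step does not depend on SageMaker, so you can test it locally before you package the rule. The following sketch uses a hypothetical helper, `accuracy_dropped`. It applies the same 0.10 absolute threshold to a small batch of made-up logits and labels.

```python
import numpy as np


def accuracy_dropped(previous_accuracy: float, logits: np.ndarray, labels: np.ndarray,
                     max_drop: float = 0.10) -> bool:
    """True if accuracy fell by more than max_drop (absolute) vs. the previous job.

    logits: (batch, num_classes) inputs to the loss; labels: (batch,) class indices.
    """
    predictions = np.argmax(logits, axis=1)
    current_accuracy = float(np.mean(predictions == labels))
    return previous_accuracy - current_accuracy > max_drop


logits = np.array([[2.0, 0.1], [0.3, 1.5], [1.2, 0.4], [0.2, 0.9]])
labels = np.array([0, 1, 1, 1])  # 3 of 4 predictions are correct: accuracy 0.75
print(accuracy_dropped(0.90, logits, labels))  # True  (drop of 0.15)
print(accuracy_dropped(0.80, logits, labels))  # False (drop of 0.05)
```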

To run the rule with your training job, you need to pass it to the PyTorch estimator object. You first need to create a custom rule configuration. See the following code:

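from sagemaker.debugger import Rule

# name, image_uri, instance settings, and source are placeholders; see the GitHub repo
check_accuracy_rule = Rule.custom(
    name="CheckAccuracy", image_uri="<custom-rule-evaluator-image-uri>",
    instance_type="ml.c4.xlarge", volume_size_in_gb=30,
    source="custom_rule/check_accuracy.py", rule_to_invoke="check_accuracy",
    rule_parameters={"previous_accuracy": "0.0"})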

The rule configuration specifies the location of the rule definition, its input parameters, and the instance type where the rule container is going to run. You need to specify the image for the rule container. For more information, see Amazon SageMaker Custom Rule Evaluator Registry Ids.

After you define the rule, pass the argument rules=[check_accuracy_rule] to the PyTorch estimator.

In each pruning iteration, you need to pass the accuracy of the previous training job to the rule, which is retrievable via SageMaker Experiments and the ExperimentAnalytics module. See the following code:

from sagemaker.analytics import ExperimentAnalytics

trial_component_analytics = ExperimentAnalytics(experiment_name="model_pruning_experiment")
accuracy = trial_component_analytics.dataframe()['scalar/accuracy_EVAL - Max'][0]

Overwrite the value in the rule configuration with the following code:

check_accuracy_rule.rule_parameters["previous_accuracy"] = str(accuracy)

In each iteration, check the status of the previous training job. If that job was stopped, exit the loop. See the following code:

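# Inside the pruning loop: stop iterating if the rule stopped the previous job
job_name = estimator.latest_training_job.name
client = estimator.sagemaker_session.sagemaker_client
description = client.describe_training_job(TrainingJobName=job_name)
if description['TrainingJobStatus'] == 'Stopped':
    break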

The following graph shows the number of parameters versus accuracy. In contrast to the previous experiment, the training job stops when it produces a low-quality model, and the experiment ends. As a consequence, it only runs eight iterations.

The following screenshot shows the Debugger view in SageMaker Studio. It shows that the custom rule found an issue.

For more information about defining and running custom rules with SageMaker Debugger, see How to Use Custom Rules.


This post discussed iterative model pruning with Amazon SageMaker. It showed how you can significantly reduce the size of your models while retaining accuracy by identifying redundant parameters that contribute little to the training process. You walked through a sample application that uses a pretrained model and saw that the model kept its accuracy over several pruning iterations before it started to degrade.

About the Authors

Nathalie Rauschmayr is an Applied Scientist at AWS, where she helps customers develop deep learning applications.




Julien Simon is an Artificial Intelligence & Machine Learning Evangelist for EMEA. Julien focuses on helping developers and enterprises bring their ideas to life.




Satadal Bhattacharjee is Principal Product Manager at AWS AI. He leads the machine learning engine PM team on projects such as SageMaker and optimizes machine learning frameworks such as TensorFlow, PyTorch, and MXNet.