AWS Machine Learning Blog
ML Explainability with Amazon SageMaker Debugger
Machine Learning (ML) impacts industries around the globe, from the financial services industry (FSI) and manufacturing to autonomous vehicles and space exploration. ML is no longer an aspirational technology exclusive to academic and research institutions; it has evolved into a mainstream technology that can benefit organizations of all sizes. However, a lack of transparency in the ML process and the black box nature of the resulting models hinder broader ML adoption in industries such as financial services and healthcare.
For a team developing ML models, the responsibility to explain model predictions grows as the impact of those predictions on business outcomes grows. For example, consumers are likely to accept a movie recommendation from an ML model without needing an explanation. The consumer may or may not agree with the recommendation, but the model developers are under little pressure to justify it. In contrast, if an ML model predicts whether a credit loan application is approved or what drug dosage to administer to a patient, the model developers are responsible for explaining the prediction. They need to answer questions such as “Why was my loan rejected?” or “Why should I take 10 mg of this drug?” For this reason, gaining visibility into the training process and developing human-explainable ML models is important.
Amazon SageMaker is a fully managed service that enables developers and data scientists to quickly and easily build, train, and deploy ML models at any scale. Amazon SageMaker Debugger is a capability within Amazon SageMaker that provides visibility into the model training process. It saves the internal model state at periodic intervals, which you can analyze in real time during training and offline after training is complete. Debugger identifies issues during model training and provides insights into the predictions the trained model makes. It comes with a set of built-in rules that detect common training issues and monitor conditions critical for successful training, and you can also author custom rules to monitor your training job.
This post discusses ML explainability, the popular explainability tool SHAP (SHapley Additive exPlanations), and the native integration of SHAP with Amazon SageMaker Debugger. As part of this post, we provide a detailed notebook that shows how to use Amazon SageMaker Debugger to provide explanations in a financial services use case, in which the model predicts whether an individual’s income is above or below $50,000. For this post, we use the UCI Adult dataset.
ML explainability
Explainability is the extent to which you can explain the internal mechanics of an ML or deep learning system in human terms. It is in contrast to the concept of the black box, in which even designers cannot explain why an AI arrives at a specific decision.
There are two types of explainability: global and local. Global explainability aims to make the overall ML model transparent and comprehensible, whereas local explainability focuses on explaining the model’s individual predictions.
The ability to explain an ML model and its predictions builds trust and improves ML adoption—the model is no longer a black box that makes predictions in a vacuum. This increases the comfort level of the consumers of model predictions. For model owners, the ability to understand the uncertainty inherent in ML models helps with debugging the model when things go wrong and improving the model for better business outcomes.
This post considers two techniques of explainability: feature importance and SHAP.
Feature importance
Feature importance is a technique that assigns each feature in the training data a score (its importance), which indicates how useful or valuable the feature is relative to the other features. In the use case of individual income prediction using XGBoost, the importance score indicates the value of each feature in the construction of the boosted decision trees within the model. The more a model uses an attribute to make key decisions in its decision trees, the higher the attribute’s relative importance.
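The following minimal pure-Python sketch makes the scoring idea concrete. It shows three common ways to turn tree splits into importance scores: count the splits that use a feature, sum their gains, or average those gains. The split list is made-up data and feature_importance is a hypothetical helper. With a trained model, XGBoost reports these scores directly through Booster.get_score(importance_type=...).

```python
from collections import defaultdict

# Hypothetical splits from a small boosted ensemble:
# each entry is (feature name, gain achieved by that split).
SPLITS = [
    ("f5", 120.0), ("f0", 40.0), ("f5", 80.0),
    ("f10", 30.0), ("f0", 20.0), ("f5", 60.0),
]


def feature_importance(splits, importance_type="weight"):
    """Return importance scores normalized to sum to 1, highest first.

    weight     -> how many splits use the feature
    total_gain -> sum of the gains of those splits
    gain       -> average gain per split that uses the feature
    """
    counts = defaultdict(int)
    gains = defaultdict(float)
    for feature, gain in splits:
        counts[feature] += 1
        gains[feature] += gain

    if importance_type == "weight":
        raw = {f: float(c) for f, c in counts.items()}
    elif importance_type == "total_gain":
        raw = dict(gains)
    elif importance_type == "gain":
        raw = {f: gains[f] / counts[f] for f in counts}
    else:
        raise ValueError(f"unknown importance_type: {importance_type!r}")

    total = sum(raw.values())
    ranked = sorted(raw.items(), key=lambda kv: kv[1], reverse=True)
    return {f: v / total for f, v in ranked}


print(feature_importance(SPLITS))                # f5 ranks first by split count
print(feature_importance(SPLITS, "total_gain"))  # and by total gain
```

The importance type matters. A feature that is used in many low-gain splits can rank high by weight but lower by gain.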
SHAP
The open-source tool SHAP uses Shapley values, which are based on coalitional game theory. It explains an ML prediction by treating each feature value of a data instance as a player in a game in which the prediction is the payout. Shapley values indicate how to distribute the payout fairly among the features: a feature’s Shapley value is its average marginal contribution to the prediction, taken over all possible combinations (coalitions) of the other features. Because of this exhaustive approach, SHAP can guarantee consistency and local accuracy. For more information, see the documentation on the SHAP website.
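The following sketch computes exact Shapley values for a toy three-feature model by enumerating every coalition. It then checks local accuracy: the contributions should sum to the difference between the prediction and the baseline prediction. The model, the instance, and the single all-zero baseline are assumptions for illustration. The SHAP library averages over background data and uses faster algorithms such as Tree SHAP for tree ensembles, because exact enumeration grows exponentially with the number of features.

```python
from itertools import combinations
from math import factorial


def toy_model(z):
    # Hypothetical model: a linear term in feature 0 plus an interaction of features 1 and 2.
    return 2.0 * z[0] + z[1] * z[2]


def shapley_values(model, x, baseline):
    """Exact Shapley values for one instance, enumerating every coalition.

    A feature absent from a coalition is replaced by its baseline value,
    which is a single-background-point simplification of SHAP explainers.
    """
    n = len(x)

    def payout(coalition):
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return model(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for subset in combinations(others, size):
                s = set(subset)
                # Shapley weight: |S|! (n - |S| - 1)! / n!
                weight = factorial(len(s)) * factorial(n - len(s) - 1) / factorial(n)
                phi[i] += weight * (payout(s | {i}) - payout(s))
    return phi


x, baseline = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)
print([round(p, 6) for p in phi])  # [2.0, 3.0, 3.0]: the interaction is split evenly
# Local accuracy: the contributions sum to f(x) - f(baseline).
print(round(sum(phi), 6), toy_model(x) - toy_model(baseline))  # 8.0 8.0
```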
Amazon SageMaker Debugger
Amazon SageMaker Debugger provides full visibility into ML model training by monitoring, recording, and analyzing the tensor data that captures the state of a training job. When you enable Amazon SageMaker Debugger for a training job, you configure which tensors to save, where to save them, and the rules to evaluate against them. Tensors, grouped into collections, define the state of the training job at any particular instant in its lifecycle. The Amazon SageMaker Debugger built-in tensor collections include feature importance (feature_importance), full SHAP (full_shap), and average SHAP (average_shap). For more information about debugging and automatically detecting common errors during model training, see Amazon SageMaker Debugger – Debug Your Machine Learning Models.
The remainder of this post highlights how to enable Amazon SageMaker Debugger for a training job, use the average_shap and full_shap built-in tensor collections, and visualize and analyze the captured tensors for model explanation.
Walkthrough overview
The walkthrough includes the following high-level steps:
- Examine the training dataset that represents the problem
- Train the model on Amazon SageMaker with the debugger turned on
- Visualize and analyze the debugger output
These steps are specific to using built-in tensors for explainability on Amazon SageMaker. Other necessary steps, such as importing libraries and setting up IAM permissions, aren’t covered in this post. You can walk through and run the code with the following notebook on the GitHub repo.
For this use case, you predict whether an individual’s income is above $50,000 or at or below it, based on features such as the individual’s age, relationship status, number of hours worked, and capital gain. You use the UCI Adult dataset.
Examining the training data
To understand the features that make up the training data, download the dataset. The following screenshot shows the first few lines of data.
The following screenshot shows the list of features.
The dataset consists of 12 features that capture an individual’s age, education, and other details. You use XGBoost to predict the probability that an individual earns more than $50,000 a year. The goal is to understand how each feature affects the XGBoost model and its predictions.
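Before training, the text income label must become a binary target and categorical columns must become numbers. The following minimal sketch shows one way to do both. It assumes the raw UCI files, where fields carry a leading space and labels in adult.test end in a period. The helper names are hypothetical, and the relationship codes are the ones this post uses in its global explanation.

```python
# Relationship codes as listed in this post's global-explanation section.
RELATIONSHIP_CODES = {
    "Not-in-family": 0, "Unmarried": 1, "Other-relative": 2,
    "Own-child": 3, "Husband": 4, "Wife": 5,
}


def income_label(raw):
    """Return 1 for '>50K' and 0 for '<=50K'.

    Raw UCI Adult fields carry a leading space, and labels in adult.test end
    with a period ('>50K.'), so both are stripped before comparing.
    """
    value = raw.strip().rstrip(".")
    if value == ">50K":
        return 1
    if value == "<=50K":
        return 0
    raise ValueError(f"unexpected income label: {raw!r}")


def encode_relationship(raw):
    """Map a relationship string such as ' Wife' to its integer code."""
    return RELATIONSHIP_CODES[raw.strip()]


print(income_label(" >50K."), income_label(" <=50K"))  # 1 0
print(encode_relationship(" Wife"))  # 5
```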
Training the XGBoost model with Amazon SageMaker Debugger enabled
You train the XGBoost model using the Amazon SageMaker Estimator API. To enable Amazon SageMaker Debugger during training, you create a DebuggerHookConfig object and add this configuration to the Estimator. The DebuggerHookConfig specifies the tensor collections you want to collect and the Amazon S3 location where the collected tensors are saved. For this use case, you collect feature_importance, average_shap, and full_shap every 10 iterations during training, as specified by save_interval in the code.
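The following sketch illustrates what a save interval means by listing the boosting rounds at which a collection would be written. It assumes that saving starts at step 0 and that every save_interval-th step is kept. The helper saved_steps and the 51-round job are hypothetical, and Debugger's own save configuration also supports options such as start and end steps.

```python
def saved_steps(num_rounds, save_interval, start_step=0):
    """Steps (boosting rounds) at which a collection would be saved.

    Assumes a step is saved when (step - start_step) is a non-negative
    multiple of save_interval.
    """
    if save_interval <= 0:
        raise ValueError("save_interval must be positive")
    return [s for s in range(start_step, num_rounds) if (s - start_step) % save_interval == 0]


# Hypothetical job with 51 boosting rounds and save_interval=10.
print(saved_steps(51, 10))  # [0, 10, 20, 30, 40, 50]
```

A larger interval stores fewer tensors in Amazon S3 but gives coarser curves when you plot values over training.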
To monitor the training process, you create an Amazon SageMaker Debugger rule. For this post, you use the built-in LossNotDecreasing rule to monitor the metrics collection. The rule alerts you if the tensors in the metrics collection haven’t decreased for more than 10 steps.
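The following plain-Python sketch shows a simplified plateau check that gives intuition for the condition the rule watches. It flags a series whose latest value isn't lower than the value recorded num_steps saved steps earlier, optionally by at least diff_percent. This is a hypothetical re-implementation for intuition only. The built-in rule runs as its own rule job, and its exact comparison logic and parameters may differ.

```python
def loss_not_decreasing(losses, num_steps=10, diff_percent=0.0):
    """Simplified plateau check (hypothetical, not the built-in rule's code).

    Alerts (returns True) when the latest loss is not at least diff_percent
    lower than the loss recorded num_steps steps earlier. Assumes positive
    loss values ordered by step.
    """
    if len(losses) <= num_steps:
        return False  # not enough history to judge yet
    previous = losses[-num_steps - 1]
    current = losses[-1]
    return current >= previous * (1.0 - diff_percent / 100.0)


improving = [1.0 - 0.05 * i for i in range(15)]
flat = [0.5] * 12
print(loss_not_decreasing(improving))  # False
print(loss_not_decreasing(flat))       # True
```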
You can configure the rule, debugger hook config, and the Estimator with a single API. See the following code:
Next, start a training job by using the Estimator object you created. See the following code:
You receive the following output:
As a result of the preceding code, Amazon SageMaker starts one training job and one rule job for you. To check the status of the rule evaluation job, enter the following code:
You receive the following output:
Visualizing and analyzing the debugger output
In this step, you visualize the feature_importance, full_shap, and average_shap tensors captured during training. The analysis arrives at global and local explanations of the model to show how individual features contribute to model predictions. You also briefly look at outlier predictions.
You use the smdebug library and the concept of a trial, which represents a single training run. The trial object holds the path to the saved tensors and lets you query them.
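A trial can be pictured as a mapping from tensor names to per-step values. The hypothetical ToyTrial class below mimics that shape in memory so the idea is easy to try. In practice, you create a real trial from the S3 path with smdebug.trials.create_trial and query it with methods such as tensor_names(), steps(), and tensor(name).values(). The tensor names and values here are made up.

```python
class ToyTrial:
    """Hypothetical in-memory stand-in for a trial: tensor name -> {step: value}."""

    def __init__(self, tensors):
        self._tensors = tensors

    def tensor_names(self, prefix=""):
        # All tensor names, optionally filtered by a name prefix.
        return sorted(n for n in self._tensors if n.startswith(prefix))

    def steps(self):
        # Every step at which at least one tensor was saved.
        return sorted({s for values in self._tensors.values() for s in values})

    def values(self, name):
        # A {step: value} mapping for one tensor, ordered by step.
        return dict(sorted(self._tensors[name].items()))


# Made-up values for two average_shap tensors saved at steps 0, 10, and 20.
trial = ToyTrial({
    "average_shap/f0": {0: 0.01, 10: 0.20, 20: 0.31},
    "average_shap/f1": {0: 0.00, 10: 0.05, 20: 0.07},
})
print(trial.tensor_names("average_shap/"))  # ['average_shap/f0', 'average_shap/f1']
print(trial.steps())                        # [0, 10, 20]
print(trial.values("average_shap/f1"))      # {0: 0.0, 10: 0.05, 20: 0.07}
```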
Within a trial, a step represents a single iteration of the training job (for XGBoost, one boosting round). Each trial has multiple steps, and a collected tensor has a particular value at each step. The tensor values are stored in the Amazon S3 location you specified earlier. See the following code:
You receive the following output:
To view the tensors collected, enter the following code:
You receive the following output:
The actual feature names aren’t included in the tensor names; features are represented as f0, f1, and so on. This prevents sensitive feature names from showing up in the analysis. Alternatively, the following code uses the actual feature names you saved earlier. You can use either approach depending on your requirements.
To view the average_shap tensor value for f1, enter the following code:
You receive the following output:
You can also plot tensors collected for multiple features. For example, to plot the feature_importance tensors, enter the following code:
The following graph shows the output:
Similarly, you can plot the average_shap tensor values collected for all features. See the following graph.
The two preceding plots show how the plotted values change over the course of training.
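When you plot, you can swap the generic f0, f1 suffixes in tensor names for readable labels. The hypothetical helper below does this, assuming you kept the feature names in training-column order. The list used here is illustrative, not the dataset's actual column order. Names without an f<index> suffix pass through unchanged.

```python
# Hypothetical feature names in training-column order (illustrative only).
FEATURE_NAMES = ["Age", "Workclass", "Relationship", "Sex"]


def readable_name(tensor_name, feature_names):
    """Swap a trailing 'f<index>' component for the matching feature name."""
    prefix, sep, last = tensor_name.rpartition("/")
    if last.startswith("f") and last[1:].isdigit():
        index = int(last[1:])
        if index < len(feature_names):
            return f"{prefix}{sep}{feature_names[index]}"
    return tensor_name  # no f<index> suffix, or index out of range


print(readable_name("average_shap/f1", FEATURE_NAMES))  # average_shap/Workclass
print(readable_name("average_shap/f9", FEATURE_NAMES))  # average_shap/f9
```

Keeping the mapping outside the saved tensors is consistent with the privacy point above: the stored artifacts only ever contain f<index> names.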
Global explanations
Global explanatory methods allow you to understand the model and its feature contributions in aggregate over multiple data points. The following graph is an aggregate bar plot of the mean absolute SHAP value for each feature. Specifically, the plot indicates that the value of relationship (Wife=5, Husband=4, Own-child=3, Other-relative=2, Unmarried=1, Not-in-family=0) plays the most important role in predicting whether an individual’s income is higher than $50,000.
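The bar plot's score is simple to compute. Average the absolute SHAP value of each feature across data points, so that positive and negative contributions don't cancel out. The following sketch uses made-up SHAP values. If your matrix comes from XGBoost's pred_contribs output, drop its last column (the bias term) first. The SHAP library draws this chart with shap.summary_plot(..., plot_type="bar").

```python
def mean_abs_shap(shap_rows, feature_names):
    """Global importance: mean absolute SHAP value per feature, highest first.

    shap_rows: one list of per-feature SHAP values per data point, with any
    bias/base-value column already removed.
    """
    n_rows = len(shap_rows)
    scores = {
        name: sum(abs(row[j]) for row in shap_rows) / n_rows
        for j, name in enumerate(feature_names)
    }
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))


# Made-up SHAP values (log-odds units) for three data points.
names = ["Age", "Relationship", "Sex"]
rows = [
    [0.4, 1.2, -0.3],
    [-0.2, -0.9, 0.1],
    [0.3, 1.5, -0.2],
]
print(mean_abs_shap(rows, names))  # Relationship first, then Age, then Sex
```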
You can further view the SHAP value distribution for each feature. The following summary plot provides more context than the bar chart: it indicates which features are most important and also their range of effects over the dataset. Red indicates a higher value of the feature and blue a lower one (normalized per feature), so the color lets you match changes in a feature’s value to changes in the prediction. For example, an increase in age leads to higher log odds for the prediction, which eventually leads to True predictions more often. You can also see that the individual’s gender (Sex) negatively affects the prediction.
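As a rough heuristic, you can summarize the summary plot's color gradient numerically by correlating a feature's values with its SHAP values. A positive correlation means higher values tend to push the prediction up. The following sketch uses made-up ages and SHAP values. Pearson correlation only captures roughly linear trends, so it complements the plot rather than replacing it.

```python
def effect_direction(feature_values, shap_values):
    """Pearson correlation between a feature's values and its SHAP values.

    Positive: higher feature values tend to push the prediction up.
    Negative: higher feature values tend to push it down.
    Returns 0.0 when either input is constant.
    """
    n = len(feature_values)
    mean_x = sum(feature_values) / n
    mean_y = sum(shap_values) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(feature_values, shap_values))
    var_x = sum((x - mean_x) ** 2 for x in feature_values)
    var_y = sum((y - mean_y) ** 2 for y in shap_values)
    if var_x == 0 or var_y == 0:
        return 0.0
    return cov / (var_x * var_y) ** 0.5


# Made-up ages and the SHAP values of the Age feature for the same rows.
ages = [22, 30, 38, 45, 60]
age_shap = [-0.6, -0.2, 0.1, 0.3, 0.5]
print(effect_direction(ages, age_shap))  # positive: older tends to push predictions up
```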
Local explanations
Local explanations focus on explaining each individual prediction. A force plot explanation shows how features contribute to pushing the model output from the base value (the average model output over the dataset) to the model output. Features pushing the prediction higher are in red; those pushing the prediction lower are in blue.
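The arithmetic behind a force plot is additive: the base value plus the sum of a point's SHAP values equals the model output. For an XGBoost binary classifier, these values live in log-odds (margin) space, and the logistic function converts the total to a probability. The following sketch uses made-up contributions and a base value chosen to correspond to a probability of 0.2. shap.force_plot can also display the result in probability units through its link="logit" option.

```python
import math


def sigmoid(z):
    """Logistic function: converts log-odds to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


def force_plot_summary(base_value, contributions):
    """Reproduce the arithmetic behind a single force plot.

    base_value: average model output in log-odds (margin) space.
    contributions: {feature name: SHAP value in log-odds} for one data point.
    """
    margin = base_value + sum(contributions.values())
    pushes_up = {f: v for f, v in contributions.items() if v > 0}    # drawn in red
    pushes_down = {f: v for f, v in contributions.items() if v < 0}  # drawn in blue
    return margin, sigmoid(margin), pushes_up, pushes_down


# Made-up example: a base value equal to a probability of 0.2.
base = math.log(0.2 / 0.8)
contribs = {"Relationship": 1.1, "Age": 0.4, "Sex": -0.2, "Race": -0.1}
margin, prob, up, down = force_plot_summary(base, contribs)
print(f"base p={sigmoid(base):.2f}, output p={prob:.2f}")
print("up:", sorted(up), "down:", sorted(down))
```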
The following plot indicates that for this particular data point, the prediction probability (0.48) is higher than the average (~0.2), primarily because of this person’s relationship status (Relationship = Wife) and, to a smaller degree, because of the higher-than-average age. Conversely, the model lowers the probability because of specific Sex and Race values, which indicates a bias in model behavior (possibly due to a bias in the data).
SHAP can rotate multiple force plots by 90 degrees and stack them side by side to show the explanations for many data points at once. See the following plot.
Although this post shows static stacked force plots, you can make the plot interactive by enabling JavaScript in the complete notebook. The interactive version lets you see how the output changes based on each feature independently. Stacked force plots strike a balance between local and global explainability.
Outliers
Outliers are extreme values that deviate from other observations in the data. For outlier predictions, it’s useful to understand the influence of various features to determine whether the outlier is a novelty, an experimental error, or a shortcoming in the model.
The following force plot shows prediction outliers that lie on either side of the baseline value. The graph indicates that an individual with workclass=Federal-gov, Age=38, and Relationship=Wife (that is, a 38-year-old wife who works for the federal government) is probably an outlier for the income group.
The notebook provides a more detailed outlier analysis.
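Before you inspect force plots, you can shortlist candidate outliers by flagging predictions whose log-odds sit far from the mean prediction on either side. The z-score threshold and the data below are assumptions for illustration. This isn't necessarily the method the notebook uses.

```python
def prediction_outliers(margins, threshold=2.0):
    """Indices of predictions whose log-odds lie more than `threshold`
    population standard deviations from the mean prediction (either side)."""
    n = len(margins)
    mean = sum(margins) / n
    std = (sum((m - mean) ** 2 for m in margins) / n) ** 0.5
    if std == 0:
        return []  # all predictions identical: nothing stands out
    return [i for i, m in enumerate(margins) if abs(m - mean) > threshold * std]


# Made-up predicted log-odds for eight individuals.
margins = [-1.5, -1.2, -1.4, -1.3, 2.5, -1.6, -1.1, -1.35]
print(prediction_outliers(margins))  # [4]
```

Each flagged index can then be explained with its own force plot to see which feature values drive the unusual prediction.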
Conclusion
This post discussed the importance of explainability for improved ML adoption and introduced the Amazon SageMaker Debugger capability with built-in tensor collections to enable model explainability. The post also provided a detailed notebook that walks you through training an ML model for a financial services use case of individual income prediction. You further analyzed the global and local explanations of the model by visualizing the captured tensors. With its ability to provide real-time insight into the training process and offline analysis of captured tensor data, Amazon SageMaker Debugger is a very powerful tool for teams developing ML models. Give Amazon SageMaker Debugger a try and leave your feedback in the comments.
About the authors
Mona Mona is an AI/ML specialist solutions architect working with the AWS World Wide Public Sector Team. She works with AWS World Wide Public Sector customers to help them adopt machine learning on a large scale.
Rahul Iyer is a Software Development Manager at AWS AI. He leads the Framework Algorithms team, building and optimizing machine learning frameworks like XGBoost and Scikit-learn.
Sireesha Muppala is an AI/ML Specialist Solutions Architect at AWS, providing guidance to customers on architecting and implementing machine learning solutions at scale. She received her Ph.D. in Computer Science from the University of Colorado, Colorado Springs. In her spare time, Sireesha loves to run and hike Colorado trails.