AWS Machine Learning Blog

New built-in Amazon SageMaker algorithms for tabular data modeling: LightGBM, CatBoost, AutoGluon-Tabular, and TabTransformer

July 2023: You can also use the newly launched JumpStart APIs, an extension of the SageMaker Python SDK. These APIs allow you to programmatically deploy and fine-tune a vast selection of JumpStart-supported pre-trained models on your own datasets. Please refer to Amazon SageMaker JumpStart models and algorithms now available via API for more details on how to use them.

Amazon SageMaker provides a suite of built-in algorithms, pre-trained models, and pre-built solution templates to help data scientists and machine learning (ML) practitioners get started on training and deploying ML models quickly. You can use these algorithms and models for both supervised and unsupervised learning. They can process various types of input data, including tabular, image, and text.

Starting today, SageMaker provides four new built-in tabular data modeling algorithms: LightGBM, CatBoost, AutoGluon-Tabular, and TabTransformer. You can use these popular, state-of-the-art algorithms for both tabular classification and regression tasks. They’re available through the built-in algorithms on the SageMaker console as well as through the Amazon SageMaker JumpStart UI inside Amazon SageMaker Studio.

The following is the list of the four new built-in algorithms, with links to their documentation, example notebooks, and source.

Documentation | Example Notebooks | Source
LightGBM Algorithm | Regression, Classification | LightGBM
CatBoost Algorithm | Regression, Classification | CatBoost
AutoGluon-Tabular Algorithm | Regression, Classification | AutoGluon-Tabular
TabTransformer Algorithm | Regression, Classification | TabTransformer

In the following sections, we provide a brief technical description of each algorithm, and examples of how to train a model via the SageMaker SDK or SageMaker JumpStart.

LightGBM

LightGBM is a popular and efficient open-source implementation of the Gradient Boosting Decision Tree (GBDT) algorithm. GBDT is a supervised learning algorithm that attempts to accurately predict a target variable by combining an ensemble of estimates from a set of simpler and weaker models. LightGBM uses additional techniques to significantly improve the efficiency and scalability of conventional GBDT.
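
To make the idea of combining weaker models concrete, the following is a minimal sketch of gradient boosting for squared-error regression on a single numeric feature, using one-split trees (stumps) as the weak learners. It is an illustration of the general GBDT idea only, not LightGBM's implementation, and it includes none of LightGBM's efficiency techniques. All function names and the toy data are our own assumptions.

```python
from statistics import mean


def fit_stump(xs, residuals):
    """Find the threshold split on one feature that minimizes squared error."""
    best = None
    # Skip the largest value so that both sides of the split are non-empty.
    for threshold in sorted(set(xs))[:-1]:
        left = [r for x, r in zip(xs, residuals) if x <= threshold]
        right = [r for x, r in zip(xs, residuals) if x > threshold]
        left_value, right_value = mean(left), mean(right)
        error = sum((r - left_value) ** 2 for r in left) + sum(
            (r - right_value) ** 2 for r in right
        )
        if best is None or error < best[0]:
            best = (error, threshold, left_value, right_value)
    _, threshold, left_value, right_value = best
    return threshold, left_value, right_value


def predict_stump(stump, x):
    threshold, left_value, right_value = stump
    return left_value if x <= threshold else right_value


def fit_gbdt(xs, ys, num_rounds=50, learning_rate=0.1):
    """Fit an additive ensemble of stumps, one boosting round at a time."""
    base = mean(ys)
    predictions = [base] * len(ys)
    stumps = []
    for _ in range(num_rounds):
        # For squared error, the negative gradient is simply the residual.
        residuals = [y - p for y, p in zip(ys, predictions)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        predictions = [
            p + learning_rate * predict_stump(stump, x)
            for p, x in zip(predictions, xs)
        ]
    return base, stumps, learning_rate


def predict_gbdt(model, x):
    base, stumps, learning_rate = model
    return base + learning_rate * sum(predict_stump(s, x) for s in stumps)


# Toy data (hypothetical): a noisy step from about 1 to about 3.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [1.2, 0.9, 1.1, 3.0, 3.2, 2.9]
model = fit_gbdt(xs, ys)
```

Each round fits a stump to the current residuals and adds a damped copy of it to the ensemble, so the training error cannot increase from one round to the next.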

CatBoost

CatBoost is a popular and high-performance open-source implementation of the GBDT algorithm. Two critical algorithmic advances are introduced in CatBoost: the implementation of ordered boosting, a permutation-driven alternative to the classic algorithm, and an innovative algorithm for processing categorical features. Both techniques were created to fight a prediction shift caused by a special kind of target leakage present in all currently existing implementations of gradient boosting algorithms.
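
The categorical-feature idea can be illustrated with a simplified sketch of ordered target statistics: each row's category is encoded using only the targets of rows that come before it in a random permutation, so the row's own target never leaks into its encoding. This is a teaching sketch under our own assumptions (the function name, the `prior` and `prior_weight` parameters, and the single permutation), not CatBoost's actual implementation.

```python
import random


def ordered_target_statistics(categories, targets, prior, prior_weight=1.0, seed=0):
    """Encode each row's category from the rows preceding it in a random order."""
    order = list(range(len(categories)))
    random.Random(seed).shuffle(order)

    running_sum = {}    # sum of targets seen so far, per category
    running_count = {}  # number of rows seen so far, per category
    encoded = [0.0] * len(categories)
    for index in order:
        category = categories[index]
        total = running_sum.get(category, 0.0)
        count = running_count.get(category, 0)
        # The statistic is computed before this row's target is added,
        # so the row's own label is never part of its encoding.
        encoded[index] = (total + prior_weight * prior) / (count + prior_weight)
        running_sum[category] = total + targets[index]
        running_count[category] = count + 1
    return encoded


# Hypothetical toy data: a binary target and one categorical column.
cities = ["paris", "rome", "paris", "oslo", "rome", "paris"]
clicked = [1, 0, 1, 0, 1, 0]
encoded_cities = ordered_target_statistics(cities, clicked, prior=0.5)
```

The first row seen for any category falls back to the prior, and later rows blend the prior with the targets observed so far.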

AutoGluon-Tabular

AutoGluon-Tabular is an open-source AutoML project developed and maintained by Amazon that performs advanced data processing, deep learning, and multi-layer stack ensembling. It automatically recognizes the data type in each column for robust data preprocessing, including special handling of text fields. AutoGluon fits various models ranging from off-the-shelf boosted trees to customized neural network models. These models are ensembled in a novel way: models are stacked in multiple layers and trained in a layer-wise manner that guarantees raw data can be translated into high-quality predictions within a given time constraint. Over-fitting is mitigated throughout this process by splitting the data in various ways with careful tracking of out-of-fold examples. AutoGluon is optimized for performance, and its out-of-the-box usage has achieved several top-3 and top-10 positions in data science competitions.
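
The role of out-of-fold examples in stacking can be sketched as follows: every row is predicted by a base model that did not see that row during fitting, and those predictions become the input features for the next layer. This is a simplified illustration, not AutoGluon's implementation. The two toy base models and all function names are stand-ins we assume for the example.

```python
from statistics import mean


def k_fold_indices(num_rows, num_folds):
    """Assign row i to fold i % num_folds and return the folds as index lists."""
    return [list(range(fold, num_rows, num_folds)) for fold in range(num_folds)]


def out_of_fold_predictions(xs, ys, fit, predict, num_folds=3):
    """Predict every row with a model that never saw that row during fitting."""
    oof = [None] * len(xs)
    for held_out in k_fold_indices(len(xs), num_folds):
        held_out_set = set(held_out)
        train_x = [x for i, x in enumerate(xs) if i not in held_out_set]
        train_y = [y for i, y in enumerate(ys) if i not in held_out_set]
        model = fit(train_x, train_y)
        for i in held_out:
            oof[i] = predict(model, xs[i])
    return oof


# Two deliberately simple base models, standing in for boosted trees
# and neural networks.
def fit_mean(train_x, train_y):
    return mean(train_y)


def predict_mean(model, x):
    return model


def fit_nearest(train_x, train_y):
    return list(zip(train_x, train_y))


def predict_nearest(model, x):
    return min(model, key=lambda pair: abs(pair[0] - x))[1]


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Each base model contributes one column of out-of-fold features
# that a second-layer model could be trained on.
stack_features = list(
    zip(
        out_of_fold_predictions(xs, ys, fit_mean, predict_mean),
        out_of_fold_predictions(xs, ys, fit_nearest, predict_nearest),
    )
)
```

Because the nearest-neighbor model would reproduce any label it has memorized, the out-of-fold scheme is what keeps its stacked feature from being a copy of the target.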

TabTransformer

TabTransformer is a novel deep tabular data modeling architecture for supervised learning. It is built upon self-attention-based Transformers. The Transformer layers transform the embeddings of categorical features into robust contextual embeddings to achieve higher prediction accuracy. Furthermore, the contextual embeddings learned from TabTransformer are highly robust against both missing and noisy data features, and provide better interpretability. This model is the product of recent Amazon Science research (paper and official blog post here) and has been widely adopted by the ML community, with various third-party implementations (Keras, AutoGluon) and social media features such as tweets, Towards Data Science, Medium, and Kaggle.
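
As a rough illustration of how self-attention turns per-column embeddings into contextual ones, the following sketch applies single-head scaled dot-product attention to the categorical embeddings of one row. It is heavily simplified and is not the TabTransformer implementation: the learned query, key, and value projections are replaced by the identity, and the embedding table and its values are hypothetical.

```python
import math


def softmax(scores):
    """Convert raw scores to positive weights that sum to 1."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(embeddings):
    """Single-head scaled dot-product attention with identity projections.

    A trained Transformer layer would apply learned query, key, and value
    matrices; they are omitted here to keep the sketch short.
    """
    dim = len(embeddings[0])
    contextual = []
    for query in embeddings:
        scores = [
            sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
            for key in embeddings
        ]
        weights = softmax(scores)
        # Each output is a weighted average of all column embeddings in the row.
        contextual.append(
            [
                sum(w * value[d] for w, value in zip(weights, embeddings))
                for d in range(dim)
            ]
        )
    return contextual


# Hypothetical embedding table: one vector per (column, category) pair.
embedding_table = {
    ("color", "red"): [1.0, 0.0, 0.0],
    ("size", "large"): [0.0, 1.0, 0.0],
    ("shape", "round"): [0.9, 0.1, 0.0],
}

row = [("color", "red"), ("size", "large"), ("shape", "round")]
contextual_embeddings = self_attention([embedding_table[feature] for feature in row])
```

After attention, the embedding of each column depends on the values in the other columns of the same row, which is what makes it contextual.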

Benefits of SageMaker built-in algorithms

When selecting an algorithm for your particular type of problem and data, using a SageMaker built-in algorithm is the easiest option, because doing so comes with the following major benefits:

  • The built-in algorithms require no coding to start running experiments. The only inputs you need to provide are the data, hyperparameters, and compute resources. This allows you to run experiments more quickly, with less overhead for tracking results and code changes.
  • The built-in algorithms come with parallelization across multiple compute instances and GPU support right out of the box for all applicable algorithms (some algorithms may not be included due to inherent limitations). If you have a lot of data with which to train your model, most built-in algorithms can easily scale to meet the demand. Even if you already have a pre-trained model, it may still be easier to use its counterpart in SageMaker and input the hyperparameters you already know rather than port it over and write a training script yourself.
  • You are the owner of the resulting model artifacts. You can take that model and deploy it on SageMaker for several different inference patterns (check out all the available deployment types) and easy endpoint scaling and management, or you can deploy it wherever else you need it.

Let’s now see how to train one of these built-in algorithms.

Train a built-in algorithm using the SageMaker SDK

To train a selected model, we need to get that model’s URI, as well as that of the training script and the container image used for training. Thankfully, these three inputs depend solely on the model name, version (for a list of the available models, see JumpStart Available Model Table), and the type of instance you want to train on. This is demonstrated in the following code snippet:

from sagemaker import image_uris, model_uris, script_uris

train_model_id, train_model_version, train_scope = "lightgbm-classification-model", "*", "training"
training_instance_type = "ml.m5.xlarge"

# Retrieve the docker image
train_image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    model_id=train_model_id,
    model_version=train_model_version,
    image_scope=train_scope,
    instance_type=training_instance_type
)
# Retrieve the training script
train_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope=train_scope
)
# Retrieve the model artifact; in the tabular case, the model is not pre-trained 
train_model_uri = model_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, model_scope=train_scope
)

The train_model_id changes to lightgbm-regression-model if we’re dealing with a regression problem. The IDs for all the other models introduced in this post are listed in the following table.

Model | Problem Type | Model ID
LightGBM | Classification | lightgbm-classification-model
LightGBM | Regression | lightgbm-regression-model
CatBoost | Classification | catboost-classification-model
CatBoost | Regression | catboost-regression-model
AutoGluon-Tabular | Classification | autogluon-classification-ensemble
AutoGluon-Tabular | Regression | autogluon-regression-ensemble
TabTransformer | Classification | pytorch-tabtransformerclassification-model
TabTransformer | Regression | pytorch-tabtransformerregression-model
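
If you switch between algorithms often, a small lookup can save you from retyping these IDs. The following sketch simply encodes the table above. The helper name `get_model_id` and the lowercase keys are our own choices and are not part of the SageMaker SDK.

```python
# Model IDs copied from the table above, keyed by (algorithm, problem type).
MODEL_IDS = {
    ("lightgbm", "classification"): "lightgbm-classification-model",
    ("lightgbm", "regression"): "lightgbm-regression-model",
    ("catboost", "classification"): "catboost-classification-model",
    ("catboost", "regression"): "catboost-regression-model",
    ("autogluon-tabular", "classification"): "autogluon-classification-ensemble",
    ("autogluon-tabular", "regression"): "autogluon-regression-ensemble",
    ("tabtransformer", "classification"): "pytorch-tabtransformerclassification-model",
    ("tabtransformer", "regression"): "pytorch-tabtransformerregression-model",
}


def get_model_id(algorithm, problem_type):
    """Return the model ID for an algorithm and problem type (case-insensitive)."""
    key = (algorithm.lower(), problem_type.lower())
    if key not in MODEL_IDS:
        raise ValueError(f"No model ID known for {algorithm!r} / {problem_type!r}")
    return MODEL_IDS[key]


train_model_id = get_model_id("CatBoost", "Regression")
```

The returned string can be used as `train_model_id` in the retrieval calls shown earlier.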

We then define where our input is on Amazon Simple Storage Service (Amazon S3). We’re using a public sample dataset for this example. We also define where we want our output to go, and retrieve the default list of hyperparameters needed to train the selected model. You can change their value to your liking.

import sagemaker
from sagemaker import hyperparameters

sess = sagemaker.Session()
region = sess.boto_session.region_name

# URI of sample training dataset
training_dataset_s3_path = f"s3://jumpstart-cache-prod-{region}/training-datasets/tabular_multiclass/"

# URI for output artifacts 
output_bucket = sess.default_bucket()
s3_output_location = f"s3://{output_bucket}/jumpstart-example-tabular-training/output"

# Retrieve the default hyper-parameters for training
hyperparameters = hyperparameters.retrieve_default(
    model_id=train_model_id, model_version=train_model_version
)

# [Optional] Override default hyperparameters with custom values
hyperparameters[
    "num_boost_round"
] = "500"  # The same hyperparameter is named as "iterations" for CatBoost
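
Because the same setting has a different name in LightGBM and CatBoost, it can help to pick the key from the model ID rather than hard-coding it. The following is a small sketch of that idea. The helper `with_boost_rounds` and the stand-in default values are hypothetical, and only the two hyperparameter names mentioned above are covered.

```python
# Name of the boosting-round hyperparameter per algorithm family.
# Only the two names mentioned in this post are listed.
BOOST_ROUNDS_KEY = {
    "lightgbm": "num_boost_round",
    "catboost": "iterations",
}


def with_boost_rounds(defaults, model_id, rounds):
    """Return a copy of the defaults with the boosting-round count overridden."""
    family = model_id.split("-")[0]
    if family not in BOOST_ROUNDS_KEY:
        raise ValueError(f"No known boosting-round hyperparameter for {model_id!r}")
    updated = dict(defaults)
    # Values are stored as strings, matching the override shown above.
    updated[BOOST_ROUNDS_KEY[family]] = str(rounds)
    return updated


# Stand-in defaults (assumed values) in place of a real retrieve_default() result.
lightgbm_defaults = {"num_boost_round": "5000", "learning_rate": "0.009"}
lightgbm_params = with_boost_rounds(
    lightgbm_defaults, "lightgbm-classification-model", 500
)
```

Returning a copy leaves the retrieved defaults untouched, which makes it easier to compare several configurations side by side.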

Finally, we instantiate a SageMaker Estimator with all the retrieved inputs and launch the training job with .fit, passing it our training dataset URI. The entry_point script provided is named transfer_learning.py (the same for other tasks and algorithms), and the input data channel passed to .fit must be named training.

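import sagemaker
from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base

# IAM role that the training job assumes (resolved automatically inside SageMaker)
aws_role = sagemaker.get_execution_role()

# Unique training job name
training_job_name = name_from_base(f"built-in-example-{train_model_id}")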

# Create SageMaker Estimator instance
tc_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",
    instance_count=1,
    instance_type=training_instance_type,
    max_run=360000,
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
)

# Launch a SageMaker Training job by passing s3 path of the training data
tc_estimator.fit(
    {"training": training_dataset_s3_path}, logs=True, job_name=training_job_name
)

Note that you can train built-in algorithms with SageMaker automatic model tuning to select the optimal hyperparameters and further improve model performance.

Train a built-in algorithm using SageMaker JumpStart

You can also train any of these built-in algorithms with a few clicks via the SageMaker JumpStart UI. JumpStart is a SageMaker feature that allows you to train and deploy built-in algorithms and pre-trained models from various ML frameworks and model hubs through a graphical interface. It also allows you to deploy fully fledged ML solutions that string together ML models and various other AWS services to solve a targeted use case.

For more information, refer to Run text classification with Amazon SageMaker JumpStart using TensorFlow Hub and Hugging Face models.

Conclusion

In this post, we announced the launch of four powerful new built-in algorithms for ML on tabular datasets now available on SageMaker. We provided a technical description of what these algorithms are, as well as an example training job for LightGBM using the SageMaker SDK.

Bring your own dataset and try these new algorithms on SageMaker, and check out the sample notebooks to use built-in algorithms available on GitHub.


About the Authors

Dr. Xin Huang is an Applied Scientist for Amazon SageMaker JumpStart and Amazon SageMaker built-in algorithms. He focuses on developing scalable machine learning algorithms. His research interests are in the area of natural language processing, explainable deep learning on tabular data, and robust analysis of non-parametric space-time clustering. He has published many papers at the ACL, ICDM, and KDD conferences and in the Journal of the Royal Statistical Society: Series A.

Dr. Ashish Khetan is a Senior Applied Scientist with Amazon SageMaker JumpStart and Amazon SageMaker built-in algorithms and helps develop machine learning algorithms. He is an active researcher in machine learning and statistical inference and has published many papers in NeurIPS, ICML, ICLR, JMLR, ACL, and EMNLP conferences.

João Moura is an AI/ML Specialist Solutions Architect at Amazon Web Services. He is mostly focused on NLP use-cases and helping customers optimize Deep Learning model training and deployment. He is also an active proponent of low-code ML solutions and ML-specialized hardware.