AWS Machine Learning Blog
Build multiclass classifiers with Amazon SageMaker linear learner
Amazon SageMaker is a fully managed service for scalable training and hosting of machine learning models. We’re adding multiclass classification support to the linear learner algorithm in Amazon SageMaker. Linear learner already provides convenient APIs for linear models such as logistic regression for ad click prediction, fraud detection, or other classification problems, and linear regression for forecasting sales, predicting delivery times, or other problems where you want to predict a numerical value. If you haven’t worked with linear learner before, you might want to start with the documentation or our previous blog post on this algorithm. If it’s your first time working with Amazon SageMaker, you can get started here.
In this blog post we’ll cover three aspects of training a multiclass classifier with linear learner:
- Training a multiclass classifier
- Multiclass classification metrics
- Training with balanced class weights
Training a multiclass classifier
Multiclass classification is a machine learning task where the outputs are known to be in a finite set of labels. For example, we might classify emails by assigning each one a label from the set inbox, work, shopping, spam. Or we might try to predict what a customer will buy from the set shirt, mug, bumper_sticker, no_purchase. If we have a dataset where each example has numerical features and a known categorical label, we can train a multiclass classifier.
Related problems: binary, multiclass, and multilabel
Multiclass classification is related to two other machine learning tasks, binary classification and the multilabel problem. Linear learner already supports binary classification, multiclass classification is now available, and multilabel support is not yet available.
If there are only two possible labels in your dataset, then you have a binary classification problem. Examples include predicting whether a transaction will be fraudulent or not based on transaction and customer data, or detecting whether a person is smiling or not based on features extracted from a photo. For each example in your dataset, one of the possible labels is correct and the other is incorrect. The person is smiling or not smiling.
If there are more than two possible labels in your dataset, then you have a multiclass classification problem. For example, predicting whether a transaction will be fraudulent, cancelled, returned, or completed as usual. Or detecting whether a person in a photo is smiling, frowning, surprised, or frightened. There are multiple possible labels, but only one is correct at a time.
If there are multiple labels, and a single training example can have more than one correct label, then you have a multilabel problem. For example, tagging an image with tags from a known set. An image of a dog catching a Frisbee at the park might be labeled as outdoors, dog, and park. For any given image, those three labels could all be true, or all be false, or any combination. Although we haven’t added support for multilabel problems yet, there are a couple of ways you can solve a multilabel problem with linear learner today. You can train a separate binary classifier for each label. Or you can train a multiclass classifier and predict not only the top class, but the top k classes, or all classes with probability scores above some threshold.
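To make that second workaround concrete, here is a minimal `numpy` sketch (the labels and probabilities are hypothetical) of turning one example's per-class probabilities into multilabel tags:

```python
import numpy as np

# Hypothetical class labels and softmax probabilities for a single image
labels = np.array(['outdoors', 'dog', 'park', 'indoors'])
probs = np.array([0.45, 0.30, 0.20, 0.05])

# Option 1: tag the image with the top k classes
k = 3
top_k_tags = labels[np.argsort(probs)[::-1][:k]]   # ['outdoors', 'dog', 'park']

# Option 2: tag the image with every class above a probability threshold
threshold = 0.15
threshold_tags = labels[probs > threshold]         # ['outdoors', 'dog', 'park']
```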
Linear learner uses a softmax loss function to train multiclass classifiers. The algorithm learns a set of weights for each class, and predicts a probability for each class. We might want to use these probabilities directly, for example if we’re classifying emails as inbox, work, shopping, spam and we have a policy to flag as spam only if the class probability is over 99.99%. But in many multiclass classification use cases, we’ll simply take the class with highest probability as the predicted label.
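As a quick illustration of the mechanics (a sketch, not linear learner's internal code), here is the softmax computation and the usual argmax decision rule, with hypothetical raw scores:

```python
import numpy as np

def softmax(scores):
    """Map raw per-class scores to probabilities that sum to 1."""
    shifted = scores - np.max(scores)   # shift for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()

# Hypothetical raw model scores for the classes inbox, work, shopping, spam
scores = np.array([3.1, 1.2, 0.4, -2.0])
probs = softmax(scores)

predicted_class = np.argmax(probs)   # the usual rule: most probable class wins
is_spam = probs[3] > 0.9999          # or a policy rule like the 99.99% spam flag
```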
Hands-on example: predicting forest cover type
As an example of multiclass prediction, let’s take a look at the Covertype dataset (copyright Jock A. Blackard and Colorado State University). The dataset contains information collected by the US Geological Survey and the US Forest Service about wilderness areas in northern Colorado. The features are measurements like soil type, elevation, and distance to water, and the labels encode the type of trees – the forest cover type – for each location. The machine learning task is to predict the cover type in a given location using the features. We’ll download and explore the dataset, then train a multiclass classifier with linear learner using the Python SDK. To run this example yourself, take a look at the notebook version of this blog post.
Note that we transformed the labels to a zero index rather than an index starting from one. That step is important, since linear learner requires the class labels to be in the range [0, k-1], where k is the number of labels. Amazon SageMaker algorithms expect the `dtype` of all feature and label values to be `float32`. Also note that we shuffled the order of examples in the training set. We used the `train_test_split` method from `sklearn`, which shuffles the rows by default. That's important for algorithms trained using stochastic gradient descent. Linear learner, like most deep learning algorithms, uses stochastic gradient descent for optimization. Shuffle your training examples, unless your data have some natural ordering that needs to be preserved, such as a forecasting problem where the training examples should all have time stamps earlier than the test examples.
We split the data into training, validation, and test sets with an 80/10/10 ratio. Using a validation set will improve training, since linear learner uses the validation data to stop training once overfitting is detected. That means shorter training times and more accurate predictions. We can also provide a test set to linear learner. The test set will not affect the final model, but algorithm logs will contain metrics from the final model's performance on the test set. Later in this post, we'll also use the test set locally to dive a little deeper into model performance.
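For reference, here is a minimal sketch of the preprocessing and splits just described. It assumes a local copy of the raw covtype.data file from the UCI repository; the notebook linked above is the authoritative version:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# The Covertype CSV has 54 feature columns followed by a label column in {1, ..., 7}
raw = np.loadtxt('covtype.data', delimiter=',')
features = raw[:, :54].astype('float32')
labels = (raw[:, 54] - 1).astype('float32')   # shift labels to the required range [0, 6]

# 80/10/10 train/validation/test split; train_test_split shuffles rows by default
train_features, rest_features, train_labels, rest_labels = train_test_split(
    features, labels, test_size=0.2, random_state=1)
val_features, test_features, val_labels, test_labels = train_test_split(
    rest_features, rest_labels, test_size=0.5, random_state=1)
```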
Exploring the data
Let’s take a look at the mix of class labels present in the training data. We’ll add meaningful category names using the mapping provided in the dataset documentation.
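A sketch of that tally, using the cover type names from the dataset documentation and the `train_labels` array from the preprocessing sketch above:

```python
import numpy as np

# Cover type names from the dataset documentation, indexed by zero-based label
category_names = ['Spruce/Fir', 'Lodgepole Pine', 'Ponderosa Pine',
                  'Cottonwood/Willow', 'Aspen', 'Douglas-fir', 'Krummholz']

# Count how often each of the 7 labels appears in the training set
counts = np.bincount(train_labels.astype(int), minlength=7)
for name, count in zip(category_names, counts):
    print(f'{name:<20} {count / counts.sum():>7.2%}')
```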
We can see that some forest cover types are much more common than others. Lodgepole Pine and Spruce/Fir are both well represented. Some labels, such as Cottonwood/Willow, are extremely rare. Later in this post, we’ll see how to fine-tune the algorithm depending on how important these rare categories are for our use case. But first we’ll train with the defaults for the best all-around model.
Training a classifier using the Amazon SageMaker Python SDK
We’ll use the high-level estimator class `LinearLearner` to instantiate our training job and inference endpoint. For an example using the Python SDK’s generic `Estimator` class, take a look at this previous post. The generic Python SDK estimator offers more control options, but the high-level estimator is more succinct and has some advantages. One is that we don’t need to specify the location of the algorithm container we want to use for training; it will pick up the latest version of the linear learner algorithm. Another advantage is that some code errors will be surfaced before a training cluster is spun up, rather than after. For example, if we try to pass `n_classes=7` instead of the correct `num_classes=7`, the high-level estimator will fail immediately, but the generic Python SDK estimator will spin up a cluster before failing.
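A sketch of the estimator setup, assuming the version 1 Python SDK that was current when this post was written (the instance type and count are illustrative):

```python
import sagemaker

# In a SageMaker notebook instance, get_execution_role() returns the attached IAM role
role = sagemaker.get_execution_role()

multiclass_estimator = sagemaker.LinearLearner(role=role,
                                               train_instance_count=1,
                                               train_instance_type='ml.m4.xlarge',
                                               predictor_type='multiclass_classifier',
                                               num_classes=7)
```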
Linear learner accepts training data in protobuf or csv content types, and accepts inference requests in protobuf, csv, or json content types. Training data have features and ground-truth labels, while the data in an inference request have only features. In a production pipeline, we recommend converting the data to the Amazon SageMaker protobuf format and storing it in S3. However, to get up and running quickly, we provide a convenience method, `record_set`, for converting and uploading when the dataset is small enough to fit in local memory. It accepts `numpy` arrays like the ones we already have, so we’ll use it here. The `RecordSet` object will keep track of the temporary S3 location of our data.
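Continuing the sketch, we wrap our arrays and launch training; each `record_set` call uploads one array pair to a temporary S3 location and returns a `RecordSet`:

```python
# Wrap the in-memory numpy arrays; each call uploads to S3 and returns a RecordSet
train_records = multiclass_estimator.record_set(train_features, train_labels, channel='train')
val_records = multiclass_estimator.record_set(val_features, val_labels, channel='validation')
test_records = multiclass_estimator.record_set(test_features, test_labels, channel='test')

# Launch the training job; validation data enables early stopping,
# and test data adds test metrics to the algorithm logs
multiclass_estimator.fit([train_records, val_records, test_records])
```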
Multiclass classification metrics
Now that we have a trained model, we want to make predictions and evaluate model performance on our test set. For that we’ll need to deploy a model hosting endpoint to accept inference requests using the estimator API:
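What follows is a sketch of that deployment step and of batched inference, again assuming the version 1 Python SDK; names like `multiclass_estimator` carry over from the training sketch above, and each response record carries a `predicted_label` along with per-class scores:

```python
import numpy as np

# Spin up a real-time hosting endpoint behind the trained model
multiclass_predictor = multiclass_estimator.deploy(initial_instance_count=1,
                                                   instance_type='ml.m4.xlarge')

# Score the test set in batches and collect the predicted class labels
predictions = []
for batch in np.array_split(test_features, 100):
    records = multiclass_predictor.predict(batch)
    predictions += [r.label['predicted_label'].float32_tensor.values[0] for r in records]
predictions = np.array(predictions)
```

With predictions and ground-truth labels side by side, standard multiclass metrics such as per-class precision, recall, and F1 follow directly, for example with scikit-learn:

```python
from sklearn.metrics import classification_report

print(classification_report(test_labels, predictions, target_names=category_names))
```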
Training with balanced class weights
With balanced class weights turned on, linear learner will count label frequencies in your training set. This is done efficiently using a sample of the training set. The weights will be the inverses of the frequencies. A label that’s present in 1/3 of the sampled training examples will get a weight of 3, and a rare label that’s present in only 0.001% of the examples will get a weight of 100,000. A label that’s not present at all in the sampled training examples will get a weight of 1,000,000 by default. To turn on class weights, use the `balance_multiclass_weights` hyperparameter:
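For instance, reusing the estimator configuration from the training sketch earlier (a sketch, not the post's exact code):

```python
# Same configuration as before, with balanced class weights enabled
balanced_estimator = sagemaker.LinearLearner(role=role,
                                             train_instance_count=1,
                                             train_instance_type='ml.m4.xlarge',
                                             predictor_type='multiclass_classifier',
                                             num_classes=7,
                                             balance_multiclass_weights=True)
```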