Amazon Web Services ブログ
Amazon SageMaker 線形学習者でマルチクラス分類器を構築する
Amazon SageMaker は、機械学習モデルのスケーラブルな訓練とホスティングのための完全マネージド型サービスです。Amazon SageMaker の線形学習者アルゴリズムにマルチクラス分類のサポートが追加されます。線形学習者は、広告のクリック予測、不正検出、またはその他の分類問題のロジスティック回帰や売上予測、配達時間の予測、または数値の予測を目的とした線形回帰などの線形モデルに利用できる API を既に提供しています。線形学習者を利用したことがない場合は、本アルゴリズムに関するドキュメントまたはこれまでのブログ投稿をご参考にして使い始めて下さい。Amazon SageMaker が初めての場合は、ここから始めて下さい。
このブログ記事では、マルチクラス分類を線形学習者で訓練する 3 つの側面について説明します。
- マルチクラス分類器の訓練
- マルチクラス分類メトリクス
- バランスの取れたクラス毎の重み付けを使った訓練
マルチクラス分類器の訓練
マルチクラス分類は、機械学習タスクの一つで、出力がラベルの有限集合に入ることで知られています。たとえば、電子メールを分類するには、それぞれに受信トレイ、仕事、ショッピング、スパムの中のいずれかのラベルを割り当てます。あるいは、顧客が shirt、mug、bumper_sticker、no_purchase の中から何を購入するかを予測しようとするかもしれません。それぞれの例が数値的な特徴や既に知っているカテゴリのラベルがある場合、マルチクラス分類器を訓練することができます。
関連する問題:バイナリ、マルチクラス、マルチラベル
マルチクラス分類は、バイナリ分類およびマルチラベル問題の 2 つの機械学習タスクに関連します。線形学習者はすでにバイナリ分類をサポートしてましたが、マルチクラス分類も利用できるようになりました。ただし、マルチラベルサポートはまだサポートされてません。
データセットに可能性のあるラベルが 2 つしかない場合は、バイナリ分類問題になります。例としては、取引や顧客のデータに基づいて取引が不正であるかどうかを予測することや、写真から抽出された特徴に基づいて人が笑顔であるかどうかを検出することなどがあります。データセットの各例では、可能性のあるラベルの 1 つが正しく、もう 1 つが間違っています。その人物は笑顔なのか、笑顔でないのか。
あなたのデータセットに 3 つ以上の可能性のあるラベルがある場合、マルチクラス分類問題になります。たとえば、トランザクションが詐欺、キャンセル、返品、または通常どおりに完了するかどうかを予測します。また、写真の人物が笑っている、悩んでいるのか、驚いているのか、あるいは恐れているのかを検出することもできます。可能性のあるラベルは複数ありますが、一度に付けられる正しいラベルは 1 つだけです。
複数のラベルがあり、1 つの訓練サンプルに複数の正しいラベルがある場合は、マルチラベル問題になります。たとえば、既知のセットから画像にタグを付けるなどです。公園でフリスビーを追っている犬の画像は、屋外、犬、および公園でラベル付けするかもしれません。どんな画像でも、これらの 3 つのラベルがすべて真、すべてが偽、あるいは何らかの組み合わせになるはずです。マルチラベル問題のサポートはまだ追加されていませんが、現状の線形学習でマルチラベル問題を解決する方法がいくつかあります。ラベルごとに別々のバイナリ分類器を訓練することができます。または、マルチクラス分類器を訓練して、最上位クラスだけでなく、最上位の k クラス、または確率スコアがあるしきい値を超えるすべてのクラスを予測できます。
線形学習者は、softmax 損失関数を使用してマルチクラス分類器を訓練します。アルゴリズムは、各クラスの重みの集合を学習し、各クラスの確率を予測します。これらの確率を直接使用することができます。たとえば、電子メールを受信トレイ、仕事、ショッピング、スパムに分類して、クラスの確率が 99.99% を超える場合にのみスパムとしてフラグを立てるポリシーを検討します。しかし、多くのマルチクラス分類のユースケースでは、予測ラベルとして最も高い確率を持つクラスを取り上げます。
実例:森林被覆の種類を予測する
マルチクラス予測の例として、Covertype データセット (著作権: Jock A. Blackard とコロラド州立大学) を見てみましょう。このデータセットには、米国地質調査所および米国森林局がコロラド州北部の荒野について収集した情報が含まれています。特徴を土壌タイプ、標高、水との距離などの測定値とし、ラベルを基に各場所の樹木の種類 (森林被覆の種類) をエンコードします。機械学習のタスクは、特徴を使用して所定の場所での被覆の種類を予測することです。データセットをダウンロードして探索し、Python SDK を使用して線形学習者でマルチクラス分類器を訓練します。この例を自分で実行するには、このブログ記事のメモを見てみましょう。
ラベルを 1 から始まるインデックスではなく、0 から始まるインデックスに変換したことに注意してください。この処理は重要です。何故なら線形学習者は、クラスラベルが [0, k-1] の範囲内に入っていることを期待してるからです。ここで、k はラベルの数を表します。Amazon SageMaker のアルゴリズムは、dtype
(すべての特徴とラベル値) が float32
であることを期待します。また、訓練セットでは、例の順序を入れ替えていることにも注意してください。numpy の train_test_split
メソッドを使います。numpy
配列に対応しているので、デフォルトで縦の列をシャッフルします。これは、確率的勾配降下法を用いて訓練されたアルゴリズムにとって重要です。線形学習者や最も深く学習するアルゴリズムでは、確率的勾配降下法を用いて最適化してます。訓練サンプルをシャッフルします。ただし、訓練サンプルにテストサンプルよりも前のタイムスタンプがあることを予測する問題など、データを自然な順序に保持する必要がある場合は除きます。
データを 80/10/10 の比率で訓練、検証、テストセットに分割しました。検証セットを使用すると、訓練が改善されます。これは、過学習が検出されると、線形学習者が訓練を中止するからです。つまり、訓練時間が短縮され、より正確な予測が可能になります。また、線形学習者にテストセットを与えることもできます。テストセットは最終モデルには影響しませんが、アルゴリズムのログにはテストセット上の最終モデルのパフォーマンスのメトリックが含まれます。この記事の後半では、テストセットをローカルで使用して、モデルのパフォーマンスについて少し掘り下げてみましょう。
データの調査
訓練データに含まれるクラスラベルの組み合わせを見てみましょう。データセットのドキュメントに記載したマッピングを使って、意味のあるカテゴリ名を追加します。
いくつかの森林被覆の種類は他よりもはるかに一般的であることがわかります。Lodgepole Pine と Spruce/Fir の両方が頻繁に現れます。Cottonwood/Willow のようなラベルは非常にまれです。この記事の後半では、これらのまれなカテゴリがユースケースにとって重要であるかどうかに応じて、アルゴリズムを微調整する方法を見ていきます。しかし、まず最高のオールアラウンドモデルのデフォルトを訓練します。
Amazon SageMaker Python SDK を使用した分類器の訓練を行う
高レベルのエスティメーターのクラスである LinearLearner
を使って訓練ジョブと推論のエンドポイントをインスタンス化します。Python SDK の一般的な Estimator
クラスをこの前の投稿で見てみましょう。一般的な Python SDK エスティメーターにはいくつかの制御オプションがありますが、高レベルのエスティメーターはより簡潔で、利点がいくつかあります。1 つは、訓練に使用するアルゴリズムコンテナの場所を指定する必要がないということです。これは、線形学習者アルゴリズムの最新バージョンを指定するからです。もう 1 つの利点は、訓練クラスタを起動する前に、コードエラーを発見することです。たとえば、誤りである n_classes=7
を正しい num_classes=7
の代わりに入力すると、高レベルのエスティメーターすぐに失敗しますが、一般的な Python SDK エスティメーターは失敗する前にクラスタを起動させます。
線形学習者は、protobuf または csv のコンテンツタイプの訓練データを受け取り、protobuf、csv、または json のコンテンツタイプの推論リクエストを受入れます。訓練データにはフィーチャーラベルとグラウンドトルースラベルが含まれますが、推論リクエストのデータにはフィーチャラベルのみになります。プロダクションパイプラインでは、データを Amazon SageMaker の protobuf 形式に変換して S3 に格納することをお勧めします。ただし、素早く起動して実行するために、便利なメソッド record_set
を用意しています。データセットがローカルメモリに収まるように十分小さい場合に、変換およびアップロードを行うメソッドです。本メソッドは、numpy
配列に対応しているので、ここで使用します。この RecordSet
はデータの一時的な S3 の保存場所を追跡します。
マルチクラス分類メトリクス
これで、訓練されたモデルが得られたので、テストセットで予測を行い、モデルのパフォーマンスを評価したいと思います。そのためには、エスティメーター API を使用して推論リクエストを受け入れるためのモデルホスティングエンドポイントをデプロイする必要があります。
平均のとれたクラス毎の重み付けの場合、線形学習者は訓練セットのラベルの頻度をカウントします。これは、訓練セットのサンプルを使用して効率的に行われます。重みは頻度の逆数になります。サンプリングされた訓練例の 1/3 に存在するあるラベルは、3 の重みを持ち、訓練例のわずか 0.001% に存在する希少なラベルは 100,000 の重みを与えられます。サンプリングされた訓練例に存在しないラベルは、デフォルトで 1,000,000 の重みを与えられます。クラスの重みを有効にするには、balance_multiclass_weights
hyperparameter: