Amazon Web Services ブログ

Amazon SageMaker Debugger を使った機械学習の説明可能性



機械学習 (ML) は、金融サービス業界 (FSI) から製造、自律走行車、および宇宙探査にいたるまで、世界中の業界に影響を及ぼします。ML はもはや学術機関および研究機関限定の単なる野心的なテクノロジーではなくなり、あらゆる規模の組織に利益をもたらす主力テクノロジーへと進化しました。しかし、ML プロセスにおける透明性の欠如と、結果として生じるモデルのブラックボックス的な性質が、金融サービスおよびヘルスケアなどの業界における ML の導入を向上させる上での妨げとなっています。

ML モデルを開発するチームにとっては、ビジネス成果に対する予測の影響が増加するにつれて、モデル予測を説明する責任も増加します。たとえば、消費者には、説明がなくても ML モデルから推薦された映画を受け入れる傾向があります。消費者がその推薦に同意するかどうかはわかりませんが、モデル開発者が予測を正当化する必要性は比較的低くなります。これに対して、クレジットローンの申し込みが承認されるかどうか、または患者に投与する薬の量を ML モデルが予測する場合、モデル開発者はその予測を説明する責任を負い、「ローンが拒否されたのはなぜですか」または「この薬を 10 ㎎ 飲まなくてはならないのはなぜですか」といった質問に対応する必要があります。 トレーニングプロセスに対する可視性を得て、人に対する説明が可能な ML モデルの開発が重要なのはこのためです。

Amazon SageMaker は、開発者およびデータサイエンティストが、あらゆる規模の ML モデルを迅速かつ簡単に構築、トレーニング、およびデプロイすることを可能にする完全マネージド型サービスです。Amazon SageMaker Debugger は Amazon SageMaker の機能で、リアルタイムおよびオフラインでの分析のためのモデルトレーニングプロセスに対する可視性を自動的に提供します。Amazon SageMaker Debugger では内部モデルの状態が定期的な間隔で保存され、トレーニング中のリアルタイムでの分析、およびトレーニング完了後のオフラインでの分析の両方を行うことができます。Amazon SageMaker Debugger はモデルのトレーニング中に問題を識別し、トレーニングされたモデルが行う予測への洞察を提供します。Amazon SageMaker Debugger には一連の組み込みルールが備わっており、これらは一般的なトレーニング問題を検知し、トレーニングが成功する上で重要な一般的な状態をモニタリングします。また、トレーニングジョブをモニタリングするカスタムルールを作成することも可能です。

この記事では、ML の説明可能性、人気の説明可能性ツールである SHAP (SHapley Additive exPlanation)、および Amazon SageMaker Debugger との SHAP のネイティブな統合について説明します。この記事の一環として、Amazon SageMaker Debugger を使用して金融サービスユースケースで説明を提供する方法を説明する詳細なノートブックも提供します。このユースケースでは、個人の所得が 50,000 USD を超えるか超えないかをモデルが予測します。この記事では、UCI Adult データセットを使用します。

ML の説明可能性

説明可能性とは、人間の言葉で ML または深層学習システムの内部構造を説明できる範囲のことです。これは、設計者でさえも AI が特定の結論に達した理由を説明できないというブラックボックスの概念とは対照的なものです。

説明可能性には、グローバルローカルの 2 つのタイプがあります。グローバル説明は ML モデル全体を透明かつ包括的にすることを目的とし、ローカル説明はモデルの個々の予測を説明することに焦点を合わせます。

ML モデルとその予測を説明する能力は、信頼を築き、ML の導入を向上させ、モデルは孤立した状態で予測を行うブラックボックスではなくなります。これは、モデル予測の消費者の安心感を高めます。モデルオーナーにとって、ML モデルにつきものの不確実性を理解する能力は、不具合が生じた場合にモデルをデバッグする、およびより良いビジネス成果のためにモデルを改善することに役立ちます。

この記事では、特徴量重要度と SHAP の 2 つの説明可能性手法を検証します。

特徴量重要度

特徴量重要度は、スコア (重要度) を使ってトレーニングデータを構成する特徴を説明する手法です。これは、ある特徴が他の特徴と比べてどれだけ価値があるか、または有益かを示します。XGBoost を使った個人所得予測のユースケースでは、重要度スコアがモデル内にあるブーストされた決定木の構造における各特徴の価値を示します。モデルが決定木で重要な判断を行うために属性を使用すればするほど、その属性の相対的重要度も高くなります。

SHAP

オープンソースツールである SHAP は、協力ゲーム理論に基づくシャープレイ値を使用します。これは、トレーニングデータインスタンスの各特徴量を、予測を報酬としたゲーム内のプレイヤーと見なすことで ML 予測を説明します。シャープレイ値は、これらの特徴間で報酬を公平に配分する方法を示します。この値はインスタンスに関するすべての可能な予測を考慮し、入力の可能な組み合わせをすべて使用します。この網羅的なアプローチのため、SHAP は整合性とローカル精度を保証することができます。詳細については、SHAP ウェブサイトのドキュメントをご覧ください。

Amazon SageMaker Debugger

Amazon SageMaker Debugger は、トレーニングジョブの状態をキャプチャするテンソルデータのモニタリング、記録、および分析を行うことによって、ML モデルのトレーニングに対する完全な可視性を提供します。Amazon SageMaker Debugger が有効化されたトレーニングジョブでは、どのテンソルを保存するか、テンソルをどこに保存するか、および トレーニングデータセットで実行するトライアルを設定します。コレクションに分類されたテンソルは、トレーニングジョブのライフライクルにおけるあらゆる時点でのジョブの状態を定義します。Amazon SageMaker Debugger に組み込まれたテンソルコレクションには、特徴量重要度、フル SHAP、および平均 SHAP が含まれます。モデルトレーニング中のデバッグ、および一般的なエラーの自動検知に関する詳細については、Amazon SageMaker Debugger – Debug Your Machine Learning Models を参照してください。

この記事の残りの部分では、トレーニングジョブに対して Amazon SageMaker Debugger を有効化する、組み込みテンソルコレクションの average_shapfull_shap を使用する、およびキャプチャされたテンソルをモデル説明のために視覚化して分析する方法を取り上げます。

チュートリアルの概要

このチュートリアルには、次のおおまかな手順が含まれます。

  1. 問題を示すトレーニングデータセットを検証する
  2. デバッガをオンにした Amazon SageMaker でモデルをトレーニングする
  3. デバッガの出力を視覚化して分析する

これらの手順は、Amazon SageMaker での説明可能性のための組み込みテンソルの使用に特化したものです。この記事では説明しませんが、この他にもライブラリのインポート、IAM のアクセス許可のセットアップ、およびその他機能に必要な手順があります。GitHub リポジトリにある次のノートブックでウォークスルーを実施し、コードを実行することができます。

このユースケースでは、個人の年齢、交際ステータス、就労時間数、およびキャピタルゲインなどのさまざまな特徴に基づいて、個人の所得が 50,000 USD に等しい、これを下回る、または上回るかを予測します。これには UCI Adult データセットを使用します。

トレーニングデータの検証

トレーニングデータを構成する特徴を理解するため、データセットをダウンロードします。以下のスクリーンショットは、データからの最初の数行を示しています。

以下のスクリーンショットは、特徴のリストです。

データセットは、個人の年齢、学歴、およびその他詳細情報をキャプチャする 12 個の異なる特徴で構成されています。ここでは、年収が 50,000 USD を超える個人の可能性を予測するために、XGBoost を使用します。目標は、各特徴が XGBoost モデルとその予測にどのように影響するかを理解することです。

Amazon SageMaker Debugger を有効にした状態での XGBoost モデルのトレーニング

XGBoost モデルは Amazon SageMaker Estimator API を使ってトレーニングします。トレーニング中に Amazon SageMaker Debugger を有効化するには、DebuggerHookConfig オブジェクトを作成し、この設定を Estimator API に追加します。DebuggerHookConfig は収集したいテンソルコレクションと、収集されたテンソルを保存する Amazon S3 の場所を指定します。このユースケースでは、トレーニングプロセス中に feature_importanceaverage_shap、および full_shap を 10 イテレーションごとに収集し、これはコードの save_interval で指定します。

トレーニングプロセスのモニタリングには、Amazon SageMaker Debugger ルールを作成します。この記事では、Amazon SageMaker に組み込まれた LossNotDecreasing ルールを使ってメトリクスコレクションをモニタリングします。このルールは、メトリクスのテンソルが 10 回を超えるステップにわたって減少しない場合にアラートを発行します。

このルール (debugger hook config) と Estimator は単一の API で設定できます。以下のコードを参照してください。

from sagemaker.debugger import rule_configs, Rule, DebuggerHookConfig, CollectionConfig
from sagemaker.estimator import Estimator

xgboost_estimator = Estimator(
    role=role,
    base_job_name=base_job_name,
    train_instance_count=1,
    train_instance_type='ml.m5.4xlarge',
    image_name=container,
    hyperparameters=hyperparameters,
    train_max_run=1800,

    debugger_hook_config=DebuggerHookConfig(
        s3_output_path=bucket_path,  # Required
        collection_configs=[
            CollectionConfig(
                name="metrics",
                parameters={
                    "save_interval": str(save_interval)
                }
            ),
            CollectionConfig(
                name="feature_importance",
                parameters={
                    "save_interval": str(save_interval)
                }
            ),
            CollectionConfig(
                name="full_shap",
                parameters={
                    "save_interval": str(save_interval)
                }
            ),
            CollectionConfig(
                name="average_shap",
                parameters={
                    "save_interval": str(save_interval)
                }
            ),
        ],
    ),

    rules=[
        Rule.sagemaker(
            rule_configs.loss_not_decreasing(),
            rule_parameters={
                "collection_names": "metrics",
                "num_steps": str(save_interval * 2),
            },
        ),
    ],
)

次に、作成した Estimator オブジェクトを使ってトレーニングジョブを開始します。以下のコードを参照してください。

from sagemaker.session import s3_input

train_input = s3_input("s3://{}/{}/{}".format(bucket, prefix, "data/train.csv"), content_type="csv")

validation_input = s3_input( "s3://{}/{}/{}".format(bucket, prefix, "data/validation.csv"), content_type="csv")

xgboost_estimator.fit(
    {"train": train_input, "validation": validation_input},
    # これは fire and forget イベントです。wait=False と設定することにより、バックグラウンドで実行するようにジョブをサブミットします。
    # Amazon SageMaker はひとつのトレーニングジョブを開始し、ノートブックの次のセルにコントロールをリリースします。
    # このノートブックに従って、トレーニングジョブのステータスを確認してください。
    wait=False
)

以下の出力が表示されます。

Training job status: InProgress, Rule Evaluation Status: InProgress
Training job status: InProgress, Rule Evaluation Status: InProgress
…..

上記のコードの結果として、Amazon SageMaker がひとつのトレーニングジョブとひとつのルールジョブを開始します。ルール評価ジョブのステータスをチェックするには、以下のコードを入力します。

xgboost_estimator.latest_training_job.rule_job_summary()

以下の出力が表示されます。

[{'RuleConfigurationName': 'LossNotDecreasing', 'RuleEvaluationJobArn': 'arn:aws:sagemaker:us-west-2:xxxxxxxxxxxx:processing-job/demo-smdebug-xgboost-adult-lossnotdecreasing-95f1ab04', 'RuleEvaluationStatus': 'InProgress', 'LastModifiedTime': datetime.datetime(2020, 3, 15, 4, 0, 31, 217000, tzinfo=tzlocal())}]

デバッガ出力の視覚化と分析

この手順では、トレーニング中にキャプチャされた feature_importancefull_shap、および average_shap の各テンソルを視覚化します。分析には、個々の特徴がモデル予測にどのように寄与するかを理解するための、モデルのグローバル説明とローカル説明を導くことが含まれます。また、外れ値予測も簡単に検証します。

ここでは、smdebug ライブラリとトライアル (単一のトレーニング実行を表す) の概念を使用します。トライアルオブジェクトにはテンソルの場所へのパスが含まれ、テンソルをクエリするためのアクセスを許可します。

トライアル内のステップは、トレーニングジョブの単一のバッチを表します。各トライアルには複数のステップがあります。収集されたテンソルには、各ステップに特定の値があります。テンソル値は、先ほど指定した Amazon S3 の場所に保存されます。以下のコードを参照してください。

from smdebug.trials import create_trial
s3_output_path = xgboost_estimator.latest_job_debugger_artifacts_path()
trial = create_trial(s3_output_path)

以下の出力が表示されます。

[2020-03-15 04:01:46.743 ip-172-16-11-140:23602 INFO s3_trial.py:42] Loading trial debug-output at path s3://sagemaker-us-west-2-xxxxxxxxxxxx/demo-smdebug-xgboost-adult-income-predi-2020-03-15-03-58-05-031/debug-output

収集されたテンソルを表示するには、以下のコードを入力します。

trial.tensor_names()

以下の出力が表示されます。

['average_shap/f0', 'average_shap/f1', ….. , 'average_shap/f11'
'feature_importance/cover/f0',….., 'feature_importance/cover/f1,
'feature_importance/gain/f0',….,  'feature_importance/gain/f11', 'feature_importance/weight/f0', …,'feature_importance/weight/f11',
'full_shap/f0', …, 'full_shap/f9',
'train-error', 'validation-error']

テンソル名に特徴の実際の名前は含まれていません。これらは f0f1 などで表されます。これは、機密性の高い特徴名が分析に表示されることを防ぎます。逆に、以下のコードは先ほど保存した実際の特徴名を使用します。要件に応じて、これらのアプローチのどちらかを使用できます。

f1average_shap テンソル値を表示するには、以下のコードを入力します。

trial.tensor("average_shap/f1").values()

以下の出力が表示されます。

 {0: array([0.], dtype=float32),
 5: array([0.], dtype=float32),
 …
 50: array([0.00796671], dtype=float32)}

収集されたテンソルを複数の特徴にプロットすることもできます。たとえば、feature_importance をプロットするには、以下のコードを入力します。

plot_feature_importance(trial, importance_type="cover")

以下のグラフはその出力です。

同様に、収集された average_shap テンソル値をすべての特徴にプロットすることもできます。以下のグラフを参照してください。

上記の 2 つのプロットにより、プロットされたメトリクスがトレーニングプロセス中にどのように変化するかの概要をつかむことができます。

グローバル説明

グローバル説明手法では、モデルとその特徴の寄与を、複数のデータポイントにまたがる全体で理解することができます。以下のグラフは、各特徴に平均絶対 SHAP 値をプロットする Aggregate Bar Plot のグラフです。具体的に説明すると、以下のプロットは所得可能性が 50,000 USD を超えるかどうかの予測において、交際関係の値 (Wife=5、Husband=4、Own-child=3、Other-relative=2、Unmarried=1、Not-in-family=0) が最も重要な役割を果たしていることを示しています。

各特徴の SHAP 値分布をさらに詳しく表示することができます。以下の Summary Plot は、俸グラフよりも多くのコンテキストを提供します。これには、どの特徴が最も重要か、およびデータセットに対するそれらの影響範囲も示されています。色分けは、特徴の価値における変化が、予測における変化にどのように影響するかを照合することも可能にします (たとえば、年齢が高くなると、予測の対数オッズが高くなり、最終的に True 予測がより頻繁に生じる結果となります)。また、個人の性別 (Sex) が予測にマイナスの影響を与えることもわかります。

赤は特徴の価値が高いことを示し、青は低くなります (特徴全体で正規化)。これは、年齢の増加が予測の対数オッズを高くし、最終的に True 予測がより頻繁に生じる結果となるといった結果を導き出します。

ローカル説明

ローカル説明は、個々の予測のそれぞれを説明することに焦点を当てます。Force Plot の説明は、モデル出力がベース値 (データセット全体の平均モデル出力) よりも高く/低くなることに特徴がどのように寄与するかを示しています。予測を押し上げる特徴は、押し下げるものはです。

以下のプロットは、この特定のデータポイントで、予想確率 (0.48) が平均よりも高い (~0.2) ことを示しています。これは、この人物が交際関係 (Relationship = Wife) にあることが主な理由で、影響度は小さくなりますが、年齢が平均以上であることも理由となります。同様に、このモデルは特定の Sex および Race 値が原因で確率が減少しており、モデルの動作におけるバイアスを示しています (データのバイアスのためだと思われます)。

SHAP では、複数のデータポイントに関する説明を理解するために、90 度回転させて複数の Force Plot をスタックすることができます。以下のプロットを参照してください。

 

この記事では静的にスタックされた Force Plot を掲載しましたが、完全なノートブックでは Javascript を有効化してプロットをインタラクティブにすることができます。そうすることで、出力が各特徴に基づいてどのように変化するかを単独で理解することができます。Force Plot のスタックは、ローカル説明とグローバル説明間のバランスを提供します。

外れ値

外れ値とは、データでの他の観察値から逸脱する極端な値のことです。外れ値予測に対するさまざまな特徴による影響を理解することは、それが目新しいものなのか、実験誤差なのか、それともモデルの欠陥なのかを判断するために役立ちます。

以下の Force Plot は、ベースライン値の両側にある予測外れ値を示すものです。このグラフは、「workclass=Federal-govAge=38、かつ Relationship=Wife (つまり、連邦職員の妻) であるならば、所得グループの外れ値になるであろう」ことを示しています。

ノートブックには、より詳しい外れ値分析が提供されています。

まとめ

この記事では、ML の導入を向上させる上での説明可能性の重要性、およびモデルの説明可能性を可能にするテンソルコレクションが組み込まれた Amazon SageMaker Debugger 機能を紹介しました。また、個人の所得予測に関する金融サービスユースケースの ML モデルのトレーニングを一通り説明する詳細なノートブックも提供しました。その後、キャプチャされたテンソルを視覚化することによって、モデルのグローバル説明とローカル説明を分析しました。トレーニングプロセスに対するリアルタイムの洞察と、キャプチャされたテンソルデータのオフライン分析を提供する能力を備えた Amazon SageMaker Debugger は、ML モデルを開発するチームにとって極めて強力なツールです。ぜひ一度 Amazon SageMaker Debugger を試して、コメント欄からフィードバックをお寄せください。


著者について

Mona Mona は、AWS World Wide Public Sector Team で働く AI/ML スペシャリストソリューションアーキテクトです。Mona は、AWS World Wide Public Sector のお客様と連携して、大規模な機械学習の導入を行うお手伝いをしています。

 

 

 

Rahul Iyer は、AWS AI のソフトウェア開発マネージャーです。Framework Algorithms チームを率いる Rahul は、XGBoost および Scikit-learn といった機械学習フレームワークの構築と最適化を行っています。

 

 

 

Sireesha Muppala は AWS の AI/ML スペシャリストソリューションアーキテクトで、大規模な機械学習ソリューションの設計と実装に関するガイダンスをお客様に提供しています。Sireesha はコロラド大学コロラドスプリングス校でコンピュータサイエンスの博士号を取得しました。余暇の楽しみは、コロラドのトレイルを走ったり、ハイキングしたりすることです。