亚马逊AWS官方博客

Amazon SageMaker 模型监控器 – 完全托管的机器学习模型自动化监控

今天,我们非常高兴地宣布推出 Amazon SageMaker 模型监控器。这是 Amazon SageMaker 的一项新功能,可以自动监控生产中的机器学习 (ML) 模型,并在出现数据质量问题时向您发出警报。

在我从事数据处理之初时,我学会了一样东西,那便是再关注数据质量都不为过。不知道您是否有过这样的经历:您花费数小时排查问题,最后知道是意外的 NULL 值或不知怎么就到了您的一个数据库的外来字符编码导致的。

由于模型实际上是根据大量数据构建的,因此不难理解,为什么 ML 从业人员会花费大量时间来维护数据集。特别是,他们会确保训练集(用于训练模型)和验证集(用于测量模型的准确性)中的数据样本具有相同的统计属性。

现在还不是松懈的时候! 尽管您可以完全控制实验数据集,但对于模型将要接收的真实数据就不是那回事了。当然,这些数据将是未经清理的,但是更令人担忧的问题是“数据漂移”,即您所接收数据的统计性质发生渐变。最小值和最大值、平均值、中位数、方差等等:所有这些都是决定模型训练期间做出的假设和决策的关键属性。我们的直觉告诉我们,这些值的任何重大变化都会影响预测的准确性:设想一下,要是由于输入特征出现漂移甚至缺失,导致一个贷款应用程序预测的金额升高,那多可怕!

检测这些条件非常困难:您将需要捕获模型接收的数据,运行各种统计分析以将这些数据与训练集进行比较,定义规则以检测漂移,并在发生漂移时发出警报……并在每次更新模型时从头再来一遍。专家级 ML 从业人员当然知道如何构建这些复杂的工具,但是却要花费大量的时间和耗费大量的资源。这不就是眉毛胡子一把抓么……

为了帮助所有客户专注于创造价值,我们构建了 Amazon SageMaker 模型监控器。下面我来进行更多介绍。

Amazon SageMaker 模型监控器简介

典型的监控会话如下。首先,我们要从 SageMaker 终端节点开始,可以使用现有的终端节点,也可以专门为了监控目的而创建新的终端节点。您可以在任何终端节点上使用 SageMaker 模型监控器,无论模型是从内置算法内置框架,还是从您自己的容器训练而来。

使用 SageMaker 开发工具包,您可以捕获发送到终端节点的部分数据(可配置),您也可以根据需要捕获预测,并将这些数据存储在您的 Amazon Simple Storage Service (S3) 存储桶中。捕获的数据会附加上元数据(内容类型、时间戳等),您可以像使用任何 S3 对象一样保护和访问它。

然后,从用于训练在终端节点上部署的模型的数据集建立基线。当然,您也可以选择使用已有的基线。这将启动 Amazon SageMaker 处理作业,其中 SageMaker 模型监控器将执行以下操作:

  • 推断输入数据的架构,即有关每个特征的类型和完整性的信息。您应该对其进行检查,并在需要时进行更新。
  • (仅对于预构建的容器)使用 Deequ(基于由 Amazon 开发并在 Amazon 使用Apache Spark 的开放源代码工具)来计算特征统计信息(博客文章研究论文)。这些统计信息包括 KLL 草图,这是一种用于在数据流上计算准确分位数的高级技术,这也是我们最近对 Deequ 做出的一项贡献

使用这些构件,下一步是启动监控计划,以使 SageMaker 模型监控器检查收集的数据和预测质量。无论使用的是内置容器还是自定义容器,都需要应用许多内置规则,并且报告会定期推送到 S3。这些报告包含在上一个时间段内接收到的数据的统计和架构信息以及检测到的任何违规情况。

最后但并非最不重要的一点是, SageMaker 模型监控器会向 Amazon CloudWatch 发出与特征相对应的指标,可用于设置控制面板和警报。CloudWatch 的摘要指标也可以在 Amazon SageMaker Studio 中看到,当然所有统计数据、监控结果和收集的数据都可以在笔记本中查看和进一步分析。

有关更多信息以及有关如何通过 AWS CloudFormation 使用 SageMaker 模型监控器的示例,请参阅开发人员指南

现在,让我们使用经过内置 XGBoost 算法训练的用户流失预测模型进行演示。

启用数据捕获

第一步是创建终端节点配置以启用数据捕获。在这里,我决定捕获 100% 的传入数据以及模型输出(即预测)。我还传递了 CSV 和 JSON 数据的内容类型。

data_capture_configuration = {
    "EnableCapture": True,
    "InitialSamplingPercentage": 100,
    "DestinationS3Uri": s3_capture_upload_path,
    "CaptureOptions": [
        { "CaptureMode": "Output" },
        { "CaptureMode": "Input" }
    ],
    "CaptureContentTypeHeader": {
       "CsvContentTypes": ["text/csv"],
       "JsonContentTypes": ["application/json"]
}

接下来,我使用常规的 CreateEndpoint API 创建终端节点。

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName = endpoint_config_name,
    ProductionVariants=[{
        'InstanceType':'ml.m5.xlarge',
        'InitialInstanceCount':1,
        'InitialVariantWeight':1,
        'ModelName':model_name,
        'VariantName':'AllTrafficVariant'
    }],
    DataCaptureConfig = data_capture_configuration)

对于已有的终端节点,我可以使用 UpdateEndpoint API 来无缝更新终端节点配置。

反复调用终端节点后,我可以在 S3 中看到一些捕获的数据(为清晰起见,对输出进行了编辑)。

$ aws s3 ls --recursive s3://sagemaker-us-west-2-123456789012/sagemaker/DEMO-ModelMonitor/datacapture/DEMO-xgb-churn-pred-model-monitor-2019-11-22-07-59-33/
AllTrafficVariant/2019/11/22/08/24-40-519-9a9273ca-09c2-45d3-96ab-fc7be2402d43.jsonl
AllTrafficVariant/2019/11/22/08/25-42-243-3e1c653b-8809-4a6b-9d51-69ada40bc809.jsonl

这是其中一个文件中的一行。

    "endpointInput":{
        "observedContentType":"text/csv",
        "mode":"INPUT",
        "data":"132,25,113.2,96,269.9,107,229.1,87,7.1,7,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,1",
        "encoding":"CSV"
     },
     "endpointOutput":{
        "observedContentType":"text/csv; charset=utf-8",
        "mode":"OUTPUT",
        "data":"0.01076381653547287",
        "encoding":"CSV"}
     },
    "eventMetadata":{
        "eventId":"6ece5c74-7497-43f1-a263-4833557ffd63",
        "inferenceTime":"2019-11-22T08:24:40Z"},
        "eventVersion":"0"}

看上去没什么问题了。现在,让我们为此模型创建一个基线。

创建监控基线
这是一个非常简单的步骤:传递基线数据集的位置以及存储结果的位置。

from processingjob_wrapper import ProcessingJob

processing_job = ProcessingJob(sm_client, role).
   create(job_name, baseline_data_uri, baseline_results_uri)

完成这项工作后,我可以在 S3 中看到两个新对象:一个用于统计信息,一个用于约束。

aws s3 ls s3://sagemaker-us-west-2-123456789012/sagemaker/DEMO-ModelMonitor/baselining/results/
constraints.json
statistics.json

constraints.json 文件告诉了我推断出的训练数据集的架构(不要忘记检查它的准确性)。每个特征都分了类,我还获得有关某个特征是否始终存在的信息(此处 1.0 表示 100%)。以下是前几行。

{
  "version" : 0.0,
  "features" : [ {
    "name" : "Churn",
    "inferred_type" : "Integral",
    "completeness" : 1.0
  }, {
    "name" : "Account Length",
    "inferred_type" : "Integral",
    "completeness" : 1.0
  }, {
    "name" : "VMail Message",
    "inferred_type" : "Integral",
    "completeness" : 1.0
  }, {
    "name" : "Day Mins",
    "inferred_type" : "Fractional",
    "completeness" : 1.0
  }, {
    "name" : "Day Calls",
    "inferred_type" : "Integral",
    "completeness" : 1.0

在该文件的末尾,我可以看到 CloudWatch 监控的配置信息:将其打开或关闭、设置漂移阈值等。

"monitoring_config" : {
    "evaluate_constraints" : "Enabled",
    "emit_metrics" : "Enabled",
    "distribution_constraints" : {
      "enable_comparisons" : true,
      "min_domain_mass" : 1.0,
      "comparison_threshold" : 1.0
    }
  }

statistics.json 文件显示每个特征(平均值、中位数、分位数等)的不同统计信息,以及终端节点接收的唯一值。示例如下。

"name" : "Day Mins",
    "inferred_type" : "Fractional",
    "numerical_statistics" : {
      "common" : {
        "num_present" : 2333,
        "num_missing" : 0
      },
      "mean" : 180.22648949849963,
      "sum" : 420468.3999999996,
      "std_dev" : 53.987178959901556,
      "min" : 0.0,
      "max" : 350.8,
      "distribution" : {
        "kll" : {
          "buckets" : [ {
            "lower_bound" : 0.0,
            "upper_bound" : 35.08,
            "count" : 14.0
          }, {
            "lower_bound" : 35.08,
            "upper_bound" : 70.16,
            "count" : 48.0
          }, {
            "lower_bound" : 70.16,
            "upper_bound" : 105.24000000000001,
            "count" : 130.0
          }, {
            "lower_bound" : 105.24000000000001,
            "upper_bound" : 140.32,
            "count" : 318.0
          }, {
            "lower_bound" : 140.32,
            "upper_bound" : 175.4,
            "count" : 565.0
          }, {
            "lower_bound" : 175.4,
            "upper_bound" : 210.48000000000002,
            "count" : 587.0
          }, {
            "lower_bound" : 210.48000000000002,
            "upper_bound" : 245.56,
            "count" : 423.0
          }, {
            "lower_bound" : 245.56,
            "upper_bound" : 280.64,
            "count" : 180.0
          }, {
            "lower_bound" : 280.64,
            "upper_bound" : 315.72,
            "count" : 58.0
          }, {
            "lower_bound" : 315.72,
            "upper_bound" : 350.8,
            "count" : 10.0
          } ],
          "sketch" : {
            "parameters" : {
              "c" : 0.64,
              "k" : 2048.0
            },
            "data" : [ [ 178.1, 160.3, 197.1, 105.2, 283.1, 113.6, 232.1, 212.7, 73.3, 176.9, 161.9, 128.6, 190.5, 223.2, 157.9, 173.1, 273.5, 275.8, 119.2, 174.6, 133.3, 145.0, 150.6, 220.2, 109.7, 155.4, 172.0, 235.6, 218.5, 92.7, 90.7, 162.3, 146.5, 210.1, 214.4, 194.4, 237.3, 255.9, 197.9, 200.2, 120, ...

现在,让我们开始监控终端节点。

监控终端节点
同样的,我们只需要调用一个 API:我只需为终端节点创建一个监控计划,并传递基线数据集的约束和统计文件。如果需要调整数据和预测,我还可以传递处理前和处理后函数。

ms = MonitoringSchedule(sm_client, role)
schedule = ms.create(
   mon_schedule_name,
   endpoint_name,
   s3_report_path,
   # record_preprocessor_source_uri=s3_code_preprocessor_uri,
   # post_analytics_source_uri=s3_code_postprocessor_uri,
   baseline_statistics_uri=baseline_results_uri + '/statistics.json',
   baseline_constraints_uri=baseline_results_uri+ '/constraints.json'
)

然后,我开始将假的数据发送到终端节点,即根据随机值构造的样本,然后等待 SageMaker 模型监控器开始生成报告。我等不及想要知道结果了!

检查报告

很快,我看到报告出现在了 S3 中。

mon_executions = sm_client.list_monitoring_executions(MonitoringScheduleName=mon_schedule_name, MaxResults=3)
for execution_summary in mon_executions['MonitoringExecutionSummaries']:
    print("ProcessingJob: {}".format(execution_summary['ProcessingJobArn'].split('/')[1]))
    print('MonitoringExecutionStatus: {} \n'.format(execution_summary['MonitoringExecutionStatus']))

ProcessingJob: model-monitoring-201911221050-df2c7fc4
MonitoringExecutionStatus: Completed 

ProcessingJob: model-monitoring-201911221040-3a738dd7
MonitoringExecutionStatus: Completed 

ProcessingJob: model-monitoring-201911221030-83f15fb9
MonitoringExecutionStatus: Completed 

让我们找到其中一个监控作业的报告。

desc_analytics_job_result=sm_client.describe_processing_job(ProcessingJobName=job_name)
report_uri=desc_analytics_job_result['ProcessingOutputConfig']['Outputs'][0]['S3Output']['S3Uri']
print('Report Uri: {}'.format(report_uri))

Report Uri: s3://sagemaker-us-west-2-123456789012/sagemaker/DEMO-ModelMonitor/reports/2019112208-2019112209

好了,我们来看看这里面到底有什么。

aws s3 ls s3://sagemaker-us-west-2-123456789012/sagemaker/DEMO-ModelMonitor/reports/2019112208-2019112209/

constraint_violations.json
constraints.json
statistics.json

如您所料,constraints.jsonstatistics.json 包含监控作业处理的数据样本的架构和统计信息。让我们直接打开第三个文件 constraints_violations.json

violations" : [ {
    "feature_name" : "State_AL",
    "constraint_check_type" : "data_type_check",
    "description" : "Value: 0.8 does not meet the constraint requirement! "
  }, {
    "feature_name" : "Eve Mins",
    "constraint_check_type" : "baseline_drift_check",
    "description" : "Numerical distance: 0.2711598746081505 exceeds numerical threshold: 0"
  }, {
    "feature_name" : "CustServ Calls",
    "constraint_check_type" : "baseline_drift_check",
    "description" : "Numerical distance: 0.6470588235294117 exceeds numerical threshold: 0"
  }

糟糕! 我好像给整数型特征分配了浮点值,结果可想而知!

一些特征也表现出了漂移现象,这也不是一个让人乐见的情况。也许我的数据提取过程出了点问题,也许数据的分布实际上已经改变,我需要重新训练模型。由于所有这些信息都可以作为 CloudWatch 指标提供,因此我可以定义阈值、设置警报甚至自动触发新的训练作业。

现已推出!

如您所见,Amazon SageMaker Model Monitor 易于设置,可帮助您快速了解 ML 模型中存在的质量问题。

现在轮到您自己实际使用了。Amazon SageMaker 模型监控器现已在提供 Amazon SageMaker 的所有商业区域推出。此功能还集成在了 Amazon SageMaker Studio (我们的 ML 项目工作台)中。最后但并非最不重要的一点是,所有信息都可以在笔记本中查看和进一步分析。

请试一试,并通过 Amazon SageMaker 的 AWS 论坛或您常用的 AWS Support 联系方式向我们发送反馈。

– Julien