AWS for Industries

Training Machine Learning Models on Multimodal Health Data with Amazon SageMaker

This post was co-authored by Olivia Choudhury, PhD, Partner Solutions Architect; Michael Hsieh, Sr. AI/ML Specialist Solutions Architect; and Andy Schuetz, PhD, Sr. Startup Solutions Architect at AWS.

This is the second blog post in a two-part series on Multimodal Machine Learning (Multimodal ML). In part one, we deployed pipelines for processing RNA sequence data, clinical data (reflective of EHR data), and medical images with human annotations. In this post, we show how to pool features from each data modality, and train a model to predict survival of patients diagnosed with Non-Small Cell Lung Cancer (NSCLC). Building on the first blog post, we’ll continue using the NSCLC Radiogenomics data set, which consists of RNA sequence data, clinical data (reflective of EHR data), medical images, and human annotations of those images [1].

Multimodal ML models can be applied to many other problems, including, but not limited to, personalized treatment, clinical decision support, and drug response prediction. And, while we present an application using genomic, clinical, and medical imaging data, the approach and architecture are applicable to a broad set of health data ML use cases and frameworks. Further, the approach we present can be applied to train models for both batch and real-time inference use cases.

Walkthrough

In part one of this series, we presented an architecture for multimodal data pipelines, as shown in Figure 1. We concluded that post having processed the genomic, clinical, and medical imaging data and loaded the resulting features from each modality into Amazon SageMaker Feature Store. In this post, we demonstrate how to pool features from the disparate modalities and train a predictive model that outperforms one trained on only one or two data modalities.

Figure 1: Architecture for integrating and analyzing multimodal health data

As shown in Figure 1, the features engineered from each data modality are written to Amazon SageMaker Feature Store, a purpose-built repository for storing ML features. Feature preparation and model training are performed using Amazon SageMaker Studio.

Prerequisites

If you completed part one of this series, you can skip the prerequisites and SageMaker Studio onboarding steps and proceed to the Multimodal Feature Store section below.

For this walkthrough, you will need an AWS account.

Running the analyses outlined in this blog post will fall within the SageMaker Free Tier, or cost less than $5 if your account is not eligible.

SageMaker Studio

We will use SageMaker Studio to work with data, author Jupyter notebooks, and access the EC2 instance used in the genomics pipeline. To get started, follow the Standard Setup procedure using AWS Identity and Access Management (IAM) to onboard to SageMaker Studio in our account. For simplicity, select Create role for the Execution role, and do not specify any S3 buckets explicitly, as shown in Figure 2. Permit SageMaker access to our S3 objects with the tag “sagemaker” and value “true”. In this way, only input and output data on S3 will be accessible.

Figure 2: Configuration of IAM role for SageMaker.

In the Network section, choose to onboard SageMaker Studio in our VPC, and specify the private subnet. Set the AppNetworkAccessType to be VpcOnly, to disable direct access from the internet.

Select Submit to create a Studio environment, and wait a few moments for the environment to be provisioned. After the SageMaker Studio IDE becomes available, select the Git icon on the left-hand toolbar and clone the repository that accompanies this blog.

By default, SageMaker Studio notebooks, local data, and model artifacts are all encrypted with AWS managed customer master keys (CMKs). In the present example, we are working with deidentified data. However, when working with Protected Health Information (PHI), it is important to encrypt all data at rest and in transit, apply narrowly defined roles, and limit access in accordance with the principles of least privilege. You can find further guidance on best practices in the white paper Architecting for HIPAA Security and Compliance.
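If you prefer to script the onboarding rather than use the console, the same network and encryption settings can be supplied programmatically. The following is a minimal sketch using the boto3 create_domain API; the domain name, execution role ARN, subnet, VPC, and KMS key identifiers are placeholders, and supplying a customer managed key is optional (AWS managed keys are used by default).

# Minimal sketch of onboarding SageMaker Studio programmatically with boto3,
# mirroring the console settings above (IAM auth, VpcOnly access, optional KMS encryption).
# All identifiers below are placeholders/assumptions.
import boto3

sm = boto3.client('sagemaker')
sm.create_domain(
    DomainName='multimodal-ml-domain',                                   # assumed name
    AuthMode='IAM',
    DefaultUserSettings={'ExecutionRole': '<SAGEMAKER-EXECUTION-ROLE-ARN>'},
    SubnetIds=['<PRIVATE-SUBNET-ID>'],
    VpcId='<VPC-ID>',
    AppNetworkAccessType='VpcOnly',                                      # disable direct internet access
    KmsKeyId='<KMS-KEY-ID>',                                             # optional customer managed key
)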

Multimodal Feature Store

Use SageMaker Feature Store to store and manage the features, obtained from the different sources of data, for training the Multimodal ML model. A FeatureGroup contains the metadata for all the data stored in Feature Store. SageMaker Feature Store supports two types of store: an online store and an offline store. The online store offers the GetRecord API to access the latest value for a record in real time with low latency, whereas the offline store keeps all versions of features and is ideal for batch access when creating training data. Create a Feature Group for each data type, i.e., genomic, clinical, and medical imaging, and add the previously constructed feature vectors as records to their corresponding Feature Groups as part of the respective processing pipelines. The following code snippet shows how to retrieve the output data generated by the genomic secondary analysis pipeline from an S3 bucket, create a FeatureGroup specific to this data, and then ingest it for storage.

Details of the implementation can be found in the preprocess-genomic-data.ipynb, preprocess-clinical-data.ipynb, and preprocess-imaging-data.ipynb notebooks inside the genomic, clinical, and imaging folders, respectively.

# Imports assumed by this snippet
import pandas as pd
from time import gmtime, strftime
from sagemaker.session import Session
from sagemaker.feature_store.feature_group import FeatureGroup

# Load the engineered genomic features produced by the secondary analysis pipeline
data_gen = pd.read_csv('s3://{}/{}'.format(<S3-BUCKET>, <FILE-NAME>))

# SageMaker session wired to both the SageMaker and Feature Store runtime clients
feature_store_session = Session(
    boto_session=<BOTO3-SESSION>,
    sagemaker_client=<BOTO3-SESSION-CLIENT-FOR-SAGEMAKER>,
    sagemaker_featurestore_runtime_client=<BOTO3-SESSION-CLIENT-FOR-SAGEMAKER-FEATURESTORE-RUNTIME>
)

# Give the feature group a unique, timestamped name
genomic_feature_group_name = 'genomic-feature-group-' + strftime('%d-%H-%M-%S', gmtime())
genomic_feature_group = FeatureGroup(name=genomic_feature_group_name, sagemaker_session=feature_store_session)

# Infer feature definitions (names and types) from the DataFrame before creating the group
genomic_feature_group.load_feature_definitions(data_frame=data_gen)

# Create the group (offline store under the given s3:// URI, online store enabled), then ingest the features
genomic_feature_group.create(
    s3_uri=f"<S3-DEFAULT-BUCKET>",
    record_identifier_name=<ID>,
    event_time_feature_name=<EVENT-TIME>,
    role_arn=<ROLE>,
    enable_online_store=True)

genomic_feature_group.ingest(data_frame=data_gen, max_workers=3, wait=True)
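Once ingested, the online store can serve low-latency lookups of the latest record via the GetRecord API mentioned above. The following is a minimal sketch; the record identifier value is a placeholder for a case ID from the dataset.

# Minimal sketch: fetch the latest feature values for one record from the online store.
# The record identifier value below is a placeholder.
import boto3

featurestore_runtime = boto3.client('sagemaker-featurestore-runtime')
record = featurestore_runtime.get_record(
    FeatureGroupName=genomic_feature_group_name,
    RecordIdentifierValueAsString='<CASE-ID>',
)
print(record['Record'])  # list of {'FeatureName': ..., 'ValueAsString': ...} pairs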

For ML training, run a query against the three feature groups to join the data stored in the offline store. For the given dataset, this integration results in 119 data samples, where each sample is a 215-dimensional vector.

# Athena table names registered by each feature group's offline store
genomic_table = <GENOMIC-TABLE-NAME>
clinical_table = <CLINICAL-TABLE-NAME>
imaging_table = <IMAGING-TABLE-NAME>

# Any of the feature groups can provide an Athena query object
query = <FEATURE-GROUP-NAME>.athena_query()

# Join the three offline-store tables on the patient identifier
query_string = 'SELECT '+genomic_table+'.*, '+clinical_table+'.*, '+imaging_table+'.* \
FROM '+genomic_table+' LEFT OUTER JOIN '+clinical_table+' ON '+clinical_table+'.case_id = '+genomic_table+'.case_id \
LEFT OUTER JOIN '+imaging_table+' ON '+clinical_table+'.case_id = '+imaging_table+'.subject \
ORDER BY '+clinical_table+'.case_id ASC;'

# Run the query, writing results to S3, and wait for it to finish
query.run(query_string=query_string, output_location='s3://'+<BUCKET>+'/'+prefix+'/query_results/')
query.wait()

# Load the joined multimodal features into a pandas DataFrame
multimodal_features = query.as_dataframe()

The implementation details can be found in the Get features from SageMaker FeatureStore based on data type section of train-test-model.ipynb, inside the model-train-test folder of the code repository.

Model Training

For the purpose of demonstration, we consider a classification task to predict patient survival. To train the ML model, construct an estimator for the gradient boosting library XGBoost using the SageMaker XGBoost container. Use the combined dataset, integrated in the previous step, for further analysis. Randomly shuffle this data and split it into 80% for training and 20% for testing the model, then further split the training data into 80% for training and 20% for validating the model. Use feature scaling to normalize the range of the independent features. To identify the most discriminative features, perform principal component analysis (PCA) on the integrated features and keep the top principal components that explain 99% of the variance in the data. Concretely, this reduces the dimensionality from 215 features down to 65 principal components, which constitute the features for the supervised learner. Since our objective is to train a baseline model with multimodal data, we use default hyperparameters and do not perform any hyperparameter tuning; adopting other classification algorithms and exploring hyperparameter tuning may yield improved results.
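The following is a minimal sketch of these steps. It assumes the multimodal_features DataFrame from the previous query, a hypothetical 0/1 label column named survival_status, and illustrative bucket/prefix names; the accompanying notebook contains the actual implementation.

# Minimal sketch: split, scale, reduce with PCA, and train with the SageMaker XGBoost container.
# Column, bucket, and prefix names below are illustrative assumptions.
import pandas as pd
import sagemaker
from sagemaker import image_uris
from sagemaker.inputs import TrainingInput
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# Separate the (assumed 0/1-encoded) label from the features
y = multimodal_features['survival_status']
X = multimodal_features.drop(columns=['survival_status'])

# 80/20 train/test split, then 80/20 train/validation split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, shuffle=True, random_state=42)

# Scale the features, then keep the principal components explaining 99% of the variance
scaler = StandardScaler().fit(X_train)
pca = PCA(n_components=0.99).fit(scaler.transform(X_train))
train_pc = pca.transform(scaler.transform(X_train))
val_pc = pca.transform(scaler.transform(X_val))

# The SageMaker XGBoost container expects CSV with the label in the first column and no header
session = sagemaker.Session()
bucket, prefix = session.default_bucket(), 'multimodal-nsclc'  # assumed locations
pd.concat([y_train.reset_index(drop=True), pd.DataFrame(train_pc)], axis=1) \
    .to_csv('train.csv', header=False, index=False)
pd.concat([y_val.reset_index(drop=True), pd.DataFrame(val_pc)], axis=1) \
    .to_csv('validation.csv', header=False, index=False)
train_s3 = session.upload_data('train.csv', bucket=bucket, key_prefix=f'{prefix}/train')
val_s3 = session.upload_data('validation.csv', bucket=bucket, key_prefix=f'{prefix}/validation')

# Train the baseline model with the built-in XGBoost container and default-style hyperparameters
container = image_uris.retrieve('xgboost', session.boto_region_name, version='1.2-1')
xgb = sagemaker.estimator.Estimator(
    container,
    role=sagemaker.get_execution_role(),
    instance_count=1,
    instance_type='ml.m5.xlarge',
    output_path=f's3://{bucket}/{prefix}/output',
    sagemaker_session=session,
)
xgb.set_hyperparameters(objective='binary:logistic', num_round=100)
xgb.fit({'train': TrainingInput(train_s3, content_type='csv'),
         'validation': TrainingInput(val_s3, content_type='csv')})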

The implementation details for training the Multimodal ML model can be found between the Split data for training and testing and Train model sections of train-test-model.ipynb, inside the model-train-test folder of the code repository.

Model Evaluation

We first evaluate the performance of the model in predicting the survival outcome when trained on the genomic, clinical, and medical imaging data modalities alone. Then, we analyze the performance of a model trained on features obtained from all data domains combined, i.e., multimodal data. For comparative analysis, we compute accuracy, F1, precision, and recall scores for each setup. As shown in Table 1, multimodal data leads to higher predictive capability than data from any single domain. This result is consistent across all four evaluation metrics. The corresponding code can be found in the Test model section of train-test-model.ipynb, inside the model-train-test folder of the code repository.

Table 1: Comparison of evaluation metrics for using different data domains to train the model for predicting survival outcome. 
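Metrics like those in Table 1 can be computed with scikit-learn. A minimal sketch, assuming y_test holds the true survival labels for the held-out test set and y_pred holds the binarized model predictions:

# Minimal sketch of the evaluation metrics; y_test and y_pred are assumed to exist.
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

print('Accuracy :', accuracy_score(y_test, y_pred))
print('F1       :', f1_score(y_test, y_pred))
print('Precision:', precision_score(y_test, y_pred))
print('Recall   :', recall_score(y_test, y_pred))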

To better understand this result, let’s further investigate the importance of each data domain for the given predictive task. To determine the importance, or discriminative property, of the data belonging to different modalities, explore the composition of the principal components that were used as features. These principal components are linear combinations of the original features taken from each modality. They constitute a set of linearly uncorrelated features that describe 99% of the variance in the data. The first principal component explains the most variance, with each subsequent principal component explaining successively less of the variance in the data. Figure 3 shows the variance explained by each original feature, color-coded by modality, for the top 20 principal components. For the first few components, medical imaging features are most heavily weighted and have higher variance (higher intensity in the heat map of Figure 3) than clinical and genomic features, indicating the importance of medical imaging features for survival outcome prediction. However, from the third principal component onward, the clinical and genomic modalities receive greater weight, showing that they explain additional variance beyond what is captured in the medical images. We see the same trend in the variable correlation circle in Figure 4, where the top 10 features contributing to the first and second principal components are from the medical imaging domain, yet genomic features (like gdf_15) and clinical features (like the measure of tobacco use, pack_years) contribute to the fourth and fifth principal components.

The code for generating the heat map and correlation circle can be found in the Get feature importance and Plot Correlation Circle sections, respectively, of train-test-model.ipynb, inside the model-train-test folder of the code repository.
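As a rough sketch of how such a loadings heat map can be produced, assuming pca is the fitted PCA object and X is the feature DataFrame from the training step (both assumptions), the squared loadings approximate each original feature’s contribution to a component:

# Minimal sketch: visualize how much each original feature contributes to the top 20
# principal components. `pca` and `X` are assumed to come from the training step.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Squared loadings per component approximate each feature's share of that component
loadings = pd.DataFrame(pca.components_[:20].T ** 2, index=X.columns)

plt.figure(figsize=(8, 24))
sns.heatmap(loadings, cmap='viridis')
plt.xlabel('Principal component')
plt.ylabel('Original feature (genomic, clinical, imaging)')
plt.show()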

Figure 3: Heat map demonstrating variance explained by each feature, color-coded by domain, for the top 20 principal components. The X-axis denotes principal components, Y-axis denotes features obtained from different domains (genomic, clinical, and medical imaging), and cell intensity denotes that feature’s level of variance (normalized value) for the corresponding principal component. For the first few principal components, medical imaging features have high variance, followed by clinical and genomic features. This indicates that for the given dataset, features obtained from the medical imaging domain are most discriminative or important for predicting survival outcome. However, all modalities contain information that can improve model performance.

Figure 4: Variable correlation circle depicting correlation between the top 25 features or variables of the fourth and fifth principal components. The angle between a pair of vectors indicates the level of correlation between them. A small angle indicates positive correlation, whereas an angle close to 180 degrees indicates negative correlation. The distance between the feature and the origin indicates how well the feature is represented in the plot. For these principal components, the top 25 features are obtained from all three domains: genomic (e.g., gdf15), clinical (e.g., pack_years, histology_adinocarcinoma), and medical imaging (e.g., original_glcm_clustershade).

Cleaning up

To avoid incurring future charges, delete the resources created in the steps above; a scripted sketch of these steps follows the list.

  1. From the SageMaker Studio file menu, select Shutdown All, and delete any resources shown in the Recent Activity section of the SageMaker Dashboard.
  2. From the Athena console, delete all tables you created in this analysis.
  3. Delete all data stored on S3.
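Steps 2 and 3, along with removal of the feature groups created earlier, can also be scripted. The following is a minimal sketch; the clinical and imaging feature group variables, bucket, and prefix are placeholders/assumptions, and deleting a feature group does not remove its offline data in S3 or its Glue table, which is why the table drop and S3 cleanup are still needed.

# Minimal sketch of scripted cleanup. Names in angle brackets and the clinical/imaging
# feature group variables are placeholders/assumptions.
import boto3

athena = boto3.client('athena')

# Delete each feature group and drop its offline-store table from the Glue/Athena catalog
for fg in [genomic_feature_group, clinical_feature_group, imaging_feature_group]:
    q = fg.athena_query()  # carries the catalog database and table name for this group
    fg.delete()
    athena.start_query_execution(
        QueryString=f'DROP TABLE IF EXISTS `{q.table_name}`;',
        QueryExecutionContext={'Database': q.database},
        ResultConfiguration={'OutputLocation': 's3://<BUCKET>/athena-query-results/'},
    )

# Remove the processed data, offline-store files, and query results stored on S3
s3 = boto3.resource('s3')
s3.Bucket('<BUCKET>').objects.filter(Prefix='<PREFIX>').delete()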

Conclusion

In this two-part blog series, we’ve shown how to deploy a set of data analysis pipelines and ML tools to efficiently process data from diverse, unstructured data modalities. By leveraging fully managed AWS services, we’ve streamlined the setup, processed multimodal data at scale, and made it easy to focus on the ML modeling rather than the infrastructure. Through experimental evaluation on real-world data, we demonstrated that integrating data from multiple modalities can enhance the predictive capability of a model.

Leveraging multimodal data promises better ML models for healthcare and life sciences, and subsequently improved care delivery and patient outcomes. While we focused on genomics, clinical data, and medical imaging, the approach we present can be applied to other data modalities. We encourage you to extend this example to the multimodal data applications that interest you most.

To learn more about healthcare & life sciences on AWS, visit aws.amazon.com/health.

References

[1] Bakr, Shaimaa, et al. “A radiogenomic dataset of non-small cell lung cancer.” Scientific data 5.1 (2018): 1-9.

Olivia Choudhury

Olivia Choudhury, PhD, is a Senior Partner SA at AWS. She helps partners, in the Healthcare and Life Sciences domain, design, develop, and scale state-of-the-art solutions leveraging AWS. She has a background in genomics, healthcare analytics, federated learning, and privacy-preserving machine learning. Outside of work, she plays board games, paints landscapes, and collects manga.

Andy Schuetz

Andy Schuetz, PhD, is a Sr. Startup Solutions Architect at AWS, where he focuses on helping customers deliver Healthcare and Life Sciences solutions. When Andy’s not building things at AWS, he prefers to be riding a mountain bike.

Michael Hsieh

Michael Hsieh is a Senior AI/ML Specialist Solutions Architect. He works with HCLS customers to advance their ML journey with AWS technologies and his expertise in medical imaging. As a Seattle transplant, he loves exploring the great mother nature the city has to offer, such as the hiking trails, scenery kayaking in the SLU, and the sunset at Shilshole Bay.