AWS Machine Learning Blog
Training batch reinforcement learning policies with Amazon SageMaker RL
Amazon SageMaker is a fully managed service that enables developers and data scientists to quickly and easily build, train, and deploy machine learning (ML) models at any scale. In addition to building ML models using more commonly used supervised and unsupervised learning techniques, you can also build reinforcement learning (RL) models using Amazon SageMaker RL.
Amazon SageMaker RL includes pre-built RL libraries and algorithms that make it easy to get started with reinforcement learning. For more information, see Amazon SageMaker RL – Managed Reinforcement Learning with Amazon SageMaker. Amazon SageMaker RL makes it easy to integrate with various simulation environments such as AWS RoboMaker, OpenAI Gym, open-source environments, and custom-built environments for training RL models. You can also use Amazon SageMaker RL containers (MXNet and TensorFlow), which include OpenAI Gym, Intel Coach, and Berkeley Ray RLlib.
This post shows how to use Amazon SageMaker RL to implement batch reinforcement learning (batch RL), in which the complete amount of learning experience—usually a set of transitions sampled from the system—is given beforehand. This technique requires you to collect a set of state and action transitions from previous policies and use them to train a new RL policy without interacting with environments.
This post also shows you how to use Amazon SageMaker RL to collect offline data from an initial random policy, train an RL policy with the offline data, and get action predictions from the trained policy, which you can use to collect offline data for the next RL policy training.
Reinforcement learning has shown promise in solving problems across multiple domains, such as portfolio management, energy optimization, and robotics. RL is a category of ML that does not require a pre-existing training dataset. Instead, in RL, a learning agent interacts with an environment (real or simulated) and learns a policy that provides an optimal sequence of actions to take. The policy the agent learns is based on the reward or penalty it receives for each action it takes.
However, for many real-world problems, the RL agent needs to learn from historical data that a deployed policy generated. For example, you may have historical data of experts playing games, users interacting with a website, or sensor data from a control system. You can use this data as input to train a new and improved RL policy by treating the historical data as the outcome of an existing deployed policy.
This approach to RL is called batch RL, in which the learning agent derives an improved policy from a fixed batch of offline dataset samples. For more information, see the “Batch Reinforcement Learning” chapter from the book “Reinforcement Learning: State-of-the-Art.”
This post includes an accompanying notebook with an example of how to use batch RL to train a new policy from an offline dataset created with predictions from a previously deployed policy. For more information, see the GitHub repo.
To create the offline dataset from a previously deployed model, this post uses Amazon SageMaker batch transform, which is a high-performance and high-throughput feature in Amazon SageMaker for generating inferences for large datasets. You can collect the inferences from batch transform and the rewards from the environment to train a better policy with batch RL. For more information, see Get Inferences for an Entire Dataset with Batch Transform.
Batch RL on Amazon SageMaker RL
For this post, you apply batch RL to the CartPole balancing problem, in which an unactuated joint attaches a pole to a cart that is moving along a frictionless track.
First, you need to formulate the cart-pole balance problem in the following RL terms:
- Objective – Prevent the pole from falling over
- Environment – The environment this post uses is part of OpenAI Gym
- State – Cart position, cart velocity, pole angle, and pole velocity at tip
- Action – Push the cart to the left or to the right
- Reward – 1 for every step taken, including the termination step
For more information, see Use Reinforcement Learning with Amazon SageMaker.
At a high level, batch RL implementation includes the following steps:
- Simulate an initial policy and collect data from this policy
- Train an RL policy with the offline data from the initial policy without interacting with the simulator
- Visualize and evaluate the trained RL policy’s performance
- Use Amazon SageMaker batch transform to make batch inferences from the trained RL policy
These steps are specific to batch RL implementation on Amazon SageMaker. There are other steps necessary to import libraries, set up permissions, and other functions that this post does not discuss. For more information, see the GitHub repo.
Simulating a random policy and collecting data
For batch RL training, you need to simulate the batches of data generated by a previously deployed policy. In a real use case, you can collect the off-policy data by interacting with the live environment using existing policies. For this post, you use OpenAI Gym CartPole-v0 as the environment to mimic a live environment and use a random policy with uniform action distribution to mimic a deployed agent.
Complete the following steps:
- Create 100 environments of CartPole-v0 and collect five episodes of data from each. See the following code:
This gives you 500 episodes in total for training. You can gather more trajectories from the environments by interacting with multiple environments simultaneously.
- Start from a random policy with uniform action probabilities for all state features. See the following code:
The average cumulative reward over 500 episodes is 22.22.
- Save the dataframe as a CSV file to use later. See the following code:
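The three steps above can be sketched end to end with the Python standard library alone. This is not the notebook's code: `MiniCartPole` is a hypothetical stand-in that uses the classic cart-pole equations with constants assumed to match CartPole-v0, the CSV column names are illustrative, and the output is written to an in-memory buffer instead of a file.

```python
import csv
import io
import math
import random


class MiniCartPole:
    """Hypothetical stand-in for CartPole-v0 (constants are assumptions)."""
    GRAVITY, MASS_CART, MASS_POLE = 9.8, 1.0, 0.1
    HALF_LENGTH, FORCE, TAU = 0.5, 10.0, 0.02
    X_LIMIT, THETA_LIMIT, MAX_STEPS = 2.4, 12 * math.pi / 180, 200

    def __init__(self, rng):
        self.rng = rng
        self.state = None
        self.steps = 0

    def reset(self):
        # Start near the upright position with small random perturbations.
        self.state = [self.rng.uniform(-0.05, 0.05) for _ in range(4)]
        self.steps = 0
        return list(self.state)

    def step(self, action):
        x, x_dot, theta, theta_dot = self.state
        force = self.FORCE if action == 1 else -self.FORCE
        total_mass = self.MASS_CART + self.MASS_POLE
        pole_ml = self.MASS_POLE * self.HALF_LENGTH
        cos_t, sin_t = math.cos(theta), math.sin(theta)
        temp = (force + pole_ml * theta_dot ** 2 * sin_t) / total_mass
        theta_acc = (self.GRAVITY * sin_t - cos_t * temp) / (
            self.HALF_LENGTH * (4.0 / 3.0 - self.MASS_POLE * cos_t ** 2 / total_mass))
        x_acc = temp - pole_ml * theta_acc * cos_t / total_mass
        # Explicit Euler integration over one time step.
        self.state = [x + self.TAU * x_dot, x_dot + self.TAU * x_acc,
                      theta + self.TAU * theta_dot, theta_dot + self.TAU * theta_acc]
        self.steps += 1
        done = (abs(self.state[0]) > self.X_LIMIT
                or abs(self.state[2]) > self.THETA_LIMIT
                or self.steps >= self.MAX_STEPS)
        return list(self.state), 1.0, done  # reward of 1 for every step


def collect_random_rollouts(num_envs=100, episodes_per_env=5, seed=0):
    """Roll out a uniform random policy and return (transition rows, episode returns)."""
    rng = random.Random(seed)
    rows, returns = [], []
    for env_id in range(num_envs):
        env = MiniCartPole(rng)
        for ep in range(episodes_per_env):
            episode_id = env_id * episodes_per_env + ep
            obs, done, total, t = env.reset(), False, 0.0, 0
            while not done:
                action = rng.randrange(2)  # uniform over {left, right}
                next_obs, reward, done = env.step(action)
                rows.append([episode_id, t, *obs, action, reward, int(done)])
                obs, total, t = next_obs, total + reward, t + 1
            returns.append(total)
    return rows, returns


rows, returns = collect_random_rollouts()
average_return = sum(returns) / len(returns)
print(f"episodes: {len(returns)}, average cumulative reward: {average_return:.2f}")

# Serialize the transitions as CSV (in memory here; a file path works the same way).
buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["episode_id", "transition_number", "cart_position", "cart_velocity",
                 "pole_angle", "pole_velocity_at_tip", "action", "reward", "done"])
writer.writerows(rows)
csv_text = buffer.getvalue()
```

Because the policy ignores the state entirely, the average return stays low, which is what makes it a useful baseline for the trained policy later on.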
Training an RL policy with the offline data
You now have the offline data and can train an RL policy with it.
In this post, you use the double deep Q-network (DDQN) algorithm, a deep RL method based on double Q-learning, to update the policy in an off-policy manner. For more information, see Deep Reinforcement Learning with Double Q-learning on arXiv. You combine it with the batch-constrained deep Q-learning (BCQ) algorithm to address the error induced by inaccurately estimated values for unseen state-action pairs. The training is completely offline. While DDQN addresses the overestimation issue of standard Q-learning, BCQ restricts the actions the policy may choose to those supported by the dataset, which mitigates extrapolation error. The dataset must contain exploratory interactions for the algorithm to learn anything useful. For more information, see the research paper Off-Policy Deep Reinforcement Learning without Exploration.
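To make the interplay of the two ideas concrete, the following is a simplified sketch of a single bootstrapped target for discrete actions. It is not the Coach implementation: the function name, the `threshold` parameter, and the use of estimated behavior-policy probabilities to decide which actions are "supported" are assumptions for illustration. The online network's values choose the next action (the double Q-learning part), but only among actions the dataset's behavior policy was reasonably likely to take (the batch constraint), and the target network's value for that action is used in the target.

```python
def bcq_ddqn_target(reward, done, q_online_next, q_target_next,
                    behavior_probs_next, gamma=0.99, threshold=0.3):
    """Return the TD target for one transition.

    q_online_next / q_target_next: Q-values of the next state, one per action.
    behavior_probs_next: estimated probability that the data-collecting policy
        takes each action in the next state (assumed to have a positive maximum).
    threshold: hypothetical cutoff on probability relative to the most likely action.
    """
    if done:
        return reward  # no bootstrapping past a terminal transition
    max_prob = max(behavior_probs_next)
    # Batch constraint: keep only actions well supported by the offline data.
    allowed = [a for a, p in enumerate(behavior_probs_next) if p / max_prob > threshold]
    # Double Q-learning: select with the online values, evaluate with the target values.
    best_action = max(allowed, key=lambda a: q_online_next[a])
    return reward + gamma * q_target_next[best_action]


# Action 1 looks attractive to the online network but is rare in the data.
constrained = bcq_ddqn_target(1.0, False, [1.0, 5.0], [2.0, 3.0], [0.9, 0.05])
unconstrained = bcq_ddqn_target(1.0, False, [1.0, 5.0], [2.0, 3.0], [0.9, 0.05],
                                threshold=0.0)
print(constrained, unconstrained)
```

With the constraint active, the rarely observed action is excluded and the target falls back to the well-supported action, which is how extrapolation error on unseen state-action pairs is kept out of the update.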
RL parameters are captured in preset-cartpole-ddqnbcq.py. You can define agent parameters to select the specific agent algorithm by using the preset file. You can also define schedule parameters, offline dataset parameters, and visualization parameters. In this preset file, you use BatchRLGraphManager without setting up parameters for the environment. See the following code:
Use the Amazon SageMaker RLEstimator object to point to a script, train-coach.py, which provides the training code. For more information, see the GitHub repo. This post uses Amazon SageMaker script mode. This allows you to customize the training script you might develop in your local environment (such as a laptop or Amazon SageMaker notebook). For more information, see Using TensorFlow eager execution with Amazon SageMaker script mode.
The instance type changes based on the local_mode setup. On an Amazon SageMaker notebook instance, you can train an RL model using local mode, in which you train the policy on the container launched on the notebook instance itself, which speeds up iterative testing and debugging. For more information, see Use the Amazon SageMaker local mode to train on your notebook instance. Amazon SageMaker allows you to switch seamlessly between training locally and distributed, managed training by simply changing one line of code. See the following code:
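The one-line switch can be pictured with a small helper. This is a sketch under assumptions: the helper name and dictionary keys are hypothetical, and `ml.m4.xlarge` is only an example of a managed instance type. The one fact it relies on is that the SageMaker Python SDK treats the instance type `local` as a request to run the training container on the current machine.

```python
def training_instance_settings(local_mode, managed_instance_type="ml.m4.xlarge"):
    """Pick the instance settings an estimator would be configured with."""
    # "local" runs the container on the notebook instance; any other value
    # requests managed training instances of that type.
    instance_type = "local" if local_mode else managed_instance_type
    return {"instance_type": instance_type, "instance_count": 1}


local_mode = True  # flip to False to request managed training
settings = training_instance_settings(local_mode)
print(settings)
```

Keeping the decision in one flag means the rest of the estimator configuration (entry point, hyperparameters, output location) stays identical between a quick local test and a managed run.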
Store intermediate training output and model checkpoints with the following code:
You can now visualize metrics for the training job. Pull the off-policy evaluation (OPE) metric of the training by using the intermediate results and plot it to see the performance of the model over time.
You can use a set of methods to investigate the performance of the current trained policy without interacting with the simulator or live environment. You can use them to estimate the effectiveness of the policy based on the dataset collected from other policies. The following code uses two OPE metrics: weighted importance sampling (WIS) and sequential doubly robust (SDR):
The following graph shows that these metrics improve as the learning agent iterates for multiple epochs over the given dataset.
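To give a feel for what WIS measures, here is a minimal sketch of the textbook trajectory-level estimator. It is not necessarily the exact variant the training job reports, and the data layout (each step as a tuple of target-policy probability, behavior-policy probability, and reward) is an assumption made for this example. Each logged episode's return is weighted by how much more (or less) likely the new policy is to have produced that episode's actions than the policy that collected the data, and the weights are normalized to sum to one.

```python
def weighted_importance_sampling(episodes, gamma=1.0):
    """Estimate the target policy's expected return from logged episodes.

    episodes: list of episodes; each episode is a list of
        (target_prob, behavior_prob, reward) tuples, one per step, where the
        probabilities are those of the action that was actually logged.
    """
    weighted_sum, weight_total = 0.0, 0.0
    for episode in episodes:
        weight, episode_return, discount = 1.0, 0.0, 1.0
        for target_prob, behavior_prob, reward in episode:
            weight *= target_prob / behavior_prob  # cumulative importance ratio
            episode_return += discount * reward
            discount *= gamma
        weighted_sum += weight * episode_return
        weight_total += weight
    if weight_total == 0.0:
        raise ValueError("target policy assigns zero probability to all logged episodes")
    return weighted_sum / weight_total


logged = [
    [(1.0, 0.5, 1.0), (1.0, 0.5, 1.0)],  # the new policy strongly prefers these actions
    [(0.5, 0.5, 1.0)],                   # both policies agree on this one
]
wis_estimate = weighted_importance_sampling(logged)
print(wis_estimate)
```

In this toy dataset, the longer episode receives four times the weight of the shorter one because the new policy is more likely to have generated it, so the estimate is pulled toward its higher return.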
Evaluating the RL policy performance
To evaluate the model trained with offline data, you need to see the cumulative rewards the agent earns by interacting with the environment. Use the last checkpointed model to run an evaluation for the RL agent. The checkpointed data from the previously trained model is passed on for evaluation and inference in the checkpoint channel. See the following code:
The following screenshot shows that the total reward of the agent is 195.33.
When you compare this to the cumulative reward of 22.22 achieved with the random policy, you can see the improvement in the RL policy that you trained using the offline data. The exact reward values might be different when you execute the notebook due to the randomness in the collected dataset.
Using Amazon SageMaker batch transform to make batch inferences
After you train the RL policy, you can use batch transform, in which you provide a set of input state features and get inferences with high throughput. You can use the inferences and the resulting environment rewards to prepare offline data for the next batch RL policy training.
The following code uses the states of the environments as input for the batch transform:
You can see how to use batch transform inferences in production to train the next RL policy. For this post, you use simulated environments to collect rollout data of a random policy. Assuming that the updated policy is now good enough to deploy, you can use batch transform to collect rollout data from this policy. The process includes the following steps:
- You use batch transform to get action predictions, given observation features from the live environment at timestep t.
- The deployed agent takes the suggested actions against the environment (simulator or real) at timestep t.
- The environment returns new observation features and rewards for timestep t+1.
- You return to using batch transform to get action predictions at timestep t+1.
This iterative procedure enables you to collect a set of data that can cover the whole episode. When data is sufficient, you can use the data to kick off batch RL training again.
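The loop described above can be sketched as follows. Everything here is a stand-in: `ToyEnv` is a hypothetical environment whose episodes end after a fixed number of steps, and `batch_predict` is a placeholder for a batch transform job that returns one action per observation row. The point is the shape of the loop: one batched inference call per timestep across all active episodes, followed by one environment step per episode.

```python
import random


class ToyEnv:
    """Hypothetical live environment: an episode ends after `horizon` steps."""

    def __init__(self, horizon, rng):
        self.horizon, self.rng, self.t = horizon, rng, 0

    def reset(self):
        self.t = 0
        return [self.rng.random()]

    def step(self, action):
        self.t += 1
        return [self.rng.random()], 1.0, self.t >= self.horizon


def batch_predict(observations):
    """Placeholder for a batch transform job over a set of observations."""
    return [1 if obs[0] > 0.5 else 0 for obs in observations]


def collect_with_batch_inference(envs):
    """Step all environments in lockstep, calling batch inference once per timestep."""
    observations = {i: env.reset() for i, env in enumerate(envs)}
    dataset, t = [], 0
    while observations:  # until every episode has terminated
        ids = sorted(observations)
        actions = batch_predict([observations[i] for i in ids])  # predictions at timestep t
        for i, action in zip(ids, actions):
            next_obs, reward, done = envs[i].step(action)  # act, then observe t+1
            dataset.append({"episode_id": i, "timestep": t,
                            "observation": observations[i], "action": action,
                            "reward": reward, "done": done})
            if done:
                del observations[i]  # finished episodes leave the batch
            else:
                observations[i] = next_obs
        t += 1
    return dataset


rng = random.Random(0)
envs = [ToyEnv(horizon=h, rng=rng) for h in (3, 5, 2)]
dataset = collect_with_batch_inference(envs)
print(len(dataset), "transitions collected")
```

The resulting list of transitions has the same role as the CSV collected from the random policy earlier: once enough episodes are logged, it becomes the offline dataset for the next round of batch RL training.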
This post showed you the detailed procedure for implementing batch RL with Amazon SageMaker RL. While this post uses a simple game of CartPole to detail the various steps involved, you can make the appropriate changes to apply these steps to other problems. Additionally, the final solution on the GitHub repo shows how to train the batch RL policy. The post also demonstrated how you can use batch transform functionality to collect off-policy data and use it to train future RL policies. You can collect rollout data from millions of user contexts efficiently and concurrently by using batch transform, and use the collected rollout data to train a better policy.
Now that you have seen how to use batch RL for the CartPole problem, you can apply this technique to other RL problems in which you need to train a new policy with data collected from previously deployed policies. For example, in an email campaign, each email user is an independent episode interacting with the deployed policy. You can apply the techniques used in this post to a new problem, iterate quickly, and train and deploy at scale on a cluster managed by Amazon SageMaker.
Give batch RL a try, and share your feedback and questions in the comments.
About the Authors
Sireesha Muppala is an AI/ML Specialist Solutions Architect at AWS, providing guidance to customers on architecting and implementing machine learning solutions at scale. She received her Ph.D. in Computer Science from University of Colorado, Colorado Springs. In her spare time, Sireesha loves to run and hike Colorado trails.
Yijie Zhuang is a Software Engineer with AWS SageMaker. He received his MS in Computer Engineering from Duke. His interests lie in building scalable algorithms and reinforcement learning systems. He contributed to Amazon SageMaker Built-in Algorithms and Amazon SageMaker RL.
Anna Luo is an Applied Scientist at AWS. She works on applying RL techniques to different domains, including supply chain and recommender systems. She received her Ph.D. in Statistics from University of California, Santa Barbara.
Bharathan Balaji is a Research Scientist in AWS and his research interests lie in reinforcement learning systems and applications. He contributed to the launch of Amazon SageMaker RL and AWS DeepRacer. He received his Ph.D. in Computer Science and Engineering from University of California, San Diego.