This post will explore meta reinforcement learning with only minimal math. We attempt to dive into the core concepts without making it too complicated.
Reinforcement Learning
Start by recalling that the RL problem can be defined by the components of a Markov Decision Process (MDP): the set of states, the action space, the transition matrix, the reward function and the initial state, \(\{ \mathcal{S}, \mathcal{A}, \mathcal{T}, \mathcal{R}, s_0 \}\). We solve an RL problem by interacting with the environment and collecting trajectories \(\tau\) in order to maximize rewards. More specifically, an agent starts at some initial state \(s_0\) and takes an action \(a \in \mathcal{A}\), which moves it to a new state \(s\) and yields a reward \(r_t\). The goal is to learn a policy that maximizes the expected sum of discounted rewards, \(\sum_t \gamma^t r_t\) for a discount factor \(\gamma \in [0, 1)\). The Bellman equation expresses this objective recursively, and solving it yields a good agent.
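To make the objective concrete, here is a minimal sketch in plain Python. It computes the discounted return \(\sum_t \gamma^t r_t\) of one trajectory, and runs value iteration, which repeatedly applies the Bellman optimality backup \(V(s) \leftarrow \max_a [\mathcal{R}(s,a) + \gamma V(s')]\), on a hypothetical two-state MDP with deterministic transitions. The MDP, its rewards and the value of \(\gamma\) are illustrative assumptions, not anything from a specific environment.

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**t * r_t, accumulated backwards over the trajectory."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g


def value_iteration(transitions, rewards, gamma, n_iters=1000):
    """Bellman optimality backups on a deterministic MDP.

    transitions[s][a] is the next state and rewards[s][a] is the reward
    for taking action a in state s.
    """
    values = [0.0] * len(transitions)
    for _ in range(n_iters):
        values = [
            max(rewards[s][a] + gamma * values[transitions[s][a]]
                for a in range(len(transitions[s])))
            for s in range(len(transitions))
        ]
    return values


# Hypothetical toy MDP: in state 0 the agent can stay (action 0) or move to
# state 1 (action 1); in state 1, staying (action 0) pays a reward of 1 per step.
T = [[0, 1], [1, 0]]
R = [[0.0, 0.0], [1.0, 0.0]]
print(discounted_return([1.0, 1.0, 1.0], gamma=0.5))  # 1 + 0.5 + 0.25 = 1.75
print(value_iteration(T, R, gamma=0.9))               # close to [9.0, 10.0]
```

The optimal agent moves to state 1 and stays there, collecting \(1/(1-\gamma) = 10\); from state 0 it gives up one step of discounting, hence 9.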
Meta-learning
Traditional supervised learning aims to find an optimal model that is great at performing some target task during test time. In contrast, meta-learning tries to develop a model \(\theta^{\star}\) that is not necessarily good to start with, but can become great at the target task using only a few update steps. In other words, we develop models that have good performance in the few-shot learning setup. To do so, a model ought to learn how to learn.
To understand how this works, we take a brief detour into multi-task learning. Multi-task learning wants a single model that performs well on many tasks at once. A key difficulty is learning each additional task without catastrophically forgetting previously learned tasks. The meta-learning setup also contains many tasks, which we refer to as support tasks, and these are used to aid the learning of the target task. While multi-task learning aims to do well on all of the tasks, classic meta-learning only cares about doing well on the one target task. Even so, meta-learning can be viewed as the more difficult problem, since it must learn the target task from substantially fewer datapoints at test time.
One way of tackling meta-learning is to use metric-based methods, many of which boil down to smarter versions of kNN. Suppose we have ten classes of images from CIFAR10. During training, we learn to embed all the images into a shared embedding space; at test time, we simply match our test image to the nearest example we have seen before. Note that there are two knobs to tweak: (1) how we embed the examples and (2) how we calculate the distance between two examples. A Matching Network (Vinyals et al., 2016) makes this a bit smarter by matching to a similarity-weighted combination of the support examples rather than to just the single nearest image. Its embedding is a BiLSTM or CNN and the distance is cosine similarity. A Relation Network (Sung et al., 2018) instead learns the distance metric itself, using a CNN that outputs a similarity score. A Prototypical Network (Snell, Swersky & Zemel, 2017) summarizes each class in the support set by a single “prototype” vector, computed as the mean of the embedded examples in that class, and classifies a test example by its squared Euclidean distance to each prototype. We can imagine many other variants built from other embedding functions and learned distance metrics.
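The decision rules of a Prototypical Network and a Matching Network can be sketched in a few lines of plain Python. To stay self-contained, the sketch assumes the embeddings are already computed (the identity embedding on made-up 2-D points), so it shows only the distance step: nearest class mean under squared Euclidean distance, and a softmax over cosine similarities used to weight the support labels. The data and function names are hypothetical.

```python
import math


def prototypes(support, n_classes):
    """Mean embedding of each class in the support set (the class "prototype")."""
    protos = []
    for c in range(n_classes):
        vecs = [x for x, y in support if y == c]
        protos.append([sum(dim) / len(vecs) for dim in zip(*vecs)])
    return protos


def proto_classify(query, protos):
    """Prototypical Network rule: nearest prototype by squared Euclidean distance."""
    dists = [sum((q - p) ** 2 for q, p in zip(query, proto)) for proto in protos]
    return dists.index(min(dists))


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def matching_classify(query, support, n_classes):
    """Matching Network rule: a softmax over cosine similarities weights each support label."""
    weights = [math.exp(cosine(query, x)) for x, _ in support]
    total = sum(weights)
    scores = [0.0] * n_classes
    for w, (_, y) in zip(weights, support):
        scores[y] += w / total
    return scores.index(max(scores))


# Made-up, already-embedded support set: two examples for each of two classes.
support = [([1.0, 0.0], 0), ([0.9, 0.1], 0), ([0.0, 1.0], 1), ([0.1, 0.9], 1)]
protos = prototypes(support, n_classes=2)
print(proto_classify([0.8, 0.2], protos))                   # 0
print(matching_classify([0.2, 0.8], support, n_classes=2))  # 1
```

Swapping in a learned embedding network, or a learned relation module in place of `cosine`, gives the other variants described above.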
The other way of tackling meta-learning is through gradient-based methods, popularized by Model-Agnostic Meta-Learning (MAML) from Finn et al. (2017). What we ultimately want is a model parameterized by \(\theta^{\star}\) that can learn from just a few datapoints (let’s say 20). So we simply mimic this setup during training. Start with some initial model \(\theta_t\) and pass in a batch of training examples from the support set. Take a few gradient steps to obtain an adapted model \(\hat{\theta}_t\); we refer to this as the inner loop. Now, pass 20 examples from the target set (to match what we expect at test time) through \(\hat{\theta}_t\) and compute the loss. Differentiate this loss with respect to the original parameters \(\theta_t\), back through the inner-loop updates, and use the resulting gradient to update \(\theta_t\), giving \(\theta_{t+1}\). Next, repeat the process with new tasks and data, reaching \(\theta_{t+2}\). We call this the outer loop. After \(n\) rounds of outer-loop training we have the model \(\theta_{t+n}\), which we use as \(\theta^{\star}\) at test time. Model \(\theta^{\star}\) should work reasonably well because, ideally, it has learned basic concepts like edges and shapes that help it learn quickly in new settings.
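Here is a minimal MAML sketch in plain Python, under deliberately toy assumptions: the model is a single weight \(w\) predicting \(y = w x\), each task is a regression problem \(y = a x\) with its own slope \(a\), the loss is mean squared error, and the inner loop takes one gradient step. Because the model is scalar, the second-order term (the derivative of the adapted weight with respect to the initial weight) can be written out by hand; setting `first_order=True` drops it, as first-order MAML does. The task slopes, input points and learning rates are made up.

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)**2) for the scalar model y_hat = w*x."""
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)


def inner_update(w, xs, ys, alpha):
    """Inner loop: one gradient step on the task's support data."""
    return w - alpha * mse_grad(w, xs, ys)


def maml_step(w, tasks, alpha, beta, first_order=False):
    """Outer loop: one meta-update of the initial weight w over a batch of tasks."""
    grads = []
    for support_x, support_y, target_x, target_y in tasks:
        w_adapted = inner_update(w, support_x, support_y, alpha)
        g_target = mse_grad(w_adapted, target_x, target_y)
        if first_order:
            # First-order MAML: treat d(w_adapted)/dw as 1.
            grads.append(g_target)
        else:
            # Chain rule through the inner step:
            # d(w_adapted)/dw = 1 - alpha * d^2(support loss)/dw^2 = 1 - alpha * mean(2*x^2)
            second_deriv = sum(2 * x * x for x in support_x) / len(support_x)
            grads.append(g_target * (1 - alpha * second_deriv))
    return w - beta * sum(grads) / len(grads)


def make_task(slope):
    """Hypothetical task y = slope * x with fixed support and target inputs."""
    support_x, target_x = [1.0, 2.0], [0.5, 1.5]
    return (support_x, [slope * x for x in support_x],
            target_x, [slope * x for x in target_x])


tasks = [make_task(a) for a in (1.0, 2.0, 3.0)]
w = 0.0
for _ in range(200):
    w = maml_step(w, tasks, alpha=0.05, beta=0.1)
print(round(w, 6))  # 2.0: for this toy problem the best initialization is the mean slope

# Test time: adapt to an unseen slope with a single inner-loop step.
new_task = make_task(2.5)
w_new = inner_update(w, new_task[0], new_task[1], alpha=0.05)  # moves from 2.0 toward 2.5
```

In a real network the second-order term is a Hessian-vector product computed by automatic differentiation, which is exactly the cost that first-order MAML avoids.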
Meta Reinforcement Learning
One way to view meta-learning is that rather than training on a bunch of datapoints, we instead train on a bunch of support tasks. This gives another distinction from multi-task learning: meta-learning improves as the number of tasks increases, whereas multi-task performance will typically decrease as the number of tasks grows. To carry this over to meta-RL, we just recognize that reinforcement learning tasks are defined by their MDPs. In other words, rather than training on a bunch of supervised support tasks, we train on a bunch of MDPs and evaluate on the target MDP at test time.
Meta-learning in a supervised setting tries to learn something useful about the problem space that is shared across different support tasks. As an example, if our support tasks include question answering about celebrity gossip, news articles and tech blogs, we could hope that the model meta-learns reading comprehension. Then at test time, when our target task is question answering about financial updates, the model only needs to learn finance jargon since it already knows how to read. For reinforcement learning, if our support MDPs include distinct reward functions for a humanoid learning to run left, forward and backward, we could hope that the model has learned to balance and move. Then, at test time, when our target MDP offers rewards for running to the right, the humanoid can learn this quickly.
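The idea of tasks as MDPs that share dynamics but differ in reward can be sketched with a hypothetical toy stand-in for the humanoid example: a point agent on a plane whose reward is its displacement along a task-specific direction. The class name, the 2-D point dynamics and the direction angles are assumptions for illustration. The support MDPs reward moving left, forward and backward, and the target MDP rewards moving right.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class DirectionMDP:
    """Toy MDP: shared point-mass dynamics, reward depends on the task's direction."""
    angle: float  # direction (radians) that this task rewards moving in

    def step(self, pos, action):
        # Shared transition function: the action is a displacement vector.
        new_pos = (pos[0] + action[0], pos[1] + action[1])
        # Task-specific reward: displacement projected onto the rewarded direction.
        reward = action[0] * math.cos(self.angle) + action[1] * math.sin(self.angle)
        return new_pos, reward


# Support MDPs (meta-training) and target MDP (test time), in the x-y plane.
support_mdps = {
    "left": DirectionMDP(math.pi),
    "forward": DirectionMDP(math.pi / 2),
    "backward": DirectionMDP(-math.pi / 2),
}
target_mdp = DirectionMDP(0.0)  # "right"

# The same action earns different rewards in different MDPs.
for name, mdp in support_mdps.items():
    _, r = mdp.step((0.0, 0.0), (0.0, 1.0))
    print(name, round(r, 3))  # left 0.0, forward 1.0, backward -1.0
```

Only the reward changes between tasks, so whatever the agent learns about moving through the shared dynamics can transfer to the target MDP.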
Optimization Methods
At a high level, there are at least three methods for optimizing a meta-RL model (Rakelly, 2019). These forms all describe the same underlying problem and can be seen as different ways of solving it.
1. Optimization
We can train our reinforcement learning method using the MAML-style gradients we have already encountered. If we optimize our model with a policy gradient algorithm, we push up the likelihood of actions that lead to higher rewards and push down the likelihood of actions that lead to lower rewards, much as we descend a loss in supervised learning. So the transition from regular meta-learning to RL-based meta-learning is to learn from reward signals rather than loss signals. Just as regular meta-learning needed first-order MAML (Finn et al., 2017) to get around the cost of second-order gradients (differentiating through the inner-loop gradient), RL-based meta-learning also needs some simplifications along the way. The major benefit of viewing meta-RL as optimization is that adaptation works just like regular deep learning, so given enough data it will converge toward a solution for the target task. We say that this method is “consistent”. The downside is that with sparse rewards the policy gradient carries little signal, so the model suffers even more than in regular RL, which already struggles in that setting.
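The policy-gradient update described above can be sketched as REINFORCE on a hypothetical two-armed bandit with a softmax policy. The arm rewards, learning rate and sample count are illustrative assumptions; in an optimization-based meta-RL method, an update like this would take the place of the supervised gradient step in MAML's inner loop.

```python
import math
import random


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]


def reinforce_step(logits, arm_rewards, lr, rng, n_samples=32):
    """One REINFORCE update: weight grad log pi(a) by the reward of each sampled action."""
    probs = softmax(logits)
    grad = [0.0] * len(logits)
    for _ in range(n_samples):
        a = rng.choices(range(len(logits)), weights=probs)[0]
        r = arm_rewards[a]
        # For a softmax policy, d log pi(a) / d logit_k = 1[k == a] - pi(k).
        for k in range(len(logits)):
            grad[k] += r * ((1.0 if k == a else 0.0) - probs[k]) / n_samples
    # Gradient ascent on expected reward.
    return [l + lr * g for l, g in zip(logits, grad)]


rng = random.Random(0)
logits = [0.0, 0.0]        # start with a uniform policy
arm_rewards = [0.0, 1.0]   # arm 1 is the rewarding action
for _ in range(100):
    logits = reinforce_step(logits, arm_rewards, lr=1.0, rng=rng)
print(softmax(logits))     # probability mass has shifted toward arm 1
```

If every reward is zero, every term of the estimate is zero and the policy does not move at all, which is the sparse-reward weakness in its most extreme form.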
2. Recurrence
We would like a model to meta-learn something about the environment even if it did not get a reward in an episode, such as how to move around in the simulated world and the fact that the current trajectory was bad (thus leading to no reward). One way to keep track of such information is to feed each trajectory through an RNN instead. The hidden state \(h\) of the RNN can then keep track of this meta-data. A single RNN runs over multiple episodes of the same task without resetting its hidden state, so the hidden-state updates play the role of the inner loop; these episodes are then used to compute a single outer-loop gradient update to the RNN's weights. The hidden state is reset to its initial value \(h_0\) when the next task begins, and the process repeats. While recurrence is expressive, capturing rich detail in the hidden state, this information is latent: nothing guarantees that it captures the details we care about or that it keeps improving with more data (i.e., it is not “consistent”).
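The recurrent approach can be sketched as a vanilla RNN whose input at each step is the observation together with the previous action, the previous reward and an episode-done flag, so that rewards from earlier episodes can shape behavior later in the same task. The hidden state is carried across episode boundaries and reset only when a new task begins. This is a hypothetical sketch: the class name and input layout are assumptions, the weights are random and untrained (training them would be the outer loop), and the policy head that would map \(h\) to actions is omitted.

```python
import math
import random

INPUT_SIZE = 4  # (observation, previous action, previous reward, done flag)


class RecurrentMemory:
    """Vanilla-RNN memory for a recurrent meta-RL agent (untrained, illustrative)."""

    def __init__(self, hidden_size=4, seed=0):
        rng = random.Random(seed)
        self.W_h = [[rng.uniform(-0.5, 0.5) for _ in range(hidden_size)]
                    for _ in range(hidden_size)]
        self.W_x = [[rng.uniform(-0.5, 0.5) for _ in range(INPUT_SIZE)]
                    for _ in range(hidden_size)]
        self.hidden_size = hidden_size
        self.reset()

    def reset(self):
        # Called only when a NEW TASK (MDP) begins, not between episodes.
        self.h = [0.0] * self.hidden_size

    def step(self, obs, prev_action, prev_reward, done):
        # h_t = tanh(W_h h_{t-1} + W_x x_t)
        x = [obs, prev_action, prev_reward, float(done)]
        self.h = [
            math.tanh(sum(w * hj for w, hj in zip(self.W_h[i], self.h))
                      + sum(w * xj for w, xj in zip(self.W_x[i], x)))
            for i in range(self.hidden_size)
        ]
        return self.h


mem = RecurrentMemory(hidden_size=4, seed=0)
# Episode 1 of a task: a reward of 1.0 arrives, then the episode ends.
mem.step(obs=0.0, prev_action=0.0, prev_reward=1.0, done=False)
h_end_of_episode_1 = mem.step(obs=0.0, prev_action=0.0, prev_reward=0.0, done=True)
# Episode 2 of the same task starts from the carried-over hidden state, not h_0.
h_start_of_episode_2 = mem.step(obs=0.0, prev_action=0.0, prev_reward=0.0, done=False)
# A new task begins: reset to h_0.
mem.reset()
```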
3. Contextual
As an improvement on the recurrence method, we can bias which details we want to capture. Suppose the model can infer which MDP it is operating in from the context (i.e., the transitions and rewards it has received so far). This helps the model produce temporally coherent trajectories when accomplishing the task, as opposed to taking random actions as it starts to explore the state space. Consequently, the contextual method aims to infer a latent variable \(z\) describing the MDP before choosing actions, which produces more structured exploration during the inner loop (Gupta et al., 2018). Concretely, we learn an inference model \(q_{\phi}(z|c)\), where \(c\) is the context, and condition the policy \(\pi_{\theta}(a|s,z)\) on the inferred \(z\). We learn \(z\) with variational inference, optimizing an ELBO that includes the KL term \(D_{KL} [q_{\phi}(z|c_i) \,\|\, p(z)]\), where \(p(z)\) is a standard normal prior. At test time, the model chooses actions conditioned on its current best estimate of the environment parameters given the context.
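The contextual recipe can be sketched with diagonal Gaussians in plain Python. Two pieces from the text are shown: the closed-form KL term \(D_{KL}[\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)] = \tfrac{1}{2}\sum_j (\mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2)\) from the ELBO, and a reparameterized sample of \(z\). How the context becomes \(q_{\phi}(z|c)\) is an assumption here: each transition in the context contributes a Gaussian factor (which a learned encoder would output) and the factors are combined as a product of Gaussians, the choice made in PEARL (Rakelly et al., 2019). The factor values below are made up.

```python
import math
import random


def product_of_gaussians(factors):
    """Combine per-transition diagonal Gaussians (mean, variance) into one posterior.

    Precisions add; the posterior mean is the precision-weighted average of the means.
    """
    dim = len(factors[0][0])
    post_mu, post_var = [], []
    for j in range(dim):
        precision = sum(1.0 / var[j] for _, var in factors)
        mean = sum(mu[j] / var[j] for mu, var in factors) / precision
        post_mu.append(mean)
        post_var.append(1.0 / precision)
    return post_mu, post_var


def kl_to_standard_normal(mu, var):
    """KL[ N(mu, diag(var)) || N(0, I) ], the regularizer in the ELBO."""
    return 0.5 * sum(m * m + v - 1.0 - math.log(v) for m, v in zip(mu, var))


def sample_z(mu, var, rng):
    """Reparameterized sample: z = mu + sigma * eps with eps ~ N(0, 1)."""
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mu, var)]


# Hypothetical encoder outputs for two context transitions (1-D latent).
context_factors = [([1.0], [1.0]), ([3.0], [1.0])]
mu, var = product_of_gaussians(context_factors)
print(mu, var)                         # [2.0] [0.5]
print(kl_to_standard_normal(mu, var))  # penalty for drifting from the prior
z = sample_z(mu, var, random.Random(0))
# The policy pi_theta(a | s, z) would now be conditioned on z.
```

Each additional transition adds precision, so the posterior over \(z\) narrows as the agent gathers context, which is what lets its exploration become more targeted over time.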
Conclusion
Meta-RL is a powerful tool for training reinforcement learning agents with limited data. When applied to dialogue policy management, we can imagine learning a policy that adapts to its environment with limited experience. Given that marketing teams often segment customers into different user types (e.g., power users, core users, casual users), we could perhaps train agents that adapt to different user segments in short time frames.