Meta Learning — Meta-Gradient Reinforcement Learning — An Implementation

Hassaan Naeem
May 9, 2022
Photo by Dylan Leagh on Unsplash

… continuing from where I left off. Before getting to the crux of the issue, implementing the paper I set out to implement, I still need one more (seemingly necessary) explanation and implementation. This is turning out to be quite the yak shave, which is all too fitting for the area of “meta” learning.

Meta-Gradient Reinforcement Learning

The meta-gradient RL algorithm, as the authors mention in the paper, is based on the principle of online cross-validation.

What is online cross-validation? Here is a comprehensive breakdown if you are so inclined. Simply put, for a time-ordered dataset, you take the first k examples as a training set, use the remaining examples (k+1, k+2, …) as a test set, and use that test set to judge your model.
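To make the split concrete, here is a minimal Python sketch of time-ordered cross-validation. The helper names `online_cv_split` and `rolling_origin_splits` are my own (hypothetical, not from any library); the second one shows the "online" flavour, where the split point keeps moving forward in time and the model is always judged on data that comes after what it was trained on.

```python
from typing import Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def online_cv_split(data: Sequence[T], k: int) -> Tuple[List[T], List[T]]:
    """First k items are the training set, everything after them is the test set.

    The data is never shuffled, so the test set always lies in the "future".
    """
    if not 0 < k < len(data):
        raise ValueError("k must leave at least one example on each side")
    return list(data[:k]), list(data[k:])


def rolling_origin_splits(data: Sequence[T], k: int,
                          horizon: int = 1) -> Iterator[Tuple[List[T], List[T]]]:
    """Yield (train, test) pairs while moving the split point forward.

    Each step trains on everything seen so far and tests on the next
    `horizon` items.
    """
    for end in range(k, len(data) - horizon + 1):
        yield list(data[:end]), list(data[end:end + horizon])


series = [10, 11, 13, 12, 15, 16]
train, test = online_cv_split(series, k=4)
print(train, test)          # [10, 11, 13, 12] [15, 16]
for tr, te in rolling_origin_splits(series, k=4):
    print(len(tr), te)      # prints "4 [15]" then "5 [16]"
```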

In a similar manner, the meta-gradient algorithm works by “rolling out” an underlying RL algorithm (in the first version of this implementation, A2C) on k samples and then measuring performance on a subsequent sample.

Xu et al., 2018 (Supplemental)

As seen above, the policy and value functions are initialized with parameters 𝜃. We then sample trajectories 𝜏; a trajectory is nothing more than a sequence of experiences, each consisting of a state S, an action A, and a reward R. I refer to this as “rolling out”. We also sample a second set of trajectories 𝜏′. In an online RL setting, this second sample is simply a continuation of where the first one left off. It is important not to be confused by the sequential layout of the pseudocode above: even though the second sample is listed before the parameter update, it is not meant to be collected with the old parameters. Think of the two samples as sequential stretches of experience; in an RL setting, the second sample is collected after computing the objective and updating the parameters.
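A toy sketch of the two samples, assuming a hypothetical one-dimensional walk environment (`toy_step`) and a policy with a single parameter (the probability of stepping right). It only illustrates the ordering: 𝜏 is collected with 𝜃, the parameters are updated, and 𝜏′ is collected with the updated parameters, starting from the state where 𝜏 stopped. The update itself is a placeholder, not the meta-gradient algorithm.

```python
import random
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Trajectory:
    """A roll-out: parallel lists of states S, actions A and rewards R."""
    states: List[int] = field(default_factory=list)
    actions: List[int] = field(default_factory=list)
    rewards: List[float] = field(default_factory=list)


def toy_step(state: int, action: int) -> Tuple[int, float]:
    """Hypothetical environment: walk left (0) or right (1); reward 1.0 for stepping right."""
    if action == 1:
        return state + 1, 1.0
    return state - 1, 0.0


def rollout(state: int, p_right: float, n: int,
            rng: random.Random) -> Tuple[Trajectory, int]:
    """Play n steps with a policy that picks 'right' with probability p_right.

    Returns the trajectory and the state it stopped in, so that the next
    roll-out can continue from exactly where this one left off.
    """
    traj = Trajectory()
    for _ in range(n):
        action = 1 if rng.random() < p_right else 0
        next_state, reward = toy_step(state, action)
        traj.states.append(state)
        traj.actions.append(action)
        traj.rewards.append(reward)
        state = next_state
    return traj, state


rng = random.Random(0)
theta = 0.5                                             # toy policy parameter: P(right)
tau, end_state = rollout(0, theta, n=5, rng=rng)        # tau is collected with theta
theta = min(1.0, theta + 0.1)                           # placeholder for theta' = theta + f(tau, theta, eta)
tau_prime, _ = rollout(end_state, theta, n=5, rng=rng)  # tau' is collected with theta'
print(tau_prime.states[0] == end_state)                 # True: tau' picks up where tau stopped
```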

Let’s take a look at the update function f(𝜏, 𝜃, 𝜂) and the TD(λ) prediction objective 𝐽(𝜏, 𝜃, 𝜂). Taking the gradient of 𝐽 w.r.t. 𝜃, applying the chain rule, and absorbing multiplicative constants, we arrive at:

Xu et al., 2018

where g_𝜂 is the λ-return and v_𝜃 is the approximate value function. The update function above can be differentiated w.r.t. 𝜂 (the meta-parameters), as shown below:

Xu et al., 2018
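The λ-return, the update f, and its derivative w.r.t. 𝛾 can be sketched in a few lines of dependency-free Python. This assumes a linear critic v_𝜃(s) = 𝜃·x(s), so that ∂v_𝜃/∂𝜃 is just the feature vector x(s), and it sums the per-step updates over the roll-out; all function names are hypothetical. The derivative treats the critic values as constants, matching ∂f/∂𝜂 = 𝛼 ∗ (∂g_𝜂/∂𝜂) ∗ (∂v_𝜃/∂𝜃) with 𝜂 = 𝛾, and a finite-difference check confirms it.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def lambda_returns(rewards: Sequence[float], values: Sequence[float],
                   gamma: float, lam: float) -> Tuple[Vector, Vector]:
    """λ-returns g_t and their derivatives dg_t/dγ for a truncated roll-out.

    rewards[t] is R_{t+1}; values[t] is v(S_t) for t = 0..n, and values[n]
    bootstraps past the end of the roll-out. Critic values are constants here:
      g_t     = R_{t+1} + γ((1 - λ) v(S_{t+1}) + λ g_{t+1}),     g_n = v(S_n)
      dg_t/dγ = ((1 - λ) v(S_{t+1}) + λ g_{t+1}) + γ λ dg_{t+1}/dγ, dg_n/dγ = 0
    """
    assert len(values) == len(rewards) + 1
    g, dg = values[-1], 0.0
    returns, dreturns = [0.0] * len(rewards), [0.0] * len(rewards)
    for t in reversed(range(len(rewards))):
        inner = (1 - lam) * values[t + 1] + lam * g  # uses g_{t+1}
        dg = inner + gamma * lam * dg                # uses dg_{t+1}
        g = rewards[t] + gamma * inner
        returns[t], dreturns[t] = g, dg
    return returns, dreturns


def td_lambda_update(theta: Sequence[float], features: Sequence[Sequence[float]],
                     rewards: Sequence[float], gamma: float, lam: float,
                     alpha: float) -> Tuple[Vector, Vector]:
    """f(τ, θ, γ) and ∂f/∂γ for a linear critic v_θ(s) = θ·x(s).

    f     = α Σ_t (g_t - v_θ(S_t)) x(S_t)   (∂v_θ/∂θ = x for a linear critic)
    ∂f/∂γ = α Σ_t (∂g_t/∂γ) x(S_t)
    features holds x(S_0), ..., x(S_n): one more entry than rewards.
    """
    values = [dot(theta, x) for x in features]
    g, dg = lambda_returns(rewards, values, gamma, lam)
    f = [0.0] * len(theta)
    df_dgamma = [0.0] * len(theta)
    for t in range(len(rewards)):
        err = g[t] - values[t]
        for i, xi in enumerate(features[t]):
            f[i] += alpha * err * xi
            df_dgamma[i] += alpha * dg[t] * xi
    return f, df_dgamma


# Finite-difference check of ∂f/∂γ on made-up numbers.
theta = [0.2, -0.1]
features = [[1.0, 0.0], [0.5, 1.0], [0.0, 2.0], [1.0, 1.0]]
rewards = [1.0, 0.0, -0.5]
f, df = td_lambda_update(theta, features, rewards, gamma=0.9, lam=0.8, alpha=0.1)
eps = 1e-6
f_hi, _ = td_lambda_update(theta, features, rewards, 0.9 + eps, 0.8, 0.1)
f_lo, _ = td_lambda_update(theta, features, rewards, 0.9 - eps, 0.8, 0.1)
numeric = [(a - b) / (2 * eps) for a, b in zip(f_hi, f_lo)]
print(all(abs(a - b) < 1e-6 for a, b in zip(df, numeric)))  # True
```

With λ = 1 the λ-return reduces to the bootstrapped n-step discounted return; with λ = 0 it is the one-step TD target.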

The key takeaway is that the meta-objective 𝐽̄(𝜏′, 𝜃′, 𝜂̄) can now be optimized using the second set of trajectories 𝜏′:

Xu et al., 2018
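To see the online cross-validation idea end to end, here is a toy sketch with a scalar linear critic, n-step returns (λ = 1), and 𝛾 as the only meta-parameter. The meta-objective is assumed to be the squared TD error on 𝜏′ with a fixed γ̄, and 𝜏′ is assumed to end in a terminal state so its targets do not depend on 𝜃′. Because 𝜃′ = 𝜃 + f(𝜏, 𝜃, 𝛾), the chain rule gives d𝐽̄/d𝛾 = ∂𝐽̄/∂𝜃′ ∗ ∂f/∂𝛾, which a finite-difference check confirms. All names and numbers are made up; this is a single meta-step from a fixed 𝜃, not the paper's full algorithm.

```python
from typing import List, Sequence, Tuple


def n_step_returns(rewards: Sequence[float], bootstrap: float,
                   gamma: float) -> Tuple[List[float], List[float]]:
    """Discounted returns g_t and dg_t/dγ, with the bootstrap value held fixed."""
    g, dg = bootstrap, 0.0
    returns, dreturns = [0.0] * len(rewards), [0.0] * len(rewards)
    for t in reversed(range(len(rewards))):
        dg = g + gamma * dg          # d/dγ of R_{t+1} + γ g_{t+1}
        g = rewards[t] + gamma * g
        returns[t], dreturns[t] = g, dg
    return returns, dreturns


def inner_update(theta, xs, x_last, rewards, gamma, alpha):
    """f(τ, θ, γ) and ∂f/∂γ for a scalar linear critic v_θ(s) = θ·x(s)."""
    g, dg = n_step_returns(rewards, theta * x_last, gamma)
    f = alpha * sum((gt - theta * x) * x for gt, x in zip(g, xs))
    df_dgamma = alpha * sum(dgt * x for dgt, x in zip(dg, xs))
    return f, df_dgamma


def meta_objective(theta_prime, xs_p, targets):
    """J̄(τ', θ', γ̄) = ½ Σ_t (ḡ_t - v_θ'(S'_t))², with the targets ḡ_t held fixed."""
    return 0.5 * sum((gt - theta_prime * x) ** 2 for gt, x in zip(targets, xs_p))


def meta_gradient(theta, tau, tau_p, gamma, gamma_bar, alpha):
    """dJ̄/dγ = ∂J̄/∂θ' · ∂θ'/∂γ, where θ' = θ + f(τ, θ, γ) so ∂θ'/∂γ = ∂f/∂γ."""
    xs, x_last, rewards = tau
    xs_p, rewards_p = tau_p
    f, df_dgamma = inner_update(theta, xs, x_last, rewards, gamma, alpha)
    theta_prime = theta + f
    targets, _ = n_step_returns(rewards_p, 0.0, gamma_bar)  # τ' ends in a terminal state
    dJ_dtheta_prime = -sum((gt - theta_prime * x) * x for gt, x in zip(targets, xs_p))
    return dJ_dtheta_prime * df_dgamma


theta, gamma, gamma_bar, alpha, beta = 0.2, 0.9, 0.99, 0.1, 0.01
tau = ([1.0, 0.5, 2.0], 1.5, [1.0, 0.0, 1.0])  # (x(S_0..S_2), x(S_3), rewards)
tau_p = ([1.0, 2.0], [0.5, 1.0])               # (x(S'_0..S'_1), rewards)
grad = meta_gradient(theta, tau, tau_p, gamma, gamma_bar, alpha)


def j_of_gamma(gm: float) -> float:
    """J̄ as a function of γ alone, for the finite-difference check."""
    f, _ = inner_update(theta, tau[0], tau[1], tau[2], gm, alpha)
    targets, _ = n_step_returns(tau_p[1], 0.0, gamma_bar)
    return meta_objective(theta + f, tau_p[0], targets)


eps = 1e-5
numeric = (j_of_gamma(gamma + eps) - j_of_gamma(gamma - eps)) / (2 * eps)
print(abs(grad - numeric) < 1e-6)  # True
gamma = gamma - beta * grad        # one meta-gradient descent step on γ
```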

See the paper and Sutton and Barto (2018) for a precise, un-watered-down explanation of the derivations above.

Meta-Gradients and A2C

In the context of A2C, the model consists of four components: an Actor Network, a Critic Network, a Memory, and an Agent.

Actor Network

This represents the policy. It consists of three linear layers: the input dimensionality matches the environment's observation space, the hidden dimensionality can be any sensible choice, and the output dimensionality matches the environment's action space. The actor has a choose-action method, which samples an action from the policy. It also has methods for saving and loading.
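A dependency-free sketch of such an actor, assuming ReLU activations between the three linear layers and a softmax over the outputs (neither is specified above). The class and method names, including `choose_action`, `save`, and `load`, are hypothetical stand-ins; a real implementation would use a deep learning framework so that the gradients needed later come from automatic differentiation.

```python
import json
import math
import random
from typing import List, Tuple


class Linear:
    """Dense layer y = W x + b with small uniform initial weights."""

    def __init__(self, n_in: int, n_out: int, rng: random.Random):
        bound = 1.0 / math.sqrt(n_in)
        self.w = [[rng.uniform(-bound, bound) for _ in range(n_in)] for _ in range(n_out)]
        self.b = [0.0] * n_out

    def __call__(self, x: List[float]) -> List[float]:
        return [sum(wi * xi for wi, xi in zip(row, x)) + bi for row, bi in zip(self.w, self.b)]


def relu(v: List[float]) -> List[float]:
    return [max(0.0, x) for x in v]


def softmax(z: List[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


class ActorNetwork:
    """Policy π_θ(a|s): obs_dim -> hidden -> hidden -> n_actions, softmax output."""

    def __init__(self, obs_dim: int, n_actions: int, hidden: int = 64, seed: int = 0):
        self.rng = random.Random(seed)
        self.layers = [Linear(obs_dim, hidden, self.rng),
                       Linear(hidden, hidden, self.rng),
                       Linear(hidden, n_actions, self.rng)]

    def probs(self, obs: List[float]) -> List[float]:
        h = relu(self.layers[0](obs))
        h = relu(self.layers[1](h))
        return softmax(self.layers[2](h))

    def choose_action(self, obs: List[float]) -> Tuple[int, float]:
        """Sample an action from π(·|obs); also return its log-probability."""
        p = self.probs(obs)
        action = self.rng.choices(range(len(p)), weights=p, k=1)[0]
        return action, math.log(p[action])

    def save(self, path: str) -> None:
        with open(path, "w") as fh:
            json.dump([[layer.w, layer.b] for layer in self.layers], fh)

    def load(self, path: str) -> None:
        with open(path) as fh:
            for layer, (w, b) in zip(self.layers, json.load(fh)):
                layer.w, layer.b = w, b


actor = ActorNetwork(obs_dim=4, n_actions=2)  # e.g. CartPole-sized spaces
action, log_prob = actor.choose_action([0.0, 0.1, 0.0, -0.1])
```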

Critic Network

This represents the value function. It consists of three linear layers, with input dimensionality matching the environment's observation space, any sensible hidden dimensionality, and output dimensionality 1, since it only returns a value for the corresponding state.
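The critic is the same shape of network with a single output unit. Again a plain-Python sketch, assuming ReLU activations on the two hidden layers and no activation on the output (a state value can be any real number); `CriticNetwork` and `value` are hypothetical names.

```python
import math
import random
from typing import List


class CriticNetwork:
    """Value function v_θ(s): obs_dim -> hidden -> hidden -> 1."""

    def __init__(self, obs_dim: int, hidden: int = 64, seed: int = 0):
        rng = random.Random(seed)
        sizes = [(obs_dim, hidden), (hidden, hidden), (hidden, 1)]
        # Each layer is a (weights, biases) pair; weights are shaped [n_out][n_in].
        self.layers = []
        for n_in, n_out in sizes:
            bound = 1.0 / math.sqrt(n_in)
            w = [[rng.uniform(-bound, bound) for _ in range(n_in)] for _ in range(n_out)]
            self.layers.append((w, [0.0] * n_out))

    def value(self, obs: List[float]) -> float:
        h = obs
        for i, (w, b) in enumerate(self.layers):
            h = [sum(wi * xi for wi, xi in zip(row, h)) + bi for row, bi in zip(w, b)]
            if i < len(self.layers) - 1:
                h = [max(0.0, x) for x in h]  # ReLU on the hidden layers only
        return h[0]  # the output layer has exactly one unit


critic = CriticNetwork(obs_dim=4)
v = critic.value([0.0, 0.1, 0.0, -0.1])
```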

Memory

This is just a helper class. It stores states, actions, and rewards during the “roll-out” phase, and is used to train both the actor and the critic.
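A minimal version of such a helper might look like this; the class and method names are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class Memory:
    """Buffer for one roll-out: what the agent saw, did, and received."""
    states: List[Any] = field(default_factory=list)
    actions: List[int] = field(default_factory=list)
    rewards: List[float] = field(default_factory=list)

    def store(self, state: Any, action: int, reward: float) -> None:
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)

    def clear(self) -> None:
        """Empty the buffer once the actor and critic have been trained on it."""
        self.states.clear()
        self.actions.clear()
        self.rewards.clear()

    def __len__(self) -> int:
        return len(self.rewards)


memory = Memory()
memory.store([0.0, 0.1], action=1, reward=1.0)
print(len(memory))  # 1
memory.clear()
```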

Agent

This is the meat of the model. It has a few helper methods to calculate discounted rewards (with and without gradient tracking). The rollout method samples a trajectory. An important caveat: for efficiency, the sample is only collected for n steps, meaning the trajectory is only played out for those n steps. This ‘myopia’ is a problem, and it is what the BMG paper aims to fix. There are two separate methods: one to calculate the gradients of the two networks, and one for the meta-parameter. Another important note is that this implementation only meta-learns 𝛾 (the discount factor). The run method is where all the learning takes place.
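One of the discounted-reward helpers (the variant without gradient tracking) could look like the sketch below; the name is hypothetical, and the per-step `dones` flag is my own assumption. Because the roll-out stops after n steps, the critic's estimate of the last state stands in for everything beyond the truncation point, and if an episode ends inside the roll-out, nothing is bootstrapped across that boundary.

```python
from typing import List, Sequence


def discounted_returns(rewards: Sequence[float], dones: Sequence[bool],
                       bootstrap_value: float, gamma: float) -> List[float]:
    """n-step discounted returns for a truncated roll-out, A2C style.

    rewards[t] and dones[t] describe transition t; bootstrap_value is the
    critic's estimate v(S_n) of the state the roll-out stopped in.
    """
    returns = [0.0] * len(rewards)
    g = bootstrap_value
    for t in reversed(range(len(rewards))):
        if dones[t]:
            g = 0.0  # the episode ended at step t: no future value after it
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


print(discounted_returns([1.0, 1.0, 1.0], [False, True, False], 10.0, 0.5))
# [1.5, 1.0, 6.0]
```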

Learning To Learn

The learning mechanism, per episode, is as follows:

  • An n-step trajectory is “rolled out”.
  • The usual A2C algorithm is run, computing the actor and critic losses.
  • Then the pre-update gradients are computed; the two important semi-gradients here are ∂𝜋/∂𝜃 (the policy) and ∂g/∂𝛾 (the λ-return).
  • The optimizers are then stepped, meaning the agent/policy parameters are now updated to 𝜃′.
  • Following this, another n-step trajectory is rolled out, continuing from the previous state. This trajectory is the cross-validation sample.
  • After the roll-out is complete, the new policy and value functions are computed, and gradients are once again computed w.r.t. these two new functions.
  • Finally, the meta-parameter (𝛾) can be updated. The update is 𝛾 = 𝛾 − 𝛽 ∗ ∂g/∂𝛾 ∗ (∂𝜋/∂𝜃 ∗ ∂𝜋′/∂𝜃′). This is line 9 of the pseudocode, under the assumption that the A term in Equations 3 and 4 of the paper is 0. 𝛽 is a hyper-parameter, the meta-learning rate. Refer to the paper for more details.
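The 𝛾 update in the last step above can be written as a small function. This sketch assumes ∂g/∂𝛾 has already been reduced to a scalar and that both policy gradients have been flattened into vectors of equal length, so their product becomes a dot product; the function name is hypothetical. The dot product measures how well the two gradient directions line up.

```python
from typing import Sequence


def meta_update_gamma(gamma: float, beta: float, dg_dgamma: float,
                      grad_pi: Sequence[float],
                      grad_pi_prime: Sequence[float]) -> float:
    """γ ← γ − β · ∂g/∂γ · (∂π/∂θ · ∂π'/∂θ'), with flattened gradient vectors."""
    if len(grad_pi) != len(grad_pi_prime):
        raise ValueError("both gradients must cover the same parameters")
    alignment = sum(a * b for a, b in zip(grad_pi, grad_pi_prime))
    return gamma - beta * dg_dgamma * alignment


# Orthogonal gradients leave γ unchanged; aligned ones move it.
print(meta_update_gamma(0.99, 0.01, 2.0, [1.0, 0.5], [0.5, -1.0]))  # 0.99
```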

Some clarification of the meta-parameter update, under the previously stated assumption:

𝛾 = 𝛾 − 𝛽 ∗ (∂𝐽̄(𝜏′, 𝜃′, 𝜂̄)/∂𝜃′) ∗ (∂f(𝜏, 𝜃, 𝜂)/∂𝜂)

∂𝐽̄(𝜏′, 𝜃′, 𝜂̄)/∂𝜃′ = (g_𝜂̄(𝜏′) − v_𝜃′(S′)) ∗ (∂log 𝜋_𝜃′(A′|S′)/∂𝜃′). In the implementation, this corresponds to the gradient of the policy (actor network) w.r.t. the original parameters (𝜃) on trajectory 𝜏, multiplied by the gradient of the policy w.r.t. the updated parameters (𝜃′) on trajectory 𝜏′. I have called these ∂𝜋/∂𝜃 and ∂𝜋′/∂𝜃′ respectively (in code, this is just a matrix multiplication of the two).

∂f(𝜏, 𝜃, 𝜂)/∂𝜂 = 𝛼 ∗ (∂g_𝜂(𝜏)/∂𝜂) ∗ (∂v_𝜃(S)/∂𝜃), whose key factor is the gradient of the returns w.r.t. 𝜂 (in this case only 𝛾) under the original parameters (𝜃), which I have called ∂g/∂𝛾.
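For a concrete instance of the ∂log 𝜋/∂𝜃 terms, here is a sketch for a linear-softmax policy. This is an assumption: the actor described earlier is a three-layer network, for which an autodiff framework would compute the same quantity. For logits z = W x(s), the gradient of log 𝜋(a|s) w.r.t. W[j][i] is (1[j = a] − 𝜋_j) ∗ x_i. The code checks this against finite differences and then forms the product of the two flattened gradients, one at (S, A) under 𝜃 and one at (S′, A′) under 𝜃′. All numbers are made up.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def softmax(z: Sequence[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def policy(W: Matrix, x: Sequence[float]) -> List[float]:
    """π_W(·|s) for a linear-softmax policy with logits z = W x(s)."""
    return softmax([sum(w * xi for w, xi in zip(row, x)) for row in W])


def grad_log_prob(W: Matrix, x: Sequence[float], a: int) -> List[float]:
    """∂ log π_W(a|s) / ∂W, flattened row by row: (1[j == a] - π_j) · x_i."""
    pi = policy(W, x)
    return [((1.0 if j == a else 0.0) - pi[j]) * xi
            for j in range(len(W)) for xi in x]


def perturbed(W: Matrix, j: int, i: int, delta: float) -> Matrix:
    """Copy of W with entry [j][i] shifted by delta."""
    out = [row[:] for row in W]
    out[j][i] += delta
    return out


def numeric_grad_log_prob(W: Matrix, x: Sequence[float], a: int,
                          eps: float = 1e-6) -> List[float]:
    """Central finite differences, used only to check grad_log_prob."""
    return [(math.log(policy(perturbed(W, j, i, eps), x)[a])
             - math.log(policy(perturbed(W, j, i, -eps), x)[a])) / (2 * eps)
            for j in range(len(W)) for i in range(len(x))]


W = [[0.1, -0.2], [0.3, 0.05], [-0.1, 0.2]]           # θ: 3 actions, 2 features
W_prime = [[0.12, -0.2], [0.28, 0.07], [-0.1, 0.18]]  # θ' after a made-up update
x, a = [1.0, 2.0], 1                                  # (S, A) from τ
x_p, a_p = [0.5, -1.0], 2                             # (S', A') from τ'

g = grad_log_prob(W, x, a)
print(all(abs(u - v) < 1e-6 for u, v in zip(g, numeric_grad_log_prob(W, x, a))))  # True

# Product of the two flattened policy gradients (∂π/∂θ · ∂π'/∂θ').
alignment = sum(u * v for u, v in zip(g, grad_log_prob(W_prime, x_p, a_p)))
```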

There is a code snippet at the end of the run method that runs a test to check the performance of the policy.

This summary may be a bit convoluted; I will clean it up when I get the chance. Reach out if you spot any errors, issues, or possible improvements.

The code repo is located here: https://github.com/l4fl4m3/meta-a2c

Credit to this repo, which inspired my implementation.

To be continued …
