Categorical DQN

An implementation of a DQN that uses distributions to represent Q, from the paper A Distributional Perspective on Reinforcement Learning.

The Categorical DQN can be summarized as:

Instead of a single Q value per action, the network outputs a distribution over `N` atoms for each action.

We start off with the idea of atoms and supports. A support is a fixed grid of return values (the atoms); each action's output distribution assigns a probability to every atom. This is illustrated by the equations and the corresponding functions.

We start with the equation…

\[ {\large Z_{\theta}(x,a) = z_i \quad w.p. \: p_i(x,a):= \frac{ e^{\theta_i(x,a)}} {\sum_j{e^{\theta_j(x,a)}}} } \]

… which shows that the output of our neural net needs to be squashed into a proper probability distribution. It also defines \(z_i\), an atom of the support that we will define very soon. Below is the implementation of the right side of the equation for \(p_i(x,a)\).

An important note is that \(\frac{e^{\theta_i(x,a)}}{\sum_j{e^{\theta_j(x,a)}}}\) is just a softmax over the atoms.


We pretend that the output of the neural net is of shape (batch_sz,n_actions,n_atoms). In this instance, there is only one action, so the output holds a single distribution: the one for action 0.

import torch
out=torch.nn.Softmax(dim=-1)(torch.randn(1,1,51))[0,0] # Batch 0, action 0; softmax over the atom dimension
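To make the softmax step concrete, here is a minimal sketch in plain Python (no PyTorch, so it runs anywhere). The helper name `softmax` and the three-atom logits are made up for illustration; the real model applies the same operation along the atom dimension of its output.

```python
import math

def softmax(logits):
    """Turn a list of raw scores (theta_i) into probabilities (p_i)."""
    # Subtracting the max does not change the result but avoids overflow in exp.
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw outputs for one action with 3 atoms.
theta = [1.0, 2.0, 3.0]
p = softmax(theta)
print(p)       # three positive numbers, largest last
print(sum(p))  # approximately 1.0
```

Larger raw scores get larger probabilities, and the result always sums to 1, which is what lets us treat the output as a distribution over atoms.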

The equation for \(Z_{\theta}\) also refers to atoms \(z_i\), which are defined by: \[ \{z_i = V_{min} + i\Delta z : 0 \leq i < N \}, \: \Delta z := \frac{V_{max} - V_{min}}{N - 1} \]

Where \(V_{max}\), \(V_{min}\), and \(N\) are constants that we define. Note that \(N\) is the number of atoms. So what does a \(z_i\) look like? We will define this in code below…


 create_support (v_min=-10, v_max=10, n_atoms=51)

Creates the support and returns the z_delta that was used.

print('z_delta: ',z_delta)
z_delta:  0.4
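As a rough sketch of what such a function could look like, here is the support formula written out in plain Python. This is an assumption based on the equation above, not the library's actual `create_support`; the name `make_support` is hypothetical.

```python
def make_support(v_min=-10, v_max=10, n_atoms=51):
    """Return (support, z_delta) with z_i = v_min + i * z_delta for 0 <= i < n_atoms."""
    z_delta = (v_max - v_min) / (n_atoms - 1)
    support = [v_min + i * z_delta for i in range(n_atoms)]
    return support, z_delta

support, z_delta = make_support()
print('z_delta: ', z_delta)     # 0.4
print(len(support))             # 51
print(support[0], support[-1])  # first and last atoms, i.e. v_min and v_max
```

The support is simply an evenly spaced, increasing grid from \(V_{min}\) to \(V_{max}\).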

Each entry of the support is a single \(z_i\). The number of \(z_i\)s is equal to the number of atoms \(N\), not the number of actions.

Note: the support is shared by all actions. What differs per action is the probability vector \(p(x,a)\) that sits on top of it, so a model with multiple actions still needs only one support.

Ok! Hopefully this wasn’t too bad to go through. We basically normalized the neural net output into proper probabilities, and created an increasing array of atom values (the output from create_support) for those probabilities to sit on.

Now for the fun part! We have this giant ass update equation:

\[ {\large (\Phi\hat{\mathcal{T}}Z_{\theta}(x,a))_i = \sum_{j=0}^{N-1} \left[ 1 - \frac{ \left| \left[ \hat{\mathcal{T}}z_j \right]_{V_{min}}^{V_{max}} - z_i \right| }{ \Delta z } \right]_0^1 p_j(x^{\prime},\pi(x^{\prime})) } \]

Good god… and we also have

\[ \hat{\mathcal{T}}z_j := r + \gamma z_j \]

where, to quote the paper:

“for each atom \(z_j\), [and] then distribute its probability \(p_j(x^{\prime},\pi(x^{\prime}))\) to the immediate neighbors of \(\hat{\mathcal{T}}z_j\)”

I highly recommend reading page 6 of the paper for a fuller explanation. I was originally wondering what the difference was between \(\pi\) and simple \(\theta\). The main difference is that \(\theta\) denotes the network parameters, while \(\pi\) is a greedy action selection, i.e. we run argmax over the Q values to get the action.
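To see what "distribute its probability to the immediate neighbors" means for a single atom, here is a small worked sketch in plain Python. The reward, discount, atom value, and probability are made-up numbers, and `split_atom` is a hypothetical helper, not a function from the library.

```python
import math

def split_atom(z_j, p_j, reward, gamma, v_min=-10, v_max=10, n_atoms=51):
    """Shift one atom with the Bellman update and split p_j between its two neighbors."""
    delta_z = (v_max - v_min) / (n_atoms - 1)
    # T z_j = r + gamma * z_j, clipped to [v_min, v_max].
    target_z = min(max(reward + gamma * z_j, v_min), v_max)
    # Fractional index of T z_j on the support, in [0, n_atoms - 1].
    b_j = (target_z - v_min) / delta_z
    lower, upper = math.floor(b_j), math.ceil(b_j)
    if lower == upper:
        # T z_j lands exactly on an atom, so that atom receives everything.
        return {lower: p_j}
    # The closer neighbor receives the larger share.
    return {lower: p_j * (upper - b_j), upper: p_j * (b_j - lower)}

shares = split_atom(z_j=2.0, p_j=0.5, reward=1.0, gamma=0.99)
print(shares)
```

Here \(\hat{\mathcal{T}}z_j = 1 + 0.99 \cdot 2 = 2.98\), which falls between atom 32 (value 2.8) and atom 33 (value 3.2). Because 2.98 is closer to 2.8, atom 32 receives the larger share of the 0.5 probability, and the two shares still add up to 0.5.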

This was a lot! Luckily they have a reformulation in algorithmic form:

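import torch
def categorical_update(v_min,v_max,n_atoms,support,delta_z,model,reward,gamma,action,next_state):
    p=model(next_state) # Assumed: probabilities of shape (n_actions,n_atoms) for a single state
    a_star=(p*support).sum(-1).argmax() # a* <- argmax_a Q(x_{t+1},a)
    m=torch.zeros((n_atoms,)) # m_i = 0 where i in 0,...,N-1
    for j in range(n_atoms):
        # Compute the projection of $ \hat{\mathcal{T}}z_j $ onto the support $ \{z_i\} $
        target_z=torch.clamp(reward+gamma*support[j],v_min,v_max)
        b_j=(target_z-v_min)/delta_z # b_j in [0,N-1]
        l,u=int(torch.floor(b_j)),int(torch.ceil(b_j))
        # Distribute probability of $ \hat{\mathcal{T}}z_j $ (literal Algorithm 1: mass is dropped when l == u)
        m[l]+=p[a_star,j]*(u-b_j)
        m[u]+=p[a_star,j]*(b_j-l)
    return m # Target for the cross entropy loss: -sum_i m_i log p_i(x_t,a_t)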

There is a small problem with the above function. This was a (fairly) literal conversion from Algorithm 1 in the paper to Python. There are some problems here:
- The current setup doesn't handle batches
- Some of the variables are a little vague
- Does not handle terminal states
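One way those three problems could be addressed is sketched below in plain Python: loop over a batch of transitions, use descriptive variable names, and for terminal transitions use only the reward as the target (there is no future return to bootstrap from). The function `project_batch` and its arguments are hypothetical; this is not the library's `categorical_update`.

```python
import math

def project_batch(next_probs, rewards, dones, gamma=0.99, v_min=-10, v_max=10, n_atoms=51):
    """Project r + gamma * z_j onto the fixed support for every transition in a batch.

    next_probs holds one list of n_atoms probabilities per transition, i.e. p(x', a*).
    """
    delta_z = (v_max - v_min) / (n_atoms - 1)
    support = [v_min + i * delta_z for i in range(n_atoms)]
    projections = []
    for probs, reward, done in zip(next_probs, rewards, dones):
        projection = [0.0] * n_atoms
        for atom, p in zip(support, probs):
            # Terminal transitions have no future return, so the target is just the reward.
            target = reward if done else reward + gamma * atom
            target = min(max(target, v_min), v_max)
            support_value = (target - v_min) / delta_z
            left, right = math.floor(support_value), math.ceil(support_value)
            if left == right:
                # The target sits exactly on an atom; give it all of the probability.
                projection[left] += p
            else:
                projection[left] += p * (right - support_value)
                projection[right] += p * (support_value - left)
        projections.append(projection)
    return projections

uniform = [1 / 51] * 51
out = project_batch([uniform, uniform], rewards=[1.0, 0.0], dones=[False, True])
print([round(sum(row), 6) for row in out])      # each projected row still sums to 1
print(max(range(51), key=lambda i: out[1][i]))  # index of the atom holding the terminal row's mass
```

For the terminal transition every atom maps to the same target (the reward), so the whole distribution collapses onto the atom closest to that reward. A real implementation would vectorize these loops over tensors rather than iterate in Python.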

Let's rename these! We will instead have:
\[ \begin{aligned} m\_i &\rightarrow projection \\ a\_star &\rightarrow next\_action \\ b\_j &\rightarrow support\_value \\ l &\rightarrow support\_left \\ u &\rightarrow support\_right \end{aligned} \]

So let's revise the problem and pretend that we have a 2-action model and a batch size of 8, where the last element has a reward of 0, and where left actions are -1 while right actions are 1.

from torch.distributions.normal import Normal

So for a single action we would have a distribution like this…


So since our model has 2 actions that it can pick, we create some distributions for them…

(torch.Size([1, 2, 51]), torch.Size([1, 2, 51]))

…where the \([1, 2, 51]\) is \([batch, action, n\_atoms]\)

model_out=torch.vstack([copy([dist_left,dist_right][i%2==0]) for i in range(1,9)]).to(device=default_device())
torch.Size([8, 2, 51])
(torch.Size([8, 2]),
 tensor([[0.1954, 0.8046],
         [0.0060, 0.9940],
         [0.1954, 0.8046],
         [0.0060, 0.9940],
         [0.1954, 0.8046],
         [0.0060, 0.9940],
         [0.1954, 0.8046],
         [0.0060, 0.9940]], device='cuda:0'))

So when we sum/normalize the distributions per batch, per action, we get an output that looks like your typical DQN output…

We can also treat this like a regular DQN and do an argmax to get actions like usual…

        [1]], device='cuda:0')
        [1]], device='cuda:0')
        [ True]], device='cuda:0')

So let's decompose the categorical_update above into something easier to read. First we will note the author’s original algorithm:

We can break this into 3 different functions:
- getting the Q
- calculating the update
- calculating the loss

We will start with the \(Q(x_{t+1},a):=\sum_i z_i p_i(x_{t+1},a)\)
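As a quick illustration of that sum, here is a minimal plain-Python sketch that turns per-action atom probabilities into Q values and then picks the greedy action. The three-atom support and the probabilities are made up, and `expected_q` is a hypothetical helper rather than the library's `CategoricalDQN.q`.

```python
def expected_q(probs, support):
    """Q(x, a) = sum_i z_i * p_i(x, a) for every action; probs is [n_actions][n_atoms]."""
    return [sum(z * p for z, p in zip(support, action_probs)) for action_probs in probs]

# Hypothetical 3-atom support and two actions.
support = [-1.0, 0.0, 1.0]
probs = [
    [0.7, 0.2, 0.1],  # action 0: mass mostly on the negative atom
    [0.1, 0.2, 0.7],  # action 1: mass mostly on the positive atom
]
q = expected_q(probs, support)
greedy_action = max(range(len(q)), key=lambda a: q[a])
print(q)              # roughly [-0.6, 0.6]
print(greedy_action)  # 1
```

Each Q value is just the mean of that action's distribution, so once the means are computed the action selection works like a regular DQN.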


 CategoricalDQN (state_sz:int, action_sz:int, n_atoms:int=51, hidden=512,
                 v_min=-10, v_max=10, head_layer=<class
                 'torch.nn.modules.linear.Linear'>, activation_fn=<class

Same as nn.Module, but no need for subclasses to call super().__init__

The CategoricalDQN.q function gets us 90% of the way to the equation above. However, you will notice that that equation is for a specific action. We will handle this in the actual update function.

torch.Size([8, 2, 51])
tensor([[ 0.3418, -0.2006],
        [ 0.1096, -0.0358],
        [-0.2790,  0.0382],
        [ 0.1743,  0.0024],
        [-0.5164,  0.0867],
        [-0.0825, -0.0634],
        [-0.5792,  0.2759],
        [-0.0598, -0.0087]], device='cuda:0', grad_fn=<SumBackward1>)
tensor([[-0.0020,  0.0022],
        [-0.0043, -0.0014],
        [-0.0063,  0.0037],
        [-0.0079,  0.0028],
        [-0.0001,  0.0039],
        [-0.0001, -0.0017],
        [-0.0008,  0.0004],
        [-0.0108,  0.0052]], device='cuda:0', grad_fn=<MeanBackward1>)


 final_distribute (projection, left, right, support_value, p_a, atom,

Does: m_l <- m_l + p_j(x_{t+1},a*)(u - b_j) operation for final states.


 distribute (projection, left, right, support_value, p_a, atom, done)

Does: m_l <- m_l + p_j(x_{t+1},a*)(u - b_j) operation for non-final states.


 categorical_update (support, delta_z, q, p, actions, rewards, dones,
                     v_min=-10, v_max=10, n_atoms=51, gamma=0.99,
                     passes=None, nsteps=1, debug=False)


 show_q_distribution (cat_dist, title='Update Distributions')

cat_dist being shape: (bs,n_atoms)


show_q_distribution(output,title='Real Model Update Distributions')


 PartialCrossEntropy (p, q)
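The body of `PartialCrossEntropy` is not shown here. As a rough illustration of the kind of loss Algorithm 1 ends with, the sketch below computes the cross entropy \(-\sum_i m_i \log p_i(x_t,a_t)\) between a projected target distribution and a predicted one. The function `cross_entropy`, its `eps` argument, and the example numbers are assumptions for illustration and may differ from what the library does.

```python
import math

def cross_entropy(target, predicted, eps=1e-8):
    """-sum_i target_i * log(predicted_i); eps guards against log(0)."""
    return -sum(t * math.log(q + eps) for t, q in zip(target, predicted))

target = [0.0, 0.25, 0.75]       # hypothetical projected distribution m
good_guess = [0.05, 0.25, 0.70]  # close to the target
bad_guess = [0.70, 0.25, 0.05]   # most mass on the wrong atom
print(cross_entropy(target, good_guess) < cross_entropy(target, bad_guess))  # True
```

The loss shrinks as the predicted distribution moves its mass toward the atoms the target puts mass on, which is the signal used to train the network.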


 CategoricalTargetQCalc (*args, **kwds)

 MultiModelRunner (*args, **kwds)

If a model contains multiple models, then we support selecting a sub model.

from torchdata.datapipes.utils import to_graph
from fastrl.envs.gym import *
from import *
# Setup Loggers
logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,

# Setup up the core NN
model = CategoricalDQN(4,2).to(device='cuda')
# Setup the Agent
agent = DQNAgent(model,[logger_base],max_steps=4000,device='cuda',
# Setup the DataBlock
block = DataBlock(
    GymTransformBlock(agent=agent,nsteps=2,nskips=2,firstlast=True), # We basically merge 2 steps into 1 and skip. 
# pipes = L(block.datapipes(['CartPole-v1']*1,n=10))
dls = L(block.dataloaders(['CartPole-v1']*1))
# Setup the Learner
learner = DQNLearner(model,dls,logger_bases=[logger_base],bs=128,
                     loss_func = PartialCrossEntropy,
loss       episode  rolling_reward  epoch  batch  epsilon
2.8483639  65       30.870000       0      1001   0.523000
from IPython.display import HTML
import as px
from torchdata.dataloader2.graph import find_dps,traverse


 show_q (cat_dist, title='Update Distributions')

cat_dist being shape: (bs,n_atoms)