DDPG

Deep Deterministic Policy Gradient for continuous action domains

(Lillicrap et al., 2016) [DDPG] Continuous Control with Deep Reinforcement Learning is based on the DPG algorithm in (Silver et al., 2014) [DPG] Deterministic Policy Gradient Algorithms.

DDPG uses an actor-critic architecture and has a similar training/learning paradigm to DQN.

Below is Algorithm 1 from (Lillicrap et al., 2016), which summarizes DDPG.

Model


init_xavier_uniform_weights

 init_xavier_uniform_weights (m:fastrl.torch_core.Module, bias=0.01)

Initializes weights for linear layers using torch.nn.init.xavier_uniform_


init_uniform_weights

 init_uniform_weights (m:fastrl.torch_core.Module, bound)

Initializes weights for linear layers using torch.nn.init.uniform_


init_kaiming_normal_weights

 init_kaiming_normal_weights (m:fastrl.torch_core.Module, bias=0.01)

Initializes weights for linear layers using torch.nn.init.kaiming_normal_

(Lillicrap et al., 2016) pg 11 notes: “The other layers were initialized from uniform distributions \([-\frac{1}{\sqrt{f}},\frac{1}{\sqrt{f}}]\) where f is the fan-in of the layer.”

init_kaiming_normal_weights is the most similar to this strategy. Other DDPG implementations have also used init_xavier_uniform_weights.

Note: There does not appear to be a major difference between performance of using either.

The same page notes: “final layer weights and biases of both the actor and critic were initialized from a uniform distribution \([-3 \times 10^{-3}, 3 \times 10^{-3}]\) and \([-3 \times 10^{-4}, 3 \times 10^{-4}]\) for the low dimensional and pixel cases respectively.” The default value for final_layer_init_fn therefore uses init_uniform_weights with a bound of 1e-4 for the low-dimensional case; for pixels it needs to be changed to 1e-5.
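As a quick sketch of these strategies (illustrative only, using plain torch.nn.init and the init_uniform_weights helper documented above; the layer sizes are arbitrary):

from functools import partial
import torch
from torch import nn

# Fan-in based init for a hidden layer, per the quote above: [-1/sqrt(f), 1/sqrt(f)]
layer = nn.Linear(400, 300)
fan_in = layer.weight.shape[1]              # f = 400 for nn.Linear(400, 300)
bound = 1.0 / fan_in ** 0.5
nn.init.uniform_(layer.weight, -bound, bound)

# Final-layer init, as used for the default final_layer_init_fn below:
# a partial of init_uniform_weights with a fixed bound (1e-4 for low dim).
final_layer_init_fn = partial(init_uniform_weights, bound=1e-4)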

The same page notes: “The low-dimensional networks had 2 hidden layers with 400 and 300 units respectively … When learning from pixels we used 3 convolutional layers (no pooling) with 32 filters at each layer. This was followed by two fully connected layers with 200 units”

We default to the low-dimensional setup; for images we augment it with a convolutional block.
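For a rough picture of the low-dimensional shape described above, here is an illustrative actor-style MLP in plain torch (the Actor/Critic classes below are the real implementations; the Critic additionally feeds actions into the second linear layer, and the final tanh comes from the paper's actor setup):

import torch
from torch import nn

state_sz, action_sz = 3, 1
low_dim_actor_sketch = nn.Sequential(
    nn.Linear(state_sz, 400), nn.ReLU(),   # hidden layer 1: 400 units
    nn.Linear(400, 300), nn.ReLU(),        # hidden layer 2: 300 units
    nn.Linear(300, action_sz), nn.Tanh(),  # bounded continuous actions
)
low_dim_actor_sketch(torch.randn(1, state_sz)).shape   # torch.Size([1, 1])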


ddpg_conv2d_block

 ddpg_conv2d_block (state_sz:Tuple[int,int,int], filters=32,
                    activation_fn=<class
                    'torch.nn.modules.activation.ReLU'>,
                    ignore_warning:bool=False)

Creates a 3-layer conv block from state_sz and returns the expected n_features output size.

Type Default Details
state_sz typing.Tuple[int, int, int] A tuple of state sizes generally representing an image of format:
[channel,width,height]
filters int 32 Number of filters to use for each conv layer
activation_fn type ReLU Activation function between each layer.
ignore_warning bool False We assume the channel dim should be at most size 3. If it is larger,
we assume the width/height are in the channel's position and need to
be transposed.
Returns typing.Tuple[torch.nn.modules.container.Sequential, int] (Convolutional block,n_features_out)
ddpg_conv2d_block((3,100,100))
(Sequential(
   (0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (1): Conv2d(3, 3, kernel_size=(32, 32), stride=(1, 1))
   (2): ReLU()
   (3): Conv2d(3, 3, kernel_size=(32, 32), stride=(1, 1))
   (4): ReLU()
   (5): Conv2d(3, 3, kernel_size=(32, 32), stride=(1, 1))
   (6): Flatten(start_dim=1, end_dim=-1)
 ),
 147)

Critic

 Critic (state_sz:int, action_sz:int, hidden1:int=400, hidden2:int=300,
         head_layer:fastrl.torch_core.Module=<class
         'torch.nn.modules.linear.Linear'>,
         activation_fn:fastrl.torch_core.Module=<class
         'torch.nn.modules.activation.ReLU'>,
         weight_init_fn:Callable=<function init_kaiming_normal_weights>,
         final_layer_init_fn:Callable=functools.partial(<function
         init_uniform_weights at 0x7f75db31ef80>, bound=0.0001), conv_bloc
         k:Union[torch.nn.modules.container.Sequential,NoneType]=None,
         batch_norm:bool=False)

Takes 2 tensors of size [B,state_sz] and [B,action_sz] -> [B,1], outputting a tensor representing the Q value

Type Default Details
state_sz int The input dim of the state / flattened conv output
action_sz int The input dim of the actions
hidden1 int 400 Number of neurons connected between the 2 input/output layers
hidden2 int 300 Number of neurons connected between the 2 input/output layers
head_layer Module Linear Output layer
activation_fn Module ReLU The activation function
weight_init_fn typing.Callable init_kaiming_normal_weights The weight initialization strategy
final_layer_init_fn typing.Callable functools.partial(<function init_uniform_weights at 0x7f75db31ef80>, bound=0.0001) Final layer initialization strategy
conv_block typing.Union[torch.nn.modules.container.Sequential, NoneType] None For pixel inputs, we can plug in a nn.Sequential block from ddpg_conv2d_block.
This means that actions will be fed into the second linear layer instead of the
first.
batch_norm bool False Whether to do batch norm.

The Critic is used by DDPG to estimate the Q value of state-action pairs. It is updated using the Bellman equation, similarly to DQN/Q-Learning, and is represented by \(Q(s,a)\)

Check that low dim input works…

torch.manual_seed(0)
critic = Critic(4,2)

state = torch.randn(1,4)
action = torch.randn(1,2)

with torch.no_grad(),evaluating(critic):
    test_eq(
        str(critic(state,action)),
        str(tensor([[0.0083]]))
    )

Check that image input works…

torch.manual_seed(0)

image_shape = (3,100,100)

conv_block,feature_out = ddpg_conv2d_block(image_shape)
critic = Critic(feature_out,2,conv_block=conv_block)

state = torch.randn(1,*image_shape)
action = torch.randn(1,2)

with torch.no_grad(),evaluating(critic):
    test_eq(
        str(critic(state,action)),
        str(tensor([[0.0102]]))
    )

Actor

 Actor (state_sz:int, action_sz:int, hidden1:int=400, hidden2:int=300,
        head_layer:fastrl.torch_core.Module=<class
        'torch.nn.modules.linear.Linear'>,
        activation_fn:fastrl.torch_core.Module=<class
        'torch.nn.modules.activation.ReLU'>,
        weight_init_fn:Callable=<function init_kaiming_normal_weights>,
        final_layer_init_fn:Callable=functools.partial(<function
        init_uniform_weights at 0x7f75db31ef80>, bound=0.0001), conv_block
        :Union[torch.nn.modules.container.Sequential,NoneType]=None,
        batch_norm:bool=False)

Takes a single tensor of size [B,state_sz] -> [B,action_sz] and outputs a tensor of actions.

Type Default Details
state_sz int The input dim of the state
action_sz int The output dim of the actions
hidden1 int 400 Number of neurons connected between the 2 input/output layers
hidden2 int 300 Number of neurons connected between the 2 input/output layers
head_layer Module Linear Output layer
activation_fn Module ReLU The activation function
weight_init_fn typing.Callable init_kaiming_normal_weights The weight initialization strategy
final_layer_init_fn typing.Callable functools.partial(<function init_uniform_weights at 0x7f75db31ef80>, bound=0.0001) Final layer initialization strategy
conv_block typing.Union[torch.nn.modules.container.Sequential, NoneType] None For pixel inputs, we can plug in a nn.Sequential block from ddpg_conv2d_block.
batch_norm bool False Whether to do batch norm.

The Actor is used by DDPG to predict actions based on state inputs and is represented by \(\mu(s|\theta^\mu)\)

Check that low dim input works…

torch.manual_seed(0)
actor = Actor(4,2)

state = torch.randn(1,4)

with torch.no_grad(),evaluating(actor):
    test_eq(
        str(actor(state)),
        str(tensor([[0.0101, 0.0083]]))
    )

Check that image input works…

torch.manual_seed(0)

image_shape = (3,100,100)

conv_block,feature_out = ddpg_conv2d_block(image_shape)
actor = Actor(feature_out,2,conv_block=conv_block)

state = torch.randn(1,*image_shape)
action = torch.randn(1,2)

with torch.no_grad(),evaluating(actor):
    test_eq(
        str(actor(state)),
        str(tensor([[0.0100, 0.0100]]))
    )

pipe_to_device

 pipe_to_device (pipe, device, debug=False)

Attempt to move an entire pipe and its pipeline to device

Ornstein-Uhlenbeck Exploration


OrnsteinUhlenbeck

 OrnsteinUhlenbeck (*args, **kwds)

Used for exploration in continuous action domains via temporally correlated noise.

[1] From https://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab

[2] Cumulatively based on Uhlenbeck et al., 1930

The OrnsteinUhlenbeck exploration for DDPG has the notation:

\(\mu'(s_t)=\mu(s_t|\theta_{t}^{\mu}) + N\)

Note: (Lillicrap et al., 2016) pg 4 says “generate temporally correlated exploration for exploration efficiency in physical control problems with inertia”. This is worth considering when training on environments that don’t involve inertia.
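For intuition, here is a minimal sketch of the underlying Ornstein-Uhlenbeck process in plain torch (not the OrnsteinUhlenbeck datapipe itself; theta, sigma, and dt are illustrative values):

import torch

def ou_samples(n_steps:int, action_sz:int, mu=0.0, theta=0.15, sigma=0.2, dt=1e-2):
    "Generate temporally correlated noise samples that drift back toward `mu`."
    x = torch.zeros(action_sz)
    out = []
    for _ in range(n_steps):
        x = x + theta * (mu - x) * dt + sigma * (dt ** 0.5) * torch.randn(action_sz)
        out.append(x)
    return torch.stack(out)

noise = ou_samples(200, action_sz=2)   # N in the notation above, added to mu(s_t)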


ExplorationComparisonLogger

 ExplorationComparisonLogger (*args, **kwds)

Allows for quickly doing a “what if” on exploration methods by comparing the actions selected via exploration with the ones chosen by the model.

Below we demonstrate that the exploration works. As the number of steps increases, epsilon decreases to zero, so the actions slowly become more deterministic.

torch.manual_seed(0)

actions = dp.iter.IterableWrapper(
    # Batch of 4 actions with dimensions 2
    torch.randn(4,2).to(device=default_device())
)

actions = OrnsteinUhlenbeck(
    actions,
    min_epsilon=0,
    max_steps=200,
    action_sz=2,
    decrement_on_val=True,
    explore_on_val=True,
    ret_original=True
)
actions.to(device=default_device())
actions = actions.cycle(count=50)
actions = ExplorationComparisonLogger(actions)
list(actions)
actions.show()

Agent


ActionUnbatcher

 ActionUnbatcher (*args, **kwds)

Removes the batch dim from an action.


ActionClip

 ActionClip (*args, **kwds)

Restricts actions from source_datapipe between clip_min and clip_max

Internally calls torch.clip
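For reference, the underlying call is just torch.clip; the bounds below are illustrative:

import torch

action = torch.tensor([[1.7, -0.4]])
torch.clip(action, -1.0, 1.0)   # tensor([[ 1.0000, -0.4000]])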


DDPGAgent

 DDPGAgent (model:__main__.Actor,
            logger_bases:Union[fastrl.loggers.core.LoggerBase,NoneType]=No
            ne, min_epsilon:float=0.2, max_epsilon:float=1,
            max_steps:int=100, dp_augmentation_fns:Union[List[fastrl.pipes
            .core.DataPipeAugmentationFn],NoneType]=None)

Produces continuous action outputs.

Type Default Details
model Actor The actor to use for mapping states to actions
logger_bases typing.Union[fastrl.loggers.core.LoggerBase, NoneType] None LoggerBases to push logs to. If None, logs will be collected and output
by the dataloader.
min_epsilon float 0.2 The minimum epsilon to drop to
max_epsilon float 1 The max/starting epsilon if epsilon is None; used for calculating the epsilon decrease speed.
max_steps int 100 Determines how fast epsilon should drop to min_epsilon. This should be the number
of steps that the agent was run through.
dp_augmentation_fns typing.Union[typing.List[fastrl.pipes.core.DataPipeAugmentationFn], NoneType] None Any augmentations to the DDPG agent.
Returns AgentHead

Check that given a step, we can get actions from the DDPGAgent

torch.manual_seed(0)

actor = Actor(4,2)

agent = DDPGAgent(actor)

input_tensor = tensor([1,2,3,4]).float()
step = SimpleStep(state=input_tensor)

for _ in range(10):
    for action in agent([step]):
        print(action)
[0.91868794 0.7928086 ]
[0.89038295 1.        ]
[0.59973884 0.5640401 ]
[0.8244315 0.6234891]
[0.6276505  0.41781008]
[0.59242374 0.68170094]
[0.339647   0.42176256]
[0.5469213  0.11551299]
[ 0.34022635 -0.23093009]
[ 0.33037645 -0.13671152]
from fastrl.envs.gym import GymTransformBlock
from fastrl.loggers.vscode_visualizers import VSCodeTransformBlock

Learner


BasicOptStepper

 BasicOptStepper (*args, **kwds)

Optimizes model using opt. source_datapipe must produce dictionaries of format: {"loss":...}; all non-dict elements will be passed through.


LossCollector

 LossCollector (*args, **kwds)

Intercepts dictionary results generated from source_datapipe that are in the format: {'loss':tensor(...)}. All other elements will be ignored and passed through.

If filter=True, then intercepted dictionaries will be filtered out by this pipe and will not be propagated to the rest of the pipeline.


SoftTargetUpdater

 SoftTargetUpdater (*args, **kwds)

Soft-Copies model to a target_model (internal) every target_sync batches.

We use SoftTargetUpdater to update the target Critic and Actor. This is characterized by the notation:

\[ \theta^{Q'} \leftarrow \tau \theta^Q + (1 - \tau)\theta^{Q'} \] \[ \theta^{\mu'} \leftarrow \tau \theta^\mu + (1 - \tau)\theta^{\mu'} \]

Both the Critic (\(Q\)) and Actor (\(\mu\)) are slowly copied to their targets based on the value \(\tau\)
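In plain PyTorch, the soft copy above looks roughly like this (an illustrative sketch, not SoftTargetUpdater itself):

import torch

def soft_update(target:torch.nn.Module, online:torch.nn.Module, tau:float=0.001):
    # theta' <- tau * theta + (1 - tau) * theta'
    with torch.no_grad():
        for t_p, p in zip(target.parameters(), online.parameters()):
            t_p.mul_(1 - tau).add_(tau * p)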


get_target_model

 get_target_model (model:Union[torch.nn.modules.module.Module,NoneType], p
                   ipe:Union[torch.utils.data.datapipes.datapipe.IterDataP
                   ipe,torch.utils.data.datapipes.datapipe.MapDataPipe],
                   model_cls:torch.nn.modules.module.Module, target_update
                   r_cls:Tuple[Union[torch.utils.data.datapipes.datapipe.I
                   terDataPipe,torch.utils.data.datapipes.datapipe.MapData
                   Pipe]]=(<class '__main__.SoftTargetUpdater'>,),
                   debug:bool=False)

Basic utility for getting the ‘target’ version of model_cls in pipe

Type Default Details
model typing.Union[torch.nn.modules.module.Module, NoneType] If model is not none, then we assume it to be the target model
and simply return it, otherwise we search for a target_model
pipe typing.Union[torch.utils.data.datapipes.datapipe.IterDataPipe, torch.utils.data.datapipes.datapipe.MapDataPipe] The pipe to start search along
model_cls Module The class of the model we are looking for
target_updater_cls typing.Tuple[typing.Union[torch.utils.data.datapipes.datapipe.IterDataPipe, torch.utils.data.datapipes.datapipe.MapDataPipe]] (<class '__main__.SoftTargetUpdater'>,) A tuple of datapipes that have a field called target_model.
get_target_model will look for these in pipe
debug bool False Verbose output

CriticLossProcessor

 CriticLossProcessor (*args, **kwds)

Produces a critic loss based on critic,t_actor,t_critic and batch StepTypes from source_datapipe where the targets and predictions are fed into loss.

This datapipe produces either Dict[Literal[‘loss’],torch.Tensor] or SimpleStep.

From (Lillicrap et al., 2016), we expect to get N transitions from \(R\) where \(R\) is source_datapipe.

\(N\) transitions \((s_i, a_i, r_i, s_{i+1})\) from \(R\) where \((s_i, a_i, r_i, s_{i+1})\) are StepType

The targets are similar to DQN since we are estimating the \(Q\) value:

\(y_i = r_i + \gamma Q' (s_{i+1}, \mu'(s_{i+1} | \theta^{\mu'})|\theta^{Q'})\)

Where \(y_i\) are the targets, \(\gamma\) is discount**nsteps, \(Q'\) is the t_critic, and \(\mu'\) is the t_actor.

\(\mu'(s_{i+1} | \theta^{\mu'})\) is the t_actor's predicted actions for \(s_{i+1}\)

Update critic by minimizing the loss: \(L = \frac{1}{N}\sum_i{(y_i - Q(s_i,a_i|\theta^Q))^2}\)

Where \(Q(s_i,a_i|\theta^Q)\) is critic(batch.state,batch.action), and the \(\frac{1}{N}\sum_i{(...)^2}\) term is just nn.MSELoss
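Putting these pieces together, the computation looks roughly like the following (an illustrative sketch, not CriticLossProcessor; the batch field names and the discount handling are assumptions):

import torch
from torch import nn

def critic_loss_sketch(critic, t_critic, t_actor, batch, discount=0.99):
    # y_i = r_i + gamma * Q'(s_{i+1}, mu'(s_{i+1}|theta^mu') | theta^Q')
    with torch.no_grad():
        next_actions = t_actor(batch.next_state)
        y = batch.reward + discount * t_critic(batch.next_state, next_actions)
    # L = 1/N * sum_i (y_i - Q(s_i, a_i|theta^Q))^2
    return nn.MSELoss()(critic(batch.state, batch.action), y)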

torch.manual_seed(0)
pipe = GymTransformBlock(agent=None,n=1000,bs=64,seed=0)(['Pendulum-v1'])
pipe = StepBatcher(pipe)

actor = Actor(3,1)
critic = Critic(3,1)

pipe = SoftTargetUpdater(pipe,critic)
pipe = CriticLossProcessor(pipe,critic,actor)

pipe_loss = LossCollector(pipe,main_buffers=[[]])
pipe = BasicOptStepper(pipe_loss,critic,1e-3)
list(pipe)
pipe_loss.show(title='Critic Loss over N-Steps')

ActorLossProcessor

 ActorLossProcessor (*args, **kwds)

Produces an actor loss based on critic,actor and batch StepTypes from source_datapipe where the targets and predictions are fed into loss.

(Lillicrap et al., 2016) notes that the actor is updated by applying the chain rule to the expected return from the start distribution \(J\) with respect to the actor parameters.

The loss is defined as the “policy gradient” below:

\[ \nabla_{\theta^{\mu}} J \approx \frac{1}{N} \sum_i{\nabla_a Q(s,a|\theta^Q)|_{s=s_i,a=\mu(s_i)} \, \nabla_{\theta^{\mu}} \mu(s|\theta^{\mu})|_{s_i}} \]

Where:

\(\frac{1}{N} \sum_i\) is the mean.

\(\nabla_{\theta^{\mu}} \mu(s|\theta^{\mu})|_{s_i}\) is the actor output.

\(\nabla_a Q(s,a|\theta^Q)|_{s=s_i,a=\mu(s_i)}\) is the critic output, using actions from the actor.

Important: A slightly confusing point: \(\nabla\) is the gradient/derivative of both terms. The point of the loss is that we want to select actions for which the critic outputs higher values. We can do this by first calling CriticLossProcessor to load the critic with gradients, then running the critic again, this time with the actor's inputs. We want the actor to push the critic toward producing higher outputs, i.e. to choose actions that maximize the critic's outputs. The confusing part is that, since PyTorch has autograd, the actual code will not match the math above exactly, for better and worse.

TODO: It would be helpful if this documentation were explained better.

Note: We actually multiply J by -1 since the optimizer tries to make the value as “small” as possible, while we want the actual value to be as big as possible. So if we have a J of 100 (high reward), it becomes -100, letting the optimizer know that it is moving in the correct direction (the more negative, the better).
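In code, the idea boils down to something like this (an illustrative sketch, not ActorLossProcessor; autograd takes care of both gradient terms in the equation above):

import torch

def actor_loss_sketch(critic, actor, batch):
    # -Q(s, mu(s)): minimizing this pushes the actor toward actions
    # that the critic scores more highly.
    return -critic(batch.state, actor(batch.state)).mean()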

actor = Actor(3,1)
critic = Critic(3,1)

agent = DDPGAgent(actor,max_steps=10000)

pipe = GymTransformBlock(agent=agent,n=1000,bs=10)(['Pendulum-v1'])
pipe = StepBatcher(pipe)

pipe = ActorLossProcessor(pipe,critic,actor)

pipe_loss = LossCollector(pipe,main_buffers=[[]])
pipe = BasicOptStepper(pipe_loss,actor,1e-3)
list(pipe)
pipe_loss.show()

DDPGLearner

 DDPGLearner (actor:__main__.Actor, critic:__main__.Critic, dls:List[Union
              [torch.utils.data.datapipes.datapipe.IterDataPipe,torch.util
              s.data.datapipes.datapipe.MapDataPipe,torchdata.dataloader2.
              dataloader2.DataLoader2]], logger_bases:Union[List[fastrl.lo
              ggers.core.LoggerBase],NoneType]=None, actor_lr:float=0.001,
              actor_opt:torch.optim.optimizer.Optimizer=<class
              'torch.optim.adam.Adam'>, critic_lr:float=0.01,
              critic_opt:torch.optim.optimizer.Optimizer=<class
              'torch.optim.adam.Adam'>, critic_target_copy_freq:int=1,
              actor_target_copy_freq:int=1, tau:float=0.001, bs:int=128,
              max_sz:int=10000, nsteps:int=1, device:torch.device=None,
              batches:int=None, dp_augmentation_fns:Union[List[fastrl.pipe
              s.core.DataPipeAugmentationFn],NoneType]=None,
              debug:bool=False)

DDPG is a continuous action, actor-critic model, first created in (Lillicrap et al., 2016). The critic produces a Q value estimate, and the actor attempts to maximize that Q value.

Type Default Details
actor Actor The actor model to use
critic Critic The critic model to use
dls typing.List[typing.Union[torch.utils.data.datapipes.datapipe.IterDataPipe, torch.utils.data.datapipes.datapipe.MapDataPipe, torchdata.dataloader2.dataloader2.DataLoader2]] A list of dls, where index=0 is the training dl.
logger_bases typing.Union[typing.List[fastrl.loggers.core.LoggerBase], NoneType] None Optional logger bases to log training/validation data to.
actor_lr float 0.001 The learning rate for the actor. Expected to learn slower than the critic
actor_opt Optimizer Adam The optimizer for the actor
critic_lr float 0.01 The learning rate for the critic. Expected to learn faster than the actor
critic_opt Optimizer Adam The optimizer for the critic
Note that weight decay doesn't seem to work well for
Pendulum, so we use regular Adam, which has the decay rate
set to 0. (Lillicrap et al., 2016) would instead use AdamW
critic_target_copy_freq int 1 Reference: SoftTargetUpdater docs
actor_target_copy_freq int 1 Reference: SoftTargetUpdater docs
tau float 0.001 Reference: SoftTargetUpdater docs
bs int 128 Reference: ExperienceReplay docs
max_sz int 10000 Reference: ExperienceReplay docs
nsteps int 1 Reference: GymStepper docs
device device None The device for the entire pipeline to use. Will move the agent, dls,
and learner to that device.
batches int None Number of batches per epoch
dp_augmentation_fns typing.Union[typing.List[fastrl.pipes.core.DataPipeAugmentationFn], NoneType] None Any augmentations to the learner
debug bool False Debug mode will output device moves
Returns LearnerHead
# Setup Loggers
logger_base = ProgressBarLogger(epoch_on_pipe=EpocherCollector,
                 batch_on_pipe=BatchCollector)

# Setup up the core NN
torch.manual_seed(0)
actor = Actor(3,1)
critic = Critic(3,1)

# Setup the Agent
agent = DDPGAgent(actor,[logger_base],max_steps=5000,min_epsilon=0.1)

# Setup the DataBlock
block = DataBlock(
    GymTransformBlock(agent=agent,nsteps=2,nskips=2,firstlast=True), 
    (GymTransformBlock(agent=agent,n=400,nsteps=2,nskips=2,firstlast=True,include_images=True),VSCodeTransformBlock())
)
dls = L(block.dataloaders(['Pendulum-v1']*1))
# Setup the Learner
learner = DDPGLearner(actor,critic,dls,logger_bases=[logger_base],
                      bs=128,max_sz=20_000,nsteps=2,
                      batches=1000)
# learner.fit(1)
learner.fit(15)
actor-loss critic-loss episode rolling_reward epoch batch
23.484377 2.7217586 10 -1606.006924 1 1001
31.69706 12.016734 20 -1542.288166 2 1001
40.06336 25.084415 30 -1472.712552 3 1001
50.46943 7.714135 40 -1450.873017 4 1001
59.180523 17.604204 50 -1406.556492 5 1001
67.14108 46.2947 60 -1338.766411 6 1001
68.84964 16.201414 70 -1135.972938 7 1001
69.09436 143.93195 80 -993.549349 8 1001
69.35192 25.970486 90 -817.256042 9 1001
73.09758 171.4342 100 -673.698936 10 1001
74.067116 261.42377 110 -546.182211 11 1001
71.98171 119.34804 119 -492.149361 12 1001
72.43381 239.37431 129 -445.763457 13 1001
74.57726 23.358807 139 -408.875556 14 1001
61.790493 230.8901 149 -369.512350 14 1001