Experience Replay

Experience Replay is likely the simplest form of memory used by RL agents.

ExperienceReplay

 ExperienceReplay (*args, **kwds)

Simplest form of memory. Takes steps from source_datapipe and stores them in memory. Each iteration outputs a batch of bs steps.
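To make the bookkeeping concrete, here is a minimal, self-contained sketch of a fixed-size replay memory. It is not fastrl's ExperienceReplay. The class name RingBufferMemory and its add/sample methods are hypothetical, and its counters only loosely mirror the _sz_tracker, _idx_tracker and _cycle_tracker attributes tested below. The sketch makes two assumptions: steps are written in order into a ring of max_sz slots, and sampling is uniform with replacement.

```python
import random
from typing import Any, List, Optional


class RingBufferMemory:
    """Hypothetical fixed-size replay memory (a sketch, not fastrl's ExperienceReplay)."""

    def __init__(self, max_sz: int, bs: int = 1, seed: Optional[int] = None):
        self.max_sz, self.bs = max_sz, bs
        self.memory: List[Any] = [None] * max_sz
        self.sz = 0     # filled slots, capped at max_sz
        self.idx = 0    # one past the slot written most recently
        self.cycle = 0  # how many times writing has wrapped back to slot 0
        self.rng = random.Random(seed)

    def add(self, step: Any) -> None:
        # Once the end of the ring is reached, the next write wraps to slot 0
        if self.idx == self.max_sz:
            self.idx = 0
            self.cycle += 1
        self.memory[self.idx] = step
        self.idx += 1
        self.sz = min(self.sz + 1, self.max_sz)

    def sample(self) -> List[Any]:
        # Uniform sampling with replacement over the filled slots only
        if self.sz == 0:
            raise ValueError("cannot sample from an empty memory")
        return [self.memory[self.rng.randrange(self.sz)] for _ in range(self.bs)]

    def __len__(self) -> int:
        return self.sz


mem = RingBufferMemory(max_sz=20, bs=4, seed=0)
for t in range(30):  # 30 integer "steps" into a 20-slot ring
    mem.add(t)
print(len(mem), mem.idx, mem.cycle)  # 20 10 1
```

After 30 steps into 20 slots, the counters end at the same values the tests below expect: size 20, index 10, one completed cycle.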

Let's generate some batches to test with.

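from fastcore.meta import delegates
from fastcore.test import *
from fastrl.pipes.core import *
from fastrl.data.block import *
from fastrl.envs.gym import *

def baseline_test(envs,total_steps,seed=0):
    "Run `envs` for `total_steps` steps without any memory; return the steps and the pipe."
    pipe = GymTransformBlock(None,n=total_steps,seed=seed)(envs)
    pipe = pipe.unbatch()
    return list(pipe), pipe

@delegates(ExperienceReplay)
def exp_replay_test(envs,total_steps,seed=0,**kwargs):
    "Same as `baseline_test`, but routes the steps through an `ExperienceReplay`."
    pipe = GymTransformBlock(None,n=total_steps,seed=seed)(envs)
    pipe = pipe.unbatch()
    pipe = ExperienceReplay(pipe,**kwargs)
    # With total_steps=None, return the pipe unconsumed so the caller can iterate it
    if total_steps is None: return None,pipe
    return list(pipe), pipe

# Zero steps: the memory should stay empty
steps, experience_replay = exp_replay_test(['CartPole-v1'],0,bs=1)
test_eq(len(experience_replay),0)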

What happens when we fill up the experience replay? Let's add 10 steps to a memory with max_sz=20. The memory is updated in place.

steps, experience_replay = exp_replay_test(['CartPole-v1'],10,max_sz=20)
test_eq(experience_replay._sz_tracker,10)
test_eq(experience_replay._idx_tracker,10)
test_eq(experience_replay._cycle_tracker,0)
test_len(experience_replay,10)

If we pull 10 more steps, the total size should be 20.

steps = [step for _,step in zip(range(10),experience_replay)]  # pull exactly 10 more steps
test_eq(experience_replay._sz_tracker,20)
test_eq(experience_replay._idx_tracker,20)
test_eq(experience_replay._cycle_tracker,0)
test_len(experience_replay,20)

The experience_replay memory should contain the same steps as a run without it.

steps, pipe = baseline_test(['CartPole-v1'],20,seed=0)
_, experience_replay = exp_replay_test(['CartPole-v1'],20,max_sz=20)

for i,(baseline_step,memory_step) in enumerate(zip(steps,experience_replay.memory)):
    test_eq(baseline_step.state,memory_step.state)
    test_eq(baseline_step.next_state,memory_step.next_state)
    print('Step ',i)
Step  0
Step  1
Step  2
Step  3
Step  4
Step  5
Step  6
Step  7
Step  8
Step  9
Step  10
Step  11
Step  12
Step  13
Step  14
Step  15
Step  16
Step  17
Step  18
Step  19

Since max_sz is 20, once 20 steps have been stored, running another 10 steps has three expected effects. First, _cycle_tracker should be 1, since this is a new cycle. Second, _idx_tracker should be 10, since it wrapped back to the start and stopped halfway through the memory. Third, _sz_tracker should still be 20.
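The expected tracker values follow a simple pattern. This sketch assumes that every step is written exactly once and that the index wraps only when the next write would run past the end. That assumption matches the tests on this page but is not taken from fastrl's code. Under it, the three trackers after n steps can be computed directly. trackers_after is a hypothetical helper.

```python
def trackers_after(n_steps: int, max_sz: int) -> tuple:
    """Hypothetical closed form for (size, idx, cycle) after n_steps writes."""
    if n_steps == 0:
        return (0, 0, 0)
    size = min(n_steps, max_sz)       # memory never grows past max_sz
    idx = (n_steps - 1) % max_sz + 1  # stays in 1..max_sz once anything is written
    cycle = (n_steps - 1) // max_sz   # wraps back to slot 0 so far
    return (size, idx, cycle)


for n in (10, 20, 30, 40):
    print(n, trackers_after(n, max_sz=20))
# 10 (10, 10, 0)
# 20 (20, 20, 0)
# 30 (20, 10, 1)
# 40 (20, 20, 1)
```

These four rows match the _sz_tracker, _idx_tracker and _cycle_tracker values asserted at 10, 20, 30 and 40 steps on this page.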

_, experience_replay = exp_replay_test(['CartPole-v1'],None,max_sz=20)
list(experience_replay.header(19))

steps = [step for _,step in zip(range(10),experience_replay)]  # pull exactly 10 more steps
test_eq(experience_replay._sz_tracker,20)
test_eq(experience_replay._idx_tracker,10)
test_eq(experience_replay._cycle_tracker,1)
test_len(experience_replay,20)

If we run the baseline for 30 steps, its last 10 steps should match the first 10 steps in memory. The memory is at max size, so it is in the middle of overwriting its oldest entries.
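Why do the last 10 baseline steps line up with the first 10 memory slots? Assume steps are written in order into a ring of max_sz slots, as in the sketches on this page. Then step t (counting from 0) lands in slot t % max_sz. ring_fill is a hypothetical helper that uses integers in place of real steps.

```python
def ring_fill(steps: list, max_sz: int) -> list:
    """Write steps in order into max_sz slots; step t ends up in slot t % max_sz."""
    memory = [None] * max_sz
    for t, step in enumerate(steps):
        memory[t % max_sz] = step  # later steps overwrite earlier ones in the same slot
    return memory


memory = ring_fill(list(range(30)), max_sz=20)
# Slots 0-9 were overwritten by steps 20-29; slots 10-19 still hold steps 10-19
assert memory[:10] == list(range(20, 30))
assert memory[10:] == list(range(10, 20))
```

This mirrors the check that follows, which compares steps[20:] from a 30-step baseline against experience_replay.memory[:10].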

steps, pipe = baseline_test(['CartPole-v1'],30)

for baseline_step,memory_step in zip(steps[20:],experience_replay.memory[:10]):
    test_eq(baseline_step.state,memory_step.state)
    test_eq(baseline_step.next_state,memory_step.next_state)

Finally, let's finish overwriting the entire memory. Afterwards, it should match the last 20 baseline steps.

steps = [step for _,step in zip(range(10),experience_replay)]  # pull exactly 10 more steps
test_eq(experience_replay._sz_tracker,20)
test_eq(experience_replay._idx_tracker,20)
test_eq(experience_replay._cycle_tracker,1)
test_len(experience_replay,20)
steps, pipe = baseline_test(['CartPole-v1'],40)

for baseline_step,memory_step in zip(steps[20:],experience_replay.memory):
    test_eq(baseline_step.state,memory_step.state)
    test_eq(baseline_step.next_state,memory_step.next_state)
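After 40 steps, the write index is back at the end of the ring. The memory is therefore already in oldest-to-newest order and can be zipped directly against steps[20:]. After 30 steps it is not: the oldest surviving step sits at the write index. The hypothetical helper below rotates the memory back into chronological order. It assumes the same ring layout, with idx pointing one past the most recent write.

```python
def chronological(memory: list, idx: int, sz: int) -> list:
    """Oldest-to-newest view of a ring buffer whose last write ended just before idx."""
    if sz < len(memory):
        # Not full yet: writes started at slot 0, so slots 0..sz-1 are in order
        return memory[:sz]
    # Full: the oldest entry is at slot idx % len(memory)
    return memory[idx:] + memory[:idx]


max_sz = 20
for n_steps, idx in ((30, 10), (40, 20)):
    memory = [None] * max_sz
    for t in range(n_steps):
        memory[t % max_sz] = t
    print(n_steps, chronological(memory, idx, sz=max_sz) == list(range(n_steps - max_sz, n_steps)))
# 30 True
# 40 True
```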

What if we sample the experience? With bs=300 and max_sz=1000, let's verify that no sampled step is identical to the one sampled immediately before it.

from copy import copy

steps, experience_replay = exp_replay_test(['CartPole-v1'],1000,bs=300,max_sz=1000)
prev_step = None
for i,sample in enumerate(experience_replay):
    for s in sample:
        # No sampled step should equal the one sampled just before it
        if prev_step is not None: test_ne(s,prev_step)
        prev_step = copy(s)
    if i>100: break
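This page doesn't say how ExperienceReplay draws its bs entries. There are two usual choices:

- Uniform sampling with replacement: a step can appear twice in one batch.
- Uniform sampling without replacement: all bs entries are distinct, which requires bs to be no larger than the number of stored steps.

Here is a minimal sketch of both using Python's random module. sample_batch is a hypothetical helper, not part of fastrl.

```python
import random


def sample_batch(memory: list, bs: int, rng: random.Random, replace: bool = True) -> list:
    """Draw bs entries uniformly from memory (hypothetical helper)."""
    if replace:
        return rng.choices(memory, k=bs)  # independent draws; repeats are possible
    return rng.sample(memory, bs)         # distinct positions; raises ValueError if bs > len(memory)


rng = random.Random(0)
memory = list(range(1000))  # stand-in for 1000 stored steps
with_repl = sample_batch(memory, 300, rng, replace=True)
without_repl = sample_batch(memory, 300, rng, replace=False)
print(len(with_repl), len(without_repl), len(set(without_repl)))  # 300 300 300
```

Because of the birthday-problem effect, a repeat inside a 300-step batch drawn with replacement from 1000 slots is almost certain.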

If we sample enough times, every slot in memory should eventually be drawn. We test this by calling sample() repeatedly with return_idxs=True and marking each memory index we see. After 150 samples, every index should have been hit.

steps, experience_replay = exp_replay_test(['CartPole-v1'],1000,bs=1,max_sz=30,return_idxs=True)
memory_hits = [False]*30
for i in range(150):
    res,idxs = experience_replay.sample()
    for idx in idxs: memory_hits[idx] = True
test_eq(all(memory_hits),True)
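How many calls to sample() does the coverage test need? Assume each call with bs=1 picks one of the 30 slots uniformly at random with replacement; the page doesn't specify the scheme. Under that assumption, this is the coupon-collector problem. The expected number of draws to see all n slots is n * H_n, where H_n is the n-th harmonic number. Inclusion-exclusion gives the probability that some slot is still unseen after a given number of draws. Both helpers below are hypothetical.

```python
from math import comb


def expected_draws_to_cover(n: int) -> float:
    """Coupon-collector expectation n * H_n for n equally likely slots."""
    return n * sum(1 / k for k in range(1, n + 1))


def prob_some_slot_missed(n: int, draws: int) -> float:
    """P(at least one of n slots is never drawn in `draws` uniform draws), by inclusion-exclusion."""
    p_all_seen = sum((-1) ** j * comb(n, j) * (1 - j / n) ** draws for j in range(n + 1))
    return 1 - p_all_seen


print(round(expected_draws_to_cover(30), 1))    # 119.8
print(round(prob_some_slot_missed(30, 150), 2))  # 0.17
```

If sampling really is uniform with replacement, 150 draws exceed the roughly 120 expected. Even so, about 17% of runs would leave at least one slot unseen, so a fixed seed or more draws would make the check more reliable.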