from fastrl.pipes.core import *
from fastrl.data.block import *
from fastrl.envs.gym import *
Experience Replay
ExperienceReplay
ExperienceReplay (*args, **kwds)
Simplest form of memory. Takes steps from `source_datapipe` and stores them in `memory`. It outputs `bs` steps.
Let's generate some batches to test with…
def baseline_test(envs, total_steps, seed=0):
    """Run `envs` through a plain `GymTransformBlock` pipeline (no replay memory).

    Args:
        envs: Environment ids forwarded to the block, e.g. ``['CartPole-v1']``.
        total_steps: Forwarded to the block as ``n``.
        seed: Forwarded to the block as ``seed``.

    Returns:
        ``(items, pipe)``: ``items`` is ``list(pipe)`` and ``pipe`` is the
        unbatched pipeline that produced them.
    """
    pipe = GymTransformBlock(None, n=total_steps, seed=seed)(envs)
    pipe = pipe.unbatch()
    return list(pipe), pipe
@delegates(ExperienceReplay)
def exp_replay_test(envs, total_steps, seed=0, **kwargs):
    """Build the same pipeline as the baseline, then wrap it in `ExperienceReplay`.

    Args:
        envs: Environment ids forwarded to `GymTransformBlock`.
        total_steps: Forwarded to the block as ``n``. When ``None`` the pipe is
            returned without being iterated.
        seed: Forwarded to the block as ``seed``.
        **kwargs: Forwarded unchanged to `ExperienceReplay`.

    Returns:
        ``(items, pipe)``: ``items`` is ``list(pipe)``, or ``None`` when
        ``total_steps`` is ``None``; ``pipe`` is the `ExperienceReplay` pipe.
    """
    pipe = GymTransformBlock(None, n=total_steps, seed=seed)(envs)
    pipe = pipe.unbatch()
    pipe = ExperienceReplay(pipe, **kwargs)
    # Hand the pipe back un-iterated so the caller decides how many items to
    # draw (presumably n=None means "no step limit" — TODO confirm).
    if total_steps is None:
        return None, pipe
    return list(pipe), pipe
# Requesting zero steps should leave the replay reporting a length of 0.
steps, experience_replay = exp_replay_test(['CartPole-v1'], 0, bs=1)
test_eq(len(experience_replay), 0)
What if we fill up the experience replay (ER)? Let's add the batches; this process happens in place…
# Run 10 steps into a replay that can hold 20, then check its bookkeeping.
steps, experience_replay = exp_replay_test(['CartPole-v1'], 10, max_sz=20)
test_len(experience_replay, 10)
test_eq(experience_replay._sz_tracker, 10)
test_eq(experience_replay._idx_tracker, 10)
test_eq(experience_replay._cycle_tracker, 0)
If we run 10 more times, the total size should be 20…
# Pull 10 more items. `range(10)` comes first in the zip so that iteration
# stops before an 11th item is drawn from the replay.
# NOTE(review): `step` binds the range index, not the replay output; the list
# only serves to drive the iteration.
steps = [step for step, _ in zip(range(10), experience_replay)]
test_len(experience_replay, 20)
test_eq(experience_replay._sz_tracker, 20)
test_eq(experience_replay._idx_tracker, 20)
test_eq(experience_replay._cycle_tracker, 0)

# Notebook-style display of the replay object.
experience_replay
The memory should contain steps identical to those we get when running without experience replay…
# Collect 20 steps without replay and 20 steps with replay.
steps, pipe = baseline_test(['CartPole-v1'], 20, seed=0)
_, experience_replay = exp_replay_test(['CartPole-v1'], 20, max_sz=20)

# Compare the baseline steps with the replay memory, slot by slot.
for i, (baseline_step, memory_step) in enumerate(zip(steps, experience_replay.memory)):
    test_eq(baseline_step.state, memory_step.state)
    test_eq(baseline_step.next_state, memory_step.next_state)
    print('Step ',i)
Step 0
Step 1
Step 2
Step 3
Step 4
Step 5
Step 6
Step 7
Step 8
Step 9
Step 10
Step 11
Step 12
Step 13
Step 14
Step 15
Step 16
Step 17
Step 18
Step 19
Since the `max_sz` is 20, and so far we have run a total of 20 steps, if we run another 10 steps, the `_cycle_tracker` should be 1 (since this is a new cycle), and the `_idx_tracker` should be 10, since it should have reset and stopped halfway through the memory. The `_sz_tracker` should still be 20.
# total_steps=None returns the pipe un-iterated; drive it manually through
# `header(19)` to pre-fill the memory before the next 10-step run.
_, experience_replay = exp_replay_test(['CartPole-v1'], None, max_sz=20)
list(experience_replay.header(19))
# Pull 10 more items; `range(10)` comes first in the zip so that iteration
# stops before an 11th item is drawn from the replay.
steps = [step for step, _ in zip(range(10), experience_replay)]
# Size stays at 20, the write index is at 10, and one cycle is recorded.
test_len(experience_replay, 20)
test_eq(experience_replay._sz_tracker, 20)
test_eq(experience_replay._idx_tracker, 10)
test_eq(experience_replay._cycle_tracker, 1)
…and if we run the baseline, the last 10 steps in the baseline should match the first 10 steps in memory, since the replay is in the middle of rewriting the memory due to being at max size.
# Compare baseline steps 20..29 with the first 10 memory slots.
steps, pipe = baseline_test(['CartPole-v1'], 30)

for baseline_step, memory_step in zip(steps[20:], experience_replay.memory[:10]):
    test_eq(baseline_step.state, memory_step.state)
    test_eq(baseline_step.next_state, memory_step.next_state)
Finally we want to finish writing over the memory in its entirety.
# Pull another 10 items; `range(10)` comes first in the zip so that iteration
# stops before an 11th item is drawn from the replay.
steps = [step for step, _ in zip(range(10), experience_replay)]
test_len(experience_replay, 20)
test_eq(experience_replay._sz_tracker, 20)
test_eq(experience_replay._idx_tracker, 20)
test_eq(experience_replay._cycle_tracker, 1)
# Compare baseline steps from index 20 onward with the whole memory.
steps, pipe = baseline_test(['CartPole-v1'], 40)

for baseline_step, memory_step in zip(steps[20:], experience_replay.memory):
    test_eq(baseline_step.state, memory_step.state)
    test_eq(baseline_step.next_state, memory_step.next_state)
Let’s verify that the steps are what we expect…
What if we sample the experience?
steps, experience_replay = exp_replay_test(['CartPole-v1'], 1000, bs=300, max_sz=1000)

# Each sampled item must differ from the one seen immediately before it.
memory = None
for i, sample in enumerate(experience_replay):
    for s in sample:
        if memory is not None:
            test_ne(s, memory)
        memory = copy(s)
    # Stop once the batch at index 101 has been processed.
    if i > 100:
        break
We should be able to sample enough times that we have sampled everything. We test this by sampling, checking whether each sample has been seen before, and recording the result.
steps, experience_replay = exp_replay_test(
    ['CartPole-v1'], 1000, bs=1, max_sz=30, return_idxs=True
)

# One flag per memory slot (max_sz=30), set once that slot's index is sampled.
memory_hits = [False] * 30
for i in range(150):
    res, idxs = experience_replay.sample()
    for idx in idxs:
        memory_hits[idx] = True

# After 150 draws, every slot should have been hit at least once.
test_eq(all(memory_hits), True)