from fastrl.pipes.core import *
from fastrl.data.block import *
from fastrl.envs.gym import *

Experience Replay
ExperienceReplay
ExperienceReplay (*args, **kwds)
Simplest form of memory. Takes steps from source_datapipe and stores them in memory. It outputs bs steps.
Let's generate some batches to test with…
def baseline_test(envs, total_steps, seed=0):
    """Run the gym pipeline with no replay memory and collect what it yields.

    Builds a `GymTransformBlock` with `n=total_steps` and the given `seed`,
    applies it to `envs`, unbatches the resulting pipe and exhausts it.

    Returns a `(steps, pipe)` tuple: the list of everything the unbatched
    pipe yielded, and the pipe itself.
    """
    block = GymTransformBlock(None, n=total_steps, seed=seed)
    step_pipe = block(envs).unbatch()
    return list(step_pipe), step_pipe
@delegates(ExperienceReplay)
def exp_replay_test(envs, total_steps, seed=0, **kwargs):
    """Build the gym pipeline, wrap it in `ExperienceReplay`, optionally run it.

    The unbatched `GymTransformBlock` pipe (built with `n=total_steps` and
    `seed`) is passed to `ExperienceReplay` together with `**kwargs`.

    Returns a `(outputs, pipe)` tuple. When `total_steps` is None the pipe is
    returned unconsumed and `outputs` is None; otherwise `outputs` is the list
    of everything the replay pipe yielded.
    """
    step_pipe = GymTransformBlock(None, n=total_steps, seed=seed)(envs).unbatch()
    replay_pipe = ExperienceReplay(step_pipe, **kwargs)
    # Leave the pipe untouched so the caller can drive it manually.
    if total_steps is None:
        return None, replay_pipe
    return list(replay_pipe), replay_pipe


# Zero total steps: the replay memory should stay empty.
steps, experience_replay = exp_replay_test(['CartPole-v1'],0,bs=1)
test_eq(len(experience_replay),0)

What if we fill up ER? Let's add the batches; this process will happen in place…
# Run 10 steps through a replay pipe created with max_sz=20.
steps, experience_replay = exp_replay_test(['CartPole-v1'],10,max_sz=20)
# Expect the size and index trackers to both read 10 ...
test_eq(experience_replay._sz_tracker,10)
test_eq(experience_replay._idx_tracker,10)
# ... and the cycle tracker to still be 0.
test_eq(experience_replay._cycle_tracker,0)
test_len(experience_replay,10)

If we run 10 more times, the total size should be 20…
# Pull 10 more outputs from the replay pipe and keep them (not the range
# counter) in `steps`. `range(10)` stays first in the zip so iteration stops
# before an 11th item is drawn from the pipe.
steps = [batch for _, batch in zip(range(10), experience_replay)]
# Expect size and index trackers at 20 and the cycle tracker still at 0.
test_eq(experience_replay._sz_tracker,20)
test_eq(experience_replay._idx_tracker,20)
test_eq(experience_replay._cycle_tracker,0)
test_len(experience_replay,20)

The experience_replay memory should contain the same steps as a run without it…
# Baseline: 20 steps straight from the environment pipe (seed 0), no replay.
steps, pipe = baseline_test(['CartPole-v1'],20,seed=0)
# Same run (the helper's default seed is 0) pushed through a max_sz=20 replay pipe.
_, experience_replay = exp_replay_test(['CartPole-v1'],20,max_sz=20)
# Compare each baseline step with the memory entry at the same position.
for i,(baseline_step,memory_step) in enumerate(zip(steps,experience_replay.memory)):
test_eq(baseline_step.state,memory_step.state)
test_eq(baseline_step.next_state,memory_step.next_state)
print('Step ',i)
Step 0
Step 1
Step 2
Step 3
Step 4
Step 5
Step 6
Step 7
Step 8
Step 9
Step 10
Step 11
Step 12
Step 13
Step 14
Step 15
Step 16
Step 17
Step 18
Step 19
Since max_sz is 20 and we have run a total of 20 steps so far, running another 10 steps should set _cycle_tracker to 1 (since this is a new cycle) and _idx_tracker to 10, since it should have reset and stopped halfway through the memory. The _sz_tracker should still be 20.
# total_steps=None: the helper hands back the replay pipe without consuming it.
_, experience_replay = exp_replay_test(['CartPole-v1'],None,max_sz=20)
# Exhaust a header(19) view of the pipe before the 10 pulls below
# (presumably to pre-fill the memory; confirm against header semantics).
list(experience_replay.header(19))
# Keep the pipe's outputs (not the range counter) in `steps`. `range(10)`
# stays first in the zip so no 11th item is drawn from the pipe.
steps = [batch for _, batch in zip(range(10), experience_replay)]
# Expect size 20, index 10 and cycle 1.
test_eq(experience_replay._sz_tracker,20)
test_eq(experience_replay._idx_tracker,10)
test_eq(experience_replay._cycle_tracker,1)
test_len(experience_replay,20)

…and if we run the baseline, the last 10 steps in the baseline should match the first 10 steps in memory, since it is in the middle of rewriting the memory due to being at max size.
# Baseline run of 30 steps with the default seed.
steps, pipe = baseline_test(['CartPole-v1'],30)
# Pair baseline steps from index 20 onward with the first 10 memory entries.
for baseline_step,memory_step in zip(steps[20:],experience_replay.memory[:10]):
test_eq(baseline_step.state,memory_step.state)
test_eq(baseline_step.next_state,memory_step.next_state)

Finally, we want to finish writing over the memory in its entirety.
# Pull 10 more outputs and keep them (not the range counter) in `steps`.
# `range(10)` stays first in the zip so no 11th item is drawn from the pipe.
steps = [batch for _, batch in zip(range(10), experience_replay)]
# Expect size 20, index 20 and cycle still 1.
test_eq(experience_replay._sz_tracker,20)
test_eq(experience_replay._idx_tracker,20)
test_eq(experience_replay._cycle_tracker,1)
test_len(experience_replay,20)

# Baseline run of 40 steps to compare against the memory contents.
steps, pipe = baseline_test(['CartPole-v1'],40)
# Pair baseline steps from index 20 onward with the memory entries, in order.
for baseline_step,memory_step in zip(steps[20:],experience_replay.memory):
test_eq(baseline_step.state,memory_step.state)
test_eq(baseline_step.next_state,memory_step.next_state)

Let’s verify that the steps are what we expect…
What if we sample the experience?
# Run 1000 steps through a replay pipe created with bs=300 and max_sz=1000.
steps, experience_replay = exp_replay_test(['CartPole-v1'],1000,bs=300,max_sz=1000)
# Holds a copy of the previously visited step; None until the first one is seen.
memory = None
for i,sample in enumerate(experience_replay):
for s in sample:
# Every step must differ from the one visited immediately before it.
if memory is not None: test_ne(s,memory)
memory = copy(s)
if i>100:break

We should be able to sample enough times that we have sampled everything. So we test this by sampling, checking whether that sample has been seen before, and then recording it.
# Run 1000 steps through a replay pipe with max_sz=30, then draw samples
# directly and mark which memory indices each draw touched.
steps, experience_replay = exp_replay_test(
    ['CartPole-v1'], 1000, bs=1, max_sz=30, return_idxs=True
)
memory_hits = [False] * 30
for _ in range(150):
    # sample() is unpacked as (result, indices); only the indices are used.
    _result, sampled_idxs = experience_replay.sample()
    for slot in sampled_idxs:
        memory_hits[slot] = True
# After 150 draws every one of the 30 indices should have been hit.
test_eq(all(memory_hits), True)