Hi, I accidentally stumbled upon a problem in the tutorial notebook while playing around with the acme and reverb APIs: sampling from a reverb table and updating its priorities synchronize in a strange way. Another artifact I encountered is that the very first transition is repeated consistently until some hidden tensorflow buffer is flushed.

What I found is that when I mutate the priorities in a reverb table using `client.mutate_priorities(table_name, my_dict)` and then sample through an iterator created from the `tf.data.Dataset` object, the new priorities only take effect after a large number of samples have been flushed. In contrast, if I don't convert the `tf.data.Dataset` to an iterator and instead use the `dataset.batch(n); dataset.take(n)` interface, sampling immediately reflects the new priorities.

The problem seems to lie in the implementation of `__iter__` in `tf.data.Dataset`, but I am posting this issue here because the Colab calls `as_numpy_iterator()` on the dataset object, and the D4PG jax agent is implemented the same way. Since this is a silent and obscure bug, it effectively rules out changing the baseline D4PG agent to use Prioritized Experience Replay.
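As far as I can tell, this behaviour is consistent with the iterator prefetching samples ahead of time. The toy sketch below (plain tf.data, no reverb involved; the buffer size of 8 is an arbitrary assumption for illustration) shows how a prefetch buffer can keep serving values that were computed before a mutation:

```python
import tensorflow as tf

# Toy illustration of the suspected mechanism: an iterator over a
# prefetched dataset keeps returning elements that were produced
# *before* the underlying state was mutated.
state = tf.Variable(0.0)

ds = tf.data.Dataset.range(20).map(lambda _: state.read_value())
it = iter(ds.prefetch(8))  # assumed buffer size, for illustration only

print(next(it).numpy())  # 0.0
state.assign(1.0)        # analogous to mutate_priorities
print(next(it).numpy())  # may still print 0.0: a stale, prefetched element
```

If reverb's dataset pipeline prefetches similarly, mutated priorities would only show up once the buffered samples are drained, which would match the flush counts observed below.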
Minimal Reproducible example:
```python
import warnings
warnings.filterwarnings('ignore')

import acme
from acme import wrappers
from acme.datasets import reverb as datasets
from acme.adders.reverb import sequence
from acme.jax import utils
import tree
import reverb
import jax
import numpy as np
from dm_control import suite

# Create dummy environment with short episodes to easily dichotomize samples
env = suite.load('cartpole', 'balance')
env = wrappers.step_limit.StepLimitWrapper(env, step_limit=5)
spec = acme.make_environment_spec(env)

# Danger: reverb.Table crashes the kernel if run more than once
table = reverb.Table(
    name='priority_table',
    sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
    remover=reverb.selectors.Fifo(),
    max_size=10_000,
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=sequence.SequenceAdder.signature(spec)
)
server = reverb.Server([table], port=None)
client = reverb.Client(f'localhost:{server.port}')

# Construct the adder such that only 1 sample is added to the table per episode.
adder = sequence.SequenceAdder(client, sequence_length=6, period=5)

def new_dataset():
    # Clear old data
    client.reset(table.name)
    return datasets.make_reverb_dataset(
        table=table.name, server_address=client.server_address, batch_size=3
    )

def fill_dataset():
    step = env.reset()
    adder.add_first(step)
    action = env.action_spec().generate_value()
    i = 0
    while (not step.last()) and i < 10:
        step = env.step(action)
        adder.add(action, step)
        i += 1
    env.close()
    adder.reset()

### Example of expected behaviour
dataset = new_dataset()
fill_dataset()

print('before mutation')
for s in dataset.take(1):
    k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()
    print(s.data.action.numpy().reshape(3, -1))  # (B, T, 1) -> (B, T)
    print('sample priority:', p)

# Halve the priorities
new_priorities = dict(zip(k, p * 0.5))
client.mutate_priorities(table.name, new_priorities)

print()
print('after mutation')
for s in dataset.take(1):
    # Priorities have been updated --> all probabilities should now be adjusted.
    print(s.data.action.numpy().reshape(3, -1))  # (B, T, 1) -> (B, T)
    print('sample priority:', s.info.priority.numpy())

### Test-cases
print('\nUsing dataset.take')
dataset = new_dataset()
fill_dataset()

# This runs fine
for repeat in range(5):
    for i in range(30):  # Flush count guess
        for s in dataset.take(1):
            k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()
        # Exponentially decay the priorities
        new_priorities = dict(zip(k, p * 0.999))
        client.mutate_priorities(table.name, new_priorities)
        for s in dataset.take(1):
            new_p = s.info.priority.numpy().ravel()
        assert not np.isclose(new_p, p).any(), "priorities did not update!"
    else:
        # No break in for loop
        print('No errors!')

print('\nUsing next on iter(dataset) - Problems start here.')
dataset = new_dataset()
fill_dataset()
it = iter(dataset)

# Repeat the test-loop as behaviour strangely changes periodically
for repeat in range(5):
    for i in range(30):  # Flush count guess
        s = next(it)
        k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()
        # Exponentially decay the priorities
        new_priorities = dict(zip(k, p * 0.999))
        client.mutate_priorities(table.name, new_priorities)
        s = next(it)
        new_p = s.info.priority.numpy().ravel()
        # Priority mutations now sync extremely slowly
        if not np.isclose(p, new_p).all():
            print(f'Priorities updated at flush-step {i}')
            break
    else:
        # No break in for loop : not reached
        print('No errors!')
```
Output:
```
before mutation
[[-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]]
sample priority: [1. 1. 1.]

after mutation
[[-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]]
sample priority: [0.5 0.5 0.5]

Using dataset.take
No errors!
No errors!
No errors!
No errors!
No errors!

Using next on iter(dataset) - Problems start here.
Priorities updated at flush-step 24
Priorities updated at flush-step 5
Priorities updated at flush-step 18
Priorities updated at flush-step 5
Priorities updated at flush-step 18
```
Proposed Solution
The problem is immediately solved if `iter(dataset)` is called at each call to `next`. Because of this, I wasn't sure whether to post this issue here or on the tensorflow github, since the underlying problem is with `tf.data.Dataset`. Personally, I would suggest creating a wrapper around `tf.data.Dataset` that either makes use of the `take` and `batch` API or reinitializes the iterator at every call. Because of how reverb implements sampling, reinitializing the dataset iterator should have no side effects.
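As a sketch of the wrapper idea (the class `ReinitializingIterator` is hypothetical, not part of acme or reverb), assuming the only requirement is that every sample reflects the table's current priorities:

```python
import tensorflow as tf

class ReinitializingIterator:
    """Hypothetical wrapper: rebuilds the tf.data iterator on every call
    to `next`, so no stale, prefetched samples are ever returned."""

    def __init__(self, dataset: tf.data.Dataset):
        self._dataset = dataset

    def __iter__(self):
        return self

    def __next__(self):
        # A fresh iterator discards any elements buffered by the old one.
        return next(iter(self._dataset))

# Usage: a learner would consume ReinitializingIterator(dataset)
# instead of iter(dataset) or dataset.as_numpy_iterator().
```

The obvious trade-off is that constructing a tf.data iterator on every step adds overhead, so the `take`/`batch` route is probably preferable in hot loops.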
Example solution:
```python
print('\nReinitializing iter on every next call - Problem Solved.')
dataset = new_dataset()
fill_dataset()
it = iter(dataset)  # Ignore this iterator

# Repeat the test-loop as behaviour strangely changes periodically
for repeat in range(5):
    for i in range(30):  # Flush count guess
        s = next(iter(dataset))  # CHANGE: call iter(dataset) every time `next` is called
        k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()
        # Exponentially decay the priorities
        new_priorities = dict(zip(k, p * 0.999))
        client.mutate_priorities(table.name, new_priorities)
        s = next(iter(dataset))  # CHANGE: call iter(dataset) every time `next` is called
        new_p = s.info.priority.numpy().ravel()
        # Priority mutations now sync immediately
        if not np.isclose(p, new_p).all():
            print(f'Priorities updated at flush-step {i}')
            break
    else:
        # No break in for loop : not reached
        print('No errors!')
```
Output (priorities are updated after every call, which is what we expect):
```
Reinitializing iter on every next call - Problem Solved.
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0
```