Skip to content

Commit

Permalink
Updated ray dependency and specified package versions
Browse files Browse the repository at this point in the history
  • Loading branch information
louiskirsch committed Mar 25, 2020
1 parent 9763b0d commit c24ada2
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ This is the official research code for the paper Kirsch et al. 2019:

Install the following dependencies (in a virtualenv preferably)
```bash
pip3 install ray gym[all] tensorflow-gpu scipy numpy
pip3 install ray[tune]==0.7.7 gym[all] 'mujoco_py>=2' tensorflow-gpu==1.13.2 scipy numpy
```

This code base uses [ray](https://github.com/ray-project/ray), if you would like to use multiple machines,
Expand Down
17 changes: 11 additions & 6 deletions ray_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

np.warnings.filterwarnings('ignore')
logger = logging.getLogger(__name__)

tf.logging.set_verbosity(tf.logging.ERROR)

class LLFSExperiment(ExtendedTrainable):

Expand Down Expand Up @@ -178,9 +178,9 @@ def _restore(self, checkpoint_data):
# noinspection PyProtectedMember
def count_required_gpus(config):
if config['agent_count'] > 1:
return math.ceil(config['agent_count'] * ray_workers.AgentWorker._num_gpus + ray_workers.ObjectiveServer._num_gpus)
return math.ceil(config['agent_count'] * ray_workers.AgentWorker.__ray_metadata__.num_gpus + ray_workers.ObjectiveServer.__ray_metadata__.num_gpus)
else:
return ray_workers.AgentWorker._num_gpus
return ray_workers.AgentWorker.__ray_metadata__.num_gpus


def init_ray(redis_address=None):
Expand All @@ -191,7 +191,7 @@ def init_ray(redis_address=None):
ray.init(object_store_memory=mem, redis_max_memory=mem, temp_dir='/tmp/metagenrl/ray')


def run(config, run_name='metagenrl', timesteps=300 * 1000, samples=1):
def run(config, run_name='metagenrl', timesteps=700 * 1000, samples=1):
tune.register_trainable(run_name, LLFSExperiment)
trial_gpus = count_required_gpus(config)
print(f'Requiring {trial_gpus} extra gpus.')
Expand Down Expand Up @@ -226,14 +226,17 @@ def test(args):
Performs meta-test training
"""
assert isinstance(args.objective, str)
config = configs.test(args.objective)
config = configs.test(args.objective, chkp=args.chkp)
config.update({
'name': args.name,
'env_name': tune.grid_search([
'Hopper-v2',
'HalfCheetah-v2',
'LunarLanderContinuous-v2',
]),
})

run(config, run_name='test-public-CheetahLunar')
run(config, run_name=f'test-public-{args.name}-chkp{args.chkp}', samples=1)


if __name__ == '__main__':
Expand All @@ -244,6 +247,8 @@ def test(args):
parser.add_argument('command', choices=FUNCTION_MAP.keys())
parser.add_argument('--redis', dest='redis_address', action='store', type=str)
parser.add_argument('--objective', action='store', type=str)
parser.add_argument('--name', action='store', type=str)
parser.add_argument('--chkp', action='store', type=int, default=-1)
parsed_args = parser.parse_args()
init_ray(parsed_args.redis_address)
func = FUNCTION_MAP[parsed_args.command]
Expand Down
2 changes: 1 addition & 1 deletion ray_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,4 +511,4 @@ def compute_objective_gradients(self, t, grad_oid):
self.objective_vars_oid = None

self.feed_dict[self.locals.plasma_grad_oid] = grad_oid
self.sess.run(self.locals.plasma_write_grads, self.feed_dict)
self.sess.run(self.locals.plasma_write_grads, self.feed_dict)
4 changes: 2 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get_store_socket():
:return: socket file path
"""
try:
return ray.worker.global_worker.plasma_client.store_socket_name
return ray.worker.global_worker.node.plasma_store_socket_name
except AttributeError:
return ''

Expand Down Expand Up @@ -222,4 +222,4 @@ def get_vars(scope, trainable_only=True):

def count_vars(scope):
v = get_vars(scope)
return sum([np.prod(var.shape.as_list()) for var in v])
return sum([np.prod(var.shape.as_list()) for var in v])

0 comments on commit c24ada2

Please sign in to comment.