Skip to content

Commit

Permalink
Merging 'harvard-edge/dev' branch into farama-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
uchendui committed Jul 27, 2024
2 parents da6f897 + 87aa785 commit 77f80d6
Show file tree
Hide file tree
Showing 81 changed files with 6,607 additions and 913 deletions.
12 changes: 12 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
.git/
.gitignore
node_modules/
npm-debug.log
Dockerfile
.dockerignore
temp/
*.md
*.egg-info/
venv/
env/
.idea/
44 changes: 44 additions & 0 deletions .github/workflows/build-docs-dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Builds the Sphinx documentation site and deploys it to GitHub Pages
# whenever the master branch is updated.
name: Deploy Docs
on:
  push:
    branches: [master]

# The deploy action pushes the built site back to the repository, so it
# needs write access to repository contents.
permissions:
  contents: write

jobs:
  docs:
    name: Generate Website
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'

      - name: Install docs dependencies
        run: pip install -r docs/requirements.txt

      - name: Install <Project>
        run: pip install -e .

      - name: Run some auxiliary scripts, e.g. build environments docs
        run: python docs/_scripts/gen_envs_mds.py

      - name: Build
        run: sphinx-build -b dirhtml -v docs _build

      # dirhtml builds the 404 page as 404/index.html; GitHub Pages expects
      # a 404.html at the site root.
      - name: Move 404
        run: mv _build/404/index.html _build/404.html

      - name: Update 404 links
        run: python docs/_scripts/move_404.py _build/404.html

      # .doctrees is a Sphinx build cache and should not be published.
      - name: Remove .doctrees
        run: rm -r _build/.doctrees

      - name: Upload to GitHub Pages
        uses: JamesIves/github-pages-deploy-action@v4
        with:
          folder: _build
12 changes: 11 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ __pycache__/

# C extensions
*.so

.ssh/
# Distribution / packaging
.Python
build/
Expand Down Expand Up @@ -127,3 +127,13 @@ dmypy.json

# Pyre type checker
.pyre/

# Ignore singularity images
*.simg
*.sif

# Ignore Intellij IDE files
.idea/

# Ignore Logs directory since it's used for running experiments
logs/
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[submodule "web_navigation"]
[submodule "a2perf/domains/web_navigation"]
path = a2perf/domains/web_navigation
url = https://github.com/Farama-Foundation/a2perf-web-nav.git
branch = dev
Expand Down
2 changes: 1 addition & 1 deletion a2perf/a2perf_benchmark_submission
Submodule a2perf_benchmark_submission updated 35 files
+3 −2 README.md
+165 −0 commands.py
+0 −0 configs/docker/QuadrupedLocomotion-v0/bc.gin
+45 −0 configs/docker/QuadrupedLocomotion-v0/ppo.gin
+0 −0 configs/docker/QuadrupedLocomotion-v0/sac.gin
+40 −0 configs/local/CircuitTraining-v0/netlist_ariane_std_cell_placer_mode_dreamplace/bc.gin
+43 −0 configs/local/CircuitTraining-v0/netlist_ariane_std_cell_placer_mode_dreamplace/ddqn.gin
+46 −0 configs/local/CircuitTraining-v0/netlist_ariane_std_cell_placer_mode_dreamplace/ppo.gin
+41 −0 configs/local/CircuitTraining-v0/netlist_toy_macro_stdcell_std_cell_placer_mode_dreamplace/bc.gin
+44 −0 configs/local/CircuitTraining-v0/netlist_toy_macro_stdcell_std_cell_placer_mode_dreamplace/ddqn.gin
+47 −0 configs/local/CircuitTraining-v0/netlist_toy_macro_stdcell_std_cell_placer_mode_dreamplace/ppo.gin
+0 −0 configs/local/QuadrupedLocomotion-v0/bc.gin
+45 −0 configs/local/QuadrupedLocomotion-v0/ppo.gin
+40 −0 configs/local/QuadrupedLocomotion-v0/sac.gin
+44 −0 configs/local/WebNavigation-v0/bc.gin
+46 −0 configs/local/WebNavigation-v0/ddqn.gin
+55 −0 configs/local/WebNavigation-v0/ppo.gin
+129 −19 inference.py
+6 −0 requirements.txt
+294 −10 train.py
+0 −0 train_lib/__init__.py
+1,119 −0 train_lib/agents.py
+0 −0 train_lib/circuit_training/__init__.py
+130 −0 train_lib/circuit_training/fully_connected_model_lib.py
+50 −0 train_lib/circuit_training/static_feature_cache.py
+717 −0 train_lib/collect.py
+268 −0 train_lib/learner_lib.py
+328 −0 train_lib/learners.py
+1,269 −0 train_lib/models.py
+183 −0 train_lib/networks.py
+174 −0 train_lib/reverb_server.py
+834 −0 train_lib/train.py
+41 −0 train_lib/triggers.py
+143 −0 train_lib/vocabulary_manager.py
+1,099 −0 train_lib/web_navigation/networks.py
3 changes: 3 additions & 0 deletions a2perf/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from . import reliability
from . import results
from . import system
186 changes: 186 additions & 0 deletions a2perf/analysis/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import functools
import json
import multiprocessing
import os
import traceback
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple

from a2perf.analysis.metrics_lib import load_training_system_data
from a2perf.domains import circuit_training
from a2perf.domains import quadruped_locomotion
from a2perf.domains import web_navigation
from a2perf.domains.tfa.suite_gym import create_domain
from a2perf.domains.tfa.utils import load_policy
from a2perf.domains.tfa.utils import perform_rollouts
from absl import app
from absl import flags
from absl import logging
import numpy as np

# Number of rollout episodes to run for each checkpoint.
_NUM_EVAL_EPISODES = flags.DEFINE_integer(
    'num_eval_episodes', 100, 'Number of episodes to evaluate the policy.'
)

# Size of the multiprocessing pool, i.e. how many checkpoints are evaluated
# concurrently (each worker creates its own environment).
_MAX_PARALLEL_ENVS = flags.DEFINE_integer(
    'max_parallel_envs', 1, 'Maximum number of parallel environments to use.'
)
# Directory containing the `policies/` subdirectory to evaluate.
# NOTE(review): the help text mentions a ROOT_DIR environment-variable
# fallback, but no such fallback is implemented in this file — confirm.
_ROOT_DIR = flags.DEFINE_string(
    'root_dir',
    None,
    'Root directory of the environment. If not set, the ROOT_DIR environment '
    'variable is used.',
)

# Gym environment id to evaluate on.
_ENV_NAME = flags.DEFINE_string(
    'env_name', 'CartPole-v0', 'The name of the environment to evaluate.'
)
# Name of the saved policy under `<root_dir>/policies/`.
_POLICY_NAME = flags.DEFINE_string(
    'policy_name', 'policy', 'The name of the policy to evaluate.'
)


def load_policy_and_perform_rollouts(
    checkpoint_path: str,
    env_name: str,
    policy_path: str,
    num_episodes: int,
    root_dir: Optional[str] = None,
) -> Dict[str, Any]:
  """Loads a policy checkpoint and evaluates it over several episodes.

  Args:
    checkpoint_path: Path of the checkpoint to restore into the policy.
    env_name: Gym environment id to evaluate on.
    policy_path: Path of the saved policy to load.
    num_episodes: Number of rollout episodes to perform.
    root_dir: Root directory forwarded to environments that need one
      (CircuitTraining / WebNavigation).

  Returns:
    A dict mapping `checkpoint_path` to summary statistics of the episode
    returns, or an empty dict if evaluation failed.
  """
  try:
    policy = load_policy(policy_path, checkpoint_path)
    # These domains locate assets/configuration relative to root_dir.
    if env_name in ('CircuitTraining-v0', 'WebNavigation-v0'):
      env = create_domain(env_name, root_dir=root_dir)
    else:
      env = create_domain(env_name)
    episode_returns = perform_rollouts(policy, env, num_episodes)

    eval_dict = {
        checkpoint_path: {
            'mean': np.mean(episode_returns).astype(float),
            'std': np.std(episode_returns).astype(float),
            'min': np.min(episode_returns).astype(float),
            'max': np.max(episode_returns).astype(float),
            'median': np.median(episode_returns).astype(float),
            'count': int(episode_returns.size),
            'rollout_returns': [float(v) for v in episode_returns],
        }
    }

    logging.info('Evaluation results for %s:', checkpoint_path)
    logging.info('\t%s', eval_dict[checkpoint_path])
    return eval_dict
  except Exception as e:
    # Broad on purpose: one failed checkpoint must not abort the whole
    # multiprocessing pool — log it and report an empty result instead.
    logging.error('Error evaluating checkpoint %s: %s', checkpoint_path, e)
    traceback.print_exc()
    return {}


def add_training_energy_cost(
    checkpoint_item: Tuple[str, Dict[str, Any]],
    total_energy_kwh: float,
) -> Tuple[str, Dict[str, Any]]:
  """Annotates one checkpoint's evaluation dict with training-energy fields.

  Args:
    checkpoint_item: A `(checkpoint_path, eval_dict)` pair as produced by
      iterating the merged evaluation results.
    total_energy_kwh: Total energy consumed by the whole training run, in kWh.

  Returns:
    The same `(checkpoint_path, eval_dict)` pair with energy fields added
    (the dict is updated in place).
  """
  checkpoint_path, checkpoint_dict = checkpoint_item

  # Checkpoint directories are named like `policy_checkpoint_0000001234`;
  # the final underscore-separated token is the checkpoint number.
  checkpoint_number = int(os.path.basename(checkpoint_path).split('_')[-1])

  checkpoint_dict.update({
      'total_training_energy_kwh': total_energy_kwh,
      # Proportional to the checkpoint number; normalized later (divided by
      # the maximum checkpoint number) so earlier checkpoints are attributed
      # less energy.
      'training_energy_kwh': total_energy_kwh * checkpoint_number,
      'checkpoint_number': checkpoint_number,
  })
  return checkpoint_path, checkpoint_dict


def main(_):
  """Evaluates every saved checkpoint and writes `evaluation.json`."""
  # Use 'spawn' so worker processes start with a fresh interpreter.
  multiprocessing.set_start_method('spawn', force=False)
  saved_model_path = os.path.join(
      _ROOT_DIR.value, 'policies', _POLICY_NAME.value
  )
  checkpoints_path = os.path.join(_ROOT_DIR.value, 'policies', 'checkpoints')

  # Get absolute paths of all checkpoints.
  all_checkpoints_paths = [
      os.path.join(checkpoints_path, checkpoint)
      for checkpoint in os.listdir(checkpoints_path)
  ]

  # Fix every argument except the checkpoint path so the pool can map over
  # checkpoint paths directly.
  partial_func = functools.partial(
      load_policy_and_perform_rollouts,
      root_dir=_ROOT_DIR.value,
      env_name=_ENV_NAME.value,
      policy_path=saved_model_path,
      num_episodes=_NUM_EVAL_EPISODES.value,
  )

  with multiprocessing.Pool(_MAX_PARALLEL_ENVS.value) as pool:
    episode_returns = pool.map(partial_func, all_checkpoints_paths)
    pool.close()
    pool.join()

  # Each worker returns a single-key dict ({} on failure); merge them all.
  all_episode_returns = {k: v for d in episode_returns for k, v in d.items()}

  # root_dir lives four levels below the experiment directory; walk back up
  # to locate the system metrics recorded for this training run.
  experiment_path = os.path.abspath(
      os.path.join(_ROOT_DIR.value, os.pardir, os.pardir, os.pardir, os.pardir)
  )
  logging.debug('Experiment path: %s', experiment_path)
  training_system_df = load_training_system_data(
      base_dir=os.path.abspath(os.path.join(experiment_path, os.pardir)),
      experiment_ids=[os.path.basename(experiment_path)],
  )

  # `energy_consumed` appears to be cumulative per run, so take the last
  # entry of each run-id and sum across runs.
  total_training_energy_kwh = (
      training_system_df.groupby(
          ['domain', 'algo', 'task', 'experiment', 'seed']
      )['energy_consumed']
      .last()
      .sum()
  )

  # Attach the training energy cost to each checkpoint's evaluation results.
  with multiprocessing.Pool(_MAX_PARALLEL_ENVS.value) as pool:
    all_episode_returns = pool.map(
        functools.partial(
            add_training_energy_cost, total_energy_kwh=total_training_energy_kwh
        ),
        all_episode_returns.items(),
    )
    pool.close()
    pool.join()

  # Turn the list of (path, dict) pairs back into a dictionary.
  all_episode_returns = dict(all_episode_returns)

  maximum_checkpoint_number = max(
      int(v['checkpoint_number']) for v in all_episode_returns.values()
  )
  logging.info('Maximum checkpoint number: %d', maximum_checkpoint_number)

  for checkpoint_path, checkpoint_dict in all_episode_returns.items():
    # Normalize so earlier checkpoints are attributed proportionally less
    # energy; the final checkpoint ends up with the full total.
    checkpoint_dict['training_energy_kwh'] = (
        checkpoint_dict['training_energy_kwh'] / maximum_checkpoint_number
    )

    # Sanity check: the final checkpoint must account for the total energy.
    # Tolerance-based comparison — exact float equality can fail due to
    # rounding in the multiply-then-divide round trip.
    if checkpoint_dict['checkpoint_number'] == maximum_checkpoint_number:
      assert np.isclose(
          checkpoint_dict['training_energy_kwh'], total_training_energy_kwh
      )

  # Save the merged results as JSON next to the policies.
  evaluation_save_path = os.path.join(
      _ROOT_DIR.value, 'policies', 'evaluation.json'
  )
  with open(evaluation_save_path, 'w') as f:
    json.dump(all_episode_returns, f, indent=2)


if __name__ == '__main__':
  app.run(main)
Loading

0 comments on commit 77f80d6

Please sign in to comment.