diff --git a/main_run_test.py b/main_run_test.py
index fa37531..76403c6 100644
--- a/main_run_test.py
+++ b/main_run_test.py
@@ -118,8 +118,7 @@ def run_sequential_test(args, logger):
         "reward": {"vshape": (1,)},
         "terminated": {"vshape": (1,), "dtype": th.uint8},
     }
-    if args.mac=="maven_mac":
-        scheme["noise"] = {"vshape": (args.noise_dim,)}
+
     groups = {
         "agents": args.n_agents
     }
diff --git a/pymarl/config/algs/maven.yaml b/pymarl/config/algs/maven.yaml
deleted file mode 100644
index f4bc2a3..0000000
--- a/pymarl/config/algs/maven.yaml
+++ /dev/null
@@ -1,57 +0,0 @@
-# --- Maven specific parameters ---
-
-# use epsilon greedy action selector
-action_selector: "epsilon_greedy"
-epsilon_start: 1.0
-epsilon_finish: 0.05
-epsilon_anneal_time: 200000
-
-runner: "episode"
-#batch_size_run: 6
-
-buffer_size: 5000
-
-
-target_update_interval: 200
-
-
-agent_output_type: "q"
-learner: "maven_learner"
-double_q: True
-mixer: "qmix"
-mixing_embed_dim: 32
-skip_connections: False
-hyper_initialization_nonzeros: 0
-
-mac: "maven_mac"
-agent: "maven_rnn"
-noise_dim: 16
-noise_embedding_dim: 32
-
-mi_loss: 0.001
-rnn_discrim: True
-rnn_agg_size: 32
-
-discrim_size: 64
-discrim_layers: 3
-
-noise_bandit: True
-noise_bandit_lr: 0.1
-noise_bandit_epsilon: 0.2
-
-mi_intrinsic: False
-mi_scaler: 0.1
-entropy_scaling: 0.001
-hard_qs: False
-
-bandit_epsilon: 0.1
-bandit_iters: 8
-bandit_batch: 64
-bandit_buffer: 512
-bandit_reward_scaling: 20
-bandit_use_state: True
-bandit_policy: True
-
-name: "maven"
-use_cuda: False
-use_tensorboard: True
\ No newline at end of file
diff --git a/pymarl/config/default.yaml b/pymarl/config/default.yaml
index 696ba6b..5a1a5b3 100644
--- a/pymarl/config/default.yaml
+++ b/pymarl/config/default.yaml
@@ -45,6 +45,3 @@ obs_last_action: True # Include the agent's last action (one_hot) in the observa
 # --- Experiment running params ---
 repeat_id: 1
 label: "default_label"
-
-# maven
-noise_bandit: False
\ No newline at end of file
diff --git a/pymarl/controllers/__init__.py b/pymarl/controllers/__init__.py
index 0533cb1..764ba87 100644
--- a/pymarl/controllers/__init__.py
+++ b/pymarl/controllers/__init__.py
@@ -2,8 +2,6 @@
 
 from .basic_controller import BasicMAC
 from .is_controller import ISMAC
-from .maven_controller import MavenMAC
 
 REGISTRY["basic_mac"] = BasicMAC
-REGISTRY["maven_mac"] = MavenMAC
 REGISTRY["is_mac"] = ISMAC
diff --git a/pymarl/controllers/maven_controller.py b/pymarl/controllers/maven_controller.py
deleted file mode 100755
index 027bf8f..0000000
--- a/pymarl/controllers/maven_controller.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from pymarl.modules.agents import REGISTRY as agent_REGISTRY
-from pymarl.components.action_selectors import REGISTRY as action_REGISTRY
-import torch as th
-
-
-# This multi-agent controller shares parameters between agents
-class MavenMAC:
-    def __init__(self, scheme, groups, args):
-        self.n_agents = args.n_agents
-        self.args = args
-        input_shape = self._get_input_shape(scheme)
-        self._build_agents(input_shape)
-        self.agent_output_type = args.agent_output_type
-
-        self.action_selector = action_REGISTRY[args.action_selector](args)
-
-        self.hidden_states = None
-
-    def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False):
-        # Only select actions for the selected batch elements in bs
-        avail_actions = ep_batch["avail_actions"][:, t_ep]
-        agent_outputs = self.forward(ep_batch, t_ep, test_mode=test_mode)
-        chosen_actions = self.action_selector.select_action(agent_outputs[bs], avail_actions[bs], t_env, test_mode=test_mode)
-        return chosen_actions
-
-    def forward(self, ep_batch, t, test_mode=False):
-        agent_inputs = self._build_inputs(ep_batch, t)
-        avail_actions = ep_batch["avail_actions"][:, t]
-        noise_vector = ep_batch["noise"][:, 0]
-        if th.sum(noise_vector) ==0:
-            print("noise_vector", noise_vector)
-            print("ERREUR NOISE = 0")
-            exit()
-        agent_outs, self.hidden_states = self.agent(agent_inputs, self.hidden_states, noise_vector)
-
-        if self.agent_output_type == "pi_logits":
-
-            if getattr(self.args, "mask_before_softmax", True):
-                # Make the logits for unavailable actions very negative to minimise their affect on the softmax
-                reshaped_avail_actions = avail_actions.reshape(ep_batch.batch_size * self.n_agents, -1)
-                agent_outs[reshaped_avail_actions == 0] = -1e10
-
-            agent_outs = th.nn.functional.softmax(agent_outs, dim=-1)
-            if not test_mode:
-                # Epsilon floor
-                epsilon_action_num = agent_outs.size(-1)
-                if getattr(self.args, "mask_before_softmax", True):
-                    # With probability epsilon, we will pick an available action uniformly
-                    epsilon_action_num = reshaped_avail_actions.sum(dim=1, keepdim=True).float()
-
-                agent_outs = ((1 - self.action_selector.epsilon) * agent_outs
-                              + th.ones_like(agent_outs) * self.action_selector.epsilon/epsilon_action_num)
-
-                if getattr(self.args, "mask_before_softmax", True):
-                    # Zero out the unavailable actions
-                    agent_outs[reshaped_avail_actions == 0] = 0.0
-
-        return agent_outs.view(ep_batch.batch_size, self.n_agents, -1)
-
-    def init_hidden(self, batch_size):
-        self.hidden_states = self.agent.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1)  # bav
-
-    def parameters(self):
-        return self.agent.parameters()
-
-    def load_state(self, other_mac):
-        self.agent.load_state_dict(other_mac.agent.state_dict())
-
-    def cuda(self):
-        self.agent.cuda()
-
-    def save_models(self, path):
-        th.save(self.agent.state_dict(), "{}/agent.th".format(path))
-
-    def load_models(self, path):
-        self.agent.load_state_dict(th.load("{}/agent.th".format(path), map_location=lambda storage, loc: storage))
-
-    def _build_agents(self, input_shape):
-        self.agent = agent_REGISTRY[self.args.agent](input_shape, self.args)
-
-    def _build_inputs(self, batch, t):
-        # Assumes homogenous agents with flat observations.
-        # Other MACs might want to e.g. delegate building inputs to each agent
-        bs = batch.batch_size
-        inputs = []
-        inputs.append(batch["obs"][:, t])  # b1av
-        if self.args.obs_last_action:
-            if t == 0:
-                inputs.append(th.zeros_like(batch["actions_onehot"][:, t]))
-            else:
-                inputs.append(batch["actions_onehot"][:, t-1])
-        if self.args.obs_agent_id:
-            inputs.append(th.eye(self.n_agents, device=batch.device).unsqueeze(0).expand(bs, -1, -1))
-
-        inputs = th.cat([x.reshape(bs*self.n_agents, -1) for x in inputs], dim=1)
-        return inputs
-
-    def _get_input_shape(self, scheme):
-        input_shape = scheme["obs"]["vshape"]
-        if self.args.obs_last_action:
-            input_shape += scheme["actions_onehot"]["vshape"][0]
-        if self.args.obs_agent_id:
-            input_shape += self.n_agents
-
-        return input_shape
diff --git a/pymarl/learners/__init__.py b/pymarl/learners/__init__.py
index 809e839..e94c827 100644
--- a/pymarl/learners/__init__.py
+++ b/pymarl/learners/__init__.py
@@ -4,7 +4,6 @@
 from .coma_learner import COMALearner
 from .comaIS_learner import COMAISLearner
 from .ddmac_learner import DDMACLearner
-from .maven_learner import MavenLearner
 
 REGISTRY = {}
 REGISTRY["q_learner"] = QLearner
diff --git a/pymarl/learners/maven_learner.py b/pymarl/learners/maven_learner.py
deleted file mode 100644
index 6e5d7a1..0000000
--- a/pymarl/learners/maven_learner.py
+++ /dev/null
@@ -1,323 +0,0 @@
-import copy
-from pymarl.components.episode_buffer import EpisodeBatch
-from pymarl.modules.mixers.vdn import VDNMixer
-from pymarl.modules.mixers.qmix import QMixer
-from pymarl.modules.mixers.maven_mixer import MavenMixer
-import torch as th
-from torch.optim import RMSprop
-import numpy as np
-
-
-class MavenLearner:
-    def __init__(self, mac, scheme, logger, args):
-        self.args = args
-        self.mac = mac
-        self.logger = logger
-
-        self.params = list(mac.parameters())
-        self.last_target_update_episode = 0
-
-        self.mixer = None
-        if args.mixer is not None:
-            if args.mixer == "vdn":
-                self.mixer = VDNMixer()
-            elif args.mixer == "qmix":
-                self.mixer = MavenMixer(args)
-            else:
-                raise ValueError("Mixer {} not recognised.".format(args.mixer))
-            self.params += list(self.mixer.parameters())
-            self.target_mixer = copy.deepcopy(self.mixer)
-
-        discrim_input = np.prod(self.args.state_shape) + self.args.n_agents * self.args.n_actions
-
-        if self.args.rnn_discrim:
-            self.rnn_agg = RNNAggregator(discrim_input, args)
-            self.discrim = Discrim(args.rnn_agg_size, self.args.noise_dim, args)
-            self.params += list(self.discrim.parameters())
-            self.params += list(self.rnn_agg.parameters())
-        else:
-            self.discrim = Discrim(discrim_input, self.args.noise_dim, args)
-            self.params += list(self.discrim.parameters())
-        self.discrim_loss = th.nn.CrossEntropyLoss(reduction="none")
-
-        self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps)
-
-        self.target_mac = copy.deepcopy(mac)
-
-        self.log_stats_t = -self.args.learner_log_interval - 1
-
-    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
-        # Get the relevant quantities
-        rewards = batch["reward"][:, :-1]
-        actions = batch["actions"][:, :-1]
-        terminated = batch["terminated"][:, :-1].float()
-        mask = batch["filled"][:, :-1].float()
-        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
-        avail_actions = batch["avail_actions"]
-        noise = batch["noise"][:, 0].unsqueeze(1).repeat(1,rewards.shape[1],1)
-
-        # Calculate estimated Q-Values
-        mac_out = []
-        self.mac.init_hidden(batch.batch_size)
-        for t in range(batch.max_seq_length):
-            agent_outs = self.mac.forward(batch, t=t)
-            mac_out.append(agent_outs)
-        mac_out = th.stack(mac_out, dim=1)  # Concat over time
-
-        # Pick the Q-Values for the actions taken by each agent
-        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim
-
-        # Calculate the Q-Values necessary for the target
-        target_mac_out = []
-        self.target_mac.init_hidden(batch.batch_size)
-        for t in range(batch.max_seq_length):
-            target_agent_outs = self.target_mac.forward(batch, t=t)
-            target_mac_out.append(target_agent_outs)
-
-        # We don't need the first timesteps Q-Value estimate for calculating targets
-        target_mac_out = th.stack(target_mac_out[1:], dim=1)  # Concat across time
-
-        # Mask out unavailable actions
-        target_mac_out[avail_actions[:, 1:] == 0] = -9999999  # From OG deepmarl
-
-        # Max over target Q-Values
-        if self.args.double_q:
-            # Get actions that maximise live Q (for double q-learning)
-            mac_out[avail_actions == 0] = -9999999
-            cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1]
-            target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
-        else:
-            target_max_qvals = target_mac_out.max(dim=3)[0]
-
-        # Mix
-        if self.mixer is not None:
-            chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1], noise)
-            target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:], noise)
-
-        # Discriminator
-        mac_out[avail_actions == 0] = -9999999
-        q_softmax_actions = th.nn.functional.softmax(mac_out[:, :-1], dim=3)
-
-        if self.args.hard_qs:
-            maxs = th.max(mac_out[:, :-1], dim=3, keepdim=True)[1]
-            zeros = th.zeros_like(q_softmax_actions)
-            zeros.scatter_(dim=3, index=maxs, value=1)
-            q_softmax_actions = zeros
-
-        q_softmax_agents = q_softmax_actions.reshape(q_softmax_actions.shape[0], q_softmax_actions.shape[1], -1)
-
-        states = batch["state"][:, :-1]
-        state_and_softactions = th.cat([q_softmax_agents, states], dim=2)
-
-        if self.args.rnn_discrim:
-            h_to_use = th.zeros(size=(batch.batch_size, self.args.rnn_agg_size)).to(states.device)
-            hs = th.ones_like(h_to_use)
-            for t in range(batch.max_seq_length - 1):
-                hs = self.rnn_agg(state_and_softactions[:, t], hs)
-                for b in range(batch.batch_size):
-                    if t == batch.max_seq_length - 2 or (mask[b, t] == 1 and mask[b, t+1] == 0):
-                        # This is the last timestep of the sequence
-                        h_to_use[b] = hs[b]
-            s_and_softa_reshaped = h_to_use
-        else:
-            s_and_softa_reshaped = state_and_softactions.reshape(-1, state_and_softactions.shape[-1])
-
-        if self.args.mi_intrinsic:
-            s_and_softa_reshaped = s_and_softa_reshaped.detach()
-
-        discrim_prediction = self.discrim(s_and_softa_reshaped)
-
-        # Cross-Entropy
-        target_repeats = 1
-        if not self.args.rnn_discrim:
-            target_repeats = q_softmax_actions.shape[1]
-        discrim_target = batch["noise"][:, 0].long().detach().max(dim=1)[1].unsqueeze(1).repeat(1, target_repeats).reshape(-1)
-        discrim_loss = self.discrim_loss(discrim_prediction, discrim_target)
-
-        if self.args.rnn_discrim:
-            averaged_discrim_loss = discrim_loss.mean()
-        else:
-            masked_discrim_loss = discrim_loss * mask.reshape(-1)
-            averaged_discrim_loss = masked_discrim_loss.sum() / mask.sum()
-        self.logger.log_stat("discrim_loss", averaged_discrim_loss.item(), t_env)
-
-
-        # Calculate 1-step Q-Learning targets
-        targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals
-        if self.args.mi_intrinsic:
-            assert self.args.rnn_discrim is False
-            targets = targets + self.args.mi_scaler * discrim_loss.view_as(rewards)
-
-        # Td-error
-        td_error = (chosen_action_qvals - targets.detach())
-
-        mask = mask.expand_as(td_error)
-
-        # 0-out the targets that came from padded data
-        masked_td_error = td_error * mask
-
-        # Normal L2 loss, take mean over actual data
-        loss = (masked_td_error ** 2).sum() / mask.sum()
-
-        loss = loss + self.args.mi_loss * averaged_discrim_loss
-
-        # Optimise
-        self.optimiser.zero_grad()
-        loss.backward()
-        grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
-        self.optimiser.step()
-
-        if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
-            self._update_targets()
-            self.last_target_update_episode = episode_num
-
-        if t_env - self.log_stats_t >= self.args.learner_log_interval:
-            self.logger.log_stat("loss", loss.item(), t_env)
-            self.logger.log_stat("grad_norm", grad_norm, t_env)
-            mask_elems = mask.sum().item()
-            self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env)
-            self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env)
-            self.logger.log_stat("target_mean", (targets * mask).sum().item()/(mask_elems * self.args.n_agents), t_env)
-            self.log_stats_t = t_env
-
-    def stats(self, batch, t_env):
-        # Get the relevant quantities
-        rewards = batch["reward"][:, :-1]
-        actions = batch["actions"][:, :-1]
-        terminated = batch["terminated"][:, :-1].float()
-        mask = batch["filled"][:, :-1].float()
-        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
-        avail_actions = batch["avail_actions"]
-        noise = batch["noise"][:, 0].unsqueeze(1).repeat(1, rewards.shape[1], 1)
-
-        # Calculate estimated Q-Values
-        mac_out = []
-        self.mac.init_hidden(batch.batch_size)
-        for t in range(batch.max_seq_length):
-            agent_outs = self.mac.forward(batch, t=t)
-            mac_out.append(agent_outs)
-        mac_out = th.stack(mac_out, dim=1)  # Concat over time
-
-        # Pick the Q-Values for the actions taken by each agent
-        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim
-        chosen_action_qvals_copy = chosen_action_qvals.clone().detach()
-        # Calculate the Q-Values necessary for the target
-        target_mac_out = []
-        self.target_mac.init_hidden(batch.batch_size)
-        for t in range(batch.max_seq_length):
-            target_agent_outs = self.target_mac.forward(batch, t=t)
-            target_mac_out.append(target_agent_outs)
-        target_mac_out_for_log = th.stack(target_mac_out, dim=1).clone().detach()
-        # We don't need the first timesteps Q-Value estimate for calculating targets
-        target_mac_out = th.stack(target_mac_out[1:], dim=1)  # Concat across time
-
-        # Mask out unavailable actions
-        target_mac_out[avail_actions[:, 1:] == 0] = -9999999  # From OG deepmarl
-
-        # Max over target Q-Values
-        if self.args.double_q:
-            # Get actions that maximise live Q (for double q-learning)
-            mac_out[avail_actions == 0] = -9999999
-            cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1]
-            target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)
-        else:
-            target_max_qvals = target_mac_out.max(dim=3)[0]
-
-        target_max_qvals_copy = target_max_qvals.clone().detach()
-        max_q_indiv = mac_out[:, :-1].max(dim=-1)[0]
-        max_target_mac_out_training = target_mac_out_for_log[:, :-1].max(dim=-1)[0]
-
-        # Mix
-        if self.mixer is not None:
-            chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1], noise)
-            target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:], noise)
-
-        real_discounted_sum = rewards.clone().detach()
-        t = rewards.size()[1] - 1  # t max
-
-        real_discounted_sum[:, t, :] = rewards[:, t, :]
-        while t>0:
-            t-=1
-            real_discounted_sum[:, t, :] = rewards[:, t, :] + self.args.gamma * real_discounted_sum[:, t+1, :]
-
-        mask_elems = mask.sum().item()
-
-        self.logger.log_stat("chosen_q_indiv_mean",
-                             (chosen_action_qvals_copy * mask).sum().item() / (mask_elems * self.args.n_agents), t_env)
-        self.logger.log_stat("chosen_target_q_indiv_mean",
-                             (target_max_qvals_copy * mask).sum().item() / (mask_elems * self.args.n_agents), t_env)
-
-        if self.mixer is not None:
-            self.logger.log_stat("chosen_q_mix_mean",
-                                 (chosen_action_qvals * mask).sum().item() / mask_elems, t_env)
-            self.logger.log_stat("target_q_mix_mean",
-                                 (target_max_qvals * mask).sum().item() / mask_elems, t_env)
-
-        self.logger.log_stat("max_q_indiv_mean",
-                             (max_q_indiv * mask).sum().item() / (mask_elems * self.args.n_agents), t_env)
-        self.logger.log_stat("max_target_q_indiv_mean",
-                             (max_target_mac_out_training * mask).sum().item() / (mask_elems * self.args.n_agents),
-                             t_env)
-        self.logger.log_stat("real_discounted_per_state_mean",
-                             (real_discounted_sum * mask).sum().item() / mask_elems,
-                             t_env)
-        self.log_stats_t = t_env
-
-
-    def _update_targets(self):
-        self.target_mac.load_state(self.mac)
-        if self.mixer is not None:
-            self.target_mixer.load_state_dict(self.mixer.state_dict())
-        self.logger.console_logger.info("Updated target network")
-
-    def cuda(self):
-        self.mac.cuda()
-        self.target_mac.cuda()
-        self.discrim.cuda()
-        if self.args.rnn_discrim:
-            self.rnn_agg.cuda()
-        if self.mixer is not None:
-            self.mixer.cuda()
-            self.target_mixer.cuda()
-
-    def save_models(self, path):
-        self.mac.save_models(path)
-        if self.mixer is not None:
-            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
-        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))
-
-    def load_models(self, path):
-        self.mac.load_models(path)
-        self.target_mac.load_models(path)
-        if self.mixer is not None:
-            self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage))
-        self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
-
-
-class Discrim(th.nn.Module):
-
-    def __init__(self, input_size, output_size, args):
-        super().__init__()
-        self.args = args
-        layers = [th.nn.Linear(input_size, self.args.discrim_size), th.nn.ReLU()]
-        for _ in range(self.args.discrim_layers - 1):
-            layers.append(th.nn.Linear(self.args.discrim_size, self.args.discrim_size))
-            layers.append(th.nn.ReLU())
-        layers.append(th.nn.Linear(self.args.discrim_size, output_size))
-        self.model = th.nn.Sequential(*layers)
-
-    def forward(self, x):
-        return self.model(x)
-
-
-class RNNAggregator(th.nn.Module):
-
-    def __init__(self, input_size, args):
-        super().__init__()
-        self.args = args
-        self.input_size = input_size
-        output_size = args.rnn_agg_size
-        self.rnn = th.nn.GRUCell(input_size, output_size)
-
-    def forward(self, x, h):
-        return self.rnn(x, h)
diff --git a/pymarl/modules/agents/maven_rnn_agent.py b/pymarl/modules/agents/maven_rnn_agent.py
deleted file mode 100644
index e41ba57..0000000
--- a/pymarl/modules/agents/maven_rnn_agent.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import torch as th
-import torch.nn as nn
-import torch.nn.functional as F
-
-class MavenRNNAgent(nn.Module):
-    def __init__(self, input_shape, args):
-        super(MavenRNNAgent, self).__init__()
-        self.args = args
-
-        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
-        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
-        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)
-
-        self.noise_fc1 = nn.Linear(args.noise_dim + args.n_agents, args.noise_embedding_dim)
-        self.noise_fc2 = nn.Linear(args.noise_embedding_dim, args.noise_embedding_dim)
-        self.noise_fc3 = nn.Linear(args.noise_embedding_dim, args.n_actions)
-
-        self.hyper = True
-        self.hyper_noise_fc1 = nn.Linear(args.noise_dim + args.n_agents, args.rnn_hidden_dim * args.n_actions)
-
-    def init_hidden(self):
-        # make hidden states on same device as model
-        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()
-
-    def forward(self, inputs, hidden_state, noise):
-        agent_ids = th.eye(self.args.n_agents, device=inputs.device).repeat(noise.shape[0], 1)
-        noise_repeated = noise.repeat(1, self.args.n_agents).reshape(agent_ids.shape[0], -1)
-
-        x = F.relu(self.fc1(inputs))
-        h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
-        h = self.rnn(x, h_in)
-        q = self.fc2(h)
-
-        noise_input = th.cat([noise_repeated, agent_ids], dim=-1)
-
-        if self.hyper:
-            W = self.hyper_noise_fc1(noise_input).reshape(-1, self.args.n_actions, self.args.rnn_hidden_dim)
-            wq = th.bmm(W, h.unsqueeze(2))
-        else:
-            z = F.tanh(self.noise_fc1(noise_input))
-            z = F.tanh(self.noise_fc2(z))
-            wz = self.noise_fc3(z)
-
-            wq = q * wz
-
-        return wq, h
diff --git a/pymarl/modules/agents/noise_rnn_agent.py b/pymarl/modules/agents/noise_rnn_agent.py
deleted file mode 100644
index 4dba473..0000000
--- a/pymarl/modules/agents/noise_rnn_agent.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import torch as th
-import torch.nn as nn
-import torch.nn.functional as F
-
-class RNNAgent(nn.Module):
-    def __init__(self, input_shape, args):
-        super(RNNAgent, self).__init__()
-        self.args = args
-
-        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
-        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
-        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)
-
-        self.noise_fc1 = nn.Linear(args.noise_dim + args.n_agents, args.noise_embedding_dim)
-        self.noise_fc2 = nn.Linear(args.noise_embedding_dim, args.noise_embedding_dim)
-        self.noise_fc3 = nn.Linear(args.noise_embedding_dim, args.n_actions)
-
-        self.hyper = True
-        self.hyper_noise_fc1 = nn.Linear(args.noise_dim + args.n_agents, args.rnn_hidden_dim * args.n_actions)
-
-    def init_hidden(self):
-        # make hidden states on same device as model
-        return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_()
-
-    def forward(self, inputs, hidden_state, noise):
-        agent_ids = th.eye(self.args.n_agents, device=inputs.device).repeat(noise.shape[0], 1)
-        noise_repeated = noise.repeat(1, self.args.n_agents).reshape(agent_ids.shape[0], -1)
-
-        x = F.relu(self.fc1(inputs))
-        h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
-        h = self.rnn(x, h_in)
-        q = self.fc2(h)
-
-        noise_input = th.cat([noise_repeated, agent_ids], dim=-1)
-
-        if self.hyper:
-            W = self.hyper_noise_fc1(noise_input).reshape(-1, self.args.n_actions, self.args.rnn_hidden_dim)
-            wq = th.bmm(W, h.unsqueeze(2))
-        else:
-            z = F.tanh(self.noise_fc1(noise_input))
-            z = F.tanh(self.noise_fc2(z))
-            wz = self.noise_fc3(z)
-
-            wq = q * wz
-
-        return wq, h
diff --git a/pymarl/modules/bandits/__init__.py b/pymarl/modules/bandits/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/pymarl/modules/bandits/const_lr.py b/pymarl/modules/bandits/const_lr.py
deleted file mode 100644
index 9ac2151..0000000
--- a/pymarl/modules/bandits/const_lr.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import numpy as np
-import torch as th
-
-
-class Constant_Lr:
-
-    def __init__(self, args):
-        self.args = args
-        self.lr = args.noise_bandit_lr
-        self.returns = [0 for _ in range(self.args.noise_dim)]
-        self.epsilon = args.noise_bandit_epsilon
-        self.noise_dim = self.args.noise_dim
-
-    def sample(self, test_mode):
-        noise_vector = []
-        for _ in range(self.args.batch_size_run):
-            noise = 0
-            # During training we are epsilon greedy.
-            # During testing we are uniform so that we can gather info about all noise seeds
-            if test_mode or np.random.random() < self.epsilon:
-                noise = np.random.randint(self.noise_dim)
-            else:
-                noise = np.argmax(self.returns)
-            one_hot_noise = th.zeros(self.noise_dim)
-            one_hot_noise[noise] = 1
-            noise_vector.append(one_hot_noise)
-        return th.stack(noise_vector)
-
-    def update_returns(self, noise, returns, test_mode):
-        if test_mode:
-            return  # Only update the returns for training.
-        for n, r in zip(noise, returns):
-            # n is onehot
-            n_idx = np.argmax(n)
-            self.returns[n_idx] = self.lr * r + (1 - self.lr) * self.returns[n_idx]
-
-
diff --git a/pymarl/modules/bandits/reinforce_hierarchial.py b/pymarl/modules/bandits/reinforce_hierarchial.py
deleted file mode 100644
index 967b091..0000000
--- a/pymarl/modules/bandits/reinforce_hierarchial.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Categorical policy for discrete z
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-import numpy as np
-from collections import deque
-
-
-class Policy(nn.Module):
-    def __init__(self, args):
-        super(Policy, self).__init__()
-        self.args = args
-        self.affine1 = nn.Linear(args.state_shape, 128)
-        self.affine2 = nn.Linear(128, args.noise_dim)
-
-    def forward(self, x):
-        x = x.view(-1, self.args.state_shape)
-        x = self.affine1(x)
-        x = F.relu(x)
-        action_scores = self.affine2(x)
-        return F.softmax(action_scores, dim=1)
-
-
-class Z_agent:
-    def __init__(self, args):
-        self.args = args
-        self.lr = args.lr
-        self.noise_dim = self.args.noise_dim
-        # size of state vector
-        self.state_shape = self.args.state_shape
-        self.policy = Policy(args)
-        self.optimizer = optim.Adam(self.policy.parameters(), lr=self.lr)
-
-    def sample(self, state):
-        probs = self.policy(state)
-        m = torch.distributions.one_hot_categorical.OneHotCategorical(probs)
-        action = m.sample()
-        return action
-
-    def update_returns(self, states, actions, returns, test_mode):
-        if test_mode:
-            return
-        probs = self.policy(states)
-        m = torch.distributions.one_hot_categorical.OneHotCategorical(probs)
-        log_probs = m.log_prob(actions)
-        self.optimizer.zero_grad()
-        policy_loss = -torch.dot(log_probs, returns)
-        policy_loss.backward()
-        self.optimizer.step()
-
-# Max entropy Z agent
-class EZ_agent:
-    def __init__(self, args, logger):
-        self.args = args
-        self.lr = args.lr
-        self.noise_dim = self.args.noise_dim
-        # size of state vector
-        self.state_shape = self.args.state_shape
-        self.policy = Policy(args)
-        self.optimizer = optim.Adam(self.policy.parameters(), lr=self.lr)
-        # Scaling factor for entropy, would roughly be similar to MI scaling
-        self.entropy_scaling = args.entropy_scaling
-        self.uniform_distrib = torch.distributions.one_hot_categorical.OneHotCategorical(torch.tensor([1/self.args.noise_dim for _ in range(self.args.noise_dim)]).repeat(self.args.batch_size_run, 1))
-
-        self.buffer = deque(maxlen=self.args.bandit_buffer)
-        self.epsilon_floor = args.bandit_epsilon
-
-        self.logger = logger
-
-    def sample(self, state, test_mode):
-        # During testing we just sample uniformly
-        if test_mode:
-            return self.uniform_distrib.sample()
-        else:
-            probs = self.policy(state)
-            m = torch.distributions.one_hot_categorical.OneHotCategorical(probs)
-            action = m.sample().cpu()
-            return action
-
-    def update_returns(self, states, actions, returns, test_mode, t):
-        if test_mode:
-            return
-
-        for s,a,r in zip(states, actions, returns):
-            self.buffer.append((s,a,torch.tensor(r, dtype=torch.float)))
-
-        for _ in range(self.args.bandit_iters):
-            idxs = np.random.randint(0, len(self.buffer), size=self.args.bandit_batch)
-            batch_elems = [self.buffer[i] for i in idxs]
-            states_ = torch.stack([x[0] for x in batch_elems]).to(states.device)
-            actions_ = torch.stack([x[1] for x in batch_elems]).to(states.device)
-            returns_ = torch.stack([x[2] for x in batch_elems]).to(states.device)
-
-            probs = self.policy(states_)
-            m = torch.distributions.one_hot_categorical.OneHotCategorical(probs)
-            log_probs = m.log_prob(actions_.to(probs.device))
-            self.optimizer.zero_grad()
-            policy_loss = -torch.dot(log_probs, torch.tensor(returns_, device=log_probs.device).float()) + self.entropy_scaling * log_probs.sum()
-            policy_loss.backward()
-            self.optimizer.step()
-
-        mean_entropy = m.entropy().mean()
-        self.logger.log_stat("bandit_entropy", mean_entropy.item(), t)
-
-
-    def cuda(self):
-        self.policy.cuda()
-
-    def save_model(self, path):
-        torch.save(self.policy.state_dict(), "{}/ez_bandit_policy.th".format(path))
-
-    def load_model(self, path):
-        self.policy.load_state_dict(torch.load("{}/ez_bandit_policy.th".format(path), map_location=lambda storage, loc:storage))
diff --git a/pymarl/modules/bandits/returns_bandit.py b/pymarl/modules/bandits/returns_bandit.py
deleted file mode 100644
index 1f13f18..0000000
--- a/pymarl/modules/bandits/returns_bandit.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# Categorical policy for discrete z
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.optim as optim
-from collections import deque
-import numpy as np
-
-
-class Net(nn.Module):
-    def __init__(self, args):
-        super(Net, self).__init__()
-        self.args = args
-        self.affine1 = nn.Linear(args.state_shape, 256)
-        self.affine2 = nn.Linear(256, 256)
-        self.affine3 = nn.Linear(256, args.noise_dim)
-        self.output_scale = self.args.bandit_reward_scaling
-
-    def forward(self, x):
-        x = x.view(-1, self.args.state_shape)
-        x = self.affine1(x)
-        x = F.relu(x)
-        x = self.affine2(x)
-        x = F.relu(x)
-        returns = self.affine3(x)
-        return returns * self.output_scale
-
-
-class ReturnsBandit:
-    def __init__(self, args, logger):
-        self.args = args
-        self.lr = args.lr
-        self.logger = logger
-        self.noise_dim = self.args.noise_dim
-        # size of state vector
-        self.state_shape = self.args.state_shape
-        self.net = Net(args)
-        self.optimizer = optim.RMSprop(self.net.parameters(), lr=self.lr)
-
-        self.buffer = deque(maxlen=self.args.bandit_buffer)
-        self.epsilon_floor = args.bandit_epsilon
-
-        self.uniform_noise = torch.distributions.one_hot_categorical.OneHotCategorical(torch.tensor([1/self.args.noise_dim for _ in range(self.args.noise_dim)]).repeat(self.args.batch_size_run, 1))
-
-    def sample(self, state, test_mode):
-        if test_mode:
-            return self.uniform_noise.sample()
-        else:
-            estimated_returns = self.net(state)
-            probs = F.softmax(estimated_returns, dim=-1)
-            probs_eps = (1 - self.epsilon_floor) * probs + self.epsilon_floor / self.noise_dim
-            m = torch.distributions.one_hot_categorical.OneHotCategorical(probs_eps)
-            action = m.sample().cpu()
-            return action
-
-    def update_returns(self, states, actions, returns, test_mode, t):
-        if test_mode:
-            return
-
-        for s,a,r in zip(states, actions, returns):
-            self.buffer.append((s,a,torch.tensor(r, dtype=torch.float)))
-
-        for _ in range(self.args.bandit_iters):
-            idxs = np.random.randint(0, len(self.buffer), size=self.args.bandit_batch)
-            batch_elems = [self.buffer[i] for i in idxs]
-            states_ = torch.stack([x[0] for x in batch_elems]).to(states.device)
-            actions_ = torch.stack([x[1] for x in batch_elems]).to(states.device)
-            returns_ = torch.stack([x[2] for x in batch_elems]).to(states.device)
-
-            if not self.args.bandit_use_state:
-                states_ = torch.ones_like(states_)
-
-            estimated_returns_all = self.net(states_)
-            estimated_returns = (estimated_returns_all * actions_).sum(dim=1)
-            loss = (returns_ - estimated_returns).pow(2).mean()
-            self.optimizer.zero_grad()
-            loss.backward()
-            self.optimizer.step()
-
-        # Log info about the last iteration
-        self.logger.log_stat("bandit_loss", loss.item(), t)
-        action_distrib = torch.distributions.OneHotCategorical(F.softmax(estimated_returns_all, dim=1))
-        mean_entropy = action_distrib.entropy().mean()
-        self.logger.log_stat("bandit_entropy", mean_entropy.item(), t)
-        mins = estimated_returns_all.min(dim=1)[0].mean().item()
-        maxs = estimated_returns_all.max(dim=1)[0].mean().item()
-        means = estimated_returns_all.mean().item()
-        self.logger.log_stat("min_returns", mins, t)
-        self.logger.log_stat("max_returns", maxs, t)
-        self.logger.log_stat("mean_returns", means, t)
-
-    def cuda(self):
-        self.net.cuda()
-
-    def save_model(self, path):
-        torch.save(self.net.state_dict(), "{}/returns_bandit_net.th".format(path))
diff --git a/pymarl/modules/bandits/uniform.py b/pymarl/modules/bandits/uniform.py
deleted file mode 100644
index e59b6c0..0000000
--- a/pymarl/modules/bandits/uniform.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import torch as th
-
-
-class Uniform:
-
-    def __init__(self, args):
-        self.args = args
-        self.noise_distrib = th.distributions.one_hot_categorical.OneHotCategorical(th.tensor([1/self.args.noise_dim for _ in range(self.args.noise_dim)]).repeat(self.args.batch_size_run, 1))
-
-    def sample(self, state, test_mode):
-        return self.noise_distrib.sample()
-
-    def update_returns(self, state, noise, returns, test_mode, t):
-        pass
\ No newline at end of file
diff --git a/pymarl/modules/mixers/maven_mixer.py b/pymarl/modules/mixers/maven_mixer.py
deleted file mode 100644
index 1fdf152..0000000
--- a/pymarl/modules/mixers/maven_mixer.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import torch as th
-import torch.nn as nn
-import torch.nn.functional as F
-import numpy as np
-
-
-class MavenMixer(nn.Module):
-    def __init__(self, args):
-        super(MavenMixer, self).__init__()
-
-        self.args = args
-        self.n_agents = args.n_agents
-        self.state_dim = int(np.prod(args.state_shape)) + args.noise_dim
-
-        self.embed_dim = args.mixing_embed_dim
-
-        self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents)
-        self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)
-
-        # Initialise the hyper networks with a fixed variance, if specified
-        if self.args.hyper_initialization_nonzeros > 0:
-            std = self.args.hyper_initialization_nonzeros ** -0.5
-            self.hyper_w_1.weight.data.normal_(std=std)
-            self.hyper_w_1.bias.data.normal_(std=std)
-            self.hyper_w_final.weight.data.normal_(std=std)
-            self.hyper_w_final.bias.data.normal_(std=std)
-
-        # Initialise the hyper-network of the skip-connections, such that the result is close to VDN
-        if self.args.skip_connections:
-            self.skip_connections = nn.Linear(self.state_dim, self.args.n_agents, bias=True)
-            self.skip_connections.bias.data.fill_(1.0)  # bias produces initial VDN weights
-
-        # State dependent bias for hidden layer
-        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)
-
-        # V(s) instead of a bias for the last layers
-        self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim),
-                               nn.ReLU(),
-                               nn.Linear(self.embed_dim, 1))
-
-    def forward(self, agent_qs, states, noise):
-        bs = agent_qs.size(0)
-        states = th.cat([states, noise], dim=-1)
-        states = states.reshape(-1, self.state_dim)
-        agent_qs = agent_qs.view(-1, 1, self.n_agents)
-        # First layer
-        w1 = th.abs(self.hyper_w_1(states))
-        b1 = self.hyper_b_1(states)
-        w1 = w1.view(-1, self.n_agents, self.embed_dim)
-        b1 = b1.view(-1, 1, self.embed_dim)
-        hidden = F.elu(th.bmm(agent_qs, w1) + b1)
-        # Second layer
-        w_final = th.abs(self.hyper_w_final(states))
-        w_final = w_final.view(-1, self.embed_dim, 1)
-        # State-dependent bias
-        v = self.V(states).view(-1, 1, 1)
-        # Skip connections
-        s = 0
-        if self.args.skip_connections:
-            ws = th.abs(self.skip_connections(states)).view(-1, self.n_agents, 1)
-            s = th.bmm(agent_qs, ws)
-        # Compute final output
-        y = th.bmm(hidden, w_final) + v + s
-        # Reshape and return
-        q_tot = y.view(bs, -1, 1)
-        return q_tot
diff --git a/pymarl/run.py b/pymarl/run.py
index 39dbfe8..558c0f2 100644
--- a/pymarl/run.py
+++ b/pymarl/run.py
@@ -93,8 +93,6 @@ def run_sequential(args, logger):
         "reward": {"vshape": (1,)},
         "terminated": {"vshape": (1,), "dtype": th.uint8},
     }
-    if args.mac=="maven_mac":
-        scheme["noise"] = {"vshape": (args.noise_dim,)}
 
     if args.mac == "is_mac":
         scheme["behavior"] = {"vshape": (env_info["n_actions"],),
diff --git a/pymarl/runners/episode_runner.py b/pymarl/runners/episode_runner.py
index 7d07782..64589f4 100644
--- a/pymarl/runners/episode_runner.py
+++ b/pymarl/runners/episode_runner.py
@@ -1,14 +1,7 @@
-import time
-
 from pymarl.envs import REGISTRY as env_REGISTRY
 from functools import partial
 from pymarl.components.episode_buffer import EpisodeBatch
 import numpy as np
-from pymarl.modules.bandits.const_lr import Constant_Lr
-from pymarl.modules.bandits.uniform import Uniform
-from pymarl.modules.bandits.reinforce_hierarchial import EZ_agent as enza
-from pymarl.modules.bandits.returns_bandit import ReturnsBandit as RBandit
-
 
 class EpisodeRunner:
 
@@ -38,20 +31,6 @@ def setup(self, scheme, groups, preprocess, mac):
                                  device=self.args.device)
         self.mac = mac
 
-        # Setup the noise distribution sampler
-        if self.args.mac == "maven_mac":
-            if self.args.noise_bandit:
-                if self.args.bandit_policy:
-                    self.noise_distrib = enza(self.args, logger=self.logger)
-                else:
-                    self.noise_distrib = RBandit(self.args, logger=self.logger)
-            else:
-                self.noise_distrib = Uniform(self.args)
-
-            self.noise_returns = {}
-            self.noise_test_won = {}
-            self.noise_train_won = {}
-
     def get_env_info(self):
         return self.env.get_env_info()
 
@@ -72,12 +51,7 @@ def run(self, test_mode=False):
         terminated = False
         episode_return = 0
         self.mac.init_hidden(batch_size=self.batch_size)
-        if self.args.mac == "maven_mac":
-            # Sample the noise at the beginning of the episode
-            self.noise = self.noise_distrib.sample(self.batch['state'][:, 0],
-                                                   test_mode)
 
-            self.batch.update({"noise": self.noise}, ts=0)
 
         while not terminated:
             pre_transition_data = {
@@ -168,12 +142,10 @@ def run(self, test_mode=False):
         return self.batch
 
     def save_models(self, path):
-        if self.args.noise_bandit:
-            self.noise_distrib.save_model(path)
+        pass
 
     def load_models(self, path):
-        if self.args.mac == "maven_mac" and self.args.noise_bandit:
-            self.noise_distrib.load_model(path)
+        pass
 
     def _log(self, returns, stats, prefix):
         self.logger.log_stat(prefix + "return_mean", np.mean(returns),
diff --git a/pymarl/runners/parallel_runner.py b/pymarl/runners/parallel_runner.py
index 6e8f684..840d402 100644
--- a/pymarl/runners/parallel_runner.py
+++ b/pymarl/runners/parallel_runner.py
@@ -3,12 +3,6 @@
 from pymarl.components.episode_buffer import EpisodeBatch
 from multiprocessing import Pipe, Process
 import numpy as np
-import torch as th
-from pymarl.modules.bandits.const_lr import Constant_Lr
-from pymarl.modules.bandits.uniform import Uniform
-from pymarl.modules.bandits.reinforce_hierarchial import EZ_agent as enza
-from pymarl.modules.bandits.returns_bandit import ReturnsBandit as RBandit
-import time
 
 # Based (very) heavily on SubprocVecEnv from OpenAI Baselines
 # https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py
@@ -50,8 +44,7 @@ def __init__(self, args, logger):
         self.log_train_stats_t = -100000
 
     def cuda(self):
-        if self.args.noise_bandit:
-            self.noise_distrib.cuda()
+        pass
 
     def setup(self, scheme, groups, preprocess, mac):
         self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size,
@@ -63,20 +56,6 @@ def setup(self, scheme, groups, preprocess, mac):
         self.groups = groups
         self.preprocess = preprocess
 
-        # Setup the noise distribution sampler
-        if self.args.mac == "maven_mac":
-            if self.args.noise_bandit:
-                if self.args.bandit_policy:
-                    self.noise_distrib = enza(self.args, logger=self.logger)
-                else:
-                    self.noise_distrib = RBandit(self.args, logger=self.logger)
-            else:
-                self.noise_distrib = Uniform(self.args)
-
-            self.noise_returns = {}
-            self.noise_test_won = {}
-            self.noise_train_won = {}
-
     def get_env_info(self):
         return self.env_info
 
@@ -108,13 +87,6 @@ def reset(self, test_mode=False):
 
         self.batch.update(pre_transition_data, ts=0)
 
-        if self.args.mac == "maven_mac":
-            # Sample the noise at the beginning of the episode
-            self.noise = self.noise_distrib.sample(self.batch['state'][:, 0],
-                                                   test_mode)
-
-            self.batch.update({"noise": self.noise}, ts=0)
-
         self.t = 0
         self.env_steps_this_run = 0
 
@@ -243,13 +215,6 @@ def run(self, test_mode=False):
 
             cur_returns.extend(episode_returns)
 
-            if self.args.mac == "maven_mac":
-                self._update_noise_returns(episode_returns, self.noise,
-                                           final_env_infos, test_mode)
-                self.noise_distrib.update_returns(self.batch['state'][:, 0],
-                                                  self.noise, episode_returns,
-                                                  test_mode, self.t_env)
-
             n_test_runs = max(1, self.args.test_nepisode // self.batch_size)* self.batch_size
             if test_mode and (len(self.test_returns) == n_test_runs):
                 self._log(cur_returns, cur_stats, log_prefix)
@@ -276,36 +241,11 @@ def _log(self, returns, stats, prefix):
                                  v / stats["n_episodes"], self.t_env)
         stats.clear()
 
-    def _update_noise_returns(self, returns, noise, stats, test_mode):
-        for n, r in zip(noise, returns):
-            n = int(np.argmax(n))
-            if n in self.noise_returns:
-                self.noise_returns[n].append(r)
-            else:
-                self.noise_returns[n] = [r]
-        if test_mode:
-            noise_won = self.noise_test_won
-        else:
-            noise_won = self.noise_train_won
-
-        if stats != [] and "battle_won" in stats[0]:
-            for n, info in zip(noise, stats):
-                if "battle_won" not in info:
-                    continue
-                bw = info["battle_won"]
-                n = int(np.argmax(n))
-                if n in noise_won:
-                    noise_won[n].append(bw)
-                else:
-                    noise_won[n] = [bw]
-
     def save_models(self, path):
-        if self.args.noise_bandit:
-            self.noise_distrib.save_model(path)
+        pass
 
     def load_models(self, path):
-        if self.args.mac == "maven_mac" and self.args.noise_bandit:
-            self.noise_distrib.load_model(path)
+        pass
 
 
 def env_worker(remote, env_fn):