Skip to content

Commit

Permalink
Merge pull request #77 from CarperAI/2.0
Browse files Browse the repository at this point in the history
Merge latest fixes into 2.0 for pip release
  • Loading branch information
jsuarez5341 authored Sep 2, 2023
2 parents b0f0345 + eca6938 commit a4317f4
Show file tree
Hide file tree
Showing 73 changed files with 2,901 additions and 2,144 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ maps/
runs/*
wandb/*

# local replay file from tests/test_deterministic_replay.py, test_render_save.py
# local replay file from test_render_save.py
tests/replay_local*.pickle
replay*
eval*

.vscode

Expand Down Expand Up @@ -162,3 +163,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

profile.run
4 changes: 2 additions & 2 deletions nmmo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .render.overlay import Overlay, OverlayRegistry
from .core import config, agent, action
from .core.action import Action
from .core.agent import Agent
from .core.agent import Agent, Scripted
from .core.env import Env
from .core.terrain import MapGenerator, Terrain

Expand All @@ -22,7 +22,7 @@
\ \:\ \ \:\ \ \:\ \ \::/ maintained at MIT in
\__\/ \__\/ \__\/ \__\/ Phillip Isola's lab '''

__all__ = ['Env', 'config', 'agent', 'Agent', 'MapGenerator', 'Terrain',
__all__ = ['Env', 'config', 'agent', 'Agent', 'Scripted', 'MapGenerator', 'Terrain',
'action', 'Action', 'material', 'spawn',
'Overlay', 'OverlayRegistry']

Expand Down
133 changes: 44 additions & 89 deletions nmmo/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
# pylint: disable=no-method-argument,unused-argument,no-self-argument,no-member

from enum import Enum, auto
from ordered_set import OrderedSet
import numpy as np
from nmmo.core.observation import Observation

from nmmo.lib import utils
from nmmo.lib.utils import staticproperty
from nmmo.systems.item import Item, Stack
from nmmo.systems.item import Stack
from nmmo.lib.log import EventCode


class NodeType(Enum):
#Tree edges
STATIC = auto() #Traverses all edges without decisions
Expand All @@ -23,7 +24,8 @@ class NodeType(Enum):
class Node(metaclass=utils.IterableNameComparable):
@classmethod
def init(cls, config):
pass
# noop_action is used in some of the N() methods
cls.noop_action = 1 if config.PROVIDE_NOOP_ACTION_TARGET else 0

@staticproperty
def edges():
Expand All @@ -46,12 +48,9 @@ def leaf():
def N(cls, config):
return len(cls.edges)

def deserialize(realm, entity, index):
def deserialize(realm, entity, index, obs: Observation):
return index

def args(stim, entity, config):
return []

class Fixed:
pass

Expand All @@ -75,7 +74,7 @@ def hook(config):
arguments = []
for action in Action.edges(config):
action.init(config)
for args in action.edges:
for args in action.edges: # pylint: disable=not-an-iterable
args.init(config)
if not 'edges' in args.__dict__:
continue
Expand Down Expand Up @@ -105,10 +104,6 @@ def edges(cls, config):
edges.append(Comm)
return edges

def args(stim, entity, config):
raise NotImplementedError


class Move(Node):
priority = 60
nodeType = NodeType.SELECTION
Expand Down Expand Up @@ -139,12 +134,13 @@ def call(realm, entity, direction):
realm.map.tiles[r_new, c_new].add_entity(entity)

# exploration record keeping. moved from entity.py, History.update()
dist_from_spawn = utils.linf(entity.spawn_pos, (r_new, c_new))
if dist_from_spawn > entity.history.exploration:
entity.history.exploration = dist_from_spawn
progress_to_center = realm.map.dist_border_center -\
utils.linf_single(realm.map.center_coord, (r_new, c_new))
if progress_to_center > entity.history.exploration:
entity.history.exploration = progress_to_center
if entity.is_player:
realm.event_log.record(EventCode.GO_FARTHEST, entity,
distance=dist_from_spawn)
distance=progress_to_center)

# CHECK ME: material.Impassible includes void, so this line is not reachable
# Does this belong to Entity/Player.update()?
Expand All @@ -169,18 +165,15 @@ class Direction(Node):
def edges():
return [North, South, East, West, Stay]

def args(stim, entity, config):
return Direction.edges

def deserialize(realm, entity, index):
def deserialize(realm, entity, index, obs: Observation):
return deserialize_fixed_arg(Direction, index)

# a quick helper function
def deserialize_fixed_arg(arg, index):
if isinstance(index, (int, np.int64)):
if index < 0:
return None # so that the action will be discarded
val = min(index-1, len(arg.edges)-1)
val = min(index, len(arg.edges)-1)
return arg.edges[val]

# if index is not int, it's probably already deserialized
Expand All @@ -203,7 +196,6 @@ class West(Node):
class Stay(Node):
delta = (0, 0)


class Attack(Node):
priority = 50
nodeType = NodeType.SELECTION
Expand All @@ -226,7 +218,7 @@ def in_range(entity, stim, config, N):
R, C = stim.shape
R, C = R//2, C//2

rets = OrderedSet([entity])
rets = set([entity])
for r in range(R-N, R+N+1):
for c in range(C-N, C+N+1):
for e in stim[r, c].entities.values():
Expand All @@ -235,14 +227,6 @@ def in_range(entity, stim, config, N):
rets = list(rets)
return rets

# CHECK ME: do we need l1 distance function?
# systems/ai/utils.py also has various distance functions
# which we may want to clean up
# def l1(pos, cent):
# r, c = pos
# r_cent, c_cent = cent
# return abs(r - r_cent) + abs(c - c_cent)

def call(realm, entity, style, target):
if style is None or target is None:
return None
Expand All @@ -256,15 +240,15 @@ def call(realm, entity, style, target):
# Testing a spawn immunity against old agents to avoid spawn camping
immunity = config.COMBAT_SPAWN_IMMUNITY
if entity.is_player and target.is_player and \
target.history.time_alive < immunity < entity.history.time_alive.val:
target.history.time_alive < immunity:
return None

#Check if self targeted
if entity.ent_id == target.ent_id:
return None

#Can't attack out of range
if utils.linf(entity.pos, target.pos) > style.attack_range(config):
if utils.linf_single(entity.pos, target.pos) > style.attack_range(config):
return None

#Execute attack
Expand Down Expand Up @@ -293,28 +277,20 @@ class Style(Node):
def edges():
return [Melee, Range, Mage]

def args(stim, entity, config):
return Style.edges

def deserialize(realm, entity, index):
def deserialize(realm, entity, index, obs: Observation):
return deserialize_fixed_arg(Style, index)


class Target(Node):
argType = None

@classmethod
def N(cls, config):
return config.PLAYER_N_OBS

def deserialize(realm, entity, index: int):
# NOTE: index is the entity id
# CHECK ME: should index be renamed to ent_id?
return realm.entity_or_none(index)
return config.PLAYER_N_OBS + cls.noop_action

def args(stim, entity, config):
#Should pass max range?
return Attack.in_range(entity, stim, config, None)
def deserialize(realm, entity, index: int, obs: Observation):
if index >= len(obs.entities.ids):
return None
return realm.entity_or_none(obs.entities.ids[index])

class Melee(Node):
nodeType = NodeType.ACTION
Expand Down Expand Up @@ -346,27 +322,17 @@ def attack_range(config):
def skill(entity):
return entity.skills.mage


class InventoryItem(Node):
argType = None

@classmethod
def N(cls, config):
return config.INVENTORY_N_OBS

# TODO(kywch): What does args do?
def args(stim, entity, config):
return stim.exchange.items()
return config.INVENTORY_N_OBS + cls.noop_action

def deserialize(realm, entity, index: int):
# NOTE: index is from the inventory, NOT item id
inventory = Item.Query.owned_by(realm.datastore, entity.id.val)

if index >= inventory.shape[0]:
def deserialize(realm, entity, index: int, obs: Observation):
if index >= len(obs.inventory.ids):
return None

item_id = inventory[index, Item.State.attr_name_to_col["id"]]
return realm.items[item_id]
return realm.items.get(obs.inventory.ids[index])

class Use(Node):
priority = 10
Expand Down Expand Up @@ -490,7 +456,6 @@ def call(realm, entity, item, target):

realm.event_log.record(EventCode.GIVE_ITEM, entity)


class GiveGold(Node):
priority = 30

Expand Down Expand Up @@ -528,37 +493,26 @@ def call(realm, entity, amount, target):
if not isinstance(amount, int):
amount = amount.val

if not (amount > 0 and entity.gold.val > 0): # no gold to give
if amount > entity.gold.val: # no gold to give
return

amount = min(amount, entity.gold.val)

entity.gold.decrement(amount)
target.gold.increment(amount)

realm.event_log.record(EventCode.GIVE_GOLD, entity)


class MarketItem(Node):
argType = None

@classmethod
def N(cls, config):
return config.MARKET_N_OBS

# TODO(kywch): What does args do?
def args(stim, entity, config):
return stim.exchange.items()

def deserialize(realm, entity, index: int):
# NOTE: index is from the market, NOT item id
market = Item.Query.for_sale(realm.datastore)
return config.MARKET_N_OBS + cls.noop_action

if index >= market.shape[0]:
def deserialize(realm, entity, index: int, obs: Observation):
if index >= len(obs.market.ids):
return None

item_id = market[index, Item.State.attr_name_to_col["id"]]
return realm.items[item_id]
return realm.items.get(obs.market.ids[index])

class Buy(Node):
priority = 20
Expand Down Expand Up @@ -659,19 +613,24 @@ class Price(Node):
@classmethod
def init(cls, config):
# gold should be > 0
Price.classes = init_discrete(range(1, config.PRICE_N_OBS+1))
cls.price_range = range(1, config.PRICE_N_OBS+1)
Price.classes = init_discrete(cls.price_range)

@classmethod
def index(cls, price):
try:
return cls.price_range.index(price)
except ValueError:
# use the max price, which is config.PRICE_N_OBS
return len(cls.price_range) - 1

@staticproperty
def edges():
return Price.classes

def args(stim, entity, config):
return Price.edges

def deserialize(realm, entity, index):
def deserialize(realm, entity, index, obs: Observation):
return deserialize_fixed_arg(Price, index)


class Token(Node):
argType = Fixed

Expand All @@ -683,13 +642,9 @@ def init(cls, config):
def edges():
return Token.classes

def args(stim, entity, config):
return Token.edges

def deserialize(realm, entity, index):
def deserialize(realm, entity, index, obs: Observation):
return deserialize_fixed_arg(Token, index)


class Comm(Node):
argType = Fixed
priority = 99
Expand Down
14 changes: 13 additions & 1 deletion nmmo/core/agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

class Agent:
policy = 'Neural'

Expand All @@ -11,10 +10,23 @@ def __init__(self, config, idx):
'''
self.config = config
self.iden = idx
self._np_random = None

def __call__(self, obs):
'''Used by scripted agents to compute actions. Override in subclasses.
Args:
obs: Agent observation provided by the environment
'''

def set_rng(self, np_random):
'''Set the random number generator for the agent for reproducibility
Args:
np_random: A numpy random.Generator object
'''
self._np_random = np_random

class Scripted(Agent):
'''Base class for scripted agents'''
policy = 'Scripted'
Loading

0 comments on commit a4317f4

Please sign in to comment.