ready for review #15

Open · wants to merge 2 commits into base: main

16 changes: 12 additions & 4 deletions scripts/explore_parameters.py
@@ -126,7 +126,9 @@ def step():
global gridworld, parameter_values, env, agent, running, stepping, terminated, t, state, total, aleph, aleph0, delta, initialMu0, initialMu20, visited_state_alephs, visited_action_alephs
print()
env._fps = values['speed_slider']
action, aleph4action = agent.localPolicy(state, aleph).sample()[0]
# action, aleph4action = agent.localPolicy(state, aleph).sample()[0]
action = agent.act()
aleph4action = agent.last_aleph4action
visited_state_alephs.add((state, aleph))
visited_action_alephs.add((state, action, aleph4action))
if values['lossCoeff4WassersteinTerminalState'] != 0:
@@ -144,8 +146,10 @@ def step():
if parameter_values['verbose'] or parameter_values['debug']:
print("t:", t, ", last delta:" ,delta, ", total:", total, ", s:", state, ", aleph4s:", aleph, ", a:", action, ", aleph4a:", aleph4action)
nextState, delta, terminated, _, info = env.step(action)
total += delta
aleph = agent.propagateAspiration(state, action, aleph4action, delta, nextState)
agent.observe(nextState, delta, terminated)
total = agent.total # total += delta
aleph = agent.last_aleph4state # agent.propagateAspiration(state, action, aleph4action, delta, nextState)

state = nextState
if terminated:
print("t:",t, ", last delta:",delta, ", final total:", total, ", final s:", state, ", aleph4s:", aleph)
@@ -183,6 +187,7 @@ def step():
t += 1
if stepping: stepping = False


def reset_env(start=False):
# TODO: only regenerate env if different from before!
global gridworld, parameter_values, env, agent, running, stepping, terminated, t, state, total, aleph, aleph0, delta, initialMu0, initialMu20, visited_state_alephs, visited_action_alephs
@@ -191,6 +196,7 @@ def reset_env(start=False):
if gridworld != old_gridworld:
env, aleph0 = make_simple_gridworld(gw=gridworld, render_mode="human", fps=values['speed_slider'])
# env = env.get_prolonged_version(5)
agent = AgentMDPPlanning(world=env)
if values['override_aleph_checkbox']:
aleph = (values['aleph0_low'], values['aleph0_high'])
else:
@@ -201,6 +207,7 @@ def reset_env(start=False):
if parameter_values['lossTemperature'] == 0:
parameter_values['lossTemperature'] = 1e-6
parameter_values.update({
'initialAspiration': aleph,
'verbose': values['verbose_checkbox'],
'debug': values['debug_checkbox'],
'allowNegativeCoeffs': True,
@@ -209,9 +216,10 @@
'wassersteinFromInitial': values['wasserstein_checkbox'],
})
print("\n\nRESTART gridworld", gridworld, parameter_values)
agent.reset(parameter_values)
state, info = env.reset()
agent.observe(state)
print("Initial state:", env.state_embedding(state), ", initial aleph:", aleph)
agent = AgentMDPPlanning(parameter_values, world=env)
# agent.localPolicy(state, aleph) # call it once to precompute tables and save time for later
initialMu0 = list(agent.ETerminalState_state(state, aleph, "default"))
initialMu20 = list(agent.ETerminalState2_state(state, aleph, "default"))
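The net effect in this script: action selection, aspiration propagation, and total-tracking now live inside the agent. A minimal sketch of the resulting loop (illustrative only; it mirrors the updated step() and reset_env() above, and parameter_values must include 'initialAspiration'):

agent = AgentMDPPlanning(world=env)          # construct once, without params
agent.reset(parameter_values)                # configure, clear caches, set initial aspiration
state, info = env.reset()
agent.observe(state)                         # register the initial state
terminated = False
while not terminated:
    action = agent.act()                     # replaces localPolicy(state, aleph).sample()[0]
    state, delta, terminated, _, info = env.step(action)
    agent.observe(state, delta, terminated)  # accumulates agent.total, propagates the aspiration
print("final total:", agent.total)
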
81 changes: 80 additions & 1 deletion src/satisfia/agents/makeMDPAgentSatisfia.py
@@ -23,7 +23,13 @@ class AspirationAgent(ABC):
reachable_states = None
default_transition = None


### Methods for initialization, resetting, clearing caches:

def __init__(self, params):
if params: self.reset(params)

def reset(self, params):
"""
If world is provided, maxAdmissibleQ, minAdmissibleQ, Q, Q2, ..., Q6 are not needed because they are computed from the world. Otherwise, these functions must be provided, e.g. as learned using some reinforcement learning algorithm. Their signature is
- maxAdmissibleQ|minAdmissibleQ: (state, action) -> float
@@ -41,7 +47,9 @@ def __init__(self, params):
if lossCoeff4StateDistance > 0, referenceState must be provided

"""

defaults = {
"initialAspiration": None,
# admissibility parameters:
"maxLambda": 1, # upper bound on local relative aspiration in each step (must be minLambda...1) # TODO: rename to lambdaHi
"minLambda": 0, # lower bound on local relative aspiration in each step (must be 0...maxLambda) # TODO: rename to lambdaLo
@@ -103,6 +111,15 @@ def __init__(self, params):
self.params.update(params)
# TODO do I need to add params_.options

self.clear_caches()
self.last_state = None
self.last_aleph4state = params["initialAspiration"]
self.last_action = None
self.last_aleph4action = None
self.last_delta = None
self.terminated = False
self.total = None

self.stateActionPairsSet = set()

assert self.params["lossTemperature"] > 0, "lossTemperature must be > 0"
@@ -216,9 +233,23 @@ def deltaVar(s, a, al4s, al4a, p):
→ aspiration4state
→ simulate (RECURSION)"""

def clear_caches(self):
"""Clear all function caches (called by reset())"""
# loop through all parent classes:
for cls in self.__class__.__mro__:
# loop through all attributes of the class:
for key, value in cls.__dict__.items():
# check if the attribute is a cached function:
if callable(value):
if hasattr(value, "cache_clear"):
value.cache_clear()
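
A minimal illustration (not part of the patch) of what this MRO walk relies on, assuming the @cache decorators used below are functools.cache or something compatible that attaches cache_clear() to the wrapped function:

from functools import cache

class Example:
    @cache
    def square(self, x):       # the wrapper stored in Example.__dict__ carries cache_clear()
        return x * x

e = Example()
e.square(3)                    # computed and memoized
Example.square.cache_clear()   # exactly what clear_caches() calls while walking the MRO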

def __getitem__(self, name):
return self.params[name]


### Methods for computing feasibility sets / reference simplices:

@cache
def maxAdmissibleV(self, state): # recursive
if self.verbose or self.debug:
@@ -259,6 +290,9 @@ def admissibility4state(self, state):
def admissibility4action(self, state, action):
return self.minAdmissibleQ(state, action), self.maxAdmissibleQ(state, action)


### Methods for computing aspirations:

# When in state, we can get any expected total in the interval
# [minAdmissibleV(state), maxAdmissibleV(state)].
# So when having aspiration aleph, we can still fulfill it in expectation if it lies in the interval.
Expand Down Expand Up @@ -324,6 +358,9 @@ def aspiration4action(self, state, action, aleph4state):
print(pad(state),"| | ╰ aspiration4action, state",prettyState(state),"action",action,"aleph4state",aleph4state,":",res,"(steadfast)")
return res


### Methods for computing loss components independent of actual policy:

@cache
def disorderingPotential_state(self, state): # recursive
if self.debug or self.verbose:
@@ -375,6 +412,9 @@ def X(other_state):
print(pad(state),"| | | | ╰ agency_state", prettyState(state), ":", res)
return res


### Methods for computing the policy, propagating aspirations, acting, and observing:

# Based on the admissibility information computed above, we can now construct the policy,
# which is a mapping taking a state and an aspiration interval as input and returning
# a categorical distribution over (action, aleph4action) pairs.
@@ -538,6 +578,39 @@ def propagateAspiration(self, state, action, aleph4action, Edel, nextState):
including Edel in the formula.
"""

def observe(self, state, delta=None, terminated=False):
"""Called after env.reset() or env.step()"""
self.last_delta = delta
if delta is not None:
if self.total is None:
self.total = delta
else:
self.total += delta
self.terminated = terminated
if not terminated:
if self.last_state is not None:
# propagate the aspiration:
self.last_aleph4state = self.propagateAspiration(self.last_state, self.last_action, self.last_aleph4action, delta, state)
# otherwise it was set in reset()
self.last_state = state
if self.verbose or self.debug:
print("observed state", prettyState(state), ", delta", delta, " (terminated", terminated, "); resulting total", self.total, ", aleph4state", self.last_aleph4state)

def act(self):
"""Choose an action based on current state and aspiration"""
assert not self.terminated, "cannot act after termination"
state, aleph4state = self.last_state, self.last_aleph4state
assert state is not None, "cannot act without having observed a state"
action, aleph4action = self.localPolicy(state, aleph4state).sample()[0]
# TODO later: potentially consult with the principal and change the aleph4state, the action, and/or the aleph4action
self.last_action, self.last_aleph4action = action, aleph4action
if self.verbose or self.debug:
print("acting in state", prettyState(state), ", choosing action", action, ", aleph4action", aleph4action)
return action


### Methods for computing loss components dependent on actual policy:

@cache
def V(self, state, aleph4state): # recursive
if self.debug:
@@ -858,6 +931,9 @@ def X(actionAndAleph):
def randomTieBreaker(self, state, action):
return random.random()


### Methods for computing overall safety loss:

# now we can combine all of the above quantities to a combined (safety) loss function:

# state, action, aleph4state, aleph4action, estActionProbability
@@ -882,6 +958,9 @@ def getData(self): # FIXME: still needed?
"locs": [state.loc for state in states],
}


### Abstract methods that need to be implemented by subclasses:

@abstractmethod
def maxAdmissibleQ(self, state, action): pass
@abstractmethod
@@ -979,7 +1058,7 @@ def __init__(self, params, maxAdmissibleQ=None, minAdmissibleQ=None,
self.possible_actions = possible_actions

class AgentMDPPlanning(AspirationAgent):
def __init__(self, params, world=None):
def __init__(self, params=None, world=None):
self.world = world
super().__init__(params)

Expand Down