Skip to content

Commit

Permalink
Simplify the code of AIR and EARTH logic; add functionalities to the …
Browse files Browse the repository at this point in the history
…DefaultTypeDef to be easier to extend.

PiperOrigin-RevId: 572171916
  • Loading branch information
oteret authored and Selforg Gardener committed Oct 10, 2023
1 parent d9caae8 commit 750332c
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 84 deletions.
143 changes: 68 additions & 75 deletions self_organising_systems/biomakerca/cells_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
limitations under the License.
"""

import jax
import jax.numpy as jp
import jax.random as jr

Expand All @@ -32,11 +31,11 @@
from self_organising_systems.biomakerca.env_logic import EMPTY_UPD_MASK
from self_organising_systems.biomakerca.env_logic import EMPTY_UPD_TYPE
from self_organising_systems.biomakerca.env_logic import KeyType
from self_organising_systems.biomakerca.env_logic import make_empty_exclusive_op_cell
from self_organising_systems.biomakerca.env_logic import make_empty_upd_state
from self_organising_systems.biomakerca.env_logic import PerceivedData
from self_organising_systems.biomakerca.env_logic import UpdateOp
from self_organising_systems.biomakerca.environments import EnvConfig
from self_organising_systems.biomakerca.utils import conditional_update


### AIR
Expand All @@ -48,34 +47,30 @@ def air_cell_op(key: KeyType, perc: PerceivedData, config: EnvConfig
AIR simply spreads through neighboring VOID cells.
"""
neigh_type = perc[0]
etd = config.etd
# choose a random neighbor.
k1, key = jr.split(key)
neigh_idx = jr.choice(k1, jp.array([0, 1, 2, 3, 5, 6, 7, 8]))

def action_fn(neigh_idx):
t_upd_mask = EMPTY_UPD_MASK.at[neigh_idx].set(1.0)
a_upd_mask = EMPTY_UPD_MASK
t_upd_type = EMPTY_UPD_TYPE.at[neigh_idx].set(etd.types.AIR)
a_upd_type = EMPTY_UPD_TYPE
t_upd_state = make_empty_upd_state(config)
a_upd_state = make_empty_upd_state(config)
t_upd_id = EMPTY_UPD_ID
a_upd_id = EMPTY_UPD_ID

return ExclusiveOp(
UpdateOp(t_upd_mask, t_upd_type, t_upd_state, t_upd_id),
UpdateOp(a_upd_mask, a_upd_type, a_upd_state, a_upd_id),
)

# needs to be in bounds and only spreads through void.
result = jax.lax.cond(
perc.neigh_type[neigh_idx] == etd.types.VOID,
action_fn, # true
lambda _: make_empty_exclusive_op_cell(config), # false
neigh_idx,

# look for a random neighbor and see if it is VOID
rnd_idx = jr.choice(key, jp.array([0, 1, 2, 3, 5, 6, 7, 8]))
is_void_f = (neigh_type[rnd_idx] == etd.types.VOID).astype(jp.float32)
is_void_i = is_void_f.astype(jp.int32)

# if the target is VOID, we create a new AIR cell there.
t_upd_mask = EMPTY_UPD_MASK.at[rnd_idx].set(is_void_f)
# note that we dont update the actor, so we don't need to fill anything here.
a_upd_mask = EMPTY_UPD_MASK
t_upd_type = EMPTY_UPD_TYPE.at[rnd_idx].set(etd.types.AIR * is_void_i)
# likewise here, if we update a, it is because fire is becoming void.
a_upd_type = EMPTY_UPD_TYPE
t_upd_state = make_empty_upd_state(config)
a_upd_state = make_empty_upd_state(config)
t_upd_id = EMPTY_UPD_ID
a_upd_id = EMPTY_UPD_ID

return ExclusiveOp(
UpdateOp(t_upd_mask, t_upd_type, t_upd_state, t_upd_id),
UpdateOp(a_upd_mask, a_upd_type, a_upd_state, a_upd_id),
)
return result


# EARTH
Expand All @@ -94,60 +89,58 @@ def earth_cell_op(key: KeyType, perc: PerceivedData, config: EnvConfig
neigh_type, neigh_state, neigh_id = perc
etd = config.etd

# first check if you can fall.
# Create the output (and modify it based on conditions)
t_upd_mask = EMPTY_UPD_MASK
a_upd_mask = EMPTY_UPD_MASK
t_upd_type = EMPTY_UPD_TYPE
a_upd_type = EMPTY_UPD_TYPE
t_upd_state = make_empty_upd_state(config)
a_upd_state = make_empty_upd_state(config)
t_upd_id = EMPTY_UPD_ID
a_upd_id = EMPTY_UPD_ID

# First, check if you can fall.
# for now, you can't fall out of bounds.
can_fall = jp.logical_and(
can_fall_i = jp.logical_and(
neigh_type[7] != etd.types.OUT_OF_BOUNDS,
(neigh_type[7] == etd.intangible_mats).any(),
)

# if you can fall, do nothing. Gravity will take care of it.

def execute_move_f(side_idx):
t_upd_mask = EMPTY_UPD_MASK.at[side_idx].set(1.0)
a_upd_mask = EMPTY_UPD_MASK.at[side_idx].set(1.0)

# switch the types, states and ids
t_upd_type = EMPTY_UPD_TYPE.at[side_idx].set(etd.types.EARTH)
a_upd_type = EMPTY_UPD_TYPE.at[side_idx].set(neigh_type[side_idx])
t_upd_state = make_empty_upd_state(config).at[side_idx].set(neigh_state[4])
a_upd_state = (
make_empty_upd_state(config).at[side_idx].set(neigh_state[side_idx])
)
t_upd_id = EMPTY_UPD_ID.at[side_idx].set(neigh_id[4])
a_upd_id = EMPTY_UPD_ID.at[side_idx].set(neigh_id[side_idx])

return ExclusiveOp(
UpdateOp(t_upd_mask, t_upd_type, t_upd_state, t_upd_id),
UpdateOp(a_upd_mask, a_upd_type, a_upd_state, a_upd_id),
)
done_i = can_fall_i

# Else, check if you can fall on a random side.
# both the side and below need to be free.
def side_walk_f(key):
rnd_idx = jr.choice(key, jp.array([3, 5]))
down_rnd_idx = rnd_idx + 3

can_fall = (
(neigh_type[rnd_idx] != etd.types.OUT_OF_BOUNDS)
& (neigh_type[down_rnd_idx] != etd.types.OUT_OF_BOUNDS)
& ((neigh_type[rnd_idx] == etd.intangible_mats).any())
& ((neigh_type[down_rnd_idx] == etd.intangible_mats).any())
)

return jax.lax.cond(
can_fall,
execute_move_f,
lambda _: make_empty_exclusive_op_cell(config),
rnd_idx,
)

k1, key = jr.split(key)
result = jax.lax.cond(
can_fall,
lambda k: make_empty_exclusive_op_cell(config), # true
side_walk_f, # false
k1,
# if yes, do that.
key, ku = jr.split(key)
side_idx = jr.choice(ku, jp.array([3, 5]))
down_side_idx = side_idx + 3

can_fall_to_side_i = (1 - done_i) * (
(neigh_type[side_idx] != etd.types.OUT_OF_BOUNDS)
& (neigh_type[down_side_idx] != etd.types.OUT_OF_BOUNDS)
& ((neigh_type[side_idx] == etd.intangible_mats).any())
& ((neigh_type[down_side_idx] == etd.intangible_mats).any())
)

return result
# update the outputs if true.
t_upd_mask = conditional_update(t_upd_mask, side_idx, 1.0, can_fall_to_side_i)
a_upd_mask = conditional_update(a_upd_mask, side_idx, 1.0, can_fall_to_side_i)
# switch the types, states and ids
t_upd_type = conditional_update(
t_upd_type, side_idx, neigh_type[4], can_fall_to_side_i)
a_upd_type = conditional_update(
a_upd_type, side_idx, neigh_type[side_idx], can_fall_to_side_i)
t_upd_state = conditional_update(
t_upd_state, side_idx, neigh_state[4], can_fall_to_side_i)
a_upd_state = conditional_update(
a_upd_state, side_idx, neigh_state[side_idx], can_fall_to_side_i)
t_upd_id = conditional_update(
t_upd_id, side_idx, neigh_id[4], can_fall_to_side_i.astype(jp.uint32))
a_upd_id = conditional_update(
a_upd_id, side_idx, neigh_id[side_idx],
can_fall_to_side_i.astype(jp.uint32))

return ExclusiveOp(
UpdateOp(t_upd_mask, t_upd_type, t_upd_state, t_upd_id),
UpdateOp(a_upd_mask, a_upd_type, a_upd_state, a_upd_id),
)
39 changes: 30 additions & 9 deletions self_organising_systems/biomakerca/environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,16 +303,37 @@ def convert_string_dict_to_type_array(d, types):


class DefaultTypeDef(EnvTypeDef):
"""Example implementation of EnvTypeDef.
"""Default implementation of EnvTypeDef.
This etd is used in the original Biomaker CA paper.
This etd, with its default values, is used in the original Biomaker CA paper.
If you are subclassing this ETD, remember to call super() during init, and
override the relevant properties of mats, such as gravity_mats.
"""

def __init__(self):
def __init__(
self, materials=DEFAULT_MATERIALS,
agent_types=DEFAULT_AGENT_TYPES,
structure_decay_mats_dict=DEFAULT_STRUCTURE_DECAY_MATS_DICT,
dissipation_rate_per_spec_dict=DEFAULT_DISSIPATION_RATE_PER_SPEC_DICT,
type_color_dict=DEFAULT_TYPE_COLOR_DICT):
"""Initialization of DefaultTypeDef.
Args:
materials: List of strings of material types.
agent_types: List of strings of agent types.
structure_decay_mats_dict: dictionary of agent type strings and structural
decay values.
dissipation_rate_per_spec_dict: dictionary of agent type strings and
modifiers of the dissipation based on the agent specialization.
type_color_dict: dictionary of agent type strings and rgb colors for
visualising them.
"""
# initialize types, specialization_idxs, agent_types, materials_list
super().__create_types__(DEFAULT_MATERIALS, DEFAULT_AGENT_TYPES)
super().__create_types__(materials, agent_types)
types = self.types
# setup material specific properties.
# setup material specific properties. If you are subclassing this, consider
# changing these values manually.
self.intangible_mats = jp.array([types.VOID, types.AIR], dtype=jp.int32)
self.gravity_mats = jp.concatenate([
jp.array([types.EARTH], dtype=jp.int32), self.agent_types], 0)
Expand All @@ -322,14 +343,14 @@ def __init__(self):
self.agent_spawnable_mats = jp.array([
types.VOID, types.AIR, types.EARTH], dtype=jp.int32)
self.structure_decay_mats = convert_string_dict_to_type_array(
DEFAULT_STRUCTURE_DECAY_MATS_DICT, types)
structure_decay_mats_dict, types)
self.aging_mats = self.agent_types

self.dissipation_rate_per_spec = convert_string_dict_to_type_array(
DEFAULT_DISSIPATION_RATE_PER_SPEC_DICT, self.specialization_idxs)
dissipation_rate_per_spec_dict, self.specialization_idxs)

self.type_color_map = convert_string_dict_to_type_array(
DEFAULT_TYPE_COLOR_DICT, types)
type_color_dict, types)

# Class abstraction checks for attributes.
super().__post_init__()
Expand Down
6 changes: 6 additions & 0 deletions self_organising_systems/biomakerca/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,9 @@ class dotdict(dict):

def split_2d(key, w, h):
return vmap(lambda k: jr.split(k, h))(jr.split(key, w))


def conditional_update(arr, idx, val, cond):
"""Update arr[idx] to val if cond is True."""
return arr.at[idx].set((1 - cond) * arr[idx] + cond * val)

0 comments on commit 750332c

Please sign in to comment.