From bb32aa72ea8869120f6565dfa44bb321ead38a20 Mon Sep 17 00:00:00 2001 From: Ettore Randazzo Date: Fri, 23 Feb 2024 05:39:00 -0800 Subject: [PATCH] Add sexual reproduction to the default configuration. Refactor how mutators generate mutator parameters so that they are contiguous to their related parameters. Add SexualMutators. Add the possibility of performing sparse computation. Add one more agent type (FLOWER_SEXUAL) Now step_f can return metrics. The current metrics are about reproduce_ops. Fix some nomenclature issues and minor bugs. PiperOrigin-RevId: 609700512 --- .../biomakerca/agent_logic.py | 130 +++- .../biomakerca/env_logic.py | 562 +++++++++++++----- .../biomakerca/environments.py | 49 +- .../notebooks/run_configuration.ipynb | 27 +- .../notebooks/run_sexual_configuration.ipynb | 523 ++++++++++++++++ .../biomakerca/extensions/eruption.py | 6 +- .../biomakerca/mutators.py | 99 ++- .../biomakerca/step_maker.py | 51 +- 8 files changed, 1206 insertions(+), 241 deletions(-) create mode 100644 self_organising_systems/biomakerca/examples/notebooks/run_sexual_configuration.ipynb diff --git a/self_organising_systems/biomakerca/agent_logic.py b/self_organising_systems/biomakerca/agent_logic.py index f5f66a6..a9d8238 100644 --- a/self_organising_systems/biomakerca/agent_logic.py +++ b/self_organising_systems/biomakerca/agent_logic.py @@ -82,18 +82,26 @@ def excl_f(self, key: KeyType, perc: PerceivedData, params: AgentProgramType Return a ExclusiveInterface. """ pass - + @abstractmethod def repr_f(self, key: KeyType, perc: PerceivedData, params: AgentProgramType ) -> ReproduceInterface: """Perform a reproduce function. - + params must be only the reproduce params, not all of them. - + Return a ReproduceInterface. """ pass + @abstractmethod + def get_sex(self, params: AgentProgramType): + """Extract the sex parameter from params. + + The sex parameter is a jax int32. + """ + pass + def __str__(self): return stringify_class(self) @@ -127,17 +135,39 @@ class BasicAgentLogic(AgentLogic): The cells can optionally perceive agent_ids. If so, they will never give nutrients to cells with different agent_ids. + + sex_sensitivity is a parameter that is used to initialize and detect the sex + of a DNA. Check get_sex to see how it is used. + I recommend having a sex_sensitivity equal to 1/sd of the mutator used. """ - def __init__(self, config: EnvConfig, perceive_ids=True, minimal_net=False): + def __init__(self, config: EnvConfig, perceive_ids=True, minimal_net=False, + make_asexual_flowers_likely=True, + make_sexual_flowers_likely=True, + sex_sensitivity=1000., init_noise=None): + """Constructor. + + make_asexual_flowers_likely and make_sexual_flowers_likely are checked to + initialize the parameters to encourage a certain kind of flowers to spawn. + + If init_noise is not None, the initialization for !minimal_net will + initialize several parameters with a certain noise, with it being + init_noise * glorot_initialization of weight matrices. + if init_noise is None, the initial network is a homomorphism of a + minimal_nel (same outputs). + """ self.config = config # the types are perceived as one-hot vectors. self.n_types = len(config.etd.types.keys()) - self.n_spec = 4 # specializations of agents + self.n_spec = len(config.etd.agent_types) # specializations of agents # Whether agent ids are perceivable by the agent. # if set to true, agents do not give energy to agents with different ids. self.perceive_ids = perceive_ids self.minimal_net = minimal_net + self.make_asexual_flowers_likely = make_asexual_flowers_likely + self.make_sexual_flowers_likely = make_sexual_flowers_likely + self.sex_sensitivity = sex_sensitivity + self.init_noise = init_noise self.state_clip_val = 3. ## Parallel op: @@ -207,11 +237,16 @@ def split_params_f( params, (n_par_params, n_par_params + self.excl_num_params), axis=-1) + def _cond_init(self, key, shape): + if self.init_noise is not None: + return glorot_normal(batch_axis=0)(key, shape) * self.init_noise + return jp.zeros(shape) + def dsm_init(self, key): if self.minimal_net: # in this case, the agent cannot modify its internal state. return (jp.empty(0),) - + # Set initial effect on state to zero. # a 2 layer NN. # # We look at proportions of neighboring cells and your internal state. @@ -225,8 +260,8 @@ def dsm_init(self, key): ku, key = jr.split(key) dw0 = glorot_normal(batch_axis=0)(ku, (self.n_spec, insize, hsize)) db0 = jp.zeros((self.n_spec, hsize)) - # output is defaulted to zero. - dw1 = jp.zeros((self.n_spec, hsize, outsize)) + key, ku = jr.split(key) + dw1 = self._cond_init(ku, (self.n_spec, hsize, outsize)) db1 = jp.zeros((self.n_spec, outsize)) return (dw0, db0, dw1, db1) @@ -281,18 +316,30 @@ def nsl_init(self, key): w = w.at[:, etd.types.AGENT_FLOWER, spec_idxs.AGENT_UNSPECIALIZED].set( div / 8.0 ) + w = w.at[:, etd.types.AGENT_FLOWER_SEXUAL, + spec_idxs.AGENT_UNSPECIALIZED].set(div / 8.0) w = w.at[:, etd.types.EARTH, spec_idxs.AGENT_UNSPECIALIZED].set( -div / 8.0 ) w = w.at[:, etd.types.AIR, spec_idxs.AGENT_UNSPECIALIZED].set( -div / 8.0 ) - # Flowers only grow if they are surrounded by leaves and some air. - w = w.at[:, etd.types.AGENT_LEAF, spec_idxs.AGENT_FLOWER].set(div / 4.0) - w = w.at[:, etd.types.AIR, spec_idxs.AGENT_FLOWER].set(div / 2.0) + + if self.make_asexual_flowers_likely: + # Flowers only grow if they are surrounded by leaves and some air. + w = w.at[:, etd.types.AGENT_LEAF, spec_idxs.AGENT_FLOWER].set(div / 4.0) + w = w.at[:, etd.types.AIR, spec_idxs.AGENT_FLOWER].set(div / 2.0) + + if self.make_sexual_flowers_likely: + # Sexual flowers start with the same chance of being spawned. + w = w.at[:, etd.types.AGENT_LEAF, spec_idxs.AGENT_FLOWER_SEXUAL].set( + div / 4.0) + w = w.at[:, etd.types.AIR, spec_idxs.AGENT_FLOWER_SEXUAL].set(div / 2.0) # If you are a flower, never change! w = w.at[spec_idxs.AGENT_FLOWER, :, spec_idxs.AGENT_FLOWER].set(div) + w = w.at[spec_idxs.AGENT_FLOWER_SEXUAL, :, + spec_idxs.AGENT_FLOWER_SEXUAL].set(div) if self.minimal_net: return w, b @@ -308,8 +355,8 @@ def nsl_init(self, key): ku, key = jr.split(key) dw0 = glorot_normal(batch_axis=0)(ku, (self.n_spec, insize, hsize)) db0 = jp.zeros((self.n_spec, hsize)) - # output is defaulted to zero. - dw1 = jp.zeros((self.n_spec, hsize, outsize)) + key, ku = jr.split(key) + dw1 = self._cond_init(ku, (self.n_spec, hsize, outsize)) db1 = jp.zeros((self.n_spec, outsize)) return (w, b), (dw0, db0, dw1, db1) @@ -374,8 +421,8 @@ def denm_init(self, key): ku, key = jr.split(key) dw0 = glorot_normal(batch_axis=0)(ku, (self.n_spec, insize, hsize)) db0 = jp.zeros((self.n_spec, hsize)) - # output is defaulted to zero. - dw1 = jp.zeros((self.n_spec, hsize, outsize)) + key, ku = jr.split(key) + dw1 = self._cond_init(ku, (self.n_spec, hsize, outsize)) db1 = jp.zeros((self.n_spec, outsize)) return (b, keep_en, (dw0, db0, dw1, db1)) @@ -407,7 +454,7 @@ def denm_f(self, params, norm_neigh_state, i, neigh_type, neigh_id, self_en): # defaults to 0 if it is not an agent. It doesn't matter since nonagents # are later filtered out. neigh_spec = jax.nn.one_hot( - self.config.etd.get_agent_specialization_idx(neigh_type), 4) + self.config.etd.get_agent_specialization_idx(neigh_type), self.n_spec) def compute_logits_f(t_state, t_spec): inputs = jp.concatenate([norm_self_state, t_state, t_spec], -1) @@ -504,8 +551,8 @@ def excl_init(self, key): ku, key = jr.split(key) dw0 = glorot_normal(batch_axis=0)(ku, (self.n_spec, insize, hsize)) db0 = jp.zeros((self.n_spec, hsize)) - # output is defaulted to zero. - dw1 = jp.zeros((self.n_spec, hsize, outsize)) + key, ku = jr.split(key) + dw1 = self._cond_init(ku, (self.n_spec, hsize, outsize)) db1 = jp.zeros((self.n_spec, outsize)) return minimal_params, (dw0, db0, dw1, db1) @@ -513,11 +560,28 @@ def excl_init(self, key): def repr_init(self, key): """Initialization for repr_f. - Simply create a default requirement of nutrients for triggering the op. + Create a default requirement of nutrients for triggering the op. Note that this value is normalized by the material_nutrient_cap. + This value is different for FLOWER and FLOWER_SEXUAL. + Also, create a sex. + A sex is inferred by doing floor(sex_p * self.sex_sensitivity). """ - return (self.config.reproduce_cost + (self.config.dissipation_per_step * 4) - + self.config.specialize_cost * 2) / self.config.material_nutrient_cap + min_repr_en = ( + self.config.reproduce_cost + (self.config.dissipation_per_step * 4) + + self.config.specialize_cost * 2) / self.config.material_nutrient_cap + # sexual reproduction combines energy of two plants. + min_repr_en_sex = min_repr_en / 2. + + # format params fn has issues with zero dimensional parameters. So as a hack + # I make this a one dimensional array. + sex_m = (jr.uniform(key, (1,)) < 0.5).astype(jp.float32) + sex_p = (0.5 + sex_m) / self.sex_sensitivity + return (min_repr_en, min_repr_en_sex, sex_p) + + def get_sex(self, params: AgentProgramType): + _, _, repr_p = self.split_params_f(params) + _,_, sex_p = self._repr_format_params_fn(repr_p) + return jp.floor(sex_p[..., 0] * self.sex_sensitivity).astype(jp.int32) def initialize(self, key): k1, k2, k3, k4, k5 = jr.split(key, 5) @@ -597,7 +661,6 @@ def par_f(self, key, perc, params): denergy_neigh = self.denm_f( denm_params, norm_neigh_state, spec_idx, neigh_type, neigh_id, self_en) - return ParallelInterface(denergy_neigh, dstate, new_spec_logit) def excl_f(self, key, perc, params): @@ -646,7 +709,8 @@ def excl_f(self, key, perc, params): # This is a probability depending on the average number of neighbors. avg_neigh_types = self.get_avg_neigh_types(neigh_type) avg_agents = avg_neigh_types[ - etd.types.AGENT_UNSPECIALIZED : etd.types.AGENT_UNSPECIALIZED + 4 + etd.types.AGENT_UNSPECIALIZED : + etd.types.AGENT_UNSPECIALIZED + self.n_spec ].sum(-1) ku, key = jax.random.split(key) rand_prob_sp = (jr.uniform(ku) < jax.nn.sigmoid( @@ -712,18 +776,26 @@ def excl_f(self, key, perc, params): def repr_f(self, key, perc, params): """Implementation of repr_f. - - A simple mask that entirely depends on how much energy we have. + + A simple mask that depends on how much energy we have, and whether the + flower is sexual or asexual. """ + min_repr_en, min_repr_en_sex, _ = self._repr_format_params_fn(params) + neigh_type, _, _ = perc + self_type = neigh_type[4] + norm_self_en = (perc.neigh_state[4, evm.EN_ST : evm.EN_ST + 2] / self.config.material_nutrient_cap) - # we could also check whether we are flowers, but it doesn't matter since - # only flowers can reproduce and it gets masked afterwards. + is_flower = self_type == self.config.etd.types.AGENT_FLOWER + is_flower_sex = self_type == self.config.etd.types.AGENT_FLOWER_SEXUAL + flower_mask_logit = is_flower * (norm_self_en > min_repr_en).all().astype( + jp.float32) + flower_sex_mask_logit = is_flower_sex * ( + norm_self_en > min_repr_en_sex).all().astype(jp.float32) # The switch checks whether the output logit is >0. - mask_logit = (norm_self_en > params).all().astype(jp.float32) - return ReproduceInterface(mask_logit) + return ReproduceInterface(flower_mask_logit + flower_sex_mask_logit) def adapt_dna_to_different_basic_logic( diff --git a/self_organising_systems/biomakerca/env_logic.py b/self_organising_systems/biomakerca/env_logic.py index 8b06aff..873149a 100644 --- a/self_organising_systems/biomakerca/env_logic.py +++ b/self_organising_systems/biomakerca/env_logic.py @@ -30,9 +30,10 @@ from collections import namedtuple from functools import partial from typing import Callable, Iterable +import copy import flax -from jax import vmap +from jax import vmap, jit import jax.numpy as jp import jax.random as jr import jax.scipy @@ -152,20 +153,20 @@ def make_empty_parallel_op_cell(config): # ReproduceOp. # flowers can create ReproduceInterfaces that are converted to ReproduceOps. -# Nutrients from the flower get converted into stored_en and we record the +# Nutrients from the flower get converted into stored_en and we record the # original position and agent id. A new seed is then tentatively created # following the logic of the environment decided upon. if "ReproduceOp" not in globals(): - ReproduceOp = namedtuple("ReproduceOp", "mask pos stored_en aid") + ReproduceOp = namedtuple("ReproduceOp", "mask pos stored_en aid is_sexual") # Helpers for making ReproduceOps. EMPTY_REPRODUCE_OP = ReproduceOp( jp.array(0.), jp.zeros([2], dtype=jp.int32), jp.zeros([2]), - jp.array(0, dtype=jp.uint32)) + jp.array(0, dtype=jp.uint32), jp.array(0, dtype=jp.int32)) ### ReproduceInterface # Interface for reproduction of agents. -# If a flower triggers reproduction, all of its energy is converted for the +# If a flower triggers reproduction, all of its energy is converted for the # new seed. However, reproduction has a cost and it may fail. # Arguments: # mask_logit: whether or not to perform reproduction. True if > 0. @@ -180,7 +181,7 @@ def make_empty_parallel_op_cell(config): # No matter the kind of cell, be it a material or an agent, they can only # perceive a limited amount of data. This is the 3x3 neighborhood of the # environment. The difference from Environment is that each cell has 9 values -# per grid. That is, neigh_type will be [w,h,9] as opposed to type_grid: [w,h]. +# per grid. That is, neigh_type will be [h,w,9] as opposed to type_grid: [h,w]. # Also note that perceived data is intended to be passed using vmap2 so that # each cell only perceives their neighbors, that is, their input is of size: # neigh_type:[9], neigh_state:[9,env_state_size], neigh_id:[9]. @@ -191,7 +192,7 @@ def make_empty_parallel_op_cell(config): def perceive_neighbors(env: Environment, etd: EnvTypeDef) -> PerceivedData: - """Return PerceivedData (gridwise, with leading axes of size [w,h]). + """Return PerceivedData (gridwise, with leading axes of size [h,w]). Cells can only perceive their neighbors. Of the neighbors, they can perceive all: type, state and agent_id. @@ -211,7 +212,7 @@ def perceive_neighbors(env: Environment, etd: EnvTypeDef) -> PerceivedData: neigh_state = jax.lax.conv_general_dilated_patches( env.state_grid[None,:], (3, 3), (1, 1), "SAME", dimension_numbers=("NHWC", "OIHW", "NHWC"))[0] - # We want to have [w,h,9,c] so that the indexing is intuitive and consistent + # We want to have [h,w,9,c] so that the indexing is intuitive and consistent # for all neigh vectors. env_state_size = env.state_grid.shape[-1] neigh_state = neigh_state.reshape( @@ -229,21 +230,69 @@ def perceive_neighbors(env: Environment, etd: EnvTypeDef) -> PerceivedData: # These functions make ExclusiveOps and execute them into the environment. def vectorize_cell_exclusive_f( - cell_f: Callable[[KeyType, PerceivedData], ExclusiveOp], w, h): + cell_f: Callable[[KeyType, PerceivedData], ExclusiveOp], h, w): """Vectorizes a cell's exclusive_f to work from 0d to 2d.""" - return lambda key, perc: vmap2(cell_f)(split_2d(key, w, h), perc) + return lambda key, perc: vmap2(cell_f)(split_2d(key, h, w), perc) def vectorize_agent_cell_f( cell_f: Callable[[KeyType, PerceivedData, AgentProgramType], - ExclusiveOp|ParallelOp], w, h): + ExclusiveOp|ParallelOp], h, w): """Vectorizes an agent cell_f to work from 0d to 2d. This works for both ExclusiveOp and ParallelOp. Note that the cell_f *requires* to have the proper argument name 'programs', so this is an informal interface. """ return lambda key, perc, progs: vmap2(partial(cell_f, programs=progs))( - split_2d(key, w, h), perc) + split_2d(key, h, w), perc) + + +def compute_sparse_agent_cell_f( + key: KeyType, + cell_f: (Callable[[KeyType, PerceivedData, AgentProgramType], + ExclusiveOp|ParallelOp] | + Callable[[ + KeyType, PerceivedData, CellPositionType, AgentProgramType], + ReproduceOp]), + perc: PerceivedData, + env_type_grid, programs: AgentProgramType, etd: EnvTypeDef, + n_sparse_max: int, + b_pos=None): + """Compute a sparse agent cell_f. + + This works for ExclusiveOp, ParallelOp and ReproduceOp. + Note that the cell_f *requires* to have the proper argument name 'programs', + so this is an informal interface. + + if it is used for a ReproduceOp, b_pos needs to be set, being a flat list of + (y,x) positions. + """ + # get the args of alive cells. + # note that cell_f will not work by itself if the cell is not an agent. + is_agent_flat = etd.is_agent_fn(env_type_grid.flatten()) + # leftmost are not agents, rightmost are agents. + sorted_agent_idx = jp.argsort(is_agent_flat) + sparse_idx = sorted_agent_idx[-n_sparse_max:] + + # compute, sparsely, cell_f + v_part_cell_f = vmap(partial(cell_f, programs=programs)) + v_keys = jr.split(key, n_sparse_max) + sparse_perc_flat = jax.tree_util.tree_map( + lambda x: x.reshape((-1,)+ x.shape[2:])[sparse_idx], perc) + if b_pos is None: + sparse_output_tree = v_part_cell_f(v_keys, sparse_perc_flat) + else: + # this is a reproduce op, and therefore requires pos too. + sparse_output_tree = v_part_cell_f( + v_keys, sparse_perc_flat, b_pos[sparse_idx]) + + # scatter the result. + # we also need to reshape so that it is [h, w, ...] shape. + return jax.tree_util.tree_map( + lambda x: jax.ops.segment_sum( + x, sparse_idx, num_segments=is_agent_flat.shape[0]).reshape( + env_type_grid.shape + x.shape[1:]), + sparse_output_tree) def make_material_exclusive_interface( @@ -396,33 +445,43 @@ def execute_and_aggregate_exclusive_ops( excl_fs: Iterable[tuple[EnvTypeType, Callable[[KeyType, PerceivedData], ExclusiveOp]]], agent_excl_f: Callable[[ - KeyType, PerceivedData, AgentProgramType], ExclusiveInterface] + KeyType, PerceivedData, AgentProgramType], ExclusiveInterface], + n_sparse_max: int|None = None ) -> ExclusiveOp: """Execute all exclusive functions and aggregate them all into a single ExclusiveOp for each cell. - + This function constructs sanitized interfaces for the input excl_fs and agent_excl_f, making sure that no laws of physics are broken. - It also then executes these operation for each cell in the grid and then - aggregates the resulting ExclusiveOp for each cell. + if n_sparse_max is None, it then executes these operation for each cell in the + grid. If n_sparse_max is an integer, it instead performs a sparse computation + for agent operations, only for alive agent cells. The computation is capped, + so some agents may be skipped if the cap is too low. The agents being chosen + in that case are NOT ensured to be random. Consider changing the code to allow + for a random subset of cells to be run if you want that behavior. + It then aggregates the resulting ExclusiveOp for each cell. Aggregation can be done because *only one function*, at most, will be allowed to output nonzero values for each cell. """ etd = config.etd perc = perceive_neighbors(env, etd) - w, h = env.type_grid.shape + h, w = env.type_grid.shape v_excl_fs = [vectorize_cell_exclusive_f( - make_material_exclusive_interface(t, f, config), w, h) for (t, f) + make_material_exclusive_interface(t, f, config), h, w) for (t, f) in excl_fs] - v_agent_excl_f = vectorize_agent_cell_f( - make_agent_exclusive_interface(agent_excl_f, config), w, h) - - k1, k2, key = jr.split(key, 3) - w, h = env.type_grid.shape - excl_ops = [f(k, perc) for k, f in - zip(jr.split(k1, len(v_excl_fs)), v_excl_fs)] + [ - v_agent_excl_f(k2, perc, programs) - ] + + agent_excl_intf_f = make_agent_exclusive_interface(agent_excl_f, config) + k1, key = jr.split(key) + if n_sparse_max is None: + v_agent_excl_f = vectorize_agent_cell_f(agent_excl_intf_f, h, w) + agent_excl_op = v_agent_excl_f(k1, perc, programs) + else: + agent_excl_op = compute_sparse_agent_cell_f( + k1, agent_excl_intf_f, perc, env.type_grid, programs, etd, n_sparse_max) + + k1, key = jr.split(key) + excl_ops = [f(k, perc) for k, f in + zip(jr.split(k1, len(v_excl_fs)), v_excl_fs)] + [agent_excl_op] excl_op = tree_map_sum_ops(excl_ops) return excl_op @@ -446,7 +505,7 @@ def env_exclusive_decision( key: KeyType, env: Environment, excl_op: ExclusiveOp): """Choose up to one excl_op for each cell and execute them, updating the env. """ - w, h, chn = env.state_grid.shape + h, w, chn = env.state_grid.shape def extract_patches_that_target_cell(x): """Extract all ExclusiveOps subarray of neighbors that target each cell. @@ -456,7 +515,7 @@ def extract_patches_that_target_cell(x): ExclusiveOps that target a given cell, and return them, so that one of them can be chosen to be executed. """ - # input is either (w,h,9) or (w,h,9,c) + # input is either (h,w,9) or (h,w,9,c) ndim = x.ndim x_shape = x.shape old_dtype = x.dtype @@ -468,7 +527,7 @@ def extract_patches_that_target_cell(x): neigh_state = jax.lax.conv_general_dilated_patches( x[None,:], (3, 3), (1, 1), "SAME", dimension_numbers=("NHWC", "OIHW", "NHWC"))[0] - # make it (w, h, k, 9) + # make it (h, w, k, 9) # where k is either 9 or 9*c neigh_state = neigh_state.reshape( neigh_state.shape[:-1] + (x.shape[-1], 9)) @@ -477,7 +536,7 @@ def extract_patches_that_target_cell(x): neigh_state = neigh_state.reshape( neigh_state.shape[:-2] + x_shape[-2:] + (9,)).transpose( (0, 1, 2, 4, 3)) - # now it's either (w,h,9,9) or (w,h,9,9,c). + # now it's either (h,w,9,9) or (h,w,9,9,c). # The leftmost '9' refers to the position of the cell's neighbors. # the rightmost '9' indicates the exclusive op slice of the neighbor, which # in turn targets 9 neighbors. @@ -498,7 +557,7 @@ def extract_patches_that_target_cell(x): # choose a random one from them. rnd_neigh_idx = vmap2(_cell_choose_random_action)( - split_2d(key, w, h), excl_op_neighs) + split_2d(key, h, w), excl_op_neighs) # Do so by zeroing out everything except the chosen update. action_mask = jax.nn.one_hot(rnd_neigh_idx, 9) @@ -531,12 +590,12 @@ def extract_patches_that_target_cell(x): ) def map_to_actor(x): pad_x = jp.pad(x, ((1, 1), (1, 1), (0, 0))) - return jp.stack([jax.lax.dynamic_slice(pad_x, s, (w, h, 9))[:,:,n] for n, s + return jp.stack([jax.lax.dynamic_slice(pad_x, s, (h, w, 9))[:,:,n] for n, s in enumerate(MAP_TO_ACTOR_SLICING)], 2) def map_to_actor_e(x): pad_x = jp.pad(x, ((1, 1), (1, 1), (0, 0), (0, 0))) - return jp.stack([jax.lax.dynamic_slice(pad_x, s+(0,), (w, h, 9, chn))[:,:,n] + return jp.stack([jax.lax.dynamic_slice(pad_x, s+(0,), (h, w, 9, chn))[:,:,n] for n, s in enumerate(MAP_TO_ACTOR_SLICING)], 2) a_upd_mask = map_to_actor(a_upd_mask_from_t).sum(-1) @@ -568,28 +627,33 @@ def env_perform_exclusive_update( excl_fs: Iterable[tuple[EnvTypeType, Callable[[KeyType, PerceivedData], ExclusiveOp]]], agent_excl_f: Callable[[ - KeyType, PerceivedData, AgentProgramType], ExclusiveInterface] + KeyType, PerceivedData, AgentProgramType], ExclusiveInterface], + n_sparse_max: int|None = None ) -> Environment: """Perform exclusive operations in the environment. - + This is the function that should be used for high level step_env design. - + Arguments: key: a jax random number generator. env: the input environment to be modified. programs: params of agents that govern their agent_excl_f. They should be one line for each agent_id allowed in the environment. config: EnvConfig describing the physics of the environment. - excl_fs: list of pairs of (env_type, func) where env_type is the type of + excl_fs: list of pairs of (env_type, func) where env_type is the type of material that triggers the exclusive func. - agent_excl_f: The exclusive function that agents perform. It takes as + agent_excl_f: The exclusive function that agents perform. It takes as input a exclusive program and outputs a ExclusiveInterface. + n_sparse_max: if None, agent_excl_f are performed for the entire grid. If it + is an integer, instead, we perform a sparse computation masked by actual + agent cells. Note that this is capped and if more agents are alive, an + undefined subset of agent cells will be run. Returns: an updated environment. """ k1, key = jr.split(key) excl_op = execute_and_aggregate_exclusive_ops( - k1, env, programs, config, excl_fs, agent_excl_f) + k1, env, programs, config, excl_fs, agent_excl_f, n_sparse_max) key, key1 = jr.split(key) env = env_exclusive_decision(key1, env, excl_op) @@ -685,7 +749,8 @@ def env_perform_parallel_update( key: KeyType, env: Environment, programs: AgentProgramType, config: EnvConfig, par_f: Callable[[ - KeyType, PerceivedData, AgentProgramType], ParallelInterface] + KeyType, PerceivedData, AgentProgramType], ParallelInterface], + n_sparse_max: int|None = None ) -> Environment: """Perform parallel operations in the environment. @@ -699,23 +764,29 @@ def env_perform_parallel_update( config: EnvConfig describing the physics of the environment. par_f: The parallel function that agents perform. It takes as input a parallel program and outputs a ParallelInterface. + n_sparse_max: either an int or None. If set to int, we will use a budget for + the amounts of agent operations allowed at each step. Returns: an updated environment. """ # First compute the ParallelOp for each cell. - w, h = env.type_grid.shape - v_par_f = vectorize_agent_cell_f(make_agent_parallel_interface( - par_f, config), w, h) + h, w = env.type_grid.shape etd = config.etd perc = perceive_neighbors(env, etd) + par_interface_f = make_agent_parallel_interface( par_f, config) k1, key = jr.split(key) - par_op = v_par_f(k1, perc, programs) + if n_sparse_max is None: + v_par_interface_f = vectorize_agent_cell_f(par_interface_f, h, w) + par_op = v_par_interface_f(k1, perc, programs) + else: + par_op = compute_sparse_agent_cell_f( + k1, par_interface_f, perc, env.type_grid, programs, etd, n_sparse_max) # Then process them. mask, denergy_neigh, dstate, new_type = par_op - w, h = env.type_grid.shape + h, w = env.type_grid.shape MAP_TO_NEIGH_SLICING = ( (2, 2, 0), # indexing (-1, -1) @@ -730,7 +801,7 @@ def env_perform_parallel_update( ) def map_to_neigh(x): pad_x = jp.pad(x, ((1, 1), (1, 1), (0, 0), (0, 0))) - return jp.stack([jax.lax.dynamic_slice(pad_x, s+(0,), (w, h, 9, 2))[:, :, n] + return jp.stack([jax.lax.dynamic_slice(pad_x, s+(0,), (h, w, 9, 2))[:, :, n] for n, s in enumerate(MAP_TO_NEIGH_SLICING)], 2) denergy = map_to_neigh(denergy_neigh).sum(2) @@ -757,9 +828,9 @@ def map_to_neigh(x): # Functions for making and processing ReproduceOps. -def vectorize_reproduce_f(repr_f, w, h): +def vectorize_reproduce_f(repr_f, h, w): return lambda key, perc, pos, progs: vmap2(partial(repr_f, programs=progs))( - split_2d(key, w, h), perc, pos) + split_2d(key, h, w), perc, pos) def _convert_to_reproduce_op( @@ -781,20 +852,27 @@ def _convert_to_reproduce_op( # must want to reproduce want_to_repr_m = (mask_logit > 0.0).astype(jp.float32) - # must be a flower - is_flower_m = (self_type == env_config.etd.types.AGENT_FLOWER).astype( - jp.float32) + # must be a flower. + # depending on the kind, flag it as either sexual or asexual reproduction. + is_flower_m = self_type == env_config.etd.types.AGENT_FLOWER + is_flower_sexual_m = self_type == env_config.etd.types.AGENT_FLOWER_SEXUAL + + is_flower_type = is_flower_m | is_flower_sexual_m # must have enough energy. + min_repr_cost = env_config.reproduce_cost * ( + 1 *is_flower_m + 0.5 * is_flower_sexual_m) has_enough_en_m = ( - (self_en >= env_config.reproduce_cost).all().astype(jp.float32) + (self_en >= min_repr_cost).all().astype(jp.float32) ) - mask = want_to_repr_m * is_flower_m * has_enough_en_m + mask = want_to_repr_m * is_flower_type * has_enough_en_m stored_en = (self_en - env_config.reproduce_cost) * mask aid = (self_id * mask).astype(jp.uint32) - return ReproduceOp(mask, pos, stored_en, aid) + # don't use booleans. + return ReproduceOp(mask, pos, stored_en, aid, + is_flower_sexual_m.astype(jp.int32)) def make_agent_reproduce_interface( @@ -816,8 +894,11 @@ def make_agent_reproduce_interface( # Note that we need to pass the extra information of the position of the cell. # this is used in the conversion to ReproduceOp, not by the cell. def f(key, perc, pos, programs): - # it has to be a flower! - is_correct_type = config.etd.types.AGENT_FLOWER == perc.neigh_type[4] + # it has to be a flower type (sexual or asexual) + is_correct_type = ( + perc.neigh_type[4] == jp.array( + [config.etd.types.AGENT_FLOWER, + config.etd.types.AGENT_FLOWER_SEXUAL], dtype=jp.uint32)).any(-1) curr_agent_id = perc.neigh_id[4] program = programs[curr_agent_id] @@ -874,23 +955,23 @@ def _select_random_position_for_seed_within_range( return chosen_idx, n_available > 0 -def env_perform_one_reproduce_op( - key: KeyType, env: Environment, repr_op: ReproduceOp, config: EnvConfig): - """Perform one single ReproduceOp. +def env_try_place_one_seed( + key: KeyType, env: Environment, op_info, config: EnvConfig): + """Try to place one seed in the environment. - For a ReproduceOp to be successful, fertile soil in the neighborhood must be + For this op to be successful, fertile soil in the neighborhood must be found. If it is, a new seed (two unspecialized cells) are placed in the environment. Their age is reset to zero, and they may have a different agent_id than their parent, if mutation was set to true. """ - mask, pos, stored_en, aid = repr_op + mask, pos, stored_en, aid = op_info etd = config.etd def true_fn(env): best_idx_per_column, column_m = find_fertile_soil(env.type_grid, etd) t_column, column_valid = _select_random_position_for_seed_within_range( - key, pos[1], config.reproduce_min_dist, config.reproduce_max_dist, + key, pos[1], config.reproduce_min_dist, config.reproduce_max_dist, column_m) def true_fn2(env): @@ -904,28 +985,34 @@ def true_fn2(env): return jax.lax.cond(mask, true_fn, lambda env: env, env) -def env_reproduce_operations(key, env, b_repr_op, config): - """Perform a batch of ReproduceOps. +def env_try_place_seeds(key, env, b_op_info, config): + """Try to place seeds in the environment. These are performed sequentially. Note that some ops may be masked and therefore be noops. """ - def body_f(carry, repr_op): + def body_f(carry, op_info): env, key = carry key, ku = jr.split(key) - env = env_perform_one_reproduce_op(ku, env, repr_op, config) + env = env_try_place_one_seed(ku, env, op_info, config) return (env, key), 0 - (env, key), _ = jax.lax.scan(body_f, (env, key), b_repr_op) + (env, key), _ = jax.lax.scan(body_f, (env, key), b_op_info) return env -def _select_subset_of_reproduce_ops(key, b_repr_op, neigh_type, config): +def _select_subset_of_reproduce_ops( + key, b_repr_op, neigh_type, config, select_sexual_repr): # Only a small subset of possible ReproduceOps are selected at each step. - # This is config dependent (config.n_reproduce_per_step). b_pos = b_repr_op.pos mask_flat = b_repr_op.mask.flatten() b_pos_flat = b_pos.reshape((-1, 2)) + n_repr_ops = (config.n_sexual_reproduce_per_step * 2 if select_sexual_repr + else config.n_reproduce_per_step) + + # Because sexual and asexual reproductions are fundamentally different, use + # select_sexual_repr to decide which kinds of reproduce_op to extract. + mask_flat *= (select_sexual_repr == b_repr_op.is_sexual).flatten() p_logits = mask_flat # Moreover, flowers can only reproduce if in contact with air. The more air, @@ -939,7 +1026,7 @@ def _select_subset_of_reproduce_ops(key, b_repr_op, neigh_type, config): k1, key = jr.split(key) selected_pos = jr.choice( k1, b_pos_flat, p=p_logits/p_logits.sum().clip(1), - shape=(config.n_reproduce_per_step,), replace=False) + shape=(n_repr_ops,), replace=False) return selected_pos @@ -952,11 +1039,21 @@ def env_perform_reproduce_update( mutate_programs=False, programs: (AgentProgramType | None) = None, mutate_f: (Callable[[KeyType, AgentProgramType], AgentProgramType] | None - ) = None): + ) = None, + enable_asexual_reproduction=True, + enable_sexual_reproduction=True, + does_sex_matter=True, + sexual_mutate_f: (Callable[[KeyType, AgentProgramType, AgentProgramType], + AgentProgramType] | None + ) = None, + split_mutator_params_f = None, + get_sex_f: (Callable[[AgentProgramType], AgentProgramType] | None) = None, + n_sparse_max: int|None = None, + return_metrics=False): """Perform reproduce operations in the environment. - + This is the function that should be used for high level step_env design. - + Arguments: key: a jax random number generator. env: the input environment to be modified. @@ -974,37 +1071,102 @@ def env_perform_reproduce_update( if mutate_programs is set to True. mutate_f: The mutation function to mutate 'programs'. This is intended to be the method 'mutate' of a Mutator class. + enable_asexual_reproduction: if set to True, asexual reproduction is + enabled. At least enable_asexual_reproduction or + enable_sexual_reproduction must be set to True. + enable_sexual_reproduction: if set to True, sexual reproduction is enabled. + does_sex_matter: if set to True, and enable_sexual_reproduction is True, + then sexual reproduction can only occur between different sex entities. + sexual_mutate_f: The sexual mutation function to mutate 'programs'. This is + intended to be the method 'mutate' of a SexualMutator class. + split_mutator_params_f: a mutator function that explains how to split the + parameters of the mutator. + get_sex_f: The function that extracts the sex of the agent from its params. + n_sparse_max: either an int or None. If set to int, we will use a budget for + the amounts of agent operations allowed at each step. + return_metrics: if True, return metrics about whether reproduction occurred, + and who are the parents and children. Returns: an updated environment. if mutate_programs is True, it also returns the updated programs. """ + assert enable_asexual_reproduction or enable_sexual_reproduction etd = config.etd perc = perceive_neighbors(env, etd) - k1, key = jr.split(key) - w, h = env.type_grid.shape + h, w = env.type_grid.shape - v_repr_f = vectorize_reproduce_f( - make_agent_reproduce_interface(repr_f, config), w, h) - - b_pos = jp.stack(jp.meshgrid(jp.arange(w), jp.arange(h), indexing="ij"), -1) - b_repr_op = v_repr_f(k1, perc, b_pos, repr_programs) + b_pos = jp.stack(jp.meshgrid(jp.arange(h), jp.arange(w), indexing="ij"), -1) + repr_interface_f = make_agent_reproduce_interface(repr_f, config) + k1, key = jr.split(key) + if n_sparse_max is None: + v_repr_f = vectorize_reproduce_f(repr_interface_f, h, w) + b_repr_op = v_repr_f(k1, perc, b_pos, repr_programs) + else: + b_repr_op = compute_sparse_agent_cell_f( + k1, repr_interface_f, perc, env.type_grid, repr_programs, etd, + n_sparse_max, b_pos.reshape((h*w, 2))) # Only a small subset of possible ReproduceOps are selected at each step. - # This is config dependent (config.n_reproduce_per_step). - selected_pos = _select_subset_of_reproduce_ops( - k1, b_repr_op, perc.neigh_type, config) - spx, spy = selected_pos[:, 0], selected_pos[:, 1] - - selected_mask = b_repr_op.mask[spx, spy] - selected_stored_en = b_repr_op.stored_en[spx, spy] - selected_aid = b_repr_op.aid[spx, spy] - - if mutate_programs: + # sexual and asexual reproductions are treated independently. + if enable_asexual_reproduction: + ## First, asexual reproduction. + # This is config dependent (config.n_reproduce_per_step). + k1, key = jr.split(key) + selected_pos = _select_subset_of_reproduce_ops( + k1, b_repr_op, perc.neigh_type, config, 0) + spx, spy = selected_pos[:, 0], selected_pos[:, 1] + + selected_mask = b_repr_op.mask[spx, spy] + selected_stored_en = b_repr_op.stored_en[spx, spy] + selected_aid = b_repr_op.aid[spx, spy] + + if enable_sexual_reproduction: + ## Sexual reproduction + # This is config dependent (2 * config.n_sexual_reproduce_per_step). + # twice, because we need 2 flowers per sexual operation. + k1, key = jr.split(key) + selected_pos_sx = _select_subset_of_reproduce_ops( + k1, b_repr_op, perc.neigh_type, config, 1) + spx_sx, spy_sx = selected_pos_sx[:, 0], selected_pos_sx[:, 1] + selected_aid_sx = b_repr_op.aid[spx_sx, spy_sx] + selected_aid_sx_1 = selected_aid_sx[::2] + selected_aid_sx_2 = selected_aid_sx[1::2] + + selected_mask_sx = b_repr_op.mask[spx_sx, spy_sx] + # to reproduce, selected_mask_sx of both parent cells to be True. Representing + # whether there actually are sexual flowers that triggered a reproduce op. + pair_repr_mask_sx = selected_mask_sx[::2] * selected_mask_sx[1::2] + if does_sex_matter: + # moreover, their sex has to be different. Done to prevent selfing. + # we could consider moving this to the sexual mutator as a responsibility. + selected_programs_sx = programs[selected_aid_sx] + logic_params, _ = vmap(split_mutator_params_f)(selected_programs_sx) + sex_of_selected = vmap(get_sex_f)(logic_params) + pair_repr_mask_sx = pair_repr_mask_sx * ( + sex_of_selected[::2] != sex_of_selected[1::2]) + + selected_stored_en_sx = b_repr_op.stored_en[spx_sx, spy_sx] + # the total energy is the sum of the two flowers. + pair_stored_en_sx = selected_stored_en_sx[::2] + selected_stored_en_sx[1::2] + + if not mutate_programs: + if enable_asexual_reproduction: + repr_aid = selected_aid + # with sexual reproduction, we assume that programs are mutated. + # but, as a failsafe, we randomly pick a parent's ID otherwise. + if enable_sexual_reproduction: + k1, key = jr.split(key) + parent_m = ( + jr.uniform(k1, [config.n_sexual_reproduce_per_step]) < 0.5).astype( + jp.uint32)[..., None] + repr_aid_sx = (selected_aid_sx_1 * parent_m + + selected_aid_sx_2 * (1 - parent_m)) + else: # Logic: # look into the pool of programs and see if some of them are not used ( # there are no agents alive with such program). - # if that is the case, create a new program with mutate_f, then modify the - # corresponding 'selected_aid'. + # if that is the case, create a new program with mutate_f or + # sexual_mutate_f, then modify the corresponding 'selected_aid'. # If there is no space, set the mask to zero instead. # get n agents per id. @@ -1015,51 +1177,136 @@ def env_perform_reproduce_update( is_agent_flat, env_aid_flat, num_segments=programs.shape[0]) sorted_na_idx = jp.argsort(n_agents_in_env).astype(jp.uint32) - # we only care about the first few indexes - sorted_na_chosen_idx = sorted_na_idx[:config.n_reproduce_per_step] - sorted_na_chosen_mask = (n_agents_in_env[sorted_na_chosen_idx] == 0 - ).astype(jp.float32) - - # assume that the number of selected reproductions is LESS than the total - # number of programs. - to_mutate_programs = programs[selected_aid] - mutated_programs = vmap(mutate_f)( - jr.split(key, config.n_reproduce_per_step), to_mutate_programs) - - mutation_mask = (selected_mask * sorted_na_chosen_mask) - mutation_mask_e = mutation_mask[:, None] - n_mutation_mask_e = 1. - mutation_mask_e - - # substitute the programs - programs = programs.at[sorted_na_chosen_idx].set( - mutation_mask_e * mutated_programs - + n_mutation_mask_e * programs[sorted_na_chosen_idx]) - # update the aid and the mask - selected_mask = mutation_mask - selected_aid = sorted_na_chosen_idx - - # these positions (if mask says yes) are then selected to reproduce. - # A seed is spawned if possible. - env = env_reproduce_operations( - key, env, - ReproduceOp(selected_mask, selected_pos, selected_stored_en, - selected_aid), - config) - # The flower is destroyed, regardless of whether the operation succeeds. - n_selected_mask = 1 - selected_mask - n_selected_mask_uint = n_selected_mask.astype(jp.uint32) - env = Environment( - env.type_grid.at[spx, spy].set( - n_selected_mask_uint * env.type_grid[spx, spy]), # 0 is VOID - env.state_grid.at[spx, spy].set( - n_selected_mask[..., None] * env.state_grid[spx, spy] - ), # set everything to zero - env.agent_id_grid.at[spx, spy].set( - n_selected_mask_uint * env.agent_id_grid[spx, spy]) # default id 0. - ) - if mutate_programs: - return env, programs - return env + + if enable_asexual_reproduction: + ## Asexual reproduction + # we only care about the first few indexes + sorted_na_chosen_idx = sorted_na_idx[:config.n_reproduce_per_step] + sorted_na_chosen_mask = (n_agents_in_env[sorted_na_chosen_idx] == 0 + ).astype(jp.float32) + + # assume that the number of selected reproductions is LESS than the total + # number of programs. + to_mutate_programs = programs[selected_aid] + ku, key = jr.split(key) + mutated_programs = vmap(mutate_f)( + jr.split(ku, config.n_reproduce_per_step), to_mutate_programs) + + mutation_mask = (selected_mask * sorted_na_chosen_mask) + mutation_mask_e = mutation_mask[:, None] + n_mutation_mask_e = 1. - mutation_mask_e + + # substitute the programs + programs = programs.at[sorted_na_chosen_idx].set( + mutation_mask_e * mutated_programs + + n_mutation_mask_e * programs[sorted_na_chosen_idx]) + # update the aid and the mask + selected_mask = mutation_mask + repr_aid = sorted_na_chosen_idx + + if enable_sexual_reproduction: + ## Sexual reproduction + # we only care about some few indexes + sorted_na_chosen_idx_sx = sorted_na_idx[ + config.n_reproduce_per_step: + config.n_reproduce_per_step+config.n_sexual_reproduce_per_step] + sorted_na_chosen_mask_sx = ( + n_agents_in_env[sorted_na_chosen_idx_sx] == 0).astype(jp.float32) + + # assume that the number of selected reproductions is LESS than the total + # number of programs. + to_mutate_programs_sx = programs[selected_aid_sx] + ku, key = jr.split(key) + mutated_programs_sx = vmap(sexual_mutate_f)( + jr.split(ku, config.n_sexual_reproduce_per_step), + to_mutate_programs_sx[::2], to_mutate_programs_sx[1::2]) + # To assess whether we can perform a sexual reproduction, we need, for + # each chosen position: + # 1. sorted_na_chosen_mask_sx to be True, representing whether there is + # space for new programs. + # 2. pair_repr_mask_sx to be True. + pair_repr_mask_sx = (pair_repr_mask_sx * sorted_na_chosen_mask_sx) + + pair_repr_mask_sx_e = pair_repr_mask_sx[:, None] + n_pair_repr_mask_sx_e = 1. - pair_repr_mask_sx_e + + # substitute the programs + programs = programs.at[sorted_na_chosen_idx_sx].set( + pair_repr_mask_sx_e * mutated_programs_sx + + n_pair_repr_mask_sx_e * programs[sorted_na_chosen_idx_sx]) + + # save the new ids for the children. + repr_aid_sx = sorted_na_chosen_idx_sx + + if return_metrics: + metrics = {} + + if enable_asexual_reproduction: + ## Asexual reproduction + # these positions (if mask says yes) are then selected to reproduce. + # A seed is spawned if possible. + k1, key = jr.split(key) + env = env_try_place_seeds( + k1, env, + (selected_mask, selected_pos, selected_stored_en, repr_aid), + config) + # The flower is destroyed, regardless of whether the operation succeeds. + n_selected_mask = 1 - selected_mask + n_selected_mask_uint = n_selected_mask.astype(jp.uint32) + env = Environment( + env.type_grid.at[spx, spy].set( + n_selected_mask_uint * env.type_grid[spx, spy]), # 0 is VOID + env.state_grid.at[spx, spy].set( + n_selected_mask[..., None] * env.state_grid[spx, spy] + ), # set everything to zero + env.agent_id_grid.at[spx, spy].set( + n_selected_mask_uint * env.agent_id_grid[spx, spy]) # default id 0. + ) + + if return_metrics: + metrics["asexual_reproduction"] = (selected_mask, selected_aid, repr_aid) + + if enable_sexual_reproduction: + ## Sexual reproduction + # seeds will spawn in a neighborhood of one of the two parents, chosen at + # random. + k1, key = jr.split(key) + pos_m = (jr.uniform(k1, [config.n_sexual_reproduce_per_step]) < 0.5).astype( + jp.int32)[..., None] + repr_pos_sx = (selected_pos_sx[::2] * pos_m + + selected_pos_sx[1::2] * (1 - pos_m)) + + k1, key = jr.split(key) + env = env_try_place_seeds( + k1, env, + (pair_repr_mask_sx, repr_pos_sx, pair_stored_en_sx, repr_aid_sx), + config) + # The flower is destroyed, regardless of whether the operation succeeds. + # update the selected_mask_sx, since some flowers may not have been actually + # selected. + # mutation_mask_sx is the outcome of a pair, so you need to repeat it twice. + destroy_mask_sx = jp.repeat(pair_repr_mask_sx, 2, axis=-1) + n_destroy_mask_sx = 1 - destroy_mask_sx + n_destroy_mask_sx_uint = n_destroy_mask_sx.astype(jp.uint32) + env = Environment( + env.type_grid.at[spx_sx, spy_sx].set( + n_destroy_mask_sx_uint * env.type_grid[spx_sx, spy_sx]), # 0 is VOID + env.state_grid.at[spx_sx, spy_sx].set( + n_destroy_mask_sx[..., None] * env.state_grid[spx_sx, spy_sx] + ), # set everything to zero + env.agent_id_grid.at[spx_sx, spy_sx].set( # default id is 0 + n_destroy_mask_sx_uint * env.agent_id_grid[spx_sx, spy_sx]) + ) + + if return_metrics: + metrics["sexual_reproduction"] = ( + pair_repr_mask_sx, selected_aid_sx_1, selected_aid_sx_2, + repr_aid_sx) + + result = (env, programs) if mutate_programs else env + if return_metrics: + return result, metrics + return result def intercept_reproduce_ops( @@ -1077,6 +1324,9 @@ def intercept_reproduce_ops( env_perform_reproduce_update). I chose to do that to avoid making the latter too complex. But I might refactor it eventually. + TODO: this was not modified for sexual reproduction! Double check it can be + used anyway. + Arguments: key: a jax random number generator. env: the input environment to be modified. @@ -1096,17 +1346,17 @@ def intercept_reproduce_ops( etd = config.etd perc = perceive_neighbors(env, etd) k1, key = jr.split(key) - w, h = env.type_grid.shape + h, w = env.type_grid.shape v_repr_f = vectorize_reproduce_f( - make_agent_reproduce_interface(repr_f, config), w, h) + make_agent_reproduce_interface(repr_f, config), h, w) - b_pos = jp.stack(jp.meshgrid(jp.arange(w), jp.arange(h), indexing="ij"), -1) + b_pos = jp.stack(jp.meshgrid(jp.arange(h), jp.arange(w), indexing="ij"), -1) b_repr_op = v_repr_f(k1, perc, b_pos, repr_programs) # Only a small subset of possible ReproduceOps are selected at each step. # This is config dependent (config.n_reproduce_per_step). selected_pos = _select_subset_of_reproduce_ops( - k1, b_repr_op, perc.neigh_type, config) + k1, b_repr_op, perc.neigh_type, config, False) spx, spy = selected_pos[:, 0], selected_pos[:, 1] selected_mask = b_repr_op.mask[spx, spy] @@ -1136,7 +1386,7 @@ def intercept_reproduce_ops( ### Gravity logic. -def _line_gravity(env, x, h, etd): +def _line_gravity(env, x, w, etd): type_grid, state_grid, agent_id_grid = env env_state_size = state_grid.shape[-1] # self needs to be affected by gravity: @@ -1152,25 +1402,25 @@ def _line_gravity(env, x, h, etd): vmap(lambda ctype: (ctype != etd.structural_mats).all())(type_grid[x])) swap_mask = (is_gravity_mat & is_down_intangible_mat & is_crumbling ).astype(jp.float32) - # [h] -> [1,h] + # [w] -> [1,w] swap_mask_e = swap_mask[None,:] swap_mask_uint_e = swap_mask_e.astype(jp.uint32) - idx_swap_x = jp.repeat(jp.array([x, x+1]), h) - idx_swap_y = jp.concatenate([jp.arange(0, h, dtype=jp.int32), - jp.arange(0, h, dtype=jp.int32)], 0) - type_slice = jax.lax.dynamic_slice(type_grid, (x, 0), (2, h)) + idx_swap_x = jp.repeat(jp.array([x, x+1]), w) + idx_swap_y = jp.concatenate([jp.arange(0, w, dtype=jp.int32), + jp.arange(0, w, dtype=jp.int32)], 0) + type_slice = jax.lax.dynamic_slice(type_grid, (x, 0), (2, w)) type_upd = (type_slice * (1 - swap_mask_uint_e) + type_slice[::-1] * swap_mask_uint_e).reshape(-1) new_type_grid = type_grid.at[idx_swap_x, idx_swap_y].set(type_upd) swap_mask_ee = swap_mask_e[..., None] state_slice = jax.lax.dynamic_slice( - state_grid, (x, 0, 0), (2, h, env_state_size)) + state_grid, (x, 0, 0), (2, w, env_state_size)) state_upd = (state_slice * (1. - swap_mask_ee) + state_slice[::-1] * swap_mask_ee).reshape([-1, env_state_size]) new_state_grid = state_grid.at[idx_swap_x, idx_swap_y].set(state_upd) # agent ids - id_slice = jax.lax.dynamic_slice(agent_id_grid, (x, 0), (2, h)) + id_slice = jax.lax.dynamic_slice(agent_id_grid, (x, 0), (2, w)) id_upd = (id_slice * (1 - swap_mask_uint_e) + id_slice[::-1] * swap_mask_uint_e).reshape(-1) new_agent_id_grid = agent_id_grid.at[idx_swap_x, idx_swap_y].set(id_upd) @@ -1188,11 +1438,11 @@ def env_process_gravity(env: Environment, etd: EnvTypeDef) -> Environment: Create a new env by applying gravity on every line, from bottom to top. Nit: right now, you can't fall off, so we start from the second to bottom. """ - w, h = env.type_grid.shape + h, w = env.type_grid.shape env, _ = jax.lax.scan( - partial(_line_gravity, h=h, etd=etd), + partial(_line_gravity, w=w, etd=etd), env, - jp.arange(w-2, -1, -1)) + jp.arange(h-2, -1, -1)) return env @@ -1349,7 +1599,7 @@ def process_energy(env: Environment, config: EnvConfig) -> Environment: neigh_asking_nutrients = jax.lax.conv_general_dilated_patches( asking_nutrients[None,:], (3, 3), (1, 1), "SAME", dimension_numbers=("NHWC", "OIHW", "NHWC"))[0] - # we want to have [w,h,9,c] so that the indexing is intuitive and consistent + # we want to have [h,w,9,c] so that the indexing is intuitive and consistent # for both neigh vectors. neigh_asking_nutrients = neigh_asking_nutrients.reshape( neigh_asking_nutrients.shape[:2] + (2, 9)).transpose((0, 1, 3, 2)) @@ -1363,7 +1613,7 @@ def process_energy(env: Environment, config: EnvConfig) -> Environment: absorb_perc = jax.lax.conv_general_dilated_patches( perc_to_give[None,:], (3, 3), (1, 1), "SAME", dimension_numbers=("NHWC", "OIHW", "NHWC"))[0] - # we want to have [w,h,9,c] so that the indexing is intuitive and consistent + # we want to have [h,w,9,c] so that the indexing is intuitive and consistent # for both neigh vectors. absorb_perc = absorb_perc.reshape( absorb_perc.shape[:2] + (2, 9)).transpose((0, 1, 3, 2)) @@ -1376,7 +1626,7 @@ def process_energy(env: Environment, config: EnvConfig) -> Environment: is_agent_grid = etd.is_agent_fn(env.type_grid) is_agent_grid_e_f = is_agent_grid.astype(jp.float32)[..., None] ag_spec_idx = etd.get_agent_specialization_idx(env.type_grid) - # [w,h,3] @ [3,2] -> [w,h,2] + # [h,w,3] @ [3,2] -> [h,w,2] agent_dissipation_rate = etd.dissipation_rate_per_spec[ag_spec_idx] dissipated_energy = (config.dissipation_per_step * agent_dissipation_rate * is_agent_grid_e_f) diff --git a/self_organising_systems/biomakerca/environments.py b/self_organising_systems/biomakerca/environments.py index db082b7..17852d5 100644 --- a/self_organising_systems/biomakerca/environments.py +++ b/self_organising_systems/biomakerca/environments.py @@ -247,9 +247,12 @@ def __str__(self): "AGENT_ROOT", # Leaf: Capable of absorbing air nutrients. "AGENT_LEAF", - # Flower: Capable of performing a reproduce operation. They *tend* to - # consume more nutrients. + # Flower: Capable of performing an asexual reproduce operation. They *tend* + # to consume more nutrients. "AGENT_FLOWER", + # Sexual Flower: Capable of performing a sexual reproduce operation. They + # *tend* to consume more nutrients. + "AGENT_FLOWER_SEXUAL", ] # indexed by the type, it tells how much structure decays. @@ -265,16 +268,18 @@ def __str__(self): "AGENT_ROOT": 5, "AGENT_LEAF": 5, "AGENT_FLOWER": 5, + "AGENT_FLOWER_SEXUAL": 5, } # A modifier of the dissipation based on the agent specialization. -# the first element is for the earth nutrients, the second element is for the +# the first element is for the earth nutrients, the second element is for the # air nutrients. DEFAULT_DISSIPATION_RATE_PER_SPEC_DICT = { "AGENT_UNSPECIALIZED": jp.array([0.5, 0.5]), "AGENT_ROOT": jp.array([1.0, 1.0]), "AGENT_LEAF": jp.array([1.0, 1.0]), "AGENT_FLOWER": jp.array([1.2, 1.2]), + "AGENT_FLOWER_SEXUAL": jp.array([1.2, 1.2]), } # Colors for visualising the default types. @@ -289,6 +294,7 @@ def __str__(self): "AGENT_ROOT": jp.array([0.52, 0.39, 0.14]), # RGB: 133,99,36 "AGENT_LEAF": jp.array([0.16, 0.49, 0.10]), # RGB: 41,125,26 "AGENT_FLOWER": jp.array([1., 0.42, 0.71]), # RGB: 255,107,181 + "AGENT_FLOWER_SEXUAL": jp.array([1., 0.749, 0.]), # RGB: 255,191,0 } def convert_string_dict_to_type_array(d, types): @@ -339,6 +345,7 @@ def __init__( types = self.types # setup material specific properties. If you are subclassing this, consider # changing these values manually. + #TODO make these uint32 self.intangible_mats = jp.array([types.VOID, types.AIR], dtype=jp.int32) self.gravity_mats = jp.concatenate([ jp.array([types.EARTH], dtype=jp.int32), self.agent_types], 0) @@ -411,11 +418,18 @@ class EnvConfig: seed can be placed. reproduce_max_dist: the maximum distance from a reproducing flower where a seed can be placed. - n_reproduce_per_step: how many reproduce ops can be selected per step to be - executed. In effect, this means that flowers may not execute reproduce ops - as soon as they desire, but they may wait, depending on how many other - flowers are asking the same in the environment. See that competition over - a scarse external resource (bees, for instance). + n_reproduce_per_step: how many asexual reproduce ops can be selected per + step to be executed. In effect, this means that flowers may not execute + reproduce ops as soon as they desire, but they may wait, depending on how + many other flowers are asking the same in the environment. See that + competition over a scarse external resource (bees, for instance). + n_sexual_reproduce_per_step: how many sexual reproduce ops can be selected + per step to be executed. Note that since 2 ops are needed per each sexual + reproduction, we will perform n_sexual_reproduce_per_step by getting twice + as many reproduce ops. In effect, this means that flowers may not execute + reproduce ops as soon as they desire, but they may wait, depending on how + many other flowers are asking the same in the environment. See that + competition over a scarse external resource (bees, for instance). nutrient_cap: maximum value of nutrients for agents. material_nutrient_cap: maximum value of nutrients for other materials. max_lifetime: the maximum lifetime of an organism. Agents age at every step. @@ -446,6 +460,7 @@ def __init__(self, reproduce_min_dist=5, reproduce_max_dist=15, n_reproduce_per_step=2, + n_sexual_reproduce_per_step=2, nutrient_cap=DEFAULT_NUTRIENT_CAP, material_nutrient_cap=DEFAULT_MATERIAL_NUTRIENT_CAP, max_lifetime=int(1e6), @@ -462,6 +477,7 @@ def __init__(self, self.reproduce_min_dist = reproduce_min_dist self.reproduce_max_dist = reproduce_max_dist self.n_reproduce_per_step = n_reproduce_per_step + self.n_sexual_reproduce_per_step = n_sexual_reproduce_per_step self.nutrient_cap = nutrient_cap self.material_nutrient_cap = material_nutrient_cap self.max_lifetime = max_lifetime @@ -596,6 +612,21 @@ def create_default_environment(config, h, w, with_earth=True, return env +def create_multiseed_environment(h, w, config, with_earth=True, + init_nutrient_perc=0.2, place_every=10): + """Create a default environment that contains several different seeds. + + Each seed has a different agent id. + """ + env = create_default_environment(config, h, w, with_earth, init_nutrient_perc) + + aid = 0 + for i in range(1, w//place_every): + env = place_seed(env, place_every*i, config, aid=aid) + aid += 1 + + return env + def infer_width(h, width_type): """Infer the width of the environment. @@ -791,9 +822,7 @@ def hsl_to_rgb(h,s,l): """ l_lt_05 = jp.array(l < 0.5).astype(jp.float32) q = l_lt_05 * l * (1 + s) + (1. - l_lt_05) * (l + s - l * s) - print(q) p = 2 * l - q - print(p) return jp.stack( [hue_to_rgb(p,q,h+1/3), hue_to_rgb(p,q,h), hue_to_rgb(p,q,h-1/3)], -1) diff --git a/self_organising_systems/biomakerca/examples/notebooks/run_configuration.ipynb b/self_organising_systems/biomakerca/examples/notebooks/run_configuration.ipynb index 0b382bb..3042e6e 100644 --- a/self_organising_systems/biomakerca/examples/notebooks/run_configuration.ipynb +++ b/self_organising_systems/biomakerca/examples/notebooks/run_configuration.ipynb @@ -73,8 +73,6 @@ "from self_organising_systems.biomakerca.step_maker import step_env\n", "from self_organising_systems.biomakerca.display_utils import zoom\n", "from self_organising_systems.biomakerca.custom_ipython_display import display\n", - "from self_organising_systems.biomakerca.env_logic import ReproduceOp\n", - "from self_organising_systems.biomakerca.env_logic import env_perform_one_reproduce_op\n", "\n", "import cv2\n", "import numpy as np\n", @@ -222,10 +220,6 @@ "# How many unique programs (organisms) are allowed in the simulation.\n", "N_MAX_PROGRAMS = 128\n", "\n", - "# if True, every 50 steps we check whether the agents go extinct. If they did,\n", - "# we replace a seed in the environment.\n", - "replace_if_extinct = False\n", - "\n", "# The number of frames of the video. This is NOT the number of steps.\n", "# The total number of steps depend on the number of steps per frame, which can\n", "# vary over time.\n", @@ -276,25 +270,6 @@ " env, programs = step_env(\n", " ku, env, config, agent_logic, programs, do_reproduction=True,\n", " mutate_programs=True, mutator=mutator)\n", - " if replace_if_extinct and step % 50 == 0:\n", - " # check if there is no alive cell.\n", - " any_alive = jit(lambda type_grid: evm.is_agent_fn(type_grid).sum() \u003e 0)(env.type_grid)\n", - " if not any_alive:\n", - " # Then place a new seed.\n", - " agent_init_nutrients = (config.dissipation_per_step * 4 +\n", - " config.specialize_cost)\n", - " ku, key = jr.split(key)\n", - " rpos = jp.stack([0, jr.randint(ku, (),minval=0, maxval=env.type_grid.shape[1])],0)\n", - " ku, key = jr.split(key)\n", - " raid = jr.randint(ku, (),minval=0, maxval=N_MAX_PROGRAMS).astype(jp.uint32)\n", - " repr_op = ReproduceOp(1., rpos, agent_init_nutrients*2, raid)\n", - " ku, key = jr.split(key)\n", - " env = jit(partial(env_perform_one_reproduce_op, config=config))(ku, env, repr_op)\n", - "\n", - " # show it, though\n", - " frame = make_frame(env, step, steps_per_frame)\n", - " for stop_i in range(10):\n", - " video.add_image(frame)\n", "\n", " video.add_image(make_frame(env, step, steps_per_frame))\n", "\n", @@ -375,7 +350,7 @@ "else:\n", " # Extract a living program from the final environment.\n", " aid_flat = env.agent_id_grid.flatten()\n", - " is_agent_flat = evm.is_agent_fn(env.type_grid).flatten().astype(jp.float32)\n", + " is_agent_flat = config.etd.is_agent_fn(env.type_grid).flatten().astype(jp.float32)\n", " n_alive_per_id = jax.ops.segment_sum(is_agent_flat, aid_flat, num_segments=N_MAX_PROGRAMS)\n", " alive_programs = programs[n_alive_per_id\u003e0]\n", " print(\"Extracted {} programs.\".format(alive_programs.shape[0]))\n", diff --git a/self_organising_systems/biomakerca/examples/notebooks/run_sexual_configuration.ipynb b/self_organising_systems/biomakerca/examples/notebooks/run_sexual_configuration.ipynb new file mode 100644 index 0000000..fd93f06 --- /dev/null +++ b/self_organising_systems/biomakerca/examples/notebooks/run_sexual_configuration.ipynb @@ -0,0 +1,523 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "LpQvs2DVfkVy" + }, + "source": [ + "# Biomaker CA: performing advanced runs on a configuration\n", + "\n", + "In this colab we show how to run models on a configuration and how to evaluate them.\n", + "\n", + "This colab allows to choose whether to perform sexual and/or asexual reproduction.\n", + "It also allows for sparse computations of agent logics.\n", + "\n", + "Copyright 2023 Google LLC\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iuhJcrwNgIEF" + }, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "287RU0sup5J2" + }, + "outputs": [], + "source": [ + "#@title install selforg package\n", + "# install the package locally\n", + "!pip install --upgrade -e git+https://github.com/google-research/self-organising-systems.git#egg=self_organising_systems\u0026subdirectory=biomakerca\n", + "# activate the locally installed package (otherwise a runtime restart is required)\n", + "import pkg_resources\n", + "import importlib\n", + "# Reload the resources because we uninstalled and reinstalled some packages.\n", + "importlib.reload(pkg_resources)\n", + "pkg_resources.get_distribution(\"self_organising_systems\").activate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DY7J0cz7NUCV" + }, + "outputs": [], + "source": [ + "#@title imports \u0026 notebook utilities\n", + "\n", + "from self_organising_systems.biomakerca import environments as evm\n", + "from self_organising_systems.biomakerca.agent_logic import BasicAgentLogic\n", + "from self_organising_systems.biomakerca.mutators import BasicMutator\n", + "from self_organising_systems.biomakerca.mutators import RandomlyAdaptiveMutator\n", + "from self_organising_systems.biomakerca.mutators import CrossOverSexualMutator\n", + "from self_organising_systems.biomakerca.step_maker import step_env\n", + "from self_organising_systems.biomakerca.display_utils import zoom, tile2d, add_text_to_img, imshow\n", + "from self_organising_systems.biomakerca.custom_ipython_display import display\n", + "from self_organising_systems.biomakerca.env_logic import env_perform_multi_world_reproduce_update\n", + "\n", + "import cv2\n", + "import numpy as np\n", + "import jax.random as jr\n", + "import jax.numpy as jp\n", + "from jax import vmap\n", + "from jax import jit\n", + "import jax\n", + "import time\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import tqdm\n", + "import mediapy as media\n", + "from functools import partial\n", + "\n", + "\n", + "def pad_text(img, text):\n", + " font = cv2.FONT_HERSHEY_SIMPLEX\n", + " orgin = (5, 15)\n", + " fontScale = 0.5\n", + " color = (0, 0, 0)\n", + " thickness = 1\n", + "\n", + " # ensure to preserve even size (assumes the input size was even.\n", + " new_h = img.shape[0]//15\n", + " new_h = new_h if new_h % 2 == 0 else new_h + 1\n", + " img = np.concatenate([np.ones([new_h, img.shape[1], img.shape[2]]), img], 0)\n", + " img = cv2.putText(img, text, orgin, font, fontScale, color, thickness, cv2.LINE_AA)\n", + " return img" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aRRMQ1aNhqv6" + }, + "source": [ + "## Select the configuration, the agent logic and the mutator\n", + "\n", + "Set soil_unbalance_limit to 0 to reproduce the original environment. Set it to 1/3 for having self-balancing environments (recommended)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dtc32MAfOBTK" + }, + "outputs": [], + "source": [ + "ec_id = \"persistence\" #@param ['persistence', 'pestilence', 'collaboration', 'sideways']\n", + "env_width_type = \"wide\" #@param ['wide', 'landscape', 'square', 'petri', '10x', '20x']\n", + "soil_unbalance_limit = 1/3 #@param [0, \"1/3\"] {type:\"raw\"}\n", + "\n", + "h = 72\n", + "if env_width_type == \"10x\":\n", + " env_width_type = h * 10\n", + "if env_width_type == \"20x\":\n", + " env_width_type = h * 20\n", + "else:\n", + " env_width_type = evm.infer_width(h, env_width_type)\n", + "\n", + "env_and_config = evm.get_env_and_config(ec_id, width_type=env_width_type, h=h)\n", + "_, config = env_and_config\n", + "\n", + "st_env = evm.create_multiseed_environment(h, env_width_type, config)\n", + "\n", + "config.soil_unbalance_limit = soil_unbalance_limit\n", + "reproduction_type = \"asexual\" #@param ['both', 'asexual', 'sexual']\n", + "does_sex_matter = True #@param ['False', 'True'] {type:\"raw\"}\n", + "\n", + "enable_asexual_reproduction = reproduction_type != \"sexual\"\n", + "enable_sexual_reproduction = reproduction_type != \"asexual\"\n", + "\n", + "sex_sensitivity = 1000 #@param [1, 10, 100, 1000] {type:\"raw\"}\n", + "\n", + "agent_model = \"extended\" #@param ['minimal', 'extended']\n", + "agent_logic = BasicAgentLogic(config, minimal_net=agent_model==\"minimal\",\n", + " make_asexual_flowers_likely=enable_asexual_reproduction,\n", + " make_sexual_flowers_likely=enable_sexual_reproduction,\n", + " init_noise=0.001, sex_sensitivity=sex_sensitivity)\n", + "\n", + "n_sparse_max = 2**13 #@param ['None', '2**13', '2**12', '2**11', '2**10'] {type:\"raw\"}\n", + "\n", + "mutator_type = \"randomly_adaptive\" #@param ['basic', 'randomly_adaptive']\n", + "sd = 1e-3\n", + "mutator = (BasicMutator(sd=sd, change_perc=0.2) if mutator_type == \"basic\"\n", + " else RandomlyAdaptiveMutator(init_sd=sd, change_perc=0.2))\n", + "sexual_mutator = CrossOverSexualMutator(mutator, n_frequencies=64)\n", + "\n", + "exp_id = \"{}_sm^{}_sens^{}_w^{}\".format(reproduction_type, does_sex_matter, sex_sensitivity, env_width_type)\n", + "print(exp_id)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "y3p9hXhKuSk2" + }, + "source": [ + "## Optionally, modify the config for custom configurations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cKshhtYDvE81" + }, + "outputs": [], + "source": [ + "print(\"Current config:\")\n", + "print('\\n'.join(\"%s: %s\" % item for item in vars(config).items()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "H1wR-EQium-X" + }, + "outputs": [], + "source": [ + "## Examples for modifying the config\n", + "## Uncomment relevant lines or do like them.\n", + "\n", + "## Regardless, to trigger the recomputation of step_env and similar,\n", + "## config needs to be a new object! So, first, we create a new copy.\n", + "import copy\n", + "config = copy.copy(config)\n", + "\n", + "## Change simple isolated parameters (most of them)\n", + "# config.struct_integrity_cap = 100\n", + "# config.max_lifetime = 500\n", + "## Vectors can be modified either by writing new vectors:\n", + "# config.dissipation_per_step = jp.array([0.02, 0.02])\n", + "## Or by multiplying previous values. Note that they are immutable!\n", + "# config.dissipation_per_step = config.dissipation_per_step * 2\n", + "\n", + "## agent_state_size is trickier, because it influences env_state_size.\n", + "## So you can either create a new config:\n", + "## Note that you would have to insert all values that you don't want to take\n", + "## default initializations.\n", + "# config = evm.EnvConfig(agent_state_size=4)\n", + "## Or you can just modify env_state_size as well.\n", + "## (env_state_size = agent_state_size + 4) for now.\n", + "# config.agent_state_size = 4\n", + "# config.env_state_size = config.agent_state_size + 4\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y-xLUg_Uh14k" + }, + "source": [ + "## Perform a simulation\n", + "\n", + "Consider modifying the code to vary the extent of the simulation and video configs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8tj0U8naJDC3" + }, + "outputs": [], + "source": [ + "# create auxiliary frames that show interesting counters.\n", + "\n", + "def make_aux_frame(n_asex, n_sex):\n", + " img = np.ones([130 if reproduction_type == \"both\" else 65, 550, 3])\n", + " yorigin = 50\n", + " if enable_asexual_reproduction:\n", + " img = add_text_to_img(\n", + " img, \"Asexual reproductions: {}\".format(n_asex),\n", + " origin=(20, yorigin), color=\"black\")\n", + " yorigin = 100\n", + " if enable_sexual_reproduction:\n", + " img = add_text_to_img(\n", + " img, \"Sexual reproductions: {}\".format(n_sex),\n", + " origin=(20, yorigin), color=\"black\")\n", + " return img\n", + "\n", + "imshow(make_aux_frame(100000, 212122))\n", + "\n", + "def make_nsexes_frame(n_sexes):\n", + " img = np.ones([65, 300, 3])\n", + " yorigin = 50\n", + " img = add_text_to_img(\n", + " img, \"Num sexes: {}\".format(n_sexes),\n", + " origin=(20, yorigin), color=\"black\")\n", + " return img\n", + "\n", + "imshow(make_nsexes_frame(200))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6OMyOYUUHodo" + }, + "outputs": [], + "source": [ + "\n", + "@partial(jit, static_argnames=[\"config\", \"n_max_programs\"])\n", + "def get_alive_programs_mask(env, config, n_max_programs):\n", + " aid_flat = env.agent_id_grid.flatten()\n", + " is_agent_flat = config.etd.is_agent_fn(env.type_grid).flatten().astype(jp.float32)\n", + " n_alive_per_id = jax.ops.segment_sum(is_agent_flat, aid_flat, num_segments=n_max_programs)\n", + "\n", + " has_alive = n_alive_per_id \u003e 0\n", + " return has_alive\n", + "\n", + "@jit\n", + "def get_alive_and_sexes(env, programs):\n", + " has_alive = get_alive_programs_mask(env, config, N_MAX_PROGRAMS)\n", + " all_sexes = vmap(agent_logic.get_sex)(mutator.split_params(programs)[0])\n", + " return has_alive, all_sexes\n", + "\n", + "def get_num_sexes(env, programs):\n", + " has_alive, all_sexes = get_alive_and_sexes(env, programs)\n", + " has_alive = np.array(has_alive)\n", + " all_sexes = np.array(all_sexes)\n", + "\n", + " sexes = all_sexes[has_alive]\n", + " return len(np.unique(sexes))\n", + "\n", + "def run_env(\n", + " key, programs, env, n_steps, step_f,\n", + " curr_asexual_repr = 0, curr_sexual_repr = 0,\n", + " zoom_sz=12,\n", + " steps_per_frame=2, when_to_double_speed=[100, 500, 1000, 2000, 5000]):\n", + "\n", + " fps = 20\n", + " def make_frame(env):\n", + " return zoom(evm.grab_image_from_env(env, config),zoom_sz)\n", + "\n", + " frame = make_frame(env)\n", + "\n", + " # remember that the metrics are per step, right now, and that are 'inbetween'\n", + " # frames, at best.\n", + " n_asexual_repr_log = []\n", + " n_sexual_repr_log = []\n", + " n_sexes_log = []\n", + "\n", + " aux_frames = [make_aux_frame(0, 0)]\n", + " num_sexes_frames = [make_nsexes_frame(get_num_sexes(env, programs))]\n", + "\n", + " out_file = \"video.mp4\"\n", + " with media.VideoWriter(out_file, shape=frame.shape[:2], fps=fps, crf=18\n", + " ) as video:\n", + " video.add_image(frame)\n", + " for i in tqdm.trange(n_steps):\n", + " if i in when_to_double_speed:\n", + " steps_per_frame *= 2\n", + "\n", + " key, ku = jr.split(key)\n", + " (env, programs), metrics = step_f(ku, env, programs=programs)\n", + "\n", + " if enable_asexual_reproduction:\n", + " step_asex_repr = int(metrics[\"asexual_reproduction\"][0].sum())\n", + " n_asexual_repr_log.append(step_asex_repr)\n", + " curr_asexual_repr += step_asex_repr\n", + " if enable_sexual_reproduction:\n", + " step_sexual_repr = int(metrics[\"sexual_reproduction\"][0].sum())\n", + " n_sexual_repr_log.append(step_sexual_repr)\n", + " curr_sexual_repr += step_sexual_repr\n", + "\n", + " # get sexes alive\n", + " num_sexes = get_num_sexes(env, programs)\n", + " n_sexes_log.append(num_sexes)\n", + " if i % steps_per_frame == 0:\n", + " video.add_image(make_frame(env))\n", + " aux_frames.append(make_aux_frame(curr_asexual_repr, curr_sexual_repr))\n", + " num_sexes_frames.append(make_nsexes_frame(num_sexes))\n", + "\n", + "\n", + " media.show_video(media.read_video(out_file))\n", + " media.show_video(aux_frames, fps=fps)\n", + " media.show_video(num_sexes_frames, fps=fps)\n", + " ret_metrics = {'n_asexual_repr_log': n_asexual_repr_log,\n", + " 'n_sexual_repr_log': n_sexual_repr_log,\n", + " 'n_sexes_log': n_sexes_log,\n", + " 'curr_asexual_repr': curr_asexual_repr,\n", + " 'curr_sexual_repr': curr_sexual_repr}\n", + " return programs, env, ret_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "65WC-fp1aI1Y" + }, + "outputs": [], + "source": [ + "key = jr.PRNGKey(42)\n", + "\n", + "# How many unique programs (organisms) are allowed in the simulation.\n", + "N_MAX_PROGRAMS = 512\n", + "\n", + "# for 20x environments, you need shorter videos.\n", + "n_steps = 15000\n", + "\n", + "# on what FRAME to double speed.\n", + "when_to_double_speed = [100, 200, 300, 400, 500]\n", + "# on what FRAME to reset speed.\n", + "when_to_reset_speed = []\n", + "fps = 20\n", + "# this affects the size of the image. If this number is not even, the resulting\n", + "# video *may* not be supported by all renderers.\n", + "zoom_sz = 4\n", + "\n", + "ku, key = jr.split(key)\n", + "programs = vmap(agent_logic.initialize)(jr.split(ku, N_MAX_PROGRAMS))\n", + "ku, key = jr.split(key)\n", + "programs = vmap(mutator.initialize)(jr.split(ku, programs.shape[0]), programs)\n", + "\n", + "env = st_env\n", + "\n", + "step_f = partial(step_env, config=config, agent_logic=agent_logic, do_reproduction=True,\n", + " enable_asexual_reproduction=enable_asexual_reproduction,\n", + " enable_sexual_reproduction=enable_sexual_reproduction,\n", + " does_sex_matter=does_sex_matter,\n", + " mutate_programs=True, mutator=mutator, sexual_mutator=sexual_mutator,\n", + " n_sparse_max=n_sparse_max, return_metrics=True)\n", + "\n", + "step = 0\n", + "# how many steps per frame we start with. This gets usually doubled many times\n", + "# during the simulation.\n", + "# In the article, we usually use 2 or 4 as the starting value, sometimes 1.\n", + "steps_per_frame = 2\n", + "\n", + "ku, key = jr.split(key)\n", + "programs, env, metrics = run_env(\n", + " ku, programs, env, n_steps, step_f, zoom_sz=6)\n", + "\n", + "\n", + "\n", + "def running_average(a, n):\n", + " a = np.concatenate([np.full([n], a[0]), a], axis=0)\n", + " return np.convolve(a, np.ones(n)/n, mode=\"valid\")\n", + "\n", + "if enable_asexual_reproduction:\n", + " plt.plot(running_average(metrics['n_asexual_repr_log'], 100), label=\"n_asexual_repr_log\")\n", + "if enable_sexual_reproduction:\n", + " plt.plot(running_average(metrics['n_sexual_repr_log'], 100), label=\"n_sexual_repr_log\")\n", + "plt.legend()\n", + "plt.show()\n", + "\n", + "plt.plot(running_average(metrics['n_sexes_log'], 100), label=\"n_sexes_log\")\n", + "plt.legend()\n", + "plt.show()\n", + "\n", + "aid_flat = env.agent_id_grid.flatten()\n", + "is_agent_flat = config.etd.is_agent_fn(env.type_grid).flatten().astype(jp.float32)\n", + "n_alive_per_id = jax.ops.segment_sum(is_agent_flat, aid_flat, num_segments=N_MAX_PROGRAMS)\n", + "alive_programs = programs[n_alive_per_id\u003e0]\n", + "print(\"Extracted {} programs.\".format(alive_programs.shape[0]))\n", + "print(\"sexes:\", vmap(agent_logic.get_sex)(mutator.split_params(alive_programs)[0]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oBggxzAZm4lM" + }, + "outputs": [], + "source": [ + "# continue...\n", + "n_steps = 25000\n", + "ku, key = jr.split(key)\n", + "\n", + "programs, env, metrics = run_env(\n", + " ku, programs, env, n_steps, step_f, zoom_sz=6,\n", + " steps_per_frame=64, when_to_double_speed=[],\n", + " curr_asexual_repr=metrics['curr_asexual_repr'],\n", + " curr_sexual_repr=metrics['curr_sexual_repr'])\n", + "\n", + "\n", + "if enable_asexual_reproduction:\n", + " plt.plot(running_average(metrics['n_asexual_repr_log'], 100), label=\"n_asexual_repr_log\")\n", + "if enable_sexual_reproduction:\n", + " plt.plot(running_average(metrics['n_sexual_repr_log'], 100), label=\"n_sexual_repr_log\")\n", + "plt.legend()\n", + "plt.show()\n", + "\n", + "plt.plot(running_average(metrics['n_sexes_log'], 100), label=\"n_sexes_log\")\n", + "plt.legend()\n", + "plt.show()\n", + "\n", + "aid_flat = env.agent_id_grid.flatten()\n", + "is_agent_flat = config.etd.is_agent_fn(env.type_grid).flatten().astype(jp.float32)\n", + "n_alive_per_id = jax.ops.segment_sum(is_agent_flat, aid_flat, num_segments=N_MAX_PROGRAMS)\n", + "alive_programs = programs[n_alive_per_id\u003e0]\n", + "print(\"Extracted {} programs.\".format(alive_programs.shape[0]))\n", + "print(\"sexes:\", vmap(agent_logic.get_sex)(mutator.split_params(alive_programs)[0]))\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "V100", + "last_runtime": { + "build_target": "//learning/grp/tools/ml_python:ml_notebook", + "kind": "private" + }, + "machine_shape": "hm", + "name": "run_sexual_configuration.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1rgSP9D1gT3S4Aljk-jLys8dFbla2InCu", + "timestamp": 1708355883667 + }, + { + "file_id": "/piper/depot/google3/third_party/py/self_organising_systems/biomakerca/examples/notebooks/run_configuration.ipynb?workspaceId=etr:biomaker::citc", + "timestamp": 1701254768987 + }, + { + "file_id": "1ADfcMRj-JmfN6VUIcuqU-3bTMGrWSkj_", + "timestamp": 1688723295778 + }, + { + "file_id": "1XY102qIEc9MY9hd-Jb6Oirmyw7ga2LZL", + "timestamp": 1688637712371 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/self_organising_systems/biomakerca/extensions/eruption.py b/self_organising_systems/biomakerca/extensions/eruption.py index b1e029f..dac3e7f 100644 --- a/self_organising_systems/biomakerca/extensions/eruption.py +++ b/self_organising_systems/biomakerca/extensions/eruption.py @@ -79,7 +79,7 @@ def is_burnable_fn(t, etd): """Return True if t is any of (LEAF,FLOWER,UNSPECIALIZED) agent types.""" burnable_types = jp.array([ etd.types.AGENT_LEAF, etd.types.AGENT_FLOWER, - etd.types.AGENT_UNSPECIALIZED]) + etd.types.AGENT_FLOWER_SEXUAL, etd.types.AGENT_UNSPECIALIZED]) return (t == burnable_types).any(axis=-1) @@ -255,7 +255,7 @@ def get_eruption_config(): return evm.get_env_and_config( "persistence", width_type="petri", h=10, etd=EruptionTypeDef()).config -### Eruption environment +### Original Eruption environment def create_eruption_env(h, config): """Create the Eruption environment. @@ -408,7 +408,7 @@ def test_freq_lava(key, st_env, config, init_program, agent_logic, mutator, def make_frame(env, lava_perc): return add_text_to_img( zoom(evm.grab_image_from_env(env, config), zoom_sz), - "LAVA FREQUENCY: {:.3f}%".format(lava_perc*100), + "LAVA FREQUENCY: {:.3f}%".format(lava_perc*100), origin=(5, 35), fontScale=1.) diff --git a/self_organising_systems/biomakerca/mutators.py b/self_organising_systems/biomakerca/mutators.py index d82bf25..3705f08 100644 --- a/self_organising_systems/biomakerca/mutators.py +++ b/self_organising_systems/biomakerca/mutators.py @@ -21,14 +21,15 @@ limitations under the License. """ from abc import ABC, abstractmethod +import math import jax.numpy as jp import jax.random as jr from self_organising_systems.biomakerca.utils import stringify_class class Mutator(ABC): - """Interface of all mutators. - + """Interface of all asexual mutators. + The abstract methods need to be implemented in order to allow for in-environment mutations, through the method step_maker.step_env. """ @@ -36,9 +37,11 @@ class Mutator(ABC): @abstractmethod def initialize(self, key, p): """Initialize mutation parameters. - + p must be one-dimensional. Return a concatenation of p and the mutation parameters. + Advice: it is better if related params are contiguous to one another. This + is because SexualMutators likely use crossing over as a mechanism. """ pass @@ -124,14 +127,19 @@ def __init__(self, init_sd, change_perc: float | None, meta_sd_perc=0.1, self.min_sd = min_sd self.max_sd = max_sd + def join_params(self, p, sd): + return jp.ravel(jp.column_stack((p, sd))) + def initialize(self, key, p): - # works also if batched. + # Does not work if batched. (use vmap) sd = jp.full_like(p, self.init_sd) - return jp.concatenate([p, sd], -1) + return self.join_params(p, sd) def split_params(self, p): # works also if batched. - return jp.split(p, 2, axis=-1) + p_e, sd_e = jp.split( + p.reshape(p.shape[:-1]+ (p.shape[-1]//2, 2)), 2, axis=-1) + return p_e[..., 0], sd_e[..., 0] def mutate(self, key, p): # not batched. @@ -154,4 +162,81 @@ def mutate(self, key, p): new_mu = mu + dmu new_sd = (sd + dsd).clip(self.min_sd, self.max_sd) - return jp.concatenate([new_mu, new_sd], -1) + return self.join_params(new_mu, new_sd) + +### Sexual mutators + + +class SexualMutator(ABC): + """Interface of all sexual mutators. + + The abstract methods need to be implemented in order to allow for + in-environment sexual mutations, through the method step_maker.step_env. + + For now, a SexualMutator does not accept variable parameters. + This is because it uses a Mutator for extra variation. + Ideally, the Mutator used is the same of asexual reproduction, but you can + put custom ones here as well. + """ + + def __init__(self, mutator: Mutator): + """Set the mutator for the general variation of parameters. + """ + self.mutator = mutator + + @abstractmethod + def mutate(self, key, p1, p2): + """Mutate params. + + The inputs p1 and p2 must be all params, as generated by 'initialize'. + This method can mutate mutation params too. + p must be one-dimensional. + """ + pass + + def __str__(self): + return stringify_class(self) + + +def valuenoise1d(key, f, n, interpolation="linear"): + """Create continous noise, useful for make a XOver mask.""" + ku, key = jr.split(key) + grads = jr.uniform(ku, (f,), minval=-1, maxval=1) + n_repeats = int(math.ceil(n/(f-1))) + prev = jp.repeat(grads[:-1], n_repeats, axis=0)[:n] + next = jp.repeat(grads[1:], n_repeats, axis=0)[:n] + # fraction of where you are + t = jp.tile(jp.linspace(0, 1, n_repeats+1)[:-1], f-1)[:n] + + # linear interpolation because we don't need anything more complex. + # if you want it cubic, do this: + if interpolation == "cubic": + t = t*t*t*(t*(t*6 - 15) + 10) + return prev + t * (next - prev) + + +class CrossOverSexualMutator(SexualMutator): + """Sexual mutator that performs a classic Crossing over recombination mutation. + + This mutator also performs an asexual mutation through the mutator input. + """ + + def __init__(self, mutator: Mutator, n_frequencies=16): + super().__init__(mutator) + self.n_frequencies = n_frequencies + + def mutate(self, key, p1, p2): + """Mutate params. + + The inputs p1 and p2 must be all params, as generated by 'initialize'. + This method can mutate mutation params too. + p must be one-dimensional. + """ + ku, key = jr.split(key) + xo_mask = valuenoise1d(ku, self.n_frequencies, p1.shape[0], "linear") > 0 + + new_p = xo_mask * p1 + (1. - xo_mask) * p2 + + ku, key = jr.split(key) + return self.mutator.mutate(ku, new_p) + diff --git a/self_organising_systems/biomakerca/step_maker.py b/self_organising_systems/biomakerca/step_maker.py index 0ce57c2..e25da18 100644 --- a/self_organising_systems/biomakerca/step_maker.py +++ b/self_organising_systems/biomakerca/step_maker.py @@ -44,13 +44,15 @@ from self_organising_systems.biomakerca.env_logic import process_structural_integrity_n_times from self_organising_systems.biomakerca.environments import EnvConfig from self_organising_systems.biomakerca.environments import Environment -from self_organising_systems.biomakerca.mutators import Mutator +from self_organising_systems.biomakerca.mutators import Mutator, SexualMutator @partial(jit, static_argnames=[ - "config", "agent_logic", "excl_fs", "do_reproduction", "mutate_programs", - "mutator", "intercept_reproduction"]) + "config", "agent_logic", "excl_fs", "do_reproduction", + "enable_asexual_reproduction", "enable_sexual_reproduction", + "does_sex_matter", "mutate_programs", "mutator", "sexual_mutator", + "intercept_reproduction", "n_sparse_max", "return_metrics"]) def step_env( key: KeyType, env: Environment, config: EnvConfig, agent_logic: AgentLogic, @@ -59,10 +61,16 @@ def step_env( tuple[EnvTypeType, Callable[[KeyType, PerceivedData], ExclusiveOp]]] = None, do_reproduction=True, + enable_asexual_reproduction=True, + enable_sexual_reproduction=False, + does_sex_matter=True, mutate_programs=False, mutator: (Mutator | None) = None, + sexual_mutator: (SexualMutator | None) = None, intercept_reproduction=False, - min_repr_energy_requirement=None): + min_repr_energy_requirement=None, + n_sparse_max: (int | None ) = None, + return_metrics=False): """Perform one step for the environment. There are several different settings for performing a step. The most important @@ -80,6 +88,12 @@ def step_env( excl_fs: the exclusive logic of materials. Defaults to AIR spreading through VOID, and EARTH acting like falling-sand. do_reproduction: whether reproduction is enabled. + enable_asexual_reproduction: if set to True, asexual reproduction is + enabled. if do_reproduction is enabled, at least one of + enable_asexual_reproduction and enable_sexual_reproduction must be True. + enable_sexual_reproduction: if set to True, sexual reproduction is enabled. + does_sex_matter: if set to True, and enable_sexual_reproduction is True, + then sexual reproduction can only occur between different sex entities. mutate_programs: relevant only if do_reproduction==True. In that case, determines whether reproduction is performed with or without mutation. Beware! Reproduction *without* mutation creates agents with identical @@ -88,6 +102,8 @@ def step_env( If set to true, 'mutator' must be a valid Mutator class. mutator: relevant only if we reproduce with mutation. In that case, mutator determines how to extract parameters and how to modify them. + sexual_mutator: relevant only if we reproduce with mutation. In that case, + it determines how to perform sexual reproduction. intercept_reproduction: useful for petri-dish-like experiments. If set to true, whenever a ReproduceOp triggers, instead of creating a new seed, we simply destroy the flower and record its occurrence. We consider it a @@ -98,14 +114,20 @@ def step_env( min_repr_energy_requirement: relevant only if intercepting reproductions. Determines whether the intercepted seed would have had enough energy to count as a successful reproduction. + n_sparse_max: either an int or None. If set to int, we will use a budget for + the amounts of agent operations allowed at each step. + return_metrics: if True, returns reproduction metrics. May be extended to + return more metrics in the future. Returns: an updated environment. If intercept_reproduction is True, returns also the number of successful reproductions intercepted. """ + assert (not do_reproduction or + (enable_asexual_reproduction or enable_sexual_reproduction)) etd = config.etd if excl_fs is None: excl_fs = ((etd.types.AIR, air_cell_op), (etd.types.EARTH, earth_cell_op)) - + if mutate_programs: agent_params, _ = vmap(mutator.split_params)(programs) else: @@ -135,17 +157,25 @@ def step_env( else: ku, key = jr.split(key) if mutate_programs: - env, programs = env_perform_reproduce_update( + # metrics are supported only here. + repr_result = env_perform_reproduce_update( ku, env, repr_programs, config, repr_f, mutate_programs, programs, - mutator.mutate) + mutator.mutate, enable_asexual_reproduction, + enable_sexual_reproduction, does_sex_matter, sexual_mutator.mutate, + mutator.split_params, agent_logic.get_sex, n_sparse_max, + return_metrics) + if return_metrics: + (env, programs), metrics = repr_result + else: + env, programs = repr_result else: env = env_perform_reproduce_update( - ku, env, repr_programs, config, repr_f) + ku, env, repr_programs, config, repr_f, n_sparse_max=n_sparse_max) # parallel updates k1, key = jr.split(key) env = env_perform_parallel_update( - k1, env, par_programs, config, agent_logic.par_f) + k1, env, par_programs, config, agent_logic.par_f, n_sparse_max) # energy absorbed and generated by materials. env = process_energy(env, config) @@ -153,11 +183,12 @@ def step_env( # exclusive updates k1, key = jr.split(key) env = env_perform_exclusive_update( - k1, env, excl_programs, config, excl_fs, agent_logic.excl_f) + k1, env, excl_programs, config, excl_fs, agent_logic.excl_f, n_sparse_max) # increase age. env = env_increase_age(env, etd) rval = (env, programs) if mutate_programs else env rval = (rval, n_successful_repr) if intercept_reproduction else rval + rval = (rval, metrics) if return_metrics else rval return rval