Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pass params as a dict into policy and state update functions #303

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions cadCAD/engine/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def env_composition(target_field, state_dict, target_value):
# mech_step
def partial_state_update(
self,
sweep_dict: Dict[str, List[Any]],
sweep_dict: Dict[str, Any],
sub_step: int,
sL,
sH,
Expand Down Expand Up @@ -149,7 +149,7 @@ def transfer_missing_fields(source, destination):
# mech_pipeline - state_update_block
def state_update_pipeline(
self,
sweep_dict: Dict[str, List[Any]],
sweep_dicts: List[Dict[str, Any]],
simulation_list,
configs: List[Tuple[List[Callable], List[Callable]]],
env_processes: Dict[str, Callable],
Expand All @@ -170,7 +170,7 @@ def state_update_pipeline(
states_list: List[Dict[str, Any]] = [genesis_states]

sub_step += 1
for [s_conf, p_conf] in configs:
for [s_conf, p_conf], sweep_dict in zip(configs, sweep_dicts):
states_list: List[Dict[str, Any]] = self.partial_state_update(
sweep_dict, sub_step, states_list, simulation_list, s_conf, p_conf, env_processes, time_step, run,
additional_objs
Expand All @@ -184,7 +184,7 @@ def state_update_pipeline(
# state_update_pipeline
def run_pipeline(
self,
sweep_dict: Dict[str, List[Any]],
sweep_dicts: List[Dict[str, Any]],
states_list: List[Dict[str, Any]],
configs: List[Tuple[List[Callable], List[Callable]]],
env_processes: Dict[str, Callable],
Expand All @@ -197,7 +197,7 @@ def run_pipeline(

for time_step in time_seq:
pipe_run: List[Dict[str, Any]] = self.state_update_pipeline(
sweep_dict, simulation_list, configs, env_processes, time_step, run, additional_objs
sweep_dicts, simulation_list, configs, env_processes, time_step, run, additional_objs
)
_, *pipe_run = pipe_run
simulation_list.append(pipe_run)
Expand All @@ -206,7 +206,7 @@ def run_pipeline(

def simulation(
self,
sweep_dict: Dict[str, List[Any]],
sweep_dicts: List[Dict[str, Any]],
states_list: List[Dict[str, Any]],
configs,
env_processes: Dict[str, Callable],
Expand All @@ -222,7 +222,7 @@ def simulation(
run += 1
subset_window.appendleft(subset_id)

def execute_run(sweep_dict, states_list, configs, env_processes, time_seq, _run) -> List[Dict[str, Any]]:
def execute_run(sweep_dicts, states_list, configs, env_processes, time_seq, _run) -> List[Dict[str, Any]]:
def generate_init_sys_metrics(genesis_states_list, sim_id, _subset_id, _run, _subset_window):
for D in genesis_states_list:
d = deepcopy(D)
Expand All @@ -235,14 +235,14 @@ def generate_init_sys_metrics(genesis_states_list, sim_id, _subset_id, _run, _su
)

first_timestep_per_run: List[Dict[str, Any]] = self.run_pipeline(
sweep_dict, states_list_copy, configs, env_processes, time_seq, run, additional_objs
sweep_dicts, states_list_copy, configs, env_processes, time_seq, run, additional_objs
)
del states_list_copy

return first_timestep_per_run

pipe_run = flatten(
[execute_run(sweep_dict, states_list, configs, env_processes, time_seq, run)]
[execute_run(sweep_dicts, states_list, configs, env_processes, time_seq, run)]
)

return pipe_run