diff --git a/benchmarks/benchmark_base.py b/benchmarks/benchmark_base.py index 493573839c7..e87d40c27c0 100644 --- a/benchmarks/benchmark_base.py +++ b/benchmarks/benchmark_base.py @@ -28,7 +28,10 @@ from tardis.transport.montecarlo.packet_collections import ( VPacketCollection, ) -from tardis.transport.montecarlo.packet_trackers import RPacketTracker +from tardis.transport.montecarlo.packet_trackers import ( + RPacketTracker, + generate_rpacket_last_interaction_tracker_list, +) class BenchmarkBase: @@ -239,7 +242,9 @@ def packet(self): @property def verysimple_packet_collection(self): - return self.nb_simulation_verysimple.transport.transport_state.packet_collection + return ( + self.nb_simulation_verysimple.transport.transport_state.packet_collection + ) @property def nb_simulation_verysimple(self): @@ -269,7 +274,9 @@ def verysimple_enable_full_relativity(self): @property def verysimple_disable_line_scattering(self): - return self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING + return ( + self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING + ) @property def verysimple_continuum_processes_enabled(self): @@ -277,11 +284,15 @@ def verysimple_continuum_processes_enabled(self): @property def verysimple_tau_russian(self): - return self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN + return ( + self.nb_simulation_verysimple.transport.montecarlo_configuration.VPACKET_TAU_RUSSIAN + ) @property def verysimple_survival_probability(self): - return self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY + return ( + self.nb_simulation_verysimple.transport.montecarlo_configuration.SURVIVAL_PROBABILITY + ) @property def static_packet(self): @@ -303,7 +314,9 @@ def set_seed(value): @property def verysimple_3vpacket_collection(self): - spectrum_frequency_grid = self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value + spectrum_frequency_grid = ( + self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value + ) return VPacketCollection( source_rpacket_index=0, spectrum_frequency_grid=spectrum_frequency_grid, @@ -404,3 +417,8 @@ def estimators(self): stim_recomb_cooling_estimator=np.empty((0, 0), dtype=np.float64), photo_ion_estimator_statistics=np.empty((0, 0), dtype=np.int64), ) + + @property + def rpacket_tracker_list(self): + no_of_packets = len(self.transport_state.packet_collection.initial_nus) + return generate_rpacket_last_interaction_tracker_list(no_of_packets) diff --git a/benchmarks/transport_montecarlo_main_loop.py b/benchmarks/transport_montecarlo_main_loop.py index 19f2106e7a3..7d6d07ddb08 100644 --- a/benchmarks/transport_montecarlo_main_loop.py +++ b/benchmarks/transport_montecarlo_main_loop.py @@ -22,6 +22,7 @@ def time_montecarlo_main_loop(self): self.montecarlo_configuration, self.transport_state.radfield_mc_estimators, self.nb_simulation_verysimple.transport.spectrum_frequency_grid.value, + self.rpacket_tracker_list, self.montecarlo_configuration.NUMBER_OF_VPACKETS, iteration=0, show_progress_bars=False, diff --git a/benchmarks/transport_montecarlo_packet_trackers.py b/benchmarks/transport_montecarlo_packet_trackers.py index c4b6c875407..a3f65f66f6d 100644 --- a/benchmarks/transport_montecarlo_packet_trackers.py +++ b/benchmarks/transport_montecarlo_packet_trackers.py @@ -4,6 +4,8 @@ from benchmarks.benchmark_base import BenchmarkBase from tardis.transport.montecarlo.packet_trackers import ( rpacket_trackers_to_dataframe, + generate_rpacket_tracker_list, + generate_rpacket_last_interaction_tracker_list, ) @@ -15,6 +17,20 @@ class BenchmarkTransportMontecarloPacketTrackers(BenchmarkBase): def time_rpacket_trackers_to_dataframe(self): sim = self.simulation_rpacket_tracking_enabled transport_state = sim.transport.transport_state - rpacket_trackers_to_dataframe( - transport_state.rpacket_tracker - ) + rpacket_trackers_to_dataframe(transport_state.rpacket_tracker) + + def time_generate_rpacket_tracker_list(self, no_of_packets, length): + generate_rpacket_tracker_list(no_of_packets, length) + + def time_generate_rpacket_last_interaction_tracker_list( + self, no_of_packets + ): + generate_rpacket_last_interaction_tracker_list(no_of_packets) + + time_generate_rpacket_tracker_list.params = ([1, 10, 50], [1, 10, 50]) + time_generate_rpacket_tracker_list.param_names = ["no_of_packets", "length"] + + time_generate_rpacket_last_interaction_tracker_list.params = [10, 100, 1000] + time_generate_rpacket_last_interaction_tracker_list.param_names = [ + "no_of_packets" + ] diff --git a/tardis/transport/montecarlo/base.py b/tardis/transport/montecarlo/base.py index 8c7559da7fb..c08a695ff53 100644 --- a/tardis/transport/montecarlo/base.py +++ b/tardis/transport/montecarlo/base.py @@ -24,6 +24,8 @@ opacity_state_initialize, ) from tardis.transport.montecarlo.packet_trackers import ( + generate_rpacket_tracker_list, + generate_rpacket_last_interaction_tracker_list, rpacket_trackers_to_dataframe, ) from tardis.util.base import ( @@ -158,12 +160,24 @@ def run( self.transport_state = transport_state number_of_vpackets = self.montecarlo_configuration.NUMBER_OF_VPACKETS + number_of_rpackets = len(transport_state.packet_collection.initial_nus) + + if self.enable_rpacket_tracking: + transport_state.rpacket_tracker = generate_rpacket_tracker_list( + number_of_rpackets, + self.montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH, + ) + else: + transport_state.rpacket_tracker = ( + generate_rpacket_last_interaction_tracker_list( + number_of_rpackets + ) + ) ( v_packets_energy_hist, last_interaction_tracker, vpacket_tracker, - rpacket_trackers, ) = montecarlo_main_loop( transport_state.packet_collection, transport_state.geometry_state, @@ -172,6 +186,7 @@ def run( self.montecarlo_configuration, transport_state.radfield_mc_estimators, self.spectrum_frequency_grid.value, + transport_state.rpacket_tracker, number_of_vpackets, iteration=iteration, show_progress_bars=show_progress_bars, @@ -199,8 +214,6 @@ def run( update_iterations_pbar(1) refresh_packet_pbar() - transport_state.rpacket_tracker = rpacket_trackers - # Need to change the implementation of rpacket_trackers_to_dataframe # Such that it also takes of the case of # RPacketLastInteractionTracker diff --git a/tardis/transport/montecarlo/configuration/base.py b/tardis/transport/montecarlo/configuration/base.py index 3e890f39c57..6a90fce53cf 100644 --- a/tardis/transport/montecarlo/configuration/base.py +++ b/tardis/transport/montecarlo/configuration/base.py @@ -81,6 +81,3 @@ def configuration_initialize(config, transport, number_of_vpackets): montecarlo_globals.ENABLE_RPACKET_TRACKING = ( transport.enable_rpacket_tracking ) - montecarlo_main_loop.ENABLE_RPACKET_TRACKING = ( - transport.enable_rpacket_tracking - ) diff --git a/tardis/transport/montecarlo/montecarlo_main_loop.py b/tardis/transport/montecarlo/montecarlo_main_loop.py index 8460b72461b..e5daa72c89d 100644 --- a/tardis/transport/montecarlo/montecarlo_main_loop.py +++ b/tardis/transport/montecarlo/montecarlo_main_loop.py @@ -10,11 +10,6 @@ consolidate_vpacket_tracker, initialize_last_interaction_tracker, ) -import tardis.transport.montecarlo.montecarlo_main_loop as montecarlo_loop -from tardis.transport.montecarlo.packet_trackers import ( - RPacketTracker, - RPacketLastInteractionTracker, -) from tardis.transport.montecarlo.r_packet import ( PacketStatus, RPacket, @@ -24,8 +19,6 @@ ) from tardis.util.base import update_packet_pbar -ENABLE_RPACKET_TRACKING = False - @njit(**njit_dict) def montecarlo_main_loop( @@ -36,6 +29,7 @@ def montecarlo_main_loop( montecarlo_configuration, estimators, spectrum_frequency_grid, + rpacket_trackers, number_of_vpackets, iteration, show_progress_bars, @@ -77,19 +71,6 @@ def montecarlo_main_loop( # Pre-allocate a list of vpacket collections for later storage vpacket_collections = List() - # Configuring the Tracking for R_Packets - rpacket_trackers = List() - if ENABLE_RPACKET_TRACKING: - for i in range(no_of_packets): - rpacket_trackers.append( - RPacketTracker( - montecarlo_configuration.INITIAL_TRACKING_ARRAY_LENGTH - ) - ) - else: - for i in range(no_of_packets): - rpacket_trackers.append(RPacketLastInteractionTracker()) - for i in range(no_of_packets): vpacket_collections.append( VPacketCollection( @@ -197,7 +178,7 @@ def montecarlo_main_loop( 1, ) - if ENABLE_RPACKET_TRACKING: + if montecarlo_globals.ENABLE_RPACKET_TRACKING: for rpacket_tracker in rpacket_trackers: rpacket_tracker.finalize_array() @@ -205,5 +186,4 @@ def montecarlo_main_loop( v_packets_energy_hist, last_interaction_tracker, vpacket_tracker, - rpacket_trackers, ) diff --git a/tardis/transport/montecarlo/packet_trackers.py b/tardis/transport/montecarlo/packet_trackers.py index f7ff0a50cd5..7a10a992f05 100644 --- a/tardis/transport/montecarlo/packet_trackers.py +++ b/tardis/transport/montecarlo/packet_trackers.py @@ -1,5 +1,6 @@ -from numba import float64, int64 +from numba import float64, int64, njit from numba.experimental import jitclass +from numba.typed import List import numpy as np import pandas as pd @@ -203,3 +204,42 @@ def track(self, r_packet): self.energy = r_packet.energy self.shell_id = r_packet.current_shell_id self.interaction_type = r_packet.last_interaction_type + + # To make it compatible with RPacketTracker + def finalize_array(self): + pass + + +@njit +def generate_rpacket_tracker_list(no_of_packets, length): + """ + Parameters + ---------- + no_of_packets : The count of RPackets that are sent in the ejecta + length : initial length of the tracking array + + Returns + ------- + A list containing RPacketTracker for each RPacket + """ + rpacket_trackers = List() + for i in range(no_of_packets): + rpacket_trackers.append(RPacketTracker(length)) + return rpacket_trackers + + +@njit +def generate_rpacket_last_interaction_tracker_list(no_of_packets): + """ + Parameters + ---------- + no_of_packets : The count of RPackets that are sent in the ejecta + + Returns + ------- + A list containing RPacketLastInteractionTracker for each RPacket + """ + rpacket_trackers = List() + for i in range(no_of_packets): + rpacket_trackers.append(RPacketLastInteractionTracker()) + return rpacket_trackers diff --git a/tardis/transport/montecarlo/tests/test_tracker_utils.py b/tardis/transport/montecarlo/tests/test_tracker_utils.py new file mode 100644 index 00000000000..0d3bc6a4d46 --- /dev/null +++ b/tardis/transport/montecarlo/tests/test_tracker_utils.py @@ -0,0 +1,38 @@ +import pytest +import numpy as np +from numba import typeof + +from tardis.transport.montecarlo.packet_trackers import ( + RPacketTracker, + RPacketLastInteractionTracker, + generate_rpacket_tracker_list, + generate_rpacket_last_interaction_tracker_list, +) + + +def test_generate_rpacket_tracker_list(): + no_of_packets = 10 + length = 10 + random_index = np.random.randint(0, no_of_packets) + + rpacket_tracker_list = generate_rpacket_tracker_list(no_of_packets, length) + + assert len(rpacket_tracker_list) == no_of_packets + assert len(rpacket_tracker_list[random_index].shell_id) == length + assert typeof(rpacket_tracker_list[random_index]) == typeof( + RPacketTracker(length) + ) + + +def test_generate_rpacket_last_interaction_tracker_list(): + no_of_packets = 50 + random_index = np.random.randint(0, no_of_packets) + + rpacket_last_interaction_tracker_list = ( + generate_rpacket_last_interaction_tracker_list(no_of_packets) + ) + + assert len(rpacket_last_interaction_tracker_list) == no_of_packets + assert typeof( + rpacket_last_interaction_tracker_list[random_index] + ) == typeof(RPacketLastInteractionTracker())