Skip to content

Commit

Permalink
starting to get second example working. adding functionality to save …
Browse files Browse the repository at this point in the history
…DarkNews tables and objects. still some issues in the pickling
  • Loading branch information
nickkamp1 committed Oct 24, 2024
1 parent 840da1f commit 6210de4
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 21 deletions.
35 changes: 35 additions & 0 deletions python/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,16 @@

import time
from siren import dataclasses as _dataclasses
from siren import math as _math
from siren.interactions import DarkNewsCrossSection,DarkNewsDecay
import numpy as np
import awkward as ak
import h5py
import pickle
try:
from DarkNews.nuclear_tools import NuclearTarget
except:
pass

THIS_DIR = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -966,3 +973,31 @@ def SaveEvents(events,
# Load events from the custom SIREN event format
def LoadEvents(filename):
return _dataclasses.LoadInteractionTrees(filename)

def SaveDarkNewsProcesses(table_dir,
primary_processes,
primary_ups_keys,
secondary_processes,
secondary_dec_keys,
pickles=True):
for primary in primary_processes.keys():
for xs,ups_key in zip(primary_processes[primary],primary_ups_keys[primary]):
subdir = "_".join(["CrossSection"] + [str(x) if type(x)!=NuclearTarget else str(x.name) for x in ups_key])
table_subdir = os.path.join(table_dir, subdir)
os.makedirs(table_subdir,exist_ok=True)
print("Saving cross section table at %s" % table_subdir)
xs.FillInterpolationTables()
xs.save_to_table(table_subdir)
# if pickles:
# with open(os.path.join(table_subdir, "xs_object.pkl"),"wb") as f:
# pickle.dump(xs,f)
for secondary in secondary_processes.keys():
for dec,dec_key in zip(secondary_processes[secondary],secondary_dec_keys[secondary]):
subdir = "_".join(["Decay"] + [str(x) if type(x)!=NuclearTarget else str(x.name) for x in dec_key])
table_subdir = os.path.join(table_dir, subdir)
os.makedirs(table_subdir,exist_ok=True)
print("Saving decay object at %s" % table_subdir)
dec.save_to_table(table_subdir)
if pickles:
with open(os.path.join(table_subdir, "dec_object.pkl"),"wb") as f:
pickle.dump(dec,f)
17 changes: 13 additions & 4 deletions resources/examples/example2/DipolePortal_CCM.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
import siren
from siren import utilities
from siren._util import GenerateEvents,SaveEvents
from siren._util import GenerateEvents,SaveEvents,get_processes_model_path,SaveDarkNewsProcesses

# Define a DarkNews model
model_kwargs = {
Expand All @@ -18,7 +18,7 @@
}

# Number of events to inject
events_to_inject = 100
events_to_inject = 1

# Experiment to run
experiment = "CCM"
Expand All @@ -27,12 +27,14 @@
# Particle to inject
primary_type = siren.dataclasses.Particle.ParticleType.NuMu

table_name = f"DarkNewsTables-v{siren.utilities.darknews_version()}"
table_name = f"DarkNewsTables-v{siren.utilities.darknews_version()}/"
table_name += "Dipole_M%2.2e_mu%2.2e"%(model_kwargs["m4"],model_kwargs["mu_tr_mu4"])
table_dir = os.path.join(get_processes_model_path("DarkNewsTables"),table_name)
os.makedirs(table_dir,exist_ok=True)


# Load DarkNews processes
primary_processes, secondary_processes = utilities.load_processes(
primary_processes, secondary_processes, primary_ups_keys, secondary_dec_keys = utilities.load_processes(
"DarkNewsTables",
primary_type=primary_type,
detector_model = detector_model,
Expand Down Expand Up @@ -118,6 +120,13 @@ def stop(datum, i):

SaveEvents(events,weighter,gen_times,output_filename="output/CCM_Dipole")

# save cross section tables
SaveDarkNewsProcesses(table_dir,
primary_processes,
primary_ups_keys,
secondary_processes,
secondary_dec_keys)



weights = [weighter(event) for event in events]
Expand Down
6 changes: 3 additions & 3 deletions resources/processes/DarkNewsTables/DarkNewsDecay.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,11 @@ def load_from_table(self, table_dir):
self.decay_norm, self.decay_integrator = pickle.load(f)

def save_to_table(self, table_dir):
with open(os.path.join(table_dir, "decay.pkl")) as f:
pickle.dump(f, {
with open(os.path.join(table_dir, "decay.pkl"),'wb') as f:
pickle.dump({
"decay_integrator": self.decay_integrator,
"decay_norm": self.decay_norm
})
}, f)

# serialization method
def get_representation(self):
Expand Down
35 changes: 21 additions & 14 deletions resources/processes/DarkNewsTables/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
siren._util.load_module("logger", logger_file)

from siren.DNModelContainer import ModelContainer
from DarkNews.nuclear_tools import NuclearTarget

# Import PyDarkNewsDecay and PyDarkNewsCrossSection
decay_file = os.path.join(base_path, "DarkNewsDecay.py")
Expand Down Expand Up @@ -99,8 +100,8 @@ def load_cross_section_from_table(
interp_tolerance=5e-2,
always_interpolate=True,
):
subdir = "_".join(["CrossSection"] + [str(x) for x in upscattering_key])
table_subdir = os.path.join(table_dir, subdir)
# subdir = "_".join(["CrossSection"] + [str(x) if type(x)!=NuclearTarget else str(x.name) for x in upscattering_key])
# table_subdir = os.path.join(table_dir, subdir)

cross_section = load_cross_section(
model_container,
Expand All @@ -109,7 +110,7 @@ def load_cross_section_from_table(
interp_tolerance=interp_tolerance,
always_interpolate=always_interpolate,
)
cross_section.load_from_table(table_subdir)
cross_section.load_from_table(table_dir)
return cross_section


Expand All @@ -121,8 +122,8 @@ def load_cross_section_from_pickle(
always_interpolate=True,
):
import pickle
subdir = "_".join(["CrossSection"] + [str(x) for x in upscattering_key])
table_subdir = os.path.join(table_dir, subdir)
# subdir = "_".join(["CrossSection"] + [str(x) if type(x)!=NuclearTarget else str(x.name) for x in upscattering_key])
# table_subdir = os.path.join(table_dir, subdir)
fname = os.path.join(table_dir, "xs_object.pkl")
with open(fname, "rb") as f:
xs_obj = pickle.load(f)
Expand Down Expand Up @@ -165,7 +166,7 @@ def attempt_to_load_cross_section(
if len(preferences) == 0:
raise ValueError("preferences must have at least one entry")

subdir = "_".join(["CrossSection"] + [str(x) for x in ups_key])
subdir = "_".join(["CrossSection"] + [str(x) if type(x)!=NuclearTarget else str(x.name) for x in ups_key])
loaded = False
cross_section = None
for p in preferences:
Expand Down Expand Up @@ -240,9 +241,9 @@ def load_cross_sections(
if table_dir is None:
table_dir = ""

cross_sections = []
cross_sections = {}
for ups_key, ups_case in models.ups_cases.items():
cross_sections.append(
cross_sections[ups_key] = (
attempt_to_load_cross_section(models, ups_key,
table_dir,
preferences,
Expand Down Expand Up @@ -365,9 +366,9 @@ def load_decays(
if table_dir is None:
table_dir = ""

decays = []
decays = {}
for decay_key, dec_case in models.dec_cases.items():
decays.append(attempt_to_load_decay(models, decay_key, table_dir, preferences))
decays[decay_key] = attempt_to_load_decay(models, decay_key, table_dir, preferences)

return decays

Expand Down Expand Up @@ -435,6 +436,8 @@ def load_processes(

if nuclear_targets is None:
nuclear_targets = GetDetectorModelTargets(detector_model)[1]
model_kwargs["nuclear_targets"] = list(nuclear_targets)
if target_types: model_kwargs["nuclear_targets"]+=list(target_types)

base_path = os.path.dirname(os.path.abspath(__file__))
table_dir = os.path.join(base_path, table_name)
Expand All @@ -456,7 +459,7 @@ def load_processes(
table_dir=table_dir,
)

cross_sections = [xs for xs in cross_sections if len([s for s in xs.GetPossibleSignatures() if s.primary_type == primary_type])>0]
cross_sections = {k:xs for k,xs in cross_sections.items() if len([s for s in xs.GetPossibleSignatures() if s.primary_type == primary_type])>0}

if fill_tables_at_start:
if Emax is None:
Expand All @@ -468,25 +471,29 @@ def load_processes(
cross_section.FillInterpolationTables(Emax=Emax)

primary_processes = collections.defaultdict(list)
primary_ups_keys = collections.defaultdict(list)
# Loop over available cross sections and save those which match primary type
for cross_section in cross_sections:
for ups_key,cross_section in cross_sections.items():
if primary_type == siren.dataclasses.Particle.ParticleType(
cross_section.ups_case.nu_projectile.pdgid
):
primary_processes[primary_type].append(cross_section)
primary_ups_keys[primary_type].append(ups_key)

secondary_processes = collections.defaultdict(list)
secondary_dec_keys = collections.defaultdict(list)
# Loop over available decays, group by parent type
for decay in decays:
for dec_key,decay in decays.items():
secondary_type = siren.dataclasses.Particle.ParticleType(
decay.dec_case.nu_parent.pdgid
)
secondary_processes[secondary_type].append(decay)
secondary_dec_keys[secondary_type].append(dec_key)


#holder = Holder()
#holder.primary_processes = primary_processes
#holder.secondary_processes = secondary_processes

return dict(primary_processes), dict(secondary_processes)
return dict(primary_processes), dict(secondary_processes), dict(primary_ups_keys), dict(secondary_dec_keys)

0 comments on commit 6210de4

Please sign in to comment.