Skip to content

Commit

Permalink
adding: recorder.plot + include set type jsonconversion
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasCouppey committed Dec 9, 2024
1 parent 3af9491 commit 6c12f5f
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 8 deletions.
3 changes: 2 additions & 1 deletion nrv/backend/_NRV_Class.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ def save(self, save=False, fname="nrv_save.json", blacklist=[], **kwargs) -> dic
key_dic[key] = {}
for i in self.__dict__[key]:
key_dic[key][i] = self.__dict__[key][i].save(**kwargs)

else:
key_dic[key] = deepcopy(self.__dict__[key])
if save:
Expand Down Expand Up @@ -286,6 +285,8 @@ def load(self, data, blacklist={}, **kwargs) -> None:
self.__dict__[key] = load_any(key_dic[key], **kwargs)
elif isinstance(self.__dict__[key], np.ndarray):
self.__dict__[key] = np.array(key_dic[key])
elif isinstance(self.__dict__[key], set):
self.__dict__[key] = set(key_dic[key])
elif is_empty_iterable(key_dic[key]):
self.__dict__[key] = eval(self.__dict__[key].__class__.__name__)()
else:
Expand Down
6 changes: 5 additions & 1 deletion nrv/backend/_file_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def json_load(filename):
results = json.load(file_to_read)
return results


## TODO add NRV_decoder to simplify NRV_class save/load methode
class NRV_Encoder(json.JSONEncoder):
"""
Json encoding class, specific for NRV2 axon
Expand All @@ -183,12 +183,16 @@ def default(self, obj):
result = float(obj)
elif isinstance(obj, np.ndarray):
result = obj.tolist()
elif isinstance(obj, set):
result = list(obj)
else:
# Let the base class Encoder handle the object
result = json.JSONEncoder.default(self, obj)
return result




######################
## DXF related code ##
######################
Expand Down
5 changes: 3 additions & 2 deletions nrv/fmod/FEM/mesh_creator/_NerveMshCreator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,12 +1205,13 @@ def add_CUFF(
contact_thickness = self.Nerve_D * 0.1
if insulator_thickness is None:
insulator_thickness = min(
5 * contact_thickness, 0.4 * (self.Outer_D - self.Nerve_D) / 2
5 * contact_thickness, 0.4 * abs(self.Outer_D - self.Nerve_D) / 2
)
if insulator_thickness < 0:
insulator_thickness = 5 * contact_thickness
if insulator_length is None:
insulator_length = 2 * contact_length
# self.reshape_outerBox(tresholded_res=True)

if ID is not None:
self.electrodes[ID]["kwargs"]["contact_length"] = contact_length
self.electrodes[ID]["kwargs"]["contact_thickness"] = contact_thickness
Expand Down
17 changes: 17 additions & 0 deletions nrv/fmod/_recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import faulthandler

import numpy as np
import matplotlib.pyplot as plt

from ..backend._log_interface import rise_error, rise_warning
from ..backend._MCore import MCH, synchronize_processes
Expand Down Expand Up @@ -795,3 +796,19 @@ def gather_all_recordings(self, results:list[dict]):
for rec in reclist:
point.recording += np.array(rec["recording_points"][i]["recording"])

def plot(self, axes: plt.axes, points:int|np.ndarray|None=None,color: str = "k", **kwgs) -> None:
if self.t is None:
rise_warning("empty recorder canot be ploted")
else:
if points is None:
points = np.arange(len(self.recording_points))
if np.iterable(points):
for i_pts in points:
self.plot(axes=axes, points=i_pts, color=color, **kwgs)
else:
if points > len(self.recording_points):
rise_warning(f"recording point {points} does not exits in recorder, and so canot be ploted")
else:
axes.plot(self.t, self.recording_points[points].recording, color=color,**kwgs)


1 change: 0 additions & 1 deletion nrv/nmod/_fascicles.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,6 @@ def simulate(
"save anyway",
)
self.save_results = True

self.__init_axon_postprocessing()
if len(self.NoR_relative_position) == 0:
self.generate_random_NoR_position()
Expand Down
13 changes: 10 additions & 3 deletions nrv/ui/_axon_postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def block(my_dict, position_key=None, t_start=0, t_stop=0):
else:
return False
else:
print("intra_stim_positions is not in dictionnary")
pass_info("intra_stim_positions is not in dictionnary")


def max_spike_position(blocked_spike_positionlist, position_max, spike_begin="down"):
Expand Down Expand Up @@ -1227,8 +1227,10 @@ def sample_keys(results:axon_results, keys_to_sample:str|set[str]={}, t_start_re
axon_results
updated results.
"""
if not results["record_g_mem"]:
rise_error("gmem not recorded nothing will be done")
if isinstance(keys_to_sample, str):
keys_to_sample = {keys_to_sample}
if len(set(keys_to_sample) - set(results.keys())):
rise_error(set(keys_to_sample) - set(results.keys()), "keys are missing to apply postprocessing. Please check simulation parameters")
else:
if t_stop_rec < 0:
t_stop_rec=results.t_sim
Expand Down Expand Up @@ -1267,7 +1269,12 @@ def sample_keys(results:axon_results, keys_to_sample:str|set[str]={}, t_start_re
"Nseg_per_sec",
"axon_path_type",
"t_sim",
"myelinated",
"intra_stim_starts",
"intra_stim_positions",
"recorder"
}

list_keys.update(keys_to_keep)
list_keys.update(keys_to_sample)
if results.ID==0:
Expand Down

0 comments on commit 6c12f5f

Please sign in to comment.