From 3181c16ae53bfae15032fc9489ac6ee86a752a32 Mon Sep 17 00:00:00 2001 From: raj1701 Date: Sun, 1 Oct 2023 14:38:35 +0000 Subject: [PATCH] Refactored __eq__ functions --- hnn_core/cell.py | 40 ++++++++++++++++++++++----------------- hnn_core/extracellular.py | 19 ++++++++++++++----- hnn_core/network.py | 14 +++++++++++--- 3 files changed, 48 insertions(+), 25 deletions(-) diff --git a/hnn_core/cell.py b/hnn_core/cell.py index cd1ce35b8..e3bb6e84b 100644 --- a/hnn_core/cell.py +++ b/hnn_core/cell.py @@ -241,15 +241,16 @@ def __eq__(self, other): if np.testing.assert_almost_equal(self_end_pt, other_end_pt, 5) is not None: return False + + all_attrs = dir(self) + attrs_to_ignore = [x for x in all_attrs if x.startswith('_')] + attrs_to_ignore.extend(['end_pts', 'mechs', 'to_dict']) + attrs_to_check = [x for x in all_attrs if x not in attrs_to_ignore] # Check all other attributes - if (self.L != other.L or - self.diam != other.diam or - self.Ra != other.Ra or - self.cm != other.cm or - self.nseg != other.nseg or - self.syns != other.syns): - return False + for attr in attrs_to_check: + if getattr(self, attr) != getattr(other, attr): + return False return True @@ -411,16 +412,21 @@ def __repr__(self): def __eq__(self, other): if not isinstance(other, Cell): return NotImplemented - if not (self.name == other.name and - self.pos == other.pos and - self.synapses == other.synapses and - self.cell_tree == other.cell_tree and - self.sect_loc == other.sect_loc and - self.dipole_pp == other.dipole_pp and - self.vsec == other.vsec and - self.isec == other.isec and - self.tonic_biases == other.tonic_biases): - return False + + all_attrs = dir(self) + attrs_to_ignore = [x for x in all_attrs if x.startswith('_')] + attrs_to_ignore.extend(['build', 'copy', 'create_tonic_bias', + 'define_shape', 'distance_section', 'gid', + 'list_IClamp', 'modify_section', + 'parconnect_from_src', 'plot_morphology', + 'record', 'sections', 'setup_source_netcon', + 'syn_create', 'to_dict']) + attrs_to_check = [x for x in all_attrs if x not in attrs_to_ignore] + + # Check all other attributes + for attr in attrs_to_check: + if getattr(self, attr) != getattr(other, attr): + return False if not (self.sections.keys() == other.sections.keys()): return False diff --git a/hnn_core/extracellular.py b/hnn_core/extracellular.py index fc3fc9db1..c9cc81cf8 100644 --- a/hnn_core/extracellular.py +++ b/hnn_core/extracellular.py @@ -357,11 +357,20 @@ def __len__(self): def __eq__(self, other): if not isinstance(other, ExtracellularArray): return NotImplemented - if not (self.positions == other.positions and - self.conductivity == other.conductivity and - self.method == other.method and - self.min_distance == other.min_distance and - (self.times == other.times).all() and + + all_attrs = dir(self) + attrs_to_ignore = [x for x in all_attrs if x.startswith('_')] + attrs_to_ignore.extend(['conductivity', 'copy', 'n_contacts', + 'plot_csd', 'plot_lfp', 'sfreq', 'smooth', + 'voltages', 'to_dict', 'times', 'voltages']) + attrs_to_check = [x for x in all_attrs if x not in attrs_to_ignore] + + # Check all other attributes + for attr in attrs_to_check: + if getattr(self, attr) != getattr(other, attr): + return False + + if not ((self.times == other.times).all() and (self.voltages == other.voltages).all()): return False diff --git a/hnn_core/network.py b/hnn_core/network.py index bc89f5746..1fb32b91b 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -434,11 +434,19 @@ def __eq__(self, other): receptor=conn['receptor']) if len(match_conns) == 0: return False + # print(dir(self)) + all_attrs = dir(self) + attrs_to_ignore = [x for x in all_attrs if x.startswith('_')] + attrs_to_ignore.extend(['add_bursty_drive', 'add_connection', + 'add_electrode_array', 'add_evoked_drive', + 'add_poisson_drive', 'add_tonic_bias', + 'clear_connectivity', 'clear_drives', + 'connectivity', 'copy', 'gid_to_type', + 'plot_cells', 'set_cell_positions', 'write']) + attrs_to_check = [x for x in all_attrs if x not in attrs_to_ignore] # Check all other attributes - for attr in ['cell_types', 'gid_ranges', 'pos_dict', 'cell_response', - 'external_drives', 'external_biases', 'rec_arrays', - 'threshold', 'delay']: + for attr in attrs_to_check: if getattr(self, attr) != getattr(other, attr): return False