Skip to content

Commit

Permalink
Merge pull request #37 from bela127/master
Browse files Browse the repository at this point in the history
fix bug that prevents deepcopy and pickling if object has a parent
  • Loading branch information
MartinBubel authored Jan 21, 2024
2 parents 39d635e + e156caa commit eef6b70
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 11 deletions.
4 changes: 2 additions & 2 deletions paramz/core/gradcheckable.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#===============================================================================
from . import HierarchyError
from .pickleable import Pickleable
from .parentable import Parentable

class Gradcheckable(Pickleable, Parentable):
class Gradcheckable(Parentable):
"""
Adds the functionality for an object to be gradcheckable.
It is just a thin wrapper of a call to the highest parent for now.
Expand Down Expand Up @@ -74,3 +73,4 @@ def _checkgrad(self, param, verbose=0, step=1e-6, tolerance=1e-3, df_tolerance=1
TODO: this can be done more efficiently, when doing it inside here
"""
raise HierarchyError("This parameter is not in a model with a likelihood, and, therefore, cannot be gradient checked!")

5 changes: 5 additions & 0 deletions paramz/core/parameter_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,11 @@ def _name_changed(self, param, old_name):
self._remove_parameter_name(None, old_name)
self._add_parameter_name(param)

def __getstate__(self):
dc = super().__getstate__()
dc.pop('_param_array_', None)
return dc

def __setstate__(self, state):
super(Parameterizable, self).__setstate__(state)
self.logger = logging.getLogger(self.__class__.__name__)
Expand Down
29 changes: 27 additions & 2 deletions paramz/core/parentable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#===============================================================================
from .pickleable import Pickleable

class Parentable(object):
class Parentable(Pickleable):
"""
Enable an Object to have a parent.
Expand All @@ -37,8 +38,9 @@ class Parentable(object):
"""
_parent_ = None
_parent_index_ = None

def __init__(self, *args, **kwargs):
super(Parentable, self).__init__()
super().__init__(*args, **kwargs)

def has_parent(self):
"""
Expand Down Expand Up @@ -73,3 +75,26 @@ def _notify_parent_change(self):
Dont do anything if in leaf node
"""
pass

def __getstate__(self):
dc = super().__getstate__()
dc.pop('_parent_', None)
return dc

def __deepcopy__(self, memo: dict):
s = self.__new__(self.__class__) # fresh instance
memo[id(self)] = s # be sure to break all cycles --> self is already done
# The above line can cause hard to understand exceptions/bugs, because s is not 'done' its state attributes need to be copied first
# If a state attribute has a link to a parent object, the parent is copied first, using the uninitialized copy s
# thereby throwing an exception when attributes of the child are accessed while copying the parent.
# so a subclass should not link to its parent or handel the link like its done here.
import copy

parent_copy = memo.get(id(self._parent_), None) #get the copy of the parent (it should already be in the memo),
# if the parent is not in the memo the copy process was not started from the parent, and the link should will be removed.
state = self.__getstate__()
updated_state = copy.deepcopy(state, memo) # standard copy
s.__setstate__(updated_state)
s._parent_ = parent_copy
return s

15 changes: 10 additions & 5 deletions paramz/core/pickleable.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def copy(self, memo=None, which=None):
parents = []
if which is None:
which = self
which.traverse_parents(parents.append) # collect parents
which.traverse_parents(parents.append) # collect parents #TODO: Refactor: This is bad, in this class we do not know anything about parentables
for p in parents:
if not id(p) in memo :memo[id(p)] = None # set all parents to be None, so they will not be copied
if not id(self.gradient) in memo:memo[id(self.gradient)] = None # reset the gradient
Expand All @@ -96,13 +96,18 @@ def copy(self, memo=None, which=None):

def __deepcopy__(self, memo):
s = self.__new__(self.__class__) # fresh instance
memo[id(self)] = s # be sure to break all cycles --> self is already done
memo[id(self)] = s # be sure to break all cycles --> self will be done after all children are done
# children should not have a link to parent as parent cant finish before children.
import copy
s.__setstate__(copy.deepcopy(self.__getstate__(), memo)) # standard copy

state = self.__getstate__()
updated_state = copy.deepcopy(state, memo) # standard copy

s.__setstate__(updated_state)
return s

def __getstate__(self):
ignore_list = ['_param_array_', # parameters get set from bottom to top
def __getstate__(self): #TODO: Refactor: This is bad, this class does not know about most of these attributes
ignore_list = [#'_param_array_', # parameters get set from bottom to top
'_gradient_array_', # as well as gradients
'_optimizer_copy_',
'logger',
Expand Down
3 changes: 1 addition & 2 deletions paramz/optimization/verbose_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@


def exponents(fnow, current_grad):
exps = [np.abs(float(fnow)),
1 if current_grad is np.nan else current_grad]
exps = [np.abs(float(fnow)), 1 if current_grad is np.nan else current_grad]
return np.sign(exps) * np.log10(exps).astype(int)


Expand Down

0 comments on commit eef6b70

Please sign in to comment.