fix bug that prevents deepcopy and pickling if object has a parent #37

Merged · 4 commits · Jan 21, 2024
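For context, this is the failure the PR fixes: a paramz object that had a parent could no longer be deep-copied or pickled. A minimal sketch, assuming paramz's top-level Param/Parameterized API (the model setup is illustrative):

import copy
import pickle
from paramz import Param, Parameterized

m = Parameterized('model')
p = Param('theta', 1.0)
m.link_parameter(p)       # p now holds a reference to its parent m

# Before this fix, the parent back-reference broke the operations below.
p2 = copy.deepcopy(p)     # copy started from the child: the parent link is dropped
m2 = copy.deepcopy(m)     # copy started from the parent: links are re-established
p3 = pickle.loads(pickle.dumps(p))  # _parent_ is popped in __getstate__, so this round-trips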
4 changes: 2 additions & 2 deletions paramz/core/gradcheckable.py
@@ -28,10 +28,9 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #===============================================================================
 from . import HierarchyError
-from .pickleable import Pickleable
 from .parentable import Parentable

-class Gradcheckable(Pickleable, Parentable):
+class Gradcheckable(Parentable):
     """
     Adds the functionality for an object to be gradcheckable.
     It is just a thin wrapper of a call to the highest parent for now.
@@ -74,3 +73,4 @@ def _checkgrad(self, param, verbose=0, step=1e-6, tolerance=1e-3, df_tolerance=1
         TODO: this can be done more efficiently, when doing it inside here
         """
         raise HierarchyError("This parameter is not in a model with a likelihood, and, therefore, cannot be gradient checked!")
+
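Dropping Pickleable from the explicit base list is safe because Parentable now inherits from it (see parentable.py below), so it stays in Gradcheckable's MRO. A quick sanity check, sketched against the new class layout:

from paramz.core.gradcheckable import Gradcheckable
from paramz.core.parentable import Parentable
from paramz.core.pickleable import Pickleable

assert issubclass(Parentable, Pickleable)   # new: Parentable brings Pickleable in
assert Pickleable in Gradcheckable.__mro__  # so Gradcheckable is still pickleable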

5 changes: 5 additions & 0 deletions paramz/core/parameter_core.py
@@ -491,6 +491,11 @@ def _name_changed(self, param, old_name):
         self._remove_parameter_name(None, old_name)
         self._add_parameter_name(param)

+    def __getstate__(self):
+        dc = super().__getstate__()
+        dc.pop('_param_array_', None)
+        return dc
+
     def __setstate__(self, state):
         super(Parameterizable, self).__setstate__(state)
         self.logger = logging.getLogger(self.__class__.__name__)
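The override above follows the cooperative super().__getstate__() pattern; a standalone sketch with illustrative names ('_param_array_' is the real attribute being dropped, Base/Child are stand-ins):

class Base:
    def __getstate__(self):
        return self.__dict__.copy()      # hand out a copy, never self.__dict__ itself

class Child(Base):
    def __getstate__(self):
        state = super().__getstate__()   # let the base class assemble the dict first
        state.pop('_cache_', None)       # drop a derived attribute; rebuilt after unpickling
        return state

Popping '_param_array_' here, rather than in Pickleable's generic ignore list, keeps knowledge of the attribute in the class that owns it, which is what the #TODO notes in pickleable.py below ask for.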
29 changes: 27 additions & 2 deletions paramz/core/parentable.py
@@ -27,8 +27,9 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #===============================================================================
+from .pickleable import Pickleable

-class Parentable(object):
+class Parentable(Pickleable):
     """
     Enable an Object to have a parent.

@@ -37,8 +38,9 @@ class Parentable(object):
     """
     _parent_ = None
     _parent_index_ = None
+
     def __init__(self, *args, **kwargs):
-        super(Parentable, self).__init__()
+        super().__init__(*args, **kwargs)

     def has_parent(self):
         """
@@ -73,3 +75,26 @@ def _notify_parent_change(self):
         Dont do anything if in leaf node
         """
         pass
+
+    def __getstate__(self):
+        dc = super().__getstate__()
+        dc.pop('_parent_', None)
+        return dc
+
+    def __deepcopy__(self, memo: dict):
+        s = self.__new__(self.__class__)  # fresh instance
+        memo[id(self)] = s  # be sure to break all cycles --> self is already done
+        # The line above can cause hard-to-understand bugs, because s is not actually 'done':
+        # its state attributes still need to be copied. If a state attribute links to a parent,
+        # the parent is copied first using the uninitialized copy s, raising an exception when
+        # the child's attributes are accessed. A subclass should therefore not link to its parent, or handle the link as done here.
+        import copy
+
+        parent_copy = memo.get(id(self._parent_), None)  # get the copy of the parent (it should already be in the memo);
+        # if the parent is not in the memo, the copy was not started from the parent, and the link will be removed.
+        state = self.__getstate__()
+        updated_state = copy.deepcopy(state, memo)  # standard copy
+        s.__setstate__(updated_state)
+        s._parent_ = parent_copy
+        return s
+
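The parent handling above, condensed into a self-contained sketch (Node is illustrative, not a paramz class); the memo lookup is what distinguishes a copy started from the parent from one started from the child:

import copy

class Node:
    def __init__(self, name):
        self.name = name
        self._parent_ = None
        self.children = []

    def add(self, child):
        child._parent_ = self
        self.children.append(child)

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop('_parent_', None)        # never serialize the back-reference
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._parent_ = None

    def __deepcopy__(self, memo):
        s = self.__new__(self.__class__)
        memo[id(self)] = s                               # break cycles before touching children
        parent_copy = memo.get(id(self._parent_), None)  # present only if the parent is being copied too
        s.__setstate__(copy.deepcopy(self.__getstate__(), memo))
        s._parent_ = parent_copy
        return s

root, leaf = Node('root'), Node('leaf')
root.add(leaf)
r2 = copy.deepcopy(root)
assert r2.children[0]._parent_ is r2         # started from the parent: link re-established
assert copy.deepcopy(leaf)._parent_ is None  # started from the child: link dropped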

15 changes: 10 additions & 5 deletions paramz/core/pickleable.py
@@ -84,7 +84,7 @@ def copy(self, memo=None, which=None):
         parents = []
         if which is None:
             which = self
-        which.traverse_parents(parents.append) # collect parents
+        which.traverse_parents(parents.append) # collect parents #TODO: Refactor: This is bad, in this class we do not know anything about parentables
         for p in parents:
             if not id(p) in memo :memo[id(p)] = None # set all parents to be None, so they will not be copied
             if not id(self.gradient) in memo:memo[id(self.gradient)] = None # reset the gradient
@@ -96,13 +96,18 @@

     def __deepcopy__(self, memo):
         s = self.__new__(self.__class__) # fresh instance
-        memo[id(self)] = s # be sure to break all cycles --> self is already done
+        memo[id(self)] = s # be sure to break all cycles --> self will be done once all children are done;
+        # children should not hold a link to the parent, as the parent cannot finish before its children.
         import copy
-        s.__setstate__(copy.deepcopy(self.__getstate__(), memo)) # standard copy
+
+        state = self.__getstate__()
+        updated_state = copy.deepcopy(state, memo) # standard copy
+
+        s.__setstate__(updated_state)
         return s

-    def __getstate__(self):
-        ignore_list = ['_param_array_', # parameters get set from bottom to top
+    def __getstate__(self): #TODO: Refactor: This is bad, this class does not know about most of these attributes
+        ignore_list = [#'_param_array_', # parameters get set from bottom to top
                        '_gradient_array_', # as well as gradients
                        '_optimizer_copy_',
                        'logger',
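The memo trick in copy() above, where every parent is registered as None before the deepcopy starts, also works with plain copy.deepcopy: anything pre-seeded in the memo counts as already copied, so the traversal never descends into it. A minimal sketch:

import copy

blocker = object()
graph = {'keep': [1, 2, 3], 'prune': blocker}

memo = {id(blocker): None}       # pretend blocker was already copied, as None
copied = copy.deepcopy(graph, memo)

assert copied['keep'] == [1, 2, 3]
assert copied['prune'] is None   # deepcopy returned the pre-seeded stand-in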
3 changes: 1 addition & 2 deletions paramz/optimization/verbose_optimization.py
@@ -35,8 +35,7 @@


 def exponents(fnow, current_grad):
-    exps = [np.abs(float(fnow)),
-            1 if current_grad is np.nan else current_grad]
+    exps = [np.abs(float(fnow)), 1 if current_grad is np.nan else current_grad]
     return np.sign(exps) * np.log10(exps).astype(int)
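For reference, exponents only extracts base-10 orders of magnitude for the verbose progress table; a quick worked call (values checked by hand):

import numpy as np

def exponents(fnow, current_grad):
    exps = [np.abs(float(fnow)), 1 if current_grad is np.nan else current_grad]
    return np.sign(exps) * np.log10(exps).astype(int)

print(exponents(1234.5, 0.01))   # [ 3. -2.] -> objective ~ 10**3, gradient ~ 10**-2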

