Can't pickle weakref objects when saving checkpoints #212

Open
tzeitim opened this issue May 7, 2023 · 3 comments
Assignees
Labels
enhancement (New feature or improvement), refactoring (Code refactor)

Comments


tzeitim commented May 7, 2023

I hope this is not too bleeding-edge, but I have no other versioning options given the combination of the GPU nodes I have access to and my software dependencies.

A quick recap:

I am pulling cellbender 0.3.0 from the branch sf_dev_0.3.0_postreg_posterior_format_h5, since I had hit the same two issues raised in PR #193 and followed @sjfleming's suggestion to use that dev branch with the latest pytorch 2.0.0.

git clone -b sf_dev_0.3.0_postreg_posterior_format_h5 https://github.com/broadinstitute/CellBender.git

Unfortunately, the issue with learning rate schedulers persisted even after pulling cellbender's 736d6.

After some digging, I managed to solve the pytorch-pyro scheduler issue by pulling pyro from this commit instead:

pip install git+https://github.com/ilia-kats/pyro/@c9ed43a1f90d2f9a92278c68319eb68962b29013

The main issue

cellbender was able to finish training, but it raised a new error when trying to write the final checkpoint (the only checkpoint attempted on this data set, as far as I am aware).

The code in cellbender's checkpoint.py, in its current form, only reports that the attempt to write the checkpoint failed before it exits:

'Could not save checkpoint'

I had to remove the try block to reveal the real error:

*** TypeError: cannot pickle 'weakref' object
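The error itself comes from pickle (which torch.save uses under the hood): weakref objects carry no pickling support at all. A minimal stdlib reproduction, independent of torch:

```python
import pickle
import weakref

class Target:
    """Any ordinary object will do; the weakref to it is what matters."""
    pass

t = Target()
ref = weakref.ref(t)

try:
    pickle.dumps(ref)  # pickle refuses weakref objects outright
except TypeError as e:
    print(e)  # e.g. "cannot pickle 'weakref' object"
```

Any object graph that reaches a weakref, however deeply nested, fails the same way.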

I dissected the individual lines that each trigger the error on their own:

torch.save(model_obj, filebase + '_model.torch')
torch.save(scheduler, filebase + '_optim.torch')
scheduler.save(filebase + '_optim.pyro')  # use PyroOptim method
pyro.get_param_store().save(filebase + '_params.pyro')

Interestingly, the model object can be saved by invoking its .state_dict() method:

torch.save(model_obj.state_dict(), filebase + '_model.torch') 

No .state_dict() exists for the scheduler object, though.

To understand the problem a bit better, I omitted the method objects within the scheduler state, and then torch.save could run! This indicated that the weakref in anneal_func was to blame.

I did a bit of googling with this information, and I think the weakref issue is very similar (maybe identical) to pytorch issue #42376.

I have decided to open this issue and document it here, as it has gone beyond my ability to resolve for now.


As a footnote, and just for the record, I wrote this simple routine to eliminate the methods mentioned above:

import weakref

def remove_weakrefs(aa):
    # aa is the nested dict returned by scheduler.get_state()
    remove_keys = []
    for i in aa.keys():
        for s in aa[i].keys():
            for k in aa[i][s].keys():
                print(f'{isinstance(aa[i], weakref.ReferenceType)} '
                      f'{isinstance(aa[i][s], weakref.ReferenceType)} '
                      f'{isinstance(aa[i][s][k], weakref.ReferenceType)} {i} {s} {k}')
                print(f'{aa[i].__class__.__name__} {aa[i][s].__class__.__name__} '
                      f'{aa[i][s][k].__class__.__name__}')
                if aa[i][s][k].__class__.__name__ == "method":
                    remove_keys.append((i, s, k))

    # delete after iterating, to avoid mutating the dicts mid-loop
    for i, s, k in remove_keys:
        aa[i][s].pop(k)
    return aa

aa = remove_weakrefs(scheduler.get_state())

torch.save(aa, filebase + '_optim.torch')  # this works
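A more general variant of the same workaround, not tied to Pyro's particular state layout, is to walk the nested state and drop any leaf that pickle itself rejects. This is only a sketch of the idea, not what cellbender does:

```python
import pickle
import weakref

def drop_unpicklable(d):
    """Return a copy of a (possibly nested) dict with unpicklable leaves removed."""
    out = {}
    for key, value in d.items():
        if isinstance(value, dict):
            out[key] = drop_unpicklable(value)
        else:
            try:
                pickle.dumps(value)   # probe the leaf
            except Exception:
                continue              # skip weakrefs, lambdas, locks, ...
            out[key] = value
    return out

class T:
    pass

t = T()
state = {"group0": {"lr": 0.01, "ref": weakref.ref(t)}, "step": 2}
print(drop_unpicklable(state))  # {'group0': {'lr': 0.01}, 'step': 2}
```

Probing each leaf with pickle.dumps costs time proportional to the leaf's size, so for states holding large tensors the cheap class-name check above may be preferable.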
@sjfleming (Member) commented:

@tzeitim thank you very much for writing in with this! This is the same problem I was running into when trying to move to pytorch 2.0.0, and I have not yet been able to figure it out either. You got farther than I did! So thank you, I appreciate it.

Thanks for pointing out the ilia-kats fix for the pyro issue, I had not seen that yet, and that seems promising.

I hadn't run into that optimizer saving issue yet, but I did run into another weakref pickling issue here:
pyro-ppl/pyro#3201

I wonder what the deal is with this weakref stuff in pytorch 2.0.0. They must have refactored some things in a way I don't understand. I wonder why I'm seeing it now with v2.0.0, but never saw it before? But that pytorch issue you linked seems like the right thing.

It seems like this agrees with your fix, and I think I might try it out:
https://github.com/numenta/nupic.research/pull/328/files

state_dict = self.lr_scheduler.state_dict()
if "anneal_func" in state_dict:
    del state_dict["anneal_func"]
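The same defensive deletion works on any state mapping before serialization; a stdlib sketch with a plain dict standing in for the real scheduler state (the lambda plays the role of the unpicklable anneal_func):

```python
import pickle

# hypothetical scheduler state; the lambda, like anneal_func, cannot be pickled
state_dict = {"lr": 0.001, "last_epoch": 3, "anneal_func": lambda pct: pct}

state_dict.pop("anneal_func", None)  # drop the offending entry before saving
blob = pickle.dumps(state_dict)
print(pickle.loads(blob))  # {'lr': 0.001, 'last_epoch': 3}
```

The cost of the workaround is that whatever the dropped entry encoded has to be reconstructed when the checkpoint is loaded.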

In my own development work, I am currently still using python 3.7 with pytorch < 2.0.0, for the reasons you pointed out. I will be working on these kinds of fixes on the sf_dev_0.3.0_postreg_python3.8 branch to enable python 3.8 and pytorch 2.0.0 compatibility in the future.

@sjfleming sjfleming self-assigned this May 8, 2023
@sjfleming sjfleming added the bug Something isn't working label May 8, 2023
@sjfleming (Member) commented:

Here's some tracking for this stuff:
#203


tzeitim commented May 8, 2023

Hi @sjfleming, thanks for your answer and the references. To be honest, I was just lucky to find that ilia-kats fix; it was only a couple of days old when I found it.

Regarding the source of this issue, I don't know. I've spent a lot of time trying to understand the chain of events that leads to a weakref, without any success. I'll report any progress when possible.

I am glad I could help you save some time or identify potential solutions for the future.

Keep up the great work!
