
Trak wrapper #106

Closed · wants to merge 8 commits

Conversation

@gumityolcu (Collaborator) commented Aug 13, 2024

Helloooo

This is a TRAK wrapper FOR A SINGLE MODEL (explanation below), plus fixes to the imports. Closes #101

1- TRAK accepts multiple independently trained models and averages their explanations (the averaging is done in a nontrivial way; please check the paper). It would not be too much of a hassle to accept several checkpoints for TRAK, but then what would retraining-based evaluation correspond to? What about model randomization? Currently, TRAK accepts one model that has already been loaded with a checkpoint, so for now we only support post-hoc TRAK.

2- @dilyabareeva, I saw you changed the import strategy because isort was complaining. I changed the configuration so that isort ignores __init__.py files, which solves the isort problem, and I corrected all the imports in this PR. A sketch of the configuration change follows below.
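
For reference, a minimal sketch of the kind of isort configuration change described in point 2, assuming isort is configured through pyproject.toml; the exact option used in the PR is not shown here, so extend_skip_glob is an assumption:

    [tool.isort]
    # Assumed option: make isort skip __init__.py files entirely.
    extend_skip_glob = ["*/__init__.py"]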

@dilyabareeva (Owner) left a comment:

Hi @gumityolcu, thanks for your work! Apart from my comments, could you add tests and docstrings for TRAK?

Resolved threads: pyproject.toml; quanda/explainers/wrappers/trak_wrapper.py (outdated)

num_params_for_grad = 0
params_iter = params_ldr if params_ldr is not None else self.model.parameters()
for p in list(params_iter):
@dilyabareeva (Owner):

Are we sure this is a robust way to count parameters?

@dilyabareeva (Owner):

This is very nitpicky, but I'm sure there are way more elegant ways to do lines 43–47.

@gumityolcu (Collaborator, Author):

We should also count the parameters of a given parameter loader, not necessarily all the parameters in the model. The user may want to consider only the final layer, etc. A compact alternative is sketched below.
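
For illustration, a compact way to do the count, in the spirit of the comment above. This is a sketch only, since the loop body is not shown in the excerpt; it assumes the loop sums per-tensor element counts (numel) and that params_ldr yields parameter tensors:

    # Count scalar parameters over the given loader, falling back to the full model.
    params_iter = params_ldr if params_ldr is not None else self.model.parameters()
    num_params_for_grad = sum(p.numel() for p in params_iter)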

Three more resolved threads on quanda/explainers/wrappers/trak_wrapper.py (outdated)

def explain(self, test, targets):
    test = test.to(self.device)
    self.traker.start_scoring_checkpoint(
@dilyabareeva (Owner):

Shall we make exp_name an argument to __init__? It can be "test" by default.

Also, I didn't look much into the trak library, but does it make more sense to do start_scoring_checkpoint in __init__ as well? For all I know, it might be doing something intensive that we don't want to repeat every time we explain.

@gumityolcu (Collaborator, Author):

I am assuming we decided to keep this as is. You can make the change very quickly if you decide to do it.
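
A minimal sketch of the suggested change, with exp_name promoted to an __init__ argument. The trimmed-down class shape and attribute names here are assumptions for illustration; the start_scoring_checkpoint keyword arguments follow the trak library's documented API, but double-check them against the actual wrapper:

    class TRAKWrapper:  # hypothetical stand-in for the actual wrapper class
        def __init__(self, model, device, exp_name="test"):
            self.model = model
            self.device = device
            self.exp_name = exp_name  # configurable once, defaulting to "test"
            # self.traker = TRAKer(...)  # set up as in the real wrapper

        def explain(self, test, targets):
            test = test.to(self.device)
            self.traker.start_scoring_checkpoint(
                exp_name=self.exp_name,  # instead of a hard-coded name per call
                checkpoint=self.model.state_dict(),
                model_id=0,
                num_targets=test.shape[0],
            )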

warnings.warn("Defaulting to BasicProjector.")
projector = "basic"

projector_cls = {
@dilyabareeva (Owner):

Why is this not outside of the class?

@gumityolcu (Collaborator, Author):

Hi, I forgot to talk about this during our meeting. My thinking is along these lines: the only reason for the projector_cls dictionary to exist is that it lets us implement the if-statement logic quickly. Looking at where it is used and why we have it, it is purely local to this explainer, so I wouldn't define it outside the class.

Of course, nothing changes in practice if we define it outside, so I'll let you decide whether you definitely want to keep it outside or are convinced by my explanation.
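
For context, a minimal sketch of the dispatch-dictionary pattern under discussion. The projector classes are imported from trak.projectors as in the trak library, but the specific keys and the set of mapped classes are assumptions:

    from trak.projectors import BasicProjector, CudaProjector, NoOpProjector

    # Map user-facing option strings to projector classes, replacing an if/elif chain.
    projector_cls = {
        "basic": BasicProjector,
        "cuda": CudaProjector,
        "noop": NoOpProjector,
    }
    chosen_cls = projector_cls[projector]  # e.g. projector == "basic" gives BasicProjector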

@gumityolcu closed this Aug 16, 2024
@gumityolcu mentioned this pull request Aug 16, 2024