Trak wrapper #106
Conversation
Hi @gumityolcu, thanks for your work! Beyond my inline comments, could you add tests and docstrings for TRAK?
num_params_for_grad = 0
params_iter = params_ldr if params_ldr is not None else self.model.parameters()
for p in list(params_iter):
are we sure this is a robust way to count parameters?
this is very nitpicky, but I'm sure there are far more elegant ways to do lines 43-47
We should also count the parameters of a given parameter loader, not necessarily all the parameters in the model. The user may want to only consider the final layer, etc.
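One compact way to address both comments above (a hypothetical sketch; `count_params` is an illustrative name, and the fallback mirrors the `params_ldr` logic in the snippet being reviewed):

```python
import torch.nn as nn


def count_params(model: nn.Module, params_iter=None) -> int:
    # Fall back to all model parameters when no parameter loader is given,
    # mirroring the params_ldr handling in the snippet above.
    params = params_iter if params_iter is not None else model.parameters()
    # numel() counts the scalar entries of each tensor, so the sum is the
    # total gradient dimension rather than the number of tensors.
    return sum(p.numel() for p in params)
```

Passing e.g. only the final layer's `.parameters()` as `params_iter` then restricts the count to that layer.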
def explain(self, test, targets):
    test = test.to(self.device)
    self.traker.start_scoring_checkpoint(
shall we make exp_name an argument to init? it can be "test" by default.
also, I didn't look much into the trak library, but does it make more sense to do start_scoring_checkpoint in init as well? it might be doing something intensive that we don't want to repeat every time we explain, for all I know
I am assuming we decided to keep this as is. You can make the change very quickly if you decide to do it.
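If the change were made, it could look roughly like this (an illustrative skeleton only, not the actual wrapper; the trak calls are elided):

```python
class TRAKWrapper:
    """Illustrative skeleton; the real wrapper wires in the trak calls."""

    def __init__(self, model, exp_name: str = "test"):
        # exp_name becomes a constructor argument defaulting to "test",
        # as suggested above, instead of being fixed inside explain().
        self.model = model
        self.exp_name = exp_name
        # If start_scoring_checkpoint is safe to run once per experiment,
        # it could also be called here rather than on every explain() call.
```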
warnings.warn("Defaulting to BasicProjector.")
projector = "basic"

projector_cls = {
why is this not outside of the class?
Hi, I forgot to talk about this during our meeting. My thinking is this: the only reason for the projector_cls dictionary to exist is that it lets us quickly implement the if statements. Looked at in terms of where it is used, it's only a local thing for this explainer, so I wouldn't define it outside the class.
Of course, nothing changes practically if we define it outside, so I'll let you decide whether you definitely want to keep it outside or are convinced by my explanation.
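The pattern under discussion, sketched with placeholder classes (names are illustrative stand-ins, not the real trak projectors):

```python
class BasicProjector:
    """Placeholder standing in for the real projector class."""


class CudaProjector:
    """Placeholder standing in for the real projector class."""


# Mapping projector names to classes replaces a chain of if/elif
# statements; kept local to the explainer, per the comment above.
projector_cls = {
    "basic": BasicProjector,
    "cuda": CudaProjector,
}


def make_projector(name: str):
    # Dispatch by name; an unknown name raises KeyError.
    return projector_cls[name]()
```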
Helloooo
This is a TRAK wrapper FOR A SINGLE MODEL (explanation below), plus fixes for imports. Closes #101
1- TRAK accepts multiple independently trained models and averages their explanations (they do the averaging in a nontrivial way; please check the paper). It would not be too much of a hassle to accept several checkpoints for TRAK. But then, what would retraining-based evaluation correspond to? What about model randomization? Currently, TRAK accepts one model that has already been loaded with a checkpoint, so for now we only support post-hoc TRAK.
2- @dilyabareeva I saw you changed the import strategy because isort was complaining. I changed the configuration so that isort ignores __init__.py files, which solves the isort problem. I corrected all imports in this PR.
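For reference, one way to make isort skip `__init__.py` files is via its `skip_glob` setting in `pyproject.toml` (a sketch; the exact configuration used in this PR may differ):

```toml
[tool.isort]
# Skip __init__.py files so their re-export ordering is not rewritten.
skip_glob = ["*/__init__.py"]
```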