
TRAK wrapper #101

Closed
dilyabareeva opened this issue Aug 6, 2024 · 3 comments · Fixed by #120
@dilyabareeva
Owner

No description provided.

@gumityolcu
Collaborator

Hello.

Some issues with the current state of the TRAK wrapper:

The TRAKer object (the underlying explainer we are interfacing with) works as follows (see the sketch after this list):

1. We "featurize" the training data given a model with a model_id. Iterating through checkpoints, we 1) load_checkpoint(), 2) featurize() the training dataset, and 3) finalize_features().
2. We then "score" the test data. Iterating through checkpoints, we 1) start_scoring_checkpoint(), 2) score() each test batch, and 3) finalize_scores().
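For concreteness, here is a minimal sketch of that two-phase flow, loosely following the TRAK quickstart (model, checkpoints, train_dataset, test_dataset, and the loaders are placeholders, and exact keyword names may differ between trak versions):

```python
from trak import TRAKer

traker = TRAKer(model=model, task="image_classification",
                train_set_size=len(train_dataset))

# Phase 1: featurize the training data, iterating over checkpoints
for model_id, ckpt in enumerate(checkpoints):
    traker.load_checkpoint(ckpt, model_id=model_id)
    for batch in train_loader:
        traker.featurize(batch=batch, num_samples=batch[0].shape[0])
traker.finalize_features()

# Phase 2: score the test data, again iterating over checkpoints
for model_id, ckpt in enumerate(checkpoints):
    traker.start_scoring_checkpoint(exp_name="test_run", checkpoint=ckpt,
                                    model_id=model_id,
                                    num_targets=len(test_dataset))
    for batch in test_loader:
        traker.score(batch=batch, num_samples=batch[0].shape[0])
scores = traker.finalize_scores(exp_name="test_run")
```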

Once you call finalize_scores(), you cannot use score() again until you call start_scoring_checkpoint(). And you cannot get explanations without finalizing the scores, unless you make a subclass of the underlying TRAK explainer.

Basically, it wants you to go through the whole test dataset and get the explanations at the end, and it is being clever by caching everything.

This causes the following problems:

  • every batch returns the same explanations, because our interface is batched and the whole process is restarted from scratch with each explanation call

  • when you destroy an object and create a new one with the same model_id and cache folder, it will reuse the cached explanations

So we either need some garbage collection, or we need to create new cache folders with new "experiment_name"s (see the sketch below). Every straightforward solution I tried failed: deleting the corresponding cache file, or making small changes in the wrapper logic.
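As a purely illustrative sketch of the "new experiment_name" idea (the naming scheme here is an assumption, not something verified against TRAK's caching behavior):

```python
import uuid

def fresh_experiment_name(prefix: str = "quanda") -> str:
    # a new exp_name per explain() call, so TRAK writes scores to a
    # fresh cache location instead of reusing/overwriting the old one
    return f"{prefix}_{uuid.uuid4().hex[:8]}"
```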

@gumityolcu
Collaborator

That's why I closed PR #106.

@dilyabareeva
Owner Author

@gumityolcu I have investigated the caching issue in more detail. It stems from TRAK using a memory-mapped numpy array saver. When we call start_scoring_checkpoint, a new sample index count is initiated and the results are saved to a specific (on-disk) memory address that depends only on those indices. When we call start_scoring_checkpoint again, the save address remains the same as for the previously calculated batches, so the old results are overwritten on disk. The explanations are returned in this memory-mapped format, referring directly to that disk address, which leads to our issues.
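A toy numpy demonstration of that hazard, independent of TRAK: two memmap handles on the same file share the same disk region, so a later write changes what an earlier caller is still holding.

```python
import numpy as np

# first "scoring round": results are written to disk and returned memmapped
mm = np.memmap("scores.bin", dtype="float32", mode="w+", shape=(4,))
mm[:] = [1, 2, 3, 4]
returned = mm                 # what the caller holds on to

# second "scoring round": same file, same on-disk address
mm2 = np.memmap("scores.bin", dtype="float32", mode="r+", shape=(4,))
mm2[:] = [9, 9, 9, 9]
mm2.flush()

print(returned)               # [9. 9. 9. 9.] -- the old results are gone
```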

Luckily, this is easily resolved if we allocate new memory for the explanations by calling copy.deepcopy 😄
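In wrapper terms, the fix is roughly the following (traker and exp_name as in the sketch above; finalize_scores() returning the scores follows the TRAK quickstart):

```python
import copy

# the scores are backed by the on-disk memmap; deep-copying detaches
# them into ordinary in-memory arrays, so the next scoring round can
# no longer overwrite what we have already returned to the user
scores = traker.finalize_scores(exp_name=exp_name)
explanations = copy.deepcopy(scores)
```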

The TRAK library allows a saver of their AbstractSaver type to be passed to a TRAKer instance. So I think a better long-term solution is to write our own saver implementation that is not memory-mapped. I will open an issue for that.
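A skeleton of that direction (hypothetical: the exact abstract methods AbstractSaver requires depend on the installed trak version, so this only marks where an in-memory implementation would go):

```python
from trak.savers import AbstractSaver


class InMemorySaver(AbstractSaver):
    """Sketch: keep features/scores in ordinary numpy arrays instead of
    np.memmap, so returned explanations are never views into a disk
    region that a later call can overwrite."""
    # override the abstract methods of the installed trak version here
    ...


# usage would look roughly like:
# traker = TRAKer(model=model, task="image_classification",
#                 train_set_size=len(train_dataset), saver=InMemorySaver(...))
```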
