-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
55bc194
commit 7a983f7
Showing
8 changed files
with
320 additions
and
192 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class ToyBenchmark(ABC): | ||
def __init__(self, device: str = "cpu", *args, **kwargs): | ||
""" | ||
I think here it would be nice to pass a general receipt for the downstream task construction. | ||
For example, we could pass | ||
- a dataset constructor that generates the dataset for training from the original | ||
dataset (either by modifying the labels, the data, or removing some samples); | ||
- a metric that generates the final score: it could be either a Metric object from our library, or maybe | ||
accuracy comparison. | ||
:param device: | ||
:param args: | ||
:param kwargs: | ||
""" | ||
self.device = device | ||
|
||
@classmethod | ||
@abstractmethod | ||
def generate(cls, *args, **kwargs): | ||
""" | ||
This method should generate all the benchmark components and persist them in the instance. | ||
""" | ||
raise NotImplementedError | ||
|
||
@classmethod | ||
@abstractmethod | ||
def load(cls, path: str, *args, **kwargs): | ||
""" | ||
This method should load the benchmark components from a file and persist them in the instance. | ||
""" | ||
raise NotImplementedError | ||
|
||
@classmethod | ||
@abstractmethod | ||
def assemble(cls, *args, **kwargs): | ||
""" | ||
This method should assemble the benchmark components from arguments and persist them in the instance. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def save(self, *args, **kwargs): | ||
""" | ||
This method should save the benchmark components to a file/folder. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def evaluate( | ||
self, | ||
*args, | ||
**kwargs, | ||
): | ||
""" | ||
Used to update the metric with new data. | ||
""" | ||
|
||
raise NotImplementedError |
Oops, something went wrong.