Skip to content

Commit

Permalink
toy benchmark introduction
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Jun 26, 2024
1 parent 55bc194 commit 7a983f7
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 192 deletions.
34 changes: 0 additions & 34 deletions src/downstream_tasks/base.py

This file was deleted.

133 changes: 0 additions & 133 deletions src/downstream_tasks/subclass_identification.py

This file was deleted.

File renamed without changes.
61 changes: 61 additions & 0 deletions src/toy_benchmarks/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from abc import ABC, abstractmethod


class ToyBenchmark(ABC):
def __init__(self, device: str = "cpu", *args, **kwargs):
"""
I think here it would be nice to pass a general receipt for the downstream task construction.
For example, we could pass
- a dataset constructor that generates the dataset for training from the original
dataset (either by modifying the labels, the data, or removing some samples);
- a metric that generates the final score: it could be either a Metric object from our library, or maybe
accuracy comparison.
:param device:
:param args:
:param kwargs:
"""
self.device = device

@classmethod
@abstractmethod
def generate(cls, *args, **kwargs):
"""
This method should generate all the benchmark components and persist them in the instance.
"""
raise NotImplementedError

@classmethod
@abstractmethod
def load(cls, path: str, *args, **kwargs):
"""
This method should load the benchmark components from a file and persist them in the instance.
"""
raise NotImplementedError

@classmethod
@abstractmethod
def assemble(cls, *args, **kwargs):
"""
This method should assemble the benchmark components from arguments and persist them in the instance.
"""
raise NotImplementedError

@abstractmethod
def save(self, *args, **kwargs):
"""
This method should save the benchmark components to a file/folder.
"""
raise NotImplementedError

@abstractmethod
def evaluate(
self,
*args,
**kwargs,
):
"""
Used to update the metric with new data.
"""

raise NotImplementedError
Loading

0 comments on commit 7a983f7

Please sign in to comment.