Skip to content

Commit

Permalink
Deprecating ElementType and doc fixes (#12)
Browse files Browse the repository at this point in the history
* deprecate ElemenType basically everywhere

* - removed tqdm and prints from the repo because they break the docs (consider removing the dependency)
- minor fixes to text in notebooks
  • Loading branch information
guybuk authored Aug 20, 2024
1 parent 68ce3c5 commit 17de994
Show file tree
Hide file tree
Showing 30 changed files with 259 additions and 293 deletions.
17 changes: 7 additions & 10 deletions bridge/display/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pandas as pd
import panel as pn

from bridge.primitives.element.element_type import ElementType
from bridge.primitives.sample.singular_sample import SingularSample

if TYPE_CHECKING:
Expand All @@ -14,9 +13,9 @@

class TextClassification(DisplayEngine[SingularDataset, SingularSample]):
def show_element(self, element: Element, element_plot_kwargs: Dict[str, Any] | None = None):
if element.etype == ElementType.class_label:
if element.etype == "class_label":
return pn.pane.Markdown(element.to_pd_series().to_frame().T.to_markdown())
elif element.etype == ElementType.text:
elif element.etype == "text":
return pn.pane.Markdown(element.data)
else:
raise NotImplementedError()
Expand All @@ -27,9 +26,7 @@ def show_sample(
element_plot_kwargs: Dict[str, Any] | None = None,
sample_plot_kwargs: Dict[str, Any] | None = None,
):
annotations_md = pd.DataFrame(
[ann.to_pd_series() for ann in sample.annotations[ElementType.class_label]]
).to_markdown()
annotations_md = pd.DataFrame([ann.to_pd_series() for ann in sample.annotations["class_label"]]).to_markdown()
text_display = pn.pane.Markdown(sample.data)
return pn.Column("# Sample Text:", text_display, "# Annotations Data:", annotations_md)

Expand All @@ -52,9 +49,9 @@ def plot_sample_by_widget(sample_id):

# class Panel(DisplayEngine):
# def show_element(self, element: Element, element_plot_kwargs: Dict[str, Any] | None = None):
# if element.etype == ElementType.class_label:
# if element.etype == "class_label":
# return self._show_class_label(element, element_plot_kwargs)
# elif element.etype == ElementType.text:
# elif element.etype == "text":
# return self._show_text(element, element_plot_kwargs)
# else:
# raise NotImplementedError()
Expand All @@ -66,9 +63,9 @@ def plot_sample_by_widget(sample_id):
# sample_plot_kwargs: Dict[str, Any] | None = None,
# ):
# annotations_md = pd.DataFrame(
# [ann.to_pd_series() for ann in sample.elements[ElementType.class_label]]
# [ann.to_pd_series() for ann in sample.elements["class_label"]]
# ).to_markdown()
# text_display = self.show_element(sample.elements[ElementType.text][0])
# text_display = self.show_element(sample.elements["text"][0])
# return pn.Column("# Sample Text:", text_display, "# Annotations Data:", annotations_md)
#
# def show_dataset(
Expand Down
27 changes: 13 additions & 14 deletions bridge/display/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pandas as pd

from bridge.display.display_engine import DisplayEngine
from bridge.primitives.element.element_type import ElementType
from bridge.primitives.sample import Sample
from bridge.utils import optional_dependencies

Expand All @@ -24,9 +23,9 @@ def __init__(self, bbox_format: str = "xyxy") -> None:

def show_element(self, element: Element, element_plot_kwargs: Dict[str, Any] | None = None):
etype = element.etype
if etype == ElementType.image:
if etype == "image":
plot = self._plot_single_image(element)
elif etype == ElementType.bbox:
elif etype == "bbox":
plot = self._plot_single_bbox(element)
else:
raise NotImplementedError(f"Invalid etype: {etype}")
Expand All @@ -42,13 +41,13 @@ def show_sample(
):
import holoviews as hv

imgs = [self._plot_single_image(element) for element in sample.elements[ElementType.image]]
if ElementType.bbox in sample.elements:
bboxes = self._plot_list_of_bbox_or_class_labels(sample.elements[ElementType.bbox])
imgs = [self._plot_single_image(element) for element in sample.elements["image"]]
if "bbox" in sample.elements:
bboxes = self._plot_list_of_bbox_or_class_labels(sample.elements["bbox"])
else:
bboxes = hv.Overlay()
if ElementType.class_label in sample.elements:
class_labels = self._plot_list_of_bbox_or_class_labels(sample.elements[ElementType.class_label])
if "class_label" in sample.elements:
class_labels = self._plot_list_of_bbox_or_class_labels(sample.elements["class_label"])
else:
class_labels = hv.Overlay()
for i in range(len(imgs)):
Expand Down Expand Up @@ -104,10 +103,10 @@ def _plot_list_of_bbox_or_class_labels(self, elements: List[Element]):
rectangle_list = []
for element in elements:
data: BoundingBox = element.data
if element.etype == ElementType.bbox:
if element.etype == "bbox":
xyxy = self._extract_bbox_coords(data)
cl = data.class_label
elif element.etype == ElementType.class_label: # assume cls
elif element.etype == "class_label": # assume cls
xyxy = [np.nan, np.nan, np.nan, np.nan]
cl = data
else:
Expand All @@ -122,7 +121,7 @@ def _plot_list_of_bbox_or_class_labels(self, elements: List[Element]):
for i, group in hv_df.groupby("class"):
p = hv.Rectangles(group, label=i)
plots.append(p)
plots = hv.Overlay(plots).opts(hv.opts.Rectangles(**self._default_kwargs(ElementType.bbox)))
plots = hv.Overlay(plots).opts(hv.opts.Rectangles(**self._default_kwargs("bbox")))
return plots

def _extract_bbox_coords(self, data):
Expand All @@ -143,12 +142,12 @@ def _extract_bbox_coords(self, data):
return xyxy

@staticmethod
def _default_kwargs(etype: ElementType) -> Dict[str, Any]:
def _default_kwargs(etype: str) -> Dict[str, Any]:
import holoviews as hv

if etype == ElementType.image:
if etype == "image":
return dict(aspect="equal", invert_yaxis=True, legend_position="left", xaxis=None, yaxis=None)
elif etype == ElementType.bbox:
elif etype == "bbox":
return dict(fill_alpha=0.0, line_width=3, line_color=hv.Cycle("Category20"))

@staticmethod
Expand Down
14 changes: 6 additions & 8 deletions bridge/primitives/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterable, Iterator, List, Sequence

import pandas as pd
from tqdm.contrib import tmap
from typing_extensions import Self

from bridge.primitives.dataset.sample_api import SampleAPI
from bridge.primitives.dataset.table_api import TableAPI
from bridge.primitives.element.element_type import ElementType
from bridge.primitives.sample import Sample
from bridge.utils.constants import ELEMENT_COLS, INDICES
from bridge.utils.helper import Displayable
Expand All @@ -27,7 +25,7 @@ def __init__(
self,
elements: pd.DataFrame,
display_engine: DisplayEngine = None,
cache_mechanisms: Dict[ElementType, CacheMechanism | None] | None = None,
cache_mechanisms: Dict[str, CacheMechanism | None] | None = None,
):
self._elements = elements
self._display_engine = display_engine
Expand Down Expand Up @@ -59,7 +57,7 @@ def merge(
self,
other: "Dataset",
display_engine: DisplayEngine | None = None,
cache_mechanisms: Dict[ElementType, CacheMechanism | None] | None = None,
cache_mechanisms: Dict[str, CacheMechanism | None] | None = None,
) -> "Dataset":
self_element_ids = self.elements.index.get_level_values(ELEMENT_COLS.ID)
other_element_ids = other.elements.index.get_level_values(ELEMENT_COLS.ID)
Expand All @@ -86,8 +84,8 @@ def get(self, sample_id: Hashable) -> Sample:
def transform_samples(
self,
transform: SampleTransform,
map_fn=tmap,
cache_mechanisms: Dict[ElementType, CacheMechanism] | None = None,
map_fn=map,
cache_mechanisms: Dict[str, CacheMechanism] | None = None,
display_engine: DisplayEngine | None = None,
) -> Self:
fn = functools.partial(
Expand All @@ -99,7 +97,7 @@ def transform_samples(
elements = [element for sample in samples for e_list in sample.elements.values() for element in e_list]
return Dataset.from_elements(elements, display_engine=display_engine)

def map_samples(self, function: Callable[[Sample], Any], map_fn=tmap):
def map_samples(self, function: Callable[[Sample], Any], map_fn=map):
outputs = map_fn(function, self)
if isinstance(outputs, GeneratorType):
return list(outputs)
Expand Down Expand Up @@ -130,7 +128,7 @@ def from_elements(
cls,
elements: Iterable[Element],
display_engine: DisplayEngine = None,
cache_mechanisms: Dict[ElementType, CacheMechanism | None] | None = None,
cache_mechanisms: Dict[str, CacheMechanism | None] | None = None,
) -> Self:
element_records = [e.to_pd_series() for e in elements]
elements_df = pd.DataFrame(element_records).set_index(INDICES)
Expand Down
8 changes: 3 additions & 5 deletions bridge/primitives/dataset/sample_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import abc
from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, Iterable, Iterator, Sequence

from tqdm.contrib import tmap
from typing_extensions import Self

if TYPE_CHECKING:
from bridge.display.display_engine import DisplayEngine
from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.element import Element
from bridge.primitives.element.element_type import ElementType
from bridge.primitives.sample import Sample
from bridge.primitives.sample.transform import SampleTransform

Expand All @@ -28,14 +26,14 @@ def get(self, sample_id: Hashable) -> Sample:
def transform_samples(
self,
transform: SampleTransform,
map_fn=tmap,
cache_mechanisms: Dict[ElementType, CacheMechanism] | None = None,
map_fn=map,
cache_mechanisms: Dict[str, CacheMechanism] | None = None,
display_engine: DisplayEngine | None = None,
) -> Self:
pass

@abc.abstractmethod
def map_samples(self, function: Callable[[Sample], Any], map_fn=tmap) -> Sequence[Sample]:
def map_samples(self, function: Callable[[Sample], Any], map_fn=map) -> Sequence[Sample]:
pass

@abc.abstractmethod
Expand Down
10 changes: 4 additions & 6 deletions bridge/primitives/dataset/singular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Callable, Dict, Hashable, List, Sequence

import pandas as pd
from tqdm.contrib import tmap
from typing_extensions import Self

from bridge.primitives.dataset.dataset import Dataset
Expand All @@ -14,7 +13,6 @@
from bridge.display import DisplayEngine
from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.element import Element
from bridge.primitives.element.element_type import ElementType
from bridge.primitives.sample.transform import SampleTransform


Expand All @@ -31,7 +29,7 @@ def __init__(
samples: pd.DataFrame,
annotations: pd.DataFrame,
display_engine: DisplayEngine = None,
cache_mechanisms: Dict[ElementType, CacheMechanism | None] | None = None,
cache_mechanisms: Dict[str, CacheMechanism | None] | None = None,
):
assert (
len(
Expand Down Expand Up @@ -132,8 +130,8 @@ def sort_annotations(self, by: str, ascending: bool = True):
def transform_samples(
self,
transform: SampleTransform,
map_fn=tmap,
cache_mechanisms: Dict[ElementType, CacheMechanism] | None = None,
map_fn=map,
cache_mechanisms: Dict[str, CacheMechanism] | None = None,
display_engine: DisplayEngine | None = None,
) -> Self:
ds = super().transform_samples(
Expand All @@ -157,7 +155,7 @@ def from_lists(
samples_list: List[Element],
annotations_list: List[Element],
display_engine: DisplayEngine = None,
cache_mechanisms: Dict[ElementType, CacheMechanism | None] | None = None,
cache_mechanisms: Dict[str, CacheMechanism | None] | None = None,
) -> Self:
sample_records = [s.to_dict() for s in samples_list]
annotation_records = [a.to_dict() for a in annotations_list]
Expand Down
5 changes: 2 additions & 3 deletions bridge/primitives/element/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from bridge.display import DisplayEngine
from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE
from bridge.primitives.element.element_type import ElementType


class Element(Displayable):
Expand All @@ -22,7 +21,7 @@ class Element(Displayable):
def __init__(
self,
element_id: Hashable,
etype: ElementType,
etype: str,
load_mechanism: LoadMechanism,
sample_id: Hashable,
display_engine: DisplayEngine | None = None,
Expand Down Expand Up @@ -55,7 +54,7 @@ def _data_impl(self):
return data

@property
def etype(self) -> ElementType:
def etype(self) -> str:
return self._etype

@property
Expand Down
12 changes: 0 additions & 12 deletions bridge/primitives/element/element_type.py

This file was deleted.

21 changes: 10 additions & 11 deletions bridge/primitives/sample/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.element import Element
from bridge.primitives.element.element_type import ElementType
from bridge.utils.constants import ELEMENT_COLS, INDICES
from bridge.utils.helper import Displayable

Expand All @@ -22,7 +21,7 @@ class Sample(Displayable):

def __init__(
self,
elements: List[Element] | Dict[ElementType, List[Element]],
elements: List[Element] | Dict[str, List[Element]],
display_engine: DisplayEngine | None = None,
):
if isinstance(elements, dict):
Expand All @@ -41,11 +40,11 @@ def id(self) -> Hashable:
return e_list[0].sample_id

@property
def elements(self) -> Dict[ElementType, List[Element]]:
def elements(self) -> Dict[str, List[Element]]:
return self._elements

@property
def data(self) -> Dict[ElementType, List[ELEMENT_DATA_TYPE]]:
def data(self) -> Dict[str, List[ELEMENT_DATA_TYPE]]:
data_dict = defaultdict(list)
for etype, elist in self._elements.items():
data_dict[etype].extend([e.data for e in elist])
Expand All @@ -57,7 +56,7 @@ def show(self, **kwargs: Any):
def transform(
self,
transform: SampleTransform,
cache_mechanisms: Dict[ElementType, CacheMechanism] | None = None,
cache_mechanisms: Dict[str, CacheMechanism] | None = None,
display_engine: DisplayEngine | None = None,
) -> "Sample":
cache_mechanisms = self._get_cache_mechanisms_for_transform(self, cache_mechanisms)
Expand All @@ -71,7 +70,7 @@ def from_pd_dataframe(
cls,
elements_df: pd.DataFrame,
display_engine: DisplayEngine | None,
cache_mechanisms: Dict[ElementType, CacheMechanism | None],
cache_mechanisms: Dict[str, CacheMechanism | None],
):
def fast_to_dict_records(df):
data = df.values.tolist()
Expand All @@ -87,7 +86,7 @@ def fast_to_dict_records(df):
for element_row in fast_to_dict_records(elements_df):
etype = element_row[ELEMENT_COLS.ETYPE]

element_type = ElementType(etype)
element_type = str(etype)
elements.append(
Element.from_dict(
element_row,
Expand All @@ -108,23 +107,23 @@ def __len__(self) -> int:
return sum(map(len, self._elements.values()))

@staticmethod
def _assert_valid_elements(elements: Dict[ElementType, List[Element]]):
def _assert_valid_elements(elements: Dict[str, List[Element]]):
sample_ids_from_elements = set([e.sample_id for e_list in elements.values() for e in e_list])
assert len(sample_ids_from_elements) == 1, (
f"All elements must contain a single sample id,"
f" got {len(sample_ids_from_elements)}: {sample_ids_from_elements}"
)

@staticmethod
def _convert_elements_list_to_dict(elements: List[Element]) -> Dict[ElementType, List[Element]]:
elements_by_type: Dict[ElementType, List[Element]] = defaultdict(list)
def _convert_elements_list_to_dict(elements: List[Element]) -> Dict[str, List[Element]]:
elements_by_type: Dict[str, List[Element]] = defaultdict(list)
for element in elements:
elements_by_type[element.etype].append(element)
d = dict(elements_by_type)
return d

@staticmethod
def _get_cache_mechanisms_for_transform(sample: Sample, cache_mechanisms: Dict[ElementType, CacheMechanism] | None):
def _get_cache_mechanisms_for_transform(sample: Sample, cache_mechanisms: Dict[str, CacheMechanism] | None):
if cache_mechanisms is None:
cache_mechanisms = {}

Expand Down
Loading

0 comments on commit 17de994

Please sign in to comment.