Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
iris-garden committed Aug 1, 2024
1 parent 21ad23a commit a1d2ab6
Show file tree
Hide file tree
Showing 8 changed files with 2,599 additions and 0 deletions.
2,095 changes: 2,095 additions & 0 deletions hail/python/hail/docs/tutorials/10-ggplot2.ipynb

Large diffs are not rendered by default.

190 changes: 190 additions & 0 deletions hail/python/hail/ggplot2/.ipynb_checkpoints/ggplot2-checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from dataclasses import asdict, replace
from textwrap import dedent, indent
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from altair import X2, Chart, LayerChart, X, Y
from pandas import DataFrame

import hail as hl
from hail import MatrixTable, Table
from hail.ggplot2.utils import typeguard_dataclass


### types ###
Data = Union[Table, MatrixTable]


@typeguard_dataclass
class Mapping:
x: Optional[str]
y: Optional[str]
# TODO add the rest of the supported aesthetic names
color: Optional[str]


Geom = Literal["bar", "line", "circle"]
Stat = Literal["identity", "bin"]


@typeguard_dataclass
class Layer:
mapping: Mapping
data: Optional[Data]
geom: Optional[Geom]
stat: Stat
# FIXME if there's only one type per param name we can make this a typeddict
params: Dict[str, Any]


@typeguard_dataclass
class Plot:
data: Optional[Data]
mapping: Mapping
layers: list[Layer]


### module-level variables ###
_plot_cache: Dict[int, List[Plot]] = {}
_stat_cache: Dict[Tuple[int, ...], Data] = {}


### constructor functions ###
def aes(x: Optional[str] = None, y: Optional[str] = None, color: Optional[str] = None) -> Mapping:
return Mapping(x, y, color)


def geom_histogram(mapping: Mapping = aes(), data: Optional[Data] = None, bins: int = 30) -> Layer:
return Layer(mapping, data, "bar", "bin", {"bins": bins})


def geom_line(mapping: Mapping = aes(), data: Optional[Data] = None) -> Layer:
return Layer(mapping, data, "line", "identity", {})


def geom_point(mapping: Mapping = aes(), data: Optional[Data] = None) -> Layer:
return Layer(mapping, data, "circle", "identity", {})


def ggplot(data: Optional[Data] = None, mapping: Mapping = aes()) -> Plot:
global _plot_cache
new_plot = Plot(data, mapping, [])
_plot_cache |= {id(new_plot): []}
return new_plot


### functionality ###
def extend(plot: Plot, other: Any) -> Plot:
global _plot_cache
kwargs: Optional[Dict[str, Any]] = None
if isinstance(other, Mapping):
kwargs = {
"mapping": replace(
plot.mapping,
**{k: v for k, v in {"x": other.x, "y": other.y, "color": other.color}.items() if v is not None},
)
}
elif isinstance(other, Layer):
kwargs = {"layers": [*plot.layers, other]}

if kwargs is None:
raise ValueError("unsupported addition to plot")

new_plot = replace(plot, **kwargs)
_plot_cache |= {id(new_plot): _plot_cache[id(plot)] + [plot]}
_plot_cache = {k: v for k, v in _plot_cache.items() if k != id(plot)}
return new_plot


setattr(Plot, "__add__", extend)


_altair_configure_mark_keys = {"color"}
_altair_encode_keys = {"x": X, "x2": X2, "y": Y}


def show(plot: Plot) -> Union[Chart, LayerChart]:
global _stat_cache
base_chart = None
for layer in plot.layers:
mapping_dict = {}
for mapping in [plot.mapping, layer.mapping]:
mapping_dict = {**mapping_dict, **{k: v for k, v in asdict(mapping).items() if v is not None}}
# TODO should we break the stat stuff out to its own function?
kwargs = {"x": {}, "x2": {}, "y": {}}
cached = _stat_cache.get((id(plot.data), layer.stat), None)
if cached is not None:
data, df = cached
elif layer.stat == "identity":
data = plot.data
df = data.to_pandas()
elif layer.stat == "bin":
# TODO add caching
x = mapping_dict.get("x", None)
if x is None:
raise ValueError("x must be supplied for stat bin")
data = plot.data.aggregate(
hl.agg.hist(
plot.data[x],
plot.data.aggregate(hl.agg.min(plot.data[x])),
plot.data.aggregate(hl.agg.max(plot.data[x])),
layer.params["bins"],
)
)
df = DataFrame([
{x: data["bin_edges"][i], "x2": data["bin_edges"][i + 1], "y": data["bin_freq"][i]}
for i in range(len(data["bin_freq"]))
])
kwargs["x"] = {"bin": "binned"}
mapping_dict["x2"] = "x2"
mapping_dict["y"] = "y"
else:
raise ValueError("unknown stat")
_stat_cache |= {(id(plot.data), layer.stat): (data, df)}
chart = Chart(df)
if layer.geom is not None:
chart = getattr(chart, f"mark_{layer.geom}")(**{
k: v for k, v in mapping_dict.items() if k in _altair_configure_mark_keys
})
chart = chart.encode(**{
k: _altair_encode_keys[k](v, **kwargs[k]) for k, v in mapping_dict.items() if k in _altair_encode_keys
})
base_chart = chart if base_chart is None else base_chart + chart
return base_chart


def undo(plot: Plot, *, depth: int = 1) -> Plot:
global _plot_cache
old_plot = _plot_cache[id(plot)][0 - depth]
_plot_cache |= {id(old_plot): _plot_cache[id(plot)][: 0 - depth]}
_plot_cache = {k: v for k, v in _plot_cache.items() if k != id(plot)}
return old_plot


## introspection ##
def plot_to_string(plot: Plot) -> str:
return dedent(f"""\
Plot(
data = {plot.data},
mapping = {indent_tail(str(plot.mapping), 3)},
layers = {indent_tail(str(plot.layers), 3)},
)""")


def indent_tail(string: str, indent_level: int = 1) -> str:
return "".join([
indent(part, " " * indent_level) if index == 2 else part for index, part in enumerate(string.partition("\n"))
])


setattr(Plot, "__str__", plot_to_string)


def mapping_to_string(mapping: Mapping) -> str:
return dedent(f"""\
Mapping(
x = {mapping.x},
y = {mapping.y},
)""")


setattr(Mapping, "__str__", mapping_to_string)
29 changes: 29 additions & 0 deletions hail/python/hail/ggplot2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typeguard import install_import_hook

install_import_hook("hail.ggplot2")

# These imports need to be placed after the import hook in order for typechecking to work.
# https://typeguard.readthedocs.io/en/stable/userguide.html#using-the-import-hook
from .altair_wrapper import ChartWrapper # noqa: E402
from .ggplot2 import ( # noqa: E402
aes,
extend,
geom_histogram,
geom_line,
geom_point,
ggplot,
show,
undo,
)

__all__ = [
"ChartWrapper",
"aes",
"extend",
"geom_point",
"geom_line",
"geom_histogram",
"ggplot",
"undo",
"show",
]
46 changes: 46 additions & 0 deletions hail/python/hail/ggplot2/altair_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Any, Union

from altair import Chart
from pandas import DataFrame

import hail
from hail import MatrixTable, Table

Data = Union[Table, MatrixTable]


class ChartWrapper:
def __init__(self, data: Data, *args, **kwargs) -> None:
self.chart_args = args
self.chart_kwargs = kwargs
self.data = data

def __setattr__(self, name: str, value: Any) -> None:
super().__setattr__(name, value)
if name == "data":
self.update_data()

def update_data(self) -> None:
self.cache = {}
self.chart = Chart(self.data.to_pandas(), *self.chart_args, **self.chart_kwargs)

def histogram(self, x: str, bins: int = 30) -> None:
if (aggregated := self.cache.get("histogram", None)) is None:
self.cache["histogram"] = (
aggregated := self.data.aggregate(
hail.agg.hist(
self.data[x],
self.data.aggregate(hail.agg.min(self.data[x])),
self.data.aggregate(hail.agg.max(self.data[x])),
bins,
)
)
)
self.chart = Chart(
DataFrame([
{"x": aggregated["bin_edges"][i], "x2": aggregated["bin_edges"][i + 1], "y": aggregated["bin_freq"][i]}
for i in range(len(aggregated["bin_freq"]))
]),
*self.chart_args,
**self.chart_kwargs,
)
Loading

0 comments on commit a1d2ab6

Please sign in to comment.