-
Notifications
You must be signed in to change notification settings - Fork 249
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
21ad23a
commit a1d2ab6
Showing
8 changed files
with
2,599 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
190 changes: 190 additions & 0 deletions
190
hail/python/hail/ggplot2/.ipynb_checkpoints/ggplot2-checkpoint.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
from dataclasses import asdict, replace | ||
from textwrap import dedent, indent | ||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union | ||
|
||
from altair import X2, Chart, LayerChart, X, Y | ||
from pandas import DataFrame | ||
|
||
import hail as hl | ||
from hail import MatrixTable, Table | ||
from hail.ggplot2.utils import typeguard_dataclass | ||
|
||
|
||
### types ###
# Alias used for every `data` annotation in this module: either kind of Hail dataset.
Data = Union[Table, MatrixTable]
|
||
|
||
@typeguard_dataclass
class Mapping:
    """Aesthetic mapping: one optional string per supported aesthetic.

    A field left as None means "not set". The declaration order of the fields
    is also the positional order of the generated constructor.
    """

    # Value encoded on the x channel; stat "bin" also uses it to index the data.
    x: Optional[str]
    # Value encoded on the y channel.
    y: Optional[str]
    # TODO add the rest of the supported aesthetic names
    # Forwarded as the `color` keyword of the altair mark when a chart is built.
    color: Optional[str]
|
||
|
||
# Supported mark names; each is appended to "mark_" to pick the altair Chart method.
Geom = Literal["bar", "line", "circle"]
# Supported statistical transforms: "identity" plots the data as-is, "bin" builds a histogram.
Stat = Literal["identity", "bin"]
|
||
|
||
@typeguard_dataclass
class Layer:
    """One drawable layer of a plot: what to draw, from which stat, with which aesthetics."""

    # Layer-level aesthetics; set (non-None) fields take precedence over the plot's mapping.
    mapping: Mapping
    # Optional dataset supplied for this layer.
    data: Optional[Data]
    # Mark to draw; None means no mark method is applied to the chart.
    geom: Optional[Geom]
    # Statistical transform applied before drawing.
    stat: Stat
    # FIXME if there's only one type per param name we can make this a typeddict
    # Free-form stat/geom parameters, e.g. {"bins": 30} for histograms.
    params: Dict[str, Any]
|
||
|
||
@typeguard_dataclass
class Plot:
    """A plot specification: a default dataset, default aesthetics and a list of layers."""

    # Default dataset for the plot.
    data: Optional[Data]
    # Plot-level aesthetics, combined with each layer's own mapping when rendering.
    mapping: Mapping
    # Layers in the order they were added.
    layers: list[Layer]
|
||
|
||
### module-level variables ###
# Undo history per plot, keyed by id() of a Plot; the list holds the plots it was derived from, oldest first.
_plot_cache: Dict[int, List[Plot]] = {}
# Memoised stat results written by `show`. Keys are tuples that start with id() of the dataset
# followed by the stat name; each value is a (stat result, pandas DataFrame) pair.
_stat_cache: Dict[Tuple[Any, ...], Tuple[Any, DataFrame]] = {}
|
||
|
||
### constructor functions ### | ||
def aes(x: Optional[str] = None, y: Optional[str] = None, color: Optional[str] = None) -> Mapping:
    """Build an aesthetic mapping; every aesthetic defaults to None (unset)."""
    return Mapping(x=x, y=y, color=color)
|
||
|
||
def geom_histogram(mapping: Mapping = aes(), data: Optional[Data] = None, bins: int = 30) -> Layer:
    """Create a histogram layer: bar marks over the "bin" stat with `bins` bins."""
    stat_params = {"bins": bins}
    return Layer(mapping=mapping, data=data, geom="bar", stat="bin", params=stat_params)
|
||
|
||
def geom_line(mapping: Mapping = aes(), data: Optional[Data] = None) -> Layer:
    """Create a line layer drawn from the data as-is ("identity" stat, no parameters)."""
    return Layer(mapping=mapping, data=data, geom="line", stat="identity", params={})
|
||
|
||
def geom_point(mapping: Mapping = aes(), data: Optional[Data] = None) -> Layer:
    """Create a point layer (circle marks) drawn from the data as-is ("identity" stat)."""
    return Layer(mapping=mapping, data=data, geom="circle", stat="identity", params={})
|
||
|
||
def ggplot(data: Optional[Data] = None, mapping: Mapping = aes()) -> Plot:
    """Start a new plot with no layers and register it with an empty undo history."""
    global _plot_cache
    fresh = Plot(data=data, mapping=mapping, layers=[])
    # Histories are keyed by object identity of the plot.
    _plot_cache[id(fresh)] = []
    return fresh
|
||
|
||
### functionality ### | ||
def extend(plot: Plot, other: Any) -> Plot:
    """Return a new plot equal to `plot` combined with `other`.

    Args:
        plot: The plot to build on; it is not modified.
        other: A `Mapping`, whose set (non-None) aesthetics override the plot's
            mapping, or a `Layer`, which is appended to the plot's layers.

    Returns:
        The new plot. Its undo history is the history of `plot` followed by
        `plot` itself.

    Raises:
        ValueError: If `other` is neither a `Mapping` nor a `Layer`.
    """
    if isinstance(other, Mapping):
        # Only aesthetics that are actually set replace the plot-level ones.
        overrides = {k: v for k, v in asdict(other).items() if v is not None}
        new_plot = replace(plot, mapping=replace(plot.mapping, **overrides))
    elif isinstance(other, Layer):
        new_plot = replace(plot, layers=[*plot.layers, other])
    else:
        raise ValueError("unsupported addition to plot")

    # The history entry of `plot` moves to the new plot. A missing entry is
    # treated as an empty history: that happens when the same base plot is
    # extended more than once, or when the Plot was constructed directly.
    history = _plot_cache.pop(id(plot), [])
    _plot_cache[id(new_plot)] = [*history, plot]
    return new_plot
|
||
|
||
# Make `plot + other` call extend(plot, other).
setattr(Plot, "__add__", extend)
|
||
|
||
# Aesthetics passed as keyword arguments to the altair mark_* call.
_altair_configure_mark_keys = {"color"}
# Aesthetics passed to Chart.encode, mapped to the altair channel class that wraps each value.
_altair_encode_keys = {"x": X, "x2": X2, "y": Y}
|
||
|
||
def show(plot: Plot) -> Union[Chart, LayerChart]:
    """Build an altair chart from a plot by rendering and layering each of its layers.

    For every layer the plot-level and layer-level mappings are merged (set
    layer aesthetics win), the layer's stat is computed into a pandas
    DataFrame, and a chart with the layer's mark and encodings is created.
    Charts are combined with `+` in layer order.

    Stat results are memoised in the module-level `_stat_cache`. The key
    includes everything the result depends on (dataset identity, stat name and,
    for "bin", the x column and bin count), so changing `x` or `bins` does not
    return a stale histogram.

    Args:
        plot: The plot to render.

    Returns:
        The chart for a single layer, or the layered chart for several. If the
        plot has no layers nothing is built and None is returned.

    Raises:
        ValueError: If neither the layer nor the plot supplies data, if the
            "bin" stat has no `x` aesthetic, or if the stat is unknown.
    """
    base_chart = None
    for layer in plot.layers:
        # Later mappings win, so layer aesthetics override plot aesthetics.
        mapping_dict = {
            k: v
            for mapping in (plot.mapping, layer.mapping)
            for k, v in asdict(mapping).items()
            if v is not None
        }

        # A layer may bring its own dataset; otherwise fall back to the plot's.
        source = plot.data if layer.data is None else layer.data
        if source is None:
            raise ValueError("no data supplied for layer")

        # TODO should we break the stat stuff out to its own function?
        kwargs = {"x": {}, "x2": {}, "y": {}}
        if layer.stat == "identity":
            cache_key = (id(source), layer.stat)
            cached = _stat_cache.get(cache_key)
            if cached is None:
                cached = (source, source.to_pandas())
                _stat_cache[cache_key] = cached
            _, df = cached
        elif layer.stat == "bin":
            x = mapping_dict.get("x")
            if x is None:
                raise ValueError("x must be supplied for stat bin")
            bins = layer.params["bins"]
            cache_key = (id(source), layer.stat, x, bins)
            cached = _stat_cache.get(cache_key)
            if cached is None:
                column = source[x]
                lower = source.aggregate(hl.agg.min(column))
                upper = source.aggregate(hl.agg.max(column))
                hist = source.aggregate(hl.agg.hist(column, lower, upper, bins))
                edges = hist["bin_edges"]
                # One row per bin: left edge under the x name, right edge as "x2", count as "y".
                df = DataFrame([
                    {x: left, "x2": right, "y": freq}
                    for left, right, freq in zip(edges, edges[1:], hist["bin_freq"])
                ])
                cached = (hist, df)
                _stat_cache[cache_key] = cached
            _, df = cached
            # These must be applied on cache hits too: the data is already
            # binned, and the derived columns need their own encodings.
            kwargs["x"] = {"bin": "binned"}
            mapping_dict["x2"] = "x2"
            mapping_dict["y"] = "y"
        else:
            raise ValueError("unknown stat")

        chart = Chart(df)
        if layer.geom is not None:
            mark_kwargs = {k: v for k, v in mapping_dict.items() if k in _altair_configure_mark_keys}
            chart = getattr(chart, f"mark_{layer.geom}")(**mark_kwargs)
        chart = chart.encode(**{
            k: _altair_encode_keys[k](v, **kwargs[k]) for k, v in mapping_dict.items() if k in _altair_encode_keys
        })
        base_chart = chart if base_chart is None else base_chart + chart
    return base_chart
|
||
|
||
def undo(plot: Plot, *, depth: int = 1) -> Plot:
    """Return the plot as it was `depth` steps before `plot`.

    The history entry of `plot` is removed and the returned plot is registered
    with the part of the history that precedes it.

    Args:
        plot: A plot with a recorded history.
        depth: How many steps to go back; must be at least 1.

    Returns:
        The earlier plot.

    Raises:
        ValueError: If `depth` is less than 1.
        KeyError: If no history is recorded for `plot`.
        IndexError: If `depth` exceeds the number of recorded steps.
    """
    global _plot_cache
    # Without this check depth=0 would index the *oldest* plot and negative
    # depths would index forward from the start of the history.
    if depth < 1:
        raise ValueError("depth must be at least 1")
    history = _plot_cache[id(plot)]
    if depth > len(history):
        raise IndexError(f"cannot undo {depth} step(s): only {len(history)} recorded")
    old_plot = history[-depth]
    _plot_cache = {k: v for k, v in _plot_cache.items() if k != id(plot)}
    _plot_cache[id(old_plot)] = history[:-depth]
    return old_plot
|
||
|
||
## introspection ## | ||
def plot_to_string(plot: Plot) -> str:
    """Render a plot as a multi-line "Plot(...)" summary of its data, mapping and layers.

    The mapping and layers are converted with str() and passed through
    `indent_tail` with level 3, which indents every line after their first.
    """
    # NOTE(review): dedent runs after the f-string is interpolated, so the common
    # leading whitespace it removes is computed over the lines of the interpolated
    # values as well as the template lines.
    return dedent(f"""\
Plot(
data = {plot.data},
mapping = {indent_tail(str(plot.mapping), 3)},
layers = {indent_tail(str(plot.layers), 3)},
)""")
|
||
|
||
def indent_tail(string: str, indent_level: int = 1) -> str:
    """Indent every line of `string` except the first.

    The text is split at its first newline; the first line and the newline are
    kept unchanged and the remainder is indented by `indent_level` copies of
    the indent unit. A string without a newline is returned unchanged.
    """
    first_line, newline, remainder = string.partition("\n")
    prefix = " " * indent_level
    return first_line + newline + indent(remainder, prefix)
|
||
|
||
# Make str(plot) use plot_to_string.
setattr(Plot, "__str__", plot_to_string)
|
||
|
||
def mapping_to_string(mapping: Mapping) -> str:
    """Render a mapping as a multi-line "Mapping(...)" summary.

    Every aesthetic field (x, y and color) is listed on its own indented line,
    including those that are None.

    Args:
        mapping: The mapping to render.

    Returns:
        The rendered text, without a trailing newline.
    """
    # Built line by line so the layout does not depend on the field values.
    return "\n".join([
        "Mapping(",
        f"    x = {mapping.x},",
        f"    y = {mapping.y},",
        f"    color = {mapping.color},",
        ")",
    ])
|
||
|
||
# Make str(mapping) use mapping_to_string.
setattr(Mapping, "__str__", mapping_to_string)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from typeguard import install_import_hook | ||
|
||
install_import_hook("hail.ggplot2") | ||
|
||
# These imports need to be placed after the import hook in order for typechecking to work. | ||
# https://typeguard.readthedocs.io/en/stable/userguide.html#using-the-import-hook | ||
from .altair_wrapper import ChartWrapper # noqa: E402 | ||
from .ggplot2 import ( # noqa: E402 | ||
aes, | ||
extend, | ||
geom_histogram, | ||
geom_line, | ||
geom_point, | ||
ggplot, | ||
show, | ||
undo, | ||
) | ||
|
||
# Public API of the package: exactly the names imported above.
__all__ = [
    "ChartWrapper",
    "aes",
    "extend",
    "geom_point",
    "geom_line",
    "geom_histogram",
    "ggplot",
    "undo",
    "show",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from typing import Any, Union | ||
|
||
from altair import Chart | ||
from pandas import DataFrame | ||
|
||
import hail | ||
from hail import MatrixTable, Table | ||
|
||
# Either kind of Hail dataset; used to annotate the data held by ChartWrapper.
Data = Union[Table, MatrixTable]
|
||
|
||
class ChartWrapper:
    """Pairs a Hail dataset with an altair Chart built from it.

    Assigning to `data` (including in the constructor) rebuilds `chart` from
    the new dataset and clears the cache of aggregation results.
    """

    def __init__(self, data: Data, *args, **kwargs) -> None:
        """Store the Chart arguments, then the data.

        Args:
            data: The dataset to chart.
            *args: Extra positional arguments forwarded to every Chart built.
            **kwargs: Extra keyword arguments forwarded to every Chart built.
        """
        self.chart_args = args
        self.chart_kwargs = kwargs
        # Must come last: assigning `data` triggers update_data, which reads
        # chart_args and chart_kwargs.
        self.data = data

    def __setattr__(self, name: str, value: Any) -> None:
        """Set the attribute and refresh the chart whenever `data` is assigned."""
        super().__setattr__(name, value)
        if name == "data":
            self.update_data()

    def update_data(self) -> None:
        """Drop cached aggregates and rebuild the chart from the current data."""
        self.cache = {}
        self.chart = Chart(self.data.to_pandas(), *self.chart_args, **self.chart_kwargs)

    def histogram(self, x: str, bins: int = 30) -> None:
        """Replace `chart` with one built from a histogram of field `x`.

        The field is binned between its minimum and maximum into `bins` bins.
        The resulting frame has one row per bin with columns "x" (left edge),
        "x2" (right edge) and "y" (count).

        The aggregate is cached per (x, bins), so asking for a different field
        or bin count is recomputed instead of reusing an unrelated result.

        Args:
            x: Name of the field to bin.
            bins: Number of bins.
        """
        cache_key = ("histogram", x, bins)
        aggregated = self.cache.get(cache_key)
        if aggregated is None:
            column = self.data[x]
            lower = self.data.aggregate(hail.agg.min(column))
            upper = self.data.aggregate(hail.agg.max(column))
            aggregated = self.data.aggregate(hail.agg.hist(column, lower, upper, bins))
            self.cache[cache_key] = aggregated
        edges = aggregated["bin_edges"]
        frame = DataFrame([
            {"x": left, "x2": right, "y": freq}
            for left, right, freq in zip(edges, edges[1:], aggregated["bin_freq"])
        ])
        self.chart = Chart(frame, *self.chart_args, **self.chart_kwargs)
Oops, something went wrong.