feat: make plotly-express dataframe agnostic via narwhals #4790

Merged
merged 131 commits into from
Nov 13, 2024
Changes from 52 commits
9873e97
non core changes
FBruzzesi Sep 28, 2024
0389591
_core overhaul
FBruzzesi Sep 28, 2024
ba93236
some _core fixes
FBruzzesi Sep 28, 2024
421fc1d
tests replace sort_index(axis=1)
FBruzzesi Sep 28, 2024
ca5c820
reset_index in concat and allow any object to pandas
FBruzzesi Sep 28, 2024
a6aab24
trendline prep
FBruzzesi Sep 29, 2024
7665f10
WIP Index
FBruzzesi Sep 29, 2024
ec4f250
clean from breakpoints
FBruzzesi Sep 29, 2024
7e0d4c2
some tests fix
FBruzzesi Sep 29, 2024
5543638
hotfix and tests output to pandas
FBruzzesi Sep 29, 2024
cd0dab7
FIX: columns never as index
FBruzzesi Sep 29, 2024
f334b32
getting there with the tests
FBruzzesi Sep 29, 2024
e5eb949
get_column instead of pandas slicing, unix to seconds
FBruzzesi Sep 30, 2024
7747e30
bump narhwals, hierarchy fastpath
FBruzzesi Oct 1, 2024
ac00b36
fix to_unindexed_series
FBruzzesi Oct 1, 2024
da80c5b
fix trendline
FBruzzesi Oct 1, 2024
8a72ba1
rm numpy dep in _core
FBruzzesi Oct 2, 2024
aeff203
fix: _check_dataframe_all_leaves
FBruzzesi Oct 2, 2024
2041bef
(maybe) fix to_unindexed_series
FBruzzesi Oct 2, 2024
71473f1
(maybe) fix to_unindexed_series
FBruzzesi Oct 2, 2024
9f74c38
started tests with constructor
FBruzzesi Oct 2, 2024
28587c9
added constructor to all tests
FBruzzesi Oct 2, 2024
1bb2448
added some comments for fixme
FBruzzesi Oct 2, 2024
f45addf
to_py_scalar and more tests
FBruzzesi Oct 3, 2024
5341759
dealing with exceptions and tests
FBruzzesi Oct 3, 2024
dfc957c
bump version, sort(...,nulls_last=True)
FBruzzesi Oct 4, 2024
90f2667
We did it: no more dups in group by :D
FBruzzesi Oct 4, 2024
fb58d1b
concat_str
FBruzzesi Oct 5, 2024
ddb3b35
fix test_several_dataframes
FBruzzesi Oct 5, 2024
37ce302
dedups customdata
FBruzzesi Oct 5, 2024
4da8768
getting there
FBruzzesi Oct 6, 2024
210e01a
xfail pyarrow chunked-array because name-less
FBruzzesi Oct 6, 2024
c00525e
all green with edge narhwals
FBruzzesi Oct 6, 2024
3486a3e
add pandas nullable constructors in tests
FBruzzesi Oct 7, 2024
c0ce093
bump narwhals and address todos
FBruzzesi Oct 9, 2024
0eb6951
check narwhals installation
FBruzzesi Oct 9, 2024
844a6a9
rm unused comments
FBruzzesi Oct 9, 2024
0c27789
rm unused code
FBruzzesi Oct 9, 2024
0e6ff78
add pyarrow and narwhals to requirements_39_pandas_2_optional
FBruzzesi Oct 9, 2024
c2337c9
requirements, test requirements optional
FBruzzesi Oct 15, 2024
2cc5d7b
refactor tests
FBruzzesi Oct 15, 2024
1b27487
address feedbacks
FBruzzesi Oct 15, 2024
23a23be
typos
FBruzzesi Oct 15, 2024
7968cff
conftest
FBruzzesi Oct 15, 2024
cf76721
merge master
FBruzzesi Oct 15, 2024
91db84b
mock interchange
FBruzzesi Oct 15, 2024
5c6772e
optional requirements
FBruzzesi Oct 15, 2024
9ec3f9e
move conftest in express folder
FBruzzesi Oct 15, 2024
400a624
hotfix and figure_factory hexbin
FBruzzesi Oct 15, 2024
1aa5163
old versions, polars[timezone], hotfix
FBruzzesi Oct 16, 2024
594ded0
fix frame value in hexbin
FBruzzesi Oct 16, 2024
6676061
copy numpy array
FBruzzesi Oct 16, 2024
d7d2884
hotfix hexbin mapbox
FBruzzesi Oct 16, 2024
d6ee676
Merge branch 'master' into plotly-with-narwhals
FBruzzesi Oct 16, 2024
82c114d
fix test
FBruzzesi Oct 17, 2024
0ceabc1
Merge branch 'plotly:master' into plotly-with-narwhals
FBruzzesi Oct 17, 2024
c9b626e
use lazy in process_dataframe_hierarchy
FBruzzesi Oct 17, 2024
87841d1
fix custom sort in process_dataframe_pie
FBruzzesi Oct 18, 2024
ffa7b3b
Merge branch 'master' into plotly-with-narwhals
archmoj Oct 21, 2024
3ba19ae
bump version and adjust core
FBruzzesi Oct 21, 2024
a70146b
use dtype.is_numeric
FBruzzesi Oct 22, 2024
1fa9fe4
Merge branch 'master' into plotly-with-narwhals
FBruzzesi Oct 22, 2024
0103aa6
revert test
FBruzzesi Oct 22, 2024
673d141
Merge branch 'plotly-with-narwhals' of https://github.com/FBruzzesi/p…
FBruzzesi Oct 22, 2024
b858ed8
feedback adjustments
FBruzzesi Oct 23, 2024
bbcf438
Merge branch 'master' into plotly-with-narwhals
FBruzzesi Oct 23, 2024
49efae2
raise if numpy is missing, conftest fix, typo
FBruzzesi Oct 25, 2024
a36bc24
__plotly_n_unique__
FBruzzesi Oct 25, 2024
c119153
Merge branch 'master' into plotly-with-narwhals
FBruzzesi Oct 25, 2024
7416407
format
FBruzzesi Oct 25, 2024
1867f6f
format
FBruzzesi Oct 25, 2024
d3a28c0
feedback adjustments
FBruzzesi Oct 27, 2024
e6e9994
use drop_null_keys, some pandas fastpaths
MarcoGorelli Oct 25, 2024
64b8c70
bump narwhals version
MarcoGorelli Oct 27, 2024
3f6b383
some improvements by Marco
FBruzzesi Oct 27, 2024
755aea8
format and pyspark path
FBruzzesi Oct 27, 2024
6f18021
add narwhals to requirements core
FBruzzesi Oct 27, 2024
4d62e73
Update packages/python/plotly/plotly/express/_core.py
FBruzzesi Oct 28, 2024
a770fd8
refactor checking for df
MarcoGorelli Oct 29, 2024
7d6f7d6
pushdown only for interchange libraries, sort out test
MarcoGorelli Oct 29, 2024
b8c10ec
Update packages/python/plotly/plotly/express/_core.py
MarcoGorelli Oct 29, 2024
490b64a
fixup
MarcoGorelli Oct 29, 2024
f7fd4c9
Merge remote-tracking branch 'origin/plotly-with-narwhals' into plotl…
MarcoGorelli Oct 29, 2024
8753acb
lint
MarcoGorelli Oct 29, 2024
1429e6f
bump narwhals version
MarcoGorelli Oct 29, 2024
878d4db
refactor checking for df and bump version
FBruzzesi Oct 29, 2024
192e0a8
use token in process_dataframe_hierarchy
FBruzzesi Oct 29, 2024
de6761c
Range(label=...) for px.funnel
FBruzzesi Oct 29, 2024
bcfef68
improve error message and in-line comments
FBruzzesi Oct 30, 2024
519cc68
better comments
FBruzzesi Oct 30, 2024
e5520a7
rm unused import and fix typo
FBruzzesi Oct 31, 2024
b855352
Merge branch 'master' into plotly-with-narwhals
FBruzzesi Oct 31, 2024
51e2b23
make sure column + token is unique, replace **{} with .alias()
FBruzzesi Oct 31, 2024
7ef9f28
WIP
FBruzzesi Oct 31, 2024
e9a367d
WIP
FBruzzesi Oct 31, 2024
12fed31
Merge branch 'master' into plotly-with-narwhals
FBruzzesi Nov 1, 2024
27b2996
use nw.get_native_namespace
FBruzzesi Nov 1, 2024
f27f959
Merge branch 'plotly-with-narwhals' of https://github.com/FBruzzesi/p…
FBruzzesi Nov 1, 2024
126a79d
Merge branch 'master' into feat/dataframe-agnostic-data
FBruzzesi Nov 1, 2024
7735366
add narwhals in various requirements
FBruzzesi Nov 1, 2024
b6516b4
docstrings
FBruzzesi Nov 1, 2024
6f1389f
rm type hints, change post_agg to use alias
FBruzzesi Nov 1, 2024
db22268
feedback adjustments
FBruzzesi Nov 1, 2024
b514c01
move imports out, fix pyarrow
FBruzzesi Nov 1, 2024
ce8fb9a
rm unused narwhals wrapper
FBruzzesi Nov 1, 2024
e47827e
comment about stable api
FBruzzesi Nov 1, 2024
9a9283a
update changelog
FBruzzesi Nov 1, 2024
2630a5a
fixup time zone handling
MarcoGorelli Nov 1, 2024
fef6dbe
modin and cudf
FBruzzesi Nov 3, 2024
48c7f62
defensive from_native call
FBruzzesi Nov 4, 2024
18cc11c
typo
FBruzzesi Nov 4, 2024
d94cbf7
fixup timezones
FBruzzesi Nov 4, 2024
c320c46
move from object to datetime dtype in _plotly_utils/test/validators
FBruzzesi Nov 4, 2024
afdb31f
simplify ecdfnorm
MarcoGorelli Nov 4, 2024
68ab52a
Merge pull request #4 from MarcoGorelli/ecdf-mode-perf
FBruzzesi Nov 5, 2024
b8ccec4
Merge branch 'master' into plotly-with-narwhals
FBruzzesi Nov 5, 2024
f102998
rm to_py_scalar call in for loop -> fix Pie performances
FBruzzesi Nov 5, 2024
2df0427
Merge branch 'plotly-with-narwhals' of https://github.com/FBruzzesi/p…
FBruzzesi Nov 5, 2024
55a0178
Merge branch 'master' into feat/dataframe-agnostic-data
FBruzzesi Nov 5, 2024
bb327d5
merge feat/dataframe-agnostic-data
FBruzzesi Nov 5, 2024
7d611fb
use return_type directly when building datasets
FBruzzesi Nov 5, 2024
a22a7be
stocks date to string and test_trendline_on_timeseries fix
FBruzzesi Nov 6, 2024
44a52e5
merge master and rm FIXME comment
FBruzzesi Nov 7, 2024
fc74b2e
do not repeat new_series unnecessarely
FBruzzesi Nov 8, 2024
499e2fa
bump version, use numpy for range
FBruzzesi Nov 8, 2024
d2e1008
trigger ci now that new version is published
FBruzzesi Nov 8, 2024
742b2ec
add narwhals to np2_optional.txt
FBruzzesi Nov 8, 2024
269dea6
version
FBruzzesi Nov 8, 2024
b1dc48d
Merge branch 'master' into plotly-with-narwhals
MarcoGorelli Nov 12, 2024
17fb96f
Merge branch 'master' into plotly-with-narwhals
FBruzzesi Nov 12, 2024
9f2c55b
Merge branch 'master' into plotly-with-narwhals
FBruzzesi Nov 13, 2024
1 change: 1 addition & 0 deletions packages/python/plotly/optional-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ ipython

## pandas deps for some matplotlib functionality ##
pandas
narwhals>=1.9.2

## scipy deps for some FigureFactory functions ##
scipy
Expand Down
9 changes: 0 additions & 9 deletions packages/python/plotly/plotly/express/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,6 @@
`plotly.express` is a terse, consistent, high-level wrapper around `plotly.graph_objects`
for rapid data exploration and figure generation. Learn more at https://plotly.com/python/plotly-express/
"""
from plotly import optional_imports

pd = optional_imports.get_module("pandas")
if pd is None:
raise ImportError(
"""\
Plotly express requires pandas to be installed."""
)

from ._imshow import imshow
from ._chart_types import ( # noqa: F401
scatter,
Expand Down
884 changes: 574 additions & 310 deletions packages/python/plotly/plotly/express/_core.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion packages/python/plotly/plotly/express/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@
zoom=["int (default `8`)", "Between 0 and 20.", "Sets map zoom level."],
orientation=[
"str, one of `'h'` for horizontal or `'v'` for vertical. ",
"(default `'v'` if `x` and `y` are provided and both continous or both categorical, ",
"(default `'v'` if `x` and `y` are provided and both continuous or both categorical, ",
"otherwise `'v'`(`'h'`) if `x`(`y`) is categorical and `y`(`x`) is continuous, ",
"otherwise `'v'`(`'h'`) if only `x`(`y`) is provided) ",
],
Expand Down
6 changes: 4 additions & 2 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from _plotly_utils.basevalidators import ColorscaleValidator
from ._core import apply_default_cascade, init_figure, configure_animation_controls
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
import pandas as pd
import narwhals.stable.v1 as nw
import numpy as np
import itertools
from plotly.utils import image_array_to_data_uri
Expand Down Expand Up @@ -321,10 +321,12 @@ def imshow(
aspect = "equal"

# --- Set the value of binary_string (forbidden for pandas)
if isinstance(img, pd.DataFrame):
img = nw.from_native(img, strict=False)
if isinstance(img, nw.DataFrame):
if binary_string:
raise ValueError("Binary strings cannot be used with pandas arrays")
is_dataframe = True
img = img.to_numpy()
else:
is_dataframe = False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
exposed as part of the public API for documentation purposes.
"""

import pandas as pd
import numpy as np
import narwhals.stable.v1 as nw

__all__ = ["ols", "lowess", "rolling", "ewm", "expanding"]

Expand All @@ -32,6 +31,8 @@ def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing):
respect to the base 10 logarithm of the input. Note that this means no zeros can
be present in the input.
"""
import numpy as np

valid_options = ["add_constant", "log_x", "log_y"]
for k in trendline_options.keys():
if k not in valid_options:
Expand Down Expand Up @@ -110,11 +111,22 @@ def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing):


def _pandas(mode, trendline_options, x_raw, y, non_missing):
import numpy as np

try:
import pandas as pd
except ImportError:
msg = "Trendline requires pandas to be installed"
raise ImportError(msg)

modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding")
trendline_options = trendline_options.copy()
function_name = trendline_options.pop("function", "mean")
function_args = trendline_options.pop("function_args", dict())
series = pd.Series(y, index=x_raw)

series = pd.Series(np.copy(y), index=x_raw.to_pandas())
@FBruzzesi (Contributor Author), Oct 16, 2024:

@MarcoGorelli with old numpy versions, the to_numpy output of arrow arrays and polars series seems to raise

ValueError: buffer source array is read-only

when it is then used by pandas operations.

I wonder if we could allow to_numpy(..., writable: bool) in narwhals to return a writable copy

Member:

For my own curiosity, why do we still need pandas here? Does Narwhals provide a Series equivalent?

@FBruzzesi (Contributor Author), Oct 25, 2024:

We do not support rolling, expanding, or ewm at this stage (see the TODO comment on the next line).

If and when that changes, pandas can become completely optional.
Another option would be to implement these functions in pure numpy, but I would rather keep that out of this PR. WDYT?

@MarcoGorelli (Contributor), Oct 25, 2024:

I don't think there's a way to implement rolling functions in numpy, especially with some of the extra options.

We'll add them to Narwhals though. We talked about it last month and decided that it wasn't the top priority at the time, but it is on the roadmap.

Member:

Got it. This approach makes sense to me -- I opened #4834 to keep track of this as a follow-up.

Contributor:

> I wonder if we could allow to_numpy(..., writable: bool) in narwhals to return a writable copy

Rather than adding an option, does it make sense here to check y.flags['WRITEABLE'] and copy only if it is False?
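
The suggestion in the last comment can be sketched as follows. This is only an illustration of the idea, assuming numpy is installed; `ensure_writeable` is a hypothetical name and is not part of plotly or narwhals.

```python
import numpy as np


def ensure_writeable(arr):
    """Return arr unchanged if it is writeable, otherwise a writeable copy."""
    arr = np.asarray(arr)
    if arr.flags["WRITEABLE"]:
        return arr
    # Copies made with ndarray.copy() own their data and are writeable
    return arr.copy()


# Simulate the read-only buffer described in the thread
read_only = np.arange(3.0)
read_only.setflags(write=False)

writeable = np.arange(3.0)

safe = ensure_writeable(read_only)
same = ensure_writeable(writeable)
```

Compared with the unconditional `np.copy(y)` in the diff, this avoids a copy when the array is already writeable.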


# TODO: If narwhals were to support rolling, ewm and expanding then we could go around these
agg = getattr(series, mode) # e.g. series.rolling
agg_obj = agg(**trendline_options) # e.g. series.rolling(**opts)
function = getattr(agg_obj, function_name) # e.g. series.rolling(**opts).mean
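
The three `getattr` lines above resolve the window type and the aggregation by name. A small sketch of that dispatch, assuming numpy and pandas are installed; `windowed_aggregate` is a hypothetical wrapper used only for illustration.

```python
import numpy as np
import pandas as pd


def windowed_aggregate(y, mode, function_name="mean", **options):
    """Apply series.<mode>(**options).<function_name>() by name."""
    series = pd.Series(np.copy(y))
    agg_obj = getattr(series, mode)(**options)  # e.g. series.rolling(window=2)
    return getattr(agg_obj, function_name)()  # e.g. .mean()


rolled = windowed_aggregate([1.0, 2.0, 3.0, 4.0], "rolling", window=2)
expanded = windowed_aggregate([1.0, 2.0, 3.0, 4.0], "expanding", function_name="max")
```

Because `mode` and `function_name` are looked up on the pandas object, this path still needs pandas until the dataframe-agnostic layer offers equivalent window functions.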
Expand Down
80 changes: 55 additions & 25 deletions packages/python/plotly/plotly/figure_factory/_hexbin_mapbox.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from plotly.express._core import build_dataframe
from plotly.express._doc import make_docstring
from plotly.express._chart_types import choropleth_mapbox, scatter_mapbox
import narwhals.stable.v1 as nw
import numpy as np
import pandas as pd


def _project_latlon_to_wgs84(lat, lon):
Expand Down Expand Up @@ -231,6 +231,7 @@ def _compute_wgs84_hexbin(
nx=None,
agg_func=None,
min_count=None,
native_namespace=None,
):
"""
Computes the lat-lon aggregation at hexagonal bin level.
Expand Down Expand Up @@ -263,7 +264,7 @@ def _compute_wgs84_hexbin(
Lat coordinates of each hexagon (shape M x 6)
np.ndarray
Lon coordinates of each hexagon (shape M x 6)
pd.Series
nw.Series
Unique id for each hexagon, to be used in the geojson data (shape M)
np.ndarray
Aggregated value in each hexagon (shape M)
Expand All @@ -288,7 +289,14 @@ def _compute_wgs84_hexbin(

# Create unique feature id based on hexagon center
centers = centers.astype(str)
hexagons_ids = pd.Series(centers[:, 0]) + "," + pd.Series(centers[:, 1])
hexagons_ids = (
nw.from_dict(
{"x1": centers[:, 0], "x2": centers[:, 1]},
native_namespace=native_namespace,
)
.select(hexagons_ids=nw.concat_str([nw.col("x1"), nw.col("x2")], separator=","))
.get_column("hexagons_ids")
)

return hexagons_lats, hexagons_lons, hexagons_ids, agreggated_value

Expand Down Expand Up @@ -344,22 +352,40 @@ def create_hexbin_mapbox(
Returns a figure aggregating scattered points into connected hexagons
"""
args = build_dataframe(args=locals(), constructor=None)

native_namespace = nw.get_native_namespace(args["data_frame"])
if agg_func is None:
agg_func = np.mean

lat_range = args["data_frame"][args["lat"]].agg(["min", "max"]).values
lon_range = args["data_frame"][args["lon"]].agg(["min", "max"]).values
lat_range = (
args["data_frame"]
.select(
nw.min(args["lat"]).name.suffix("_min"),
nw.max(args["lat"]).name.suffix("_max"),
)
.to_numpy()
.squeeze()
)

lon_range = (
args["data_frame"]
.select(
nw.min(args["lon"]).name.suffix("_min"),
nw.max(args["lon"]).name.suffix("_max"),
)
.to_numpy()
.squeeze()
)

hexagons_lats, hexagons_lons, hexagons_ids, count = _compute_wgs84_hexbin(
lat=args["data_frame"][args["lat"]].values,
lon=args["data_frame"][args["lon"]].values,
lat=args["data_frame"].get_column(args["lat"]).to_numpy(),
lon=args["data_frame"].get_column(args["lon"]).to_numpy(),
lat_range=lat_range,
lon_range=lon_range,
color=None,
nx=nx_hexagon,
agg_func=agg_func,
min_count=min_count,
native_namespace=native_namespace,
)

geojson = _hexagons_to_geojson(hexagons_lats, hexagons_lons, hexagons_ids)
Expand All @@ -381,41 +407,43 @@ def create_hexbin_mapbox(
center = dict(lat=lat_range.mean(), lon=lon_range.mean())

if args["animation_frame"] is not None:
groups = args["data_frame"].groupby(args["animation_frame"]).groups
groups = dict(args["data_frame"].group_by(args["animation_frame"]).__iter__())
else:
groups = {0: args["data_frame"].index}
groups = {(0,): args["data_frame"]}

agg_data_frame_list = []
for frame, index in groups.items():
df = args["data_frame"].loc[index]
for key, df in groups.items():
_, _, hexagons_ids, aggregated_value = _compute_wgs84_hexbin(
lat=df[args["lat"]].values,
lon=df[args["lon"]].values,
lat=df.get_column(args["lat"]).to_numpy(),
lon=df.get_column(args["lon"]).to_numpy(),
lat_range=lat_range,
lon_range=lon_range,
color=df[args["color"]].values if args["color"] else None,
color=df.get_column(args["color"]).to_numpy() if args["color"] else None,
nx=nx_hexagon,
agg_func=agg_func,
min_count=min_count,
native_namespace=native_namespace,
)
agg_data_frame_list.append(
pd.DataFrame(
np.c_[hexagons_ids, aggregated_value], columns=["locations", "color"]
nw.from_dict(
{
"frame": [key[0]] * len(hexagons_ids),
"locations": hexagons_ids,
"color": aggregated_value,
},
native_namespace=native_namespace,
)
)
agg_data_frame = (
pd.concat(agg_data_frame_list, axis=0, keys=groups.keys())
.rename_axis(index=("frame", "index"))
.reset_index("frame")
)

agg_data_frame["color"] = pd.to_numeric(agg_data_frame["color"])
agg_data_frame = nw.concat(agg_data_frame_list, how="vertical").with_columns(
color=nw.col("color").cast(nw.Int64)
)

if range_color is None:
range_color = [agg_data_frame["color"].min(), agg_data_frame["color"].max()]

fig = choropleth_mapbox(
data_frame=agg_data_frame,
data_frame=agg_data_frame.to_native(),
geojson=geojson,
locations="locations",
color="color",
Expand All @@ -440,7 +468,9 @@ def create_hexbin_mapbox(
if show_original_data:
original_fig = scatter_mapbox(
data_frame=(
args["data_frame"].sort_values(by=args["animation_frame"])
args["data_frame"].sort(
by=args["animation_frame"], descending=False, nulls_last=True
)
if args["animation_frame"] is not None
else args["data_frame"]
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4474,7 +4474,7 @@ def test_build_dataframe(self):
lon = np.random.randn(N)
color = np.ones(N)
frame = np.random.randint(0, n_frames, N)
df = pd.DataFrame(
df = pd.DataFrame( # TODO: Test other constructors?
np.c_[lat, lon, color, frame],
columns=["Latitude", "Longitude", "Metric", "Frame"],
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

from narwhals.typing import IntoDataFrame
from narwhals.utils import parse_version


def pandas_constructor(obj) -> IntoDataFrame:
return pd.DataFrame(obj) # type: ignore[no-any-return]


def pandas_nullable_constructor(obj) -> IntoDataFrame:
return pd.DataFrame(obj).convert_dtypes(dtype_backend="numpy_nullable") # type: ignore[no-any-return]


def pandas_pyarrow_constructor(obj) -> IntoDataFrame:
return pd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow") # type: ignore[no-any-return]


def polars_eager_constructor(obj) -> IntoDataFrame:
return pl.DataFrame(obj)


def pyarrow_table_constructor(obj) -> IntoDataFrame:
return pa.table(obj) # type: ignore[no-any-return]


constructors = [polars_eager_constructor, pyarrow_table_constructor, pandas_constructor]

if parse_version(pd.__version__) >= parse_version("2.0.0"):
    # Extend rather than reassign, so the polars, pyarrow and plain pandas
    # constructors defined above stay in the list
    constructors.extend(
        [
            pandas_nullable_constructor,
            pandas_pyarrow_constructor,
        ]
    )


@pytest.fixture(params=constructors)
def constructor(request: pytest.FixtureRequest):
return request.param # type: ignore[no-any-return]
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import pandas as pd
import plotly.express as px
from pytest import approx
import pytest
import random


def test_facets():
df = px.data.tips()
def test_facets(constructor):
data = px.data.tips().to_dict(orient="list")
Contributor:

What if the constructor functions were modified to accept either a Pandas df or a dict? Then if a Pandas df is passed it could call .to_dict() on it. Would save a bit of boilerplate in the tests.

Contributor Author:

Yes, that makes sense. However, after our meeting today I started working on making the data(set) module dataframe-agnostic. A separate PR should be ready tomorrow, and it may be a quick win for simplifying data loading in the tests as well. I will ping you there.

df = constructor(data)

fig = px.scatter(df, x="total_bill", y="tip")
assert "xaxis2" not in fig.layout
assert "yaxis2" not in fig.layout
Expand Down Expand Up @@ -46,8 +47,9 @@ def test_facets():
assert fig.layout.yaxis4.domain[0] - fig.layout.yaxis.domain[1] == approx(0.08)


def test_facets_with_marginals():
df = px.data.tips()
def test_facets_with_marginals(constructor):
data = px.data.tips().to_dict(orient="list")
df = constructor(data)

fig = px.histogram(df, x="total_bill", facet_col="sex", marginal="rug")
assert len(fig.data) == 4
Expand Down Expand Up @@ -93,12 +95,11 @@ def test_facets_with_marginals():
assert len(fig.data) == 2 # ignore all marginals


@pytest.fixture
def bad_facet_spacing_df():
def bad_facet_spacing_df(constructor_func):
NROWS = 101
NDATA = 1000
categories = [n % NROWS for n in range(NDATA)]
df = pd.DataFrame(
df = constructor_func(
{
"x": [random.random() for _ in range(NDATA)],
"y": [random.random() for _ in range(NDATA)],
Expand All @@ -108,8 +109,8 @@ def bad_facet_spacing_df():
return df


def test_bad_facet_spacing_eror(bad_facet_spacing_df):
df = bad_facet_spacing_df
def test_bad_facet_spacing_error(constructor):
df = bad_facet_spacing_df(constructor_func=constructor)
with pytest.raises(
ValueError, match="Use the facet_row_spacing argument to adjust this spacing."
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
@pytest.mark.parametrize("px_fn", [px.scatter, px.density_heatmap, px.density_contour])
@pytest.mark.parametrize("marginal_x", [None, "histogram", "box", "violin"])
@pytest.mark.parametrize("marginal_y", [None, "rug"])
def test_xy_marginals(px_fn, marginal_x, marginal_y):
df = px.data.tips()
def test_xy_marginals(constructor, px_fn, marginal_x, marginal_y):
data = px.data.tips().to_dict(orient="list")
df = constructor(data)

fig = px_fn(
df, x="total_bill", y="tip", marginal_x=marginal_x, marginal_y=marginal_y
Expand All @@ -17,8 +18,9 @@ def test_xy_marginals(px_fn, marginal_x, marginal_y):
@pytest.mark.parametrize("px_fn", [px.histogram, px.ecdf])
@pytest.mark.parametrize("marginal", [None, "rug", "histogram", "box", "violin"])
@pytest.mark.parametrize("orientation", ["h", "v"])
def test_single_marginals(px_fn, marginal, orientation):
df = px.data.tips()
def test_single_marginals(constructor, px_fn, marginal, orientation):
data = px.data.tips().to_dict(orient="list")
df = constructor(data)

fig = px_fn(
df, x="total_bill", y="total_bill", marginal=marginal, orientation=orientation
Expand Down