Skip to content

Commit

Permalink
Merge pull request #4790 from FBruzzesi/plotly-with-narwhals
Browse files Browse the repository at this point in the history
feat: make plotly-express dataframe agnostic via narwhals
  • Loading branch information
emilykl authored Nov 13, 2024
2 parents ffb571b + 9f2c55b commit 5898816
Show file tree
Hide file tree
Showing 38 changed files with 1,689 additions and 834 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).

### Updated

- Updated plotly.py to use base64 encoding of arrays in plotly JSON to improve performance.
- Updated plotly.py to use base64 encoding of arrays in plotly JSON to improve performance.
- Add `subtitle` attribute to all Plotly Express traces.
- Make plotly-express dataframe agnostic via Narwhals [#4790](https://github.com/plotly/plotly.py/pull/4790).

## [5.24.1] - 2024-09-12

Expand Down
52 changes: 22 additions & 30 deletions packages/python/plotly/_plotly_utils/basevalidators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import sys
import warnings
import narwhals.stable.v1 as nw

from _plotly_utils.optional_imports import get_module

Expand Down Expand Up @@ -72,8 +73,6 @@ def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False):
"""
np = get_module("numpy")

# Don't force pandas to be loaded, we only want to know if it's already loaded
pd = get_module("pandas", should_load=False)
assert np is not None

# ### Process kind ###
Expand All @@ -93,34 +92,26 @@ def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False):
"O": "object",
}

# Handle pandas Series and Index objects
if pd and isinstance(v, (pd.Series, pd.Index)):
if v.dtype.kind in numeric_kinds:
# Get the numeric numpy array so we use fast path below
v = v.values
elif v.dtype.kind == "M":
# Convert datetime Series/Index to numpy array of datetimes
if isinstance(v, pd.Series):
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
# Series.dt.to_pydatetime will return Index[object]
# https://github.com/pandas-dev/pandas/pull/52459
v = np.array(v.dt.to_pydatetime())
else:
# DatetimeIndex
v = v.to_pydatetime()
elif pd and isinstance(v, pd.DataFrame) and len(set(v.dtypes)) == 1:
dtype = v.dtypes.tolist()[0]
if dtype.kind in numeric_kinds:
v = v.values
elif dtype.kind == "M":
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
# Series.dt.to_pydatetime will return Index[object]
# https://github.com/pandas-dev/pandas/pull/52459
v = [
np.array(row.dt.to_pydatetime()).tolist() for i, row in v.iterrows()
]
# With `pass_through=True`, the original object will be returned if unable to convert
# to a Narwhals DataFrame or Series.
v = nw.from_native(v, allow_series=True, pass_through=True)

if isinstance(v, nw.Series):
if v.dtype == nw.Datetime and v.dtype.time_zone is not None:
# Remove time zone so that local time is displayed
v = v.dt.replace_time_zone(None).to_numpy()
else:
v = v.to_numpy()
elif isinstance(v, nw.DataFrame):
schema = v.schema
overrides = {}
for key, val in schema.items():
if val == nw.Datetime and val.time_zone is not None:
# Remove time zone so that local time is displayed
overrides[key] = nw.col(key).dt.replace_time_zone(None)
if overrides:
v = v.with_columns(**overrides)
v = v.to_numpy()

if not isinstance(v, np.ndarray):
# v has its own logic on how to convert itself into a numpy array
Expand Down Expand Up @@ -193,6 +184,7 @@ def is_homogeneous_array(v):
np
and isinstance(v, np.ndarray)
or (pd and isinstance(v, (pd.Series, pd.Index)))
or (isinstance(v, nw.Series))
):
return True
if is_numpy_convertable(v):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,13 @@ def color_categorical_pandas(request, pandas_type):
def dates_array(request):
return np.array(
[
datetime(year=2013, month=10, day=10),
datetime(year=2013, month=11, day=10),
datetime(year=2013, month=12, day=10),
datetime(year=2014, month=1, day=10),
datetime(year=2014, month=2, day=10),
]
"2013-10-10",
"2013-11-10",
"2013-12-10",
"2014-01-10",
"2014-02-10",
],
dtype="datetime64[ns]",
)


Expand Down Expand Up @@ -183,7 +184,7 @@ def test_data_array_validator_dates_series(
assert isinstance(res, np.ndarray)

# Check dtype
assert res.dtype == "object"
assert res.dtype == "<M8[ns]"

# Check values
np.testing.assert_array_equal(res, dates_array)
Expand All @@ -200,7 +201,7 @@ def test_data_array_validator_dates_dataframe(
assert isinstance(res, np.ndarray)

# Check dtype
assert res.dtype == "object"
assert res.dtype == "<M8[ns]"

# Check values
np.testing.assert_array_equal(res, dates_array.reshape(len(dates_array), 1))
1 change: 1 addition & 0 deletions packages/python/plotly/optional-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ ipython

## pandas deps for some matplotlib functionality ##
pandas
narwhals>=1.13.3

## scipy deps for some FigureFactory functions ##
scipy
Expand Down
Loading

0 comments on commit 5898816

Please sign in to comment.