Skip to content

Commit

Permalink
explict on dtypes on dtype on several location
Browse files Browse the repository at this point in the history
  • Loading branch information
isVoid committed Feb 7, 2024
1 parent 448eb18 commit 558ce57
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 17 deletions.
24 changes: 17 additions & 7 deletions python/cuspatial/cuspatial/core/geoseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,8 @@ def from_points_xy(cls, points_xy):
GeoSeries:
A GeoSeries made of the points.
"""
coords_dtype = "f8" if len(points_xy) == 0 else None
coords_dtype = _check_coords_dtype(points_xy)

return cls(
GeoColumn._from_points_xy(as_column(points_xy, dtype=coords_dtype))
)
Expand All @@ -710,8 +711,9 @@ def from_multipoints_xy(cls, multipoints_xy, geometry_offset):
Parameters
----------
points_xy: array-like
Coordinates of the points, interpreted as interleaved x-y coords.
multipoints_xy: array-like
Coordinates of the multipoints, interpreted as interleaved x-y
coords.
geometry_offset: array-like
Offsets indicating the starting index of the multipoint. Multiply
the index by 2 results in the starting index of the coordinate.
Expand All @@ -732,7 +734,7 @@ def from_multipoints_xy(cls, multipoints_xy, geometry_offset):
1 MULTIPOINT (2.00000 2.00000, 3.00000 3.00000)
dtype: geometry
"""
coords_dtype = "f8" if len(multipoints_xy) == 0 else None
coords_dtype = coords_dtype = _check_coords_dtype(multipoints_xy)
return cls(
GeoColumn._from_multipoints_xy(
as_column(multipoints_xy, dtype=coords_dtype),
Expand All @@ -750,7 +752,8 @@ def from_linestrings_xy(
Parameters
----------
linestrings_xy : array-like
Coordinates of the points, interpreted as interleaved x-y coords.
Coordinates of the linestring, interpreted as interleaved x-y
coords.
geometry_offset : array-like
Offsets of the first coordinate of each geometry. The length of
this array is the number of geometries. Offsets with a difference
Expand All @@ -777,7 +780,7 @@ def from_linestrings_xy(
0 LINESTRING (0 0, 1 1, 2 2, 3 3, 4 4, 5 5)
dtype: geometry
"""
coords_dtype = "f8" if len(linestrings_xy) == 0 else None
coords_dtype = _check_coords_dtype(linestrings_xy)
return cls(
GeoColumn._from_linestrings_xy(
as_column(linestrings_xy, dtype=coords_dtype),
Expand Down Expand Up @@ -827,7 +830,7 @@ def from_polygons_xy(
0 POLYGON (0 0, 1 1, 2 2, 3 3, 4 4, 5 5)
dtype: geometry
"""
coords_dtype = "f8" if len(polygons_xy) == 0 else None
coords_dtype = _check_coords_dtype(polygons_xy)
return cls(
GeoColumn._from_polygons_xy(
as_column(polygons_xy, dtype=coords_dtype),
Expand Down Expand Up @@ -1488,3 +1491,10 @@ def distance(self, other, align=True):
if other_is_scalar:
res.index = self.index
return res


def _check_coords_dtype(coords):
if hasattr(coords, "dtype"):
return coords.dtype
else:
return "f8" if len(coords) == 0 else None
7 changes: 4 additions & 3 deletions python/cuspatial/cuspatial/core/spatial/nearest_points.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import cupy as cp

import cudf
from cudf.core.column import as_column

import cuspatial._lib.nearest_points as nearest_points
Expand Down Expand Up @@ -51,9 +52,9 @@ def pairwise_point_linestring_nearest_points(

if len(points) == 0:
data = {
"point_geometry_id": [],
"linestring_geometry_id": [],
"segment_id": [],
"point_geometry_id": cudf.Series([], dtype="i4"),
"linestring_geometry_id": cudf.Series([], dtype="i4"),
"segment_id": cudf.Series([], dtype="i4"),
"geometry": GeoSeries([]),
}
return GeoDataFrame._from_data(data)
Expand Down
10 changes: 10 additions & 0 deletions python/cuspatial/cuspatial/core/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,16 @@ def trajectory_bounding_boxes(num_trajectories, object_ids, points: GeoSeries):
1 1.0 1.0 3.0 3.0
"""

if len(points) == 0:
return DataFrame(
{
"x_min": Series([], dtype=points.points.x.dtype),
"y_min": Series([], dtype=points.points.x.dtype),
"x_max": Series([], dtype=points.points.x.dtype),
"y_max": Series([], dtype=points.points.x.dtype),
}
)

if len(points) > 0 and not contains_only_points(points):
raise ValueError("`points` must only contain point geometries.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ def test_linestring_bounding_boxes_empty():
result,
cudf.DataFrame(
{
"minx": cudf.Series([]),
"miny": cudf.Series([]),
"maxx": cudf.Series([]),
"maxy": cudf.Series([]),
"minx": cudf.Series([], dtype=np.float64),
"miny": cudf.Series([], dtype=np.float64),
"maxx": cudf.Series([], dtype=np.float64),
"maxy": cudf.Series([], dtype=np.float64),
}
),
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import geopandas as gpd
import pandas as pd
import pytest
import shapely
from geopandas.testing import assert_geodataframe_equal
Expand Down Expand Up @@ -73,9 +74,9 @@ def test_empty_input():
)
expected = gpd.GeoDataFrame(
{
"point_geometry_id": [],
"linestring_geometry_id": [],
"segment_id": [],
"point_geometry_id": pd.Series([], dtype="i4"),
"linestring_geometry_id": pd.Series([], dtype="i4"),
"segment_id": pd.Series([], dtype="i4"),
"geometry": gpd.GeoSeries(),
}
)
Expand Down

0 comments on commit 558ce57

Please sign in to comment.