Skip to content

Commit

Permalink
Add CUDA device-callable projection API (#1490)
Browse files Browse the repository at this point in the history
Closes #1489

This PR refactors device `cuspatial::detail::pipeline` into the public API via a type alias `cuspatial::device_projection` which can be passed to a CUDA kernel and invoked to transform coordinates.  `cuspatial::projection::get_device_projection(direction)` can be used to get a `device_projection`.

This required changing the direction parameter for `cuspatial::detail::pipeline` to a constructor parameter rather than a template parameter. I benchmarked before and after this change and saw no significant difference.

I have added tests and an example in `README.txt`.

Authors:
  - Mark Harris (https://github.com/harrism)

Approvers:
  - Michael Wang (https://github.com/isVoid)
  - Paul Taylor (https://github.com/trxcllnt)

URL: #1490
  • Loading branch information
harrism authored Dec 2, 2024
1 parent ee70be4 commit ab07122
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 150 deletions.
51 changes: 51 additions & 0 deletions cpp/cuproj/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ Sydney, Australia from WGS84 (lat, lon) coordinates to UTM zone 56S (x, y) coord
#include <cuproj/projection_factories.cuh>
#include <cuproj/vec_2d.hpp>

using T = float;

// Make a projection to convert WGS84 (lat, lon) coordinates to UTM zone 56S (x, y) coordinates
auto proj = cuproj::make_projection<cuproj::vec_2d<T>>("EPSG:4326", "EPSG:32756");

Expand All @@ -36,3 +38,52 @@ thrust::device_vector<cuproj::vec_2d<T>> d_out(d_in.size());
// Convert the coordinates. Works the same with a vector of many coordinates.
proj.transform(d_in.begin(), d_in.end(), d_out.begin(), cuproj::direction::FORWARD);
```
### Projections in CUDA device code
The C++ API also supports transforming coordinate in CUDA device code. Create a
`projection` as above, then get a `device_projection` object from it, which can
be passed to a kernel launch. Here's an example kernel.
```cpp
template <typename T>
using coordinate = typename cuproj::vec_2d<T>;
template <typename T>
using device_projection = cuproj::device_projection<coordinate<T>>;
__global__
void example_kernel(device_projection const d_proj,
cuproj::vec_2d<float> const* in,
cuproj::vec_2d<float>* out,
size_t n)
{
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x;
i < n;
i += gridDim.x * blockDim.x) {
out[i] = d_proj.transform(in[i]);
}
}
```

The corresponding host code:

```cpp
using coordinate = cuproj::vec_2d<float>;

// Make a projection to convert WGS84 (lat, lon) coordinates to
// UTM zone 56S (x, y) coordinates
auto proj = cuproj::make_projection<coordinate>("EPSG:4326", "EPSG:32756");

// Sydney, NSW, Australia
coordinate sydney{-33.858700, 151.214000};
thrust::device_vector<coordinate> d_in{1, sydney};
thrust::device_vector<coordinate> d_out(d_in.size());

auto d_proj = proj->get_device_projection(cuproj::direction::FORWARD);
std::size_t block_size = 256;
std::size_t grid_size = (d_in.size() + block_size - 1) / block_size;
example_kernel<<<grid_size, block_size>>>(
d_proj, d_in.data().get(), d_out.data().get(), d_in.size());
cudaDeviceSynchronize();
```
109 changes: 60 additions & 49 deletions cpp/cuproj/include/cuproj/detail/pipeline.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,18 +16,18 @@

#pragma once

#include <cuproj/assert.cuh>
#include <cuproj/operation/axis_swap.cuh>
#include <cuproj/operation/clamp_angular_coordinates.cuh>
#include <cuproj/operation/degrees_to_radians.cuh>
#include <cuproj/operation/offset_scale_cartesian_coordinates.cuh>
#include <cuproj/operation/operation.cuh>
#include <cuproj/operation/transverse_mercator.cuh>

#include <cuda/std/iterator>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>

#include <iterator>

namespace cuproj {
namespace detail {

Expand All @@ -39,15 +39,9 @@ namespace detail {
* @tparam dir The direction of the pipeline, FORWARD or INVERSE
* @tparam T the coordinate value type
*/
template <typename Coordinate,
direction dir = direction::FORWARD,
typename T = typename Coordinate::value_type>
template <typename Coordinate, typename T = typename Coordinate::value_type>
class pipeline {
public:
using iterator_type = std::conditional_t<dir == direction::FORWARD,
operation_type const*,
std::reverse_iterator<operation_type const*>>;

/**
* @brief Construct a new pipeline object with the given operations and parameters
*
Expand All @@ -57,62 +51,79 @@ class pipeline {
*/
pipeline(projection_parameters<T> const& params,
operation_type const* ops,
std::size_t num_stages)
: params_(params), d_ops(ops), num_stages(num_stages)
std::size_t num_stages,
direction dir = direction::FORWARD)
: params_(params), d_ops(ops), num_stages(num_stages), dir_(dir)
{
if constexpr (dir == direction::FORWARD) {
first_ = d_ops;
} else {
first_ = std::reverse_iterator(d_ops + num_stages);
}
}

/**
* @brief Apply the pipeline to the given coordinate
* @brief Transform a coordinate using the pipeline
*
* @param c The coordinate to transform
* @return The transformed coordinate
*/
__device__ Coordinate operator()(Coordinate const& c) const
inline __device__ Coordinate operator()(Coordinate const& c) const
{
Coordinate c_out{c};
thrust::for_each_n(thrust::seq, first_, num_stages, [&](auto const& op) {
switch (op) {
case operation_type::AXIS_SWAP: {
auto op = axis_swap<Coordinate>{};
c_out = op(c_out, dir);
break;
}
case operation_type::DEGREES_TO_RADIANS: {
auto op = degrees_to_radians<Coordinate>{};
c_out = op(c_out, dir);
break;
}
case operation_type::CLAMP_ANGULAR_COORDINATES: {
auto op = clamp_angular_coordinates<Coordinate>{params_};
c_out = op(c_out, dir);
break;
}
case operation_type::OFFSET_SCALE_CARTESIAN_COORDINATES: {
auto op = offset_scale_cartesian_coordinates<Coordinate>{params_};
c_out = op(c_out, dir);
break;
}
case operation_type::TRANSVERSE_MERCATOR: {
auto op = transverse_mercator<Coordinate>{params_};
c_out = op(c_out, dir);
break;
}
}
});
// depending on direction, get a forward or reverse iterator to d_ops
if (dir_ == direction::FORWARD) {
auto first = d_ops;
thrust::for_each_n(
thrust::seq, first, num_stages, [&](auto const& op) { c_out = dispatch_op(c_out, op); });
} else {
auto first = cuda::std::reverse_iterator(d_ops + num_stages);
thrust::for_each_n(
thrust::seq, first, num_stages, [&](auto const& op) { c_out = dispatch_op(c_out, op); });
}
return c_out;
}

/**
* @brief Transform a coordinate using the pipeline
*
* @note this is an alias for operator() to allow for a more natural syntax
*
* @param c The coordinate to transform
* @return The transformed coordinate
*/
inline __device__ Coordinate transform(Coordinate const& c) const { return operator()(c); }

private:
projection_parameters<T> params_;
operation_type const* d_ops;
iterator_type first_;
std::size_t num_stages;
direction dir_;

inline __device__ Coordinate dispatch_op(Coordinate const& c, operation_type const& op) const
{
switch (op) {
case operation_type::AXIS_SWAP: {
auto op = axis_swap<Coordinate>{};
return op(c, dir_);
}
case operation_type::DEGREES_TO_RADIANS: {
auto op = degrees_to_radians<Coordinate>{};
return op(c, dir_);
}
case operation_type::CLAMP_ANGULAR_COORDINATES: {
auto op = clamp_angular_coordinates<Coordinate>{params_};
return op(c, dir_);
}
case operation_type::OFFSET_SCALE_CARTESIAN_COORDINATES: {
auto op = offset_scale_cartesian_coordinates<Coordinate>{params_};
return op(c, dir_);
}
case operation_type::TRANSVERSE_MERCATOR: {
auto op = transverse_mercator<Coordinate>{params_};
return op(c, dir_);
}
default: {
cuproj_assert("Invalid operation type");
return c;
}
}
}
};

} // namespace detail
Expand Down
49 changes: 0 additions & 49 deletions cpp/cuproj/include/cuproj/detail/wrap_to_pi.cuh

This file was deleted.

21 changes: 14 additions & 7 deletions cpp/cuproj/include/cuproj/operation/clamp_angular_coordinates.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,8 +24,6 @@

#include <thrust/iterator/transform_iterator.h>

#include <algorithm>

namespace cuproj {

/**
Expand Down Expand Up @@ -62,7 +60,8 @@ class clamp_angular_coordinates : operation<Coordinate> {
* @param dir The direction of the operation
* @return The clamped coordinate
*/
CUPROJ_HOST_DEVICE Coordinate operator()(Coordinate const& coord, direction dir) const
[[nodiscard]] CUPROJ_HOST_DEVICE Coordinate operator()(Coordinate const& coord,
direction dir) const
{
if (dir == direction::FORWARD)
return forward(coord);
Expand All @@ -81,7 +80,7 @@ class clamp_angular_coordinates : operation<Coordinate> {
* @param coord The coordinate to clamp
* @return The clamped coordinate
*/
CUPROJ_HOST_DEVICE Coordinate forward(Coordinate const& coord) const
[[nodiscard]] CUPROJ_HOST_DEVICE Coordinate forward(Coordinate const& coord) const
{
// check for latitude or longitude over-range
T t = (coord.y < 0 ? -coord.y : coord.y) - M_PI_2;
Expand All @@ -92,7 +91,7 @@ class clamp_angular_coordinates : operation<Coordinate> {

/* Clamp latitude to -pi/2..pi/2 degree range */
auto half_pi = static_cast<T>(M_PI_2);
xy.y = std::clamp(xy.y, -half_pi, half_pi);
xy.y = clamp(xy.y, -half_pi, half_pi);

// Distance from central meridian, taking system zero meridian into account
xy.x = (xy.x - prime_meridian_offset_) - lam0_;
Expand All @@ -112,7 +111,7 @@ class clamp_angular_coordinates : operation<Coordinate> {
* @param coord The coordinate to clamp
* @return The clamped coordinate
*/
CUPROJ_HOST_DEVICE Coordinate inverse(Coordinate const& coord) const
[[nodiscard]] inline CUPROJ_HOST_DEVICE Coordinate inverse(Coordinate const& coord) const
{
Coordinate xy = coord;

Expand All @@ -125,6 +124,14 @@ class clamp_angular_coordinates : operation<Coordinate> {
return xy;
}

[[nodiscard]] inline CUPROJ_HOST_DEVICE const T& clamp(const T& val,
const T& low,
const T& high) const
{
CUPROJ_HOST_DEVICE_EXPECTS(!(low < high), "Invalid clamp range");
return val < low ? low : (high < val) ? high : val;
}

T lam0_{}; // central meridian
T prime_meridian_offset_{};
};
Expand Down
Loading

0 comments on commit ab07122

Please sign in to comment.