Skip to content

Commit

Permalink
Add converting to ndarray (#16)
Browse files Browse the repository at this point in the history
* dont run dev actions on main

* initial ndarray implementation

* fix index feature definitions

* add ndarray and fix gh actions

* make ndarray default

* add reading ndarray with coordinates

* extend ndarray test
  • Loading branch information
Quba1 authored Feb 5, 2024
1 parent 7eb1fe0 commit a2c8f08
Show file tree
Hide file tree
Showing 10 changed files with 217 additions and 19 deletions.
20 changes: 13 additions & 7 deletions .github/workflows/rust-dev.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
name: cargodev

on: [push, pull_request]
on:
push:
branches-ignore:
- main
pull_request:
branches:
- "**"

env:
CARGO_TERM_COLOR: always
Expand All @@ -23,13 +29,13 @@ jobs:
cargo clean
- name: Test with cargo
run: |
RUST_BACKTRACE=full cargo test --features "experimental_index"
RUST_BACKTRACE=full cargo test --features "experimental_index, message_ndarray" -- --include-ignored
- name: Check with clippy
run: |
cargo clippy --features "experimental_index"
cargo clippy --features "experimental_index, message_ndarray"
- name: Build release
run: |
cargo test --features "experimental_index" -- --include-ignored
cargo build --release --features "experimental_index, message_ndarray"
build-macos:

Expand All @@ -45,10 +51,10 @@ jobs:
cargo clean
- name: Test with cargo
run: |
RUST_BACKTRACE=full cargo test --features "experimental_index"
RUST_BACKTRACE=full cargo test --features "experimental_index, message_ndarray" -- --include-ignored
- name: Check with clippy
run: |
cargo clippy --features "experimental_index"
cargo clippy --features "experimental_index, message_ndarray"
- name: Build release
run: |
cargo test --features "experimental_index" -- --include-ignored
cargo build --release --features "experimental_index, message_ndarray"
8 changes: 4 additions & 4 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ jobs:
cargo clean
- name: Build with cargo
run: |
cargo build --release --features "experimental_index"
cargo build --release --features "experimental_index, message_ndarray"
cargo clean
- name: Test with cargo
run: |
cargo test --features "experimental_index"
cargo test --features "experimental_index, message_ndarray"
cargo clean
- name: Benchmark with criterion
run: |
Expand All @@ -53,11 +53,11 @@ jobs:
cargo clean
- name: Build with cargo
run: |
cargo build --release --features "experimental_index"
cargo build --release --features "experimental_index, message_ndarray"
cargo clean
- name: Test with cargo
run: |
cargo test --features "experimental_index"
cargo test --features "experimental_index, message_ndarray"
cargo clean
- name: Benchmark with criterion
run: |
Expand Down
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ num-derive = "0.4.1"
num-traits = "0.2"
fallible-iterator = "0.3"
fallible-streaming-iterator = "0.1.9"
ndarray = { version = "0.15", default-features = false, optional = true, features = ["std"]}

[dev-dependencies]
reqwest = { version = "0.11", features = ["rustls-tls"] }
Expand All @@ -36,10 +37,13 @@ criterion = "0.5"
testing_logger = "0.1"
rand = "0.8"
anyhow = "1.0"
float-cmp = "0.9"

[features]
default = ["message_ndarray"]
docs = ["eccodes-sys/docs"]
experimental_index = []
message_ndarray = ["ndarray"]

[package.metadata.docs.rs]
features = ["docs", "experimental_index"]
Expand Down
Binary file added data/iceland-surface.grib.923a8.idx
Binary file not shown.
25 changes: 25 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ pub enum CodesError {
#[error("ecCodes function returned a non-zero code {0}")]
Internal(#[from] CodesInternal),

#[cfg(feature = "message_ndarray")]
/// Returned when function in `message_ndarray` module cannot convert
/// the message to ndarray. Check [`MessageNdarrayError`] for more details.
#[error("error occured while converting KeyedMessage to ndarray {0}")]
NdarrayConvert(#[from] MessageNdarrayError),

///Returned when one of libc functions returns a non-zero error code.
///Check libc documentation for details of the errors.
///For libc reference check these websites: ([1](https://man7.org/linux/man-pages/index.html))
Expand Down Expand Up @@ -67,6 +73,25 @@ pub enum CodesError {
NullPtr,
}

#[cfg(feature = "message_ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "message_ndarray")))]
#[derive(Error, Debug)]
/// Errors returned by the `message_ndarray` module.
pub enum MessageNdarrayError {
/// Returned when functions converting to ndarray cannot correctly
/// read key necessary for the conversion.
#[error("Requested key {0} has a different type than expected")]
UnexpectedKeyType(String),

/// Returned when length of values array is not equal to
/// product of Ni and Nj keys.
#[error("The length of the values array ({0}) is different than expected ({1})")]
UnexpectedValuesLength(usize, i64),

#[error("Error occured while converting to ndarray: {0}")]
InvalidShape(#[from] ndarray::ShapeError),
}

#[derive(Copy, Eq, PartialEq, Clone, Ord, PartialOrd, Hash, Error, Debug, FromPrimitive)]
///Errors returned by internal ecCodes library functions.
///Copied directly from the ecCodes API.
Expand Down
5 changes: 4 additions & 1 deletion src/intermediate_bindings/codes_handle.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::ptr::{self};

use eccodes_sys::{codes_context, codes_handle, codes_index, CODES_LOCK};
use eccodes_sys::{codes_context, codes_handle};
#[cfg(feature = "experimental_index")]
use eccodes_sys::{codes_index, CODES_LOCK};
use libc::FILE;
use num_traits::FromPrimitive;

Expand Down Expand Up @@ -56,6 +58,7 @@ pub unsafe fn codes_handle_delete(handle: *mut codes_handle) -> Result<(), Codes
Ok(())
}

#[cfg(feature = "experimental_index")]
pub unsafe fn codes_handle_new_from_index(
index: *mut codes_index,
) -> Result<*mut codes_handle, CodesError> {
Expand Down
8 changes: 4 additions & 4 deletions src/intermediate_bindings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ pub use codes_get::{
codes_get_long_array, codes_get_message, codes_get_native_type, codes_get_size,
codes_get_string,
};
pub use codes_handle::{
codes_handle_clone, codes_handle_delete, codes_handle_new_from_file,
codes_handle_new_from_index,
};
#[cfg(feature = "experimental_index")]
pub use codes_handle::codes_handle_new_from_index;
pub use codes_handle::{codes_handle_clone, codes_handle_delete, codes_handle_new_from_file};
#[cfg(feature = "experimental_index")]
pub use codes_index::{
codes_index_add_file, codes_index_delete, codes_index_new, codes_index_read,
codes_index_select_double, codes_index_select_long, codes_index_select_string,
Expand Down
6 changes: 3 additions & 3 deletions src/keys_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl KeyedMessage {
pub fn new_keys_iterator(
&self,
flags: Vec<KeysIteratorFlags>,
namespace: String,
namespace: &str,
) -> Result<KeysIterator, CodesError> {
let flags = flags.iter().map(|f| *f as u32).sum();

Expand Down Expand Up @@ -215,7 +215,7 @@ mod tests {
KeysIteratorFlags::SkipDuplicates, //32
];

let namespace = "geography".to_owned();
let namespace = "geography";

let mut kiter = current_message.new_keys_iterator(flags, namespace)?;

Expand All @@ -238,7 +238,7 @@ mod tests {
KeysIteratorFlags::AllKeys, //0
];

let namespace = "blabla".to_owned();
let namespace = "blabla";

let mut kiter = current_message.new_keys_iterator(flags, namespace)?;

Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ pub mod errors;
mod intermediate_bindings;
pub mod keyed_message;
pub mod keys_iterator;
#[cfg(feature = "message_ndarray")]
#[cfg_attr(docsrs, doc(cfg(feature = "message_ndarray")))]
pub mod message_ndarray;
mod pointer_guard;

pub use codes_handle::{CodesHandle, ProductKind};
Expand Down
157 changes: 157 additions & 0 deletions src/message_ndarray.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
use ndarray::{s, Array2, Array3};

use crate::{errors::MessageNdarrayError, CodesError, KeyType, KeyedMessage};

impl KeyedMessage {
/// Returns [y, x] ([Nj, Ni], [lat, lon]) ndarray from the message,
/// x coordinates are increasing with the i index,
/// y coordinates are decreasing with the j index.
pub fn to_ndarray(&self) -> Result<Array2<f64>, CodesError> {
let ni = if let KeyType::Int(ni) = self.read_key("Ni")?.value {
ni
} else {
return Err(MessageNdarrayError::UnexpectedKeyType("Ni".to_owned()).into());
};

let nj = if let KeyType::Int(nj) = self.read_key("Nj")?.value {
nj
} else {
return Err(MessageNdarrayError::UnexpectedKeyType("Nj".to_owned()).into());
};

let vals = if let KeyType::FloatArray(vals) = self.read_key("values")?.value {
vals
} else {
return Err(MessageNdarrayError::UnexpectedKeyType("values".to_owned()).into());
};

if vals.len() != (ni * nj) as usize {
return Err(MessageNdarrayError::UnexpectedValuesLength(vals.len(), ni * nj).into());
}

let shape = (nj as usize, ni as usize);
let vals = Array2::from_shape_vec(shape, vals).map_err(|e| MessageNdarrayError::from(e))?;

Ok(vals)
}

pub fn to_lons_lats_values(
&self,
) -> Result<(Array2<f64>, Array2<f64>, Array2<f64>), CodesError> {
let ni = if let KeyType::Int(ni) = self.read_key("Ni")?.value {
ni
} else {
return Err(MessageNdarrayError::UnexpectedKeyType("Ni".to_owned()).into());
};

let nj = if let KeyType::Int(nj) = self.read_key("Nj")?.value {
nj
} else {
return Err(MessageNdarrayError::UnexpectedKeyType("Nj".to_owned()).into());
};

let latlonvals = if let KeyType::FloatArray(vals) = self.read_key("latLonValues")?.value {
vals
} else {
return Err(MessageNdarrayError::UnexpectedKeyType("values".to_owned()).into());
};

if latlonvals.len() != (ni * nj * 3) as usize {
return Err(
MessageNdarrayError::UnexpectedValuesLength(latlonvals.len(), ni * nj * 3).into(),
);
}

let shape = (nj as usize, ni as usize, 3_usize);
let mut latlonvals =
Array3::from_shape_vec(shape, latlonvals).map_err(|e| MessageNdarrayError::from(e))?;
let (lats, lons, vals) =
latlonvals
.view_mut()
.multi_slice_move((s![.., .., 0], s![.., .., 1], s![.., .., 2]));

Ok((lons.into_owned(), lats.into_owned(), vals.into_owned()))
}
}

#[cfg(test)]
mod tests {
use float_cmp::assert_approx_eq;

use super::*;
use crate::codes_handle::CodesHandle;
use crate::FallibleStreamingIterator;
use crate::KeyType;
use crate::ProductKind;
use std::path::Path;

#[test]
fn test_to_ndarray() -> Result<(), CodesError> {
let file_path = Path::new("./data/iceland-surface.grib");
let mut handle = CodesHandle::new_from_file(file_path, ProductKind::GRIB)?;

while let Some(msg) = handle.next()? {
if msg.read_key("shortName")?.value == KeyType::Str("2d".to_string()) {
let ndarray = msg.to_ndarray()?;

// values from xarray
assert_approx_eq!(f64, ndarray[[0, 0]], 276.37793, epsilon = 0.000_1);
assert_approx_eq!(f64, ndarray[[0, 48]], 276.65723, epsilon = 0.000_1);
assert_approx_eq!(f64, ndarray[[16, 0]], 277.91113, epsilon = 0.000_1);
assert_approx_eq!(f64, ndarray[[16, 48]], 280.34277, epsilon = 0.000_1);
assert_approx_eq!(f64, ndarray[[5, 5]], 276.03418, epsilon = 0.000_1);
assert_approx_eq!(f64, ndarray[[10, 10]], 277.59082, epsilon = 0.000_1);
assert_approx_eq!(f64, ndarray[[15, 15]], 277.68652, epsilon = 0.000_1);
assert_approx_eq!(f64, ndarray[[8, 37]], 273.2744, epsilon = 0.000_1);

break;
}
}

Ok(())
}

#[test]
fn test_lons_lats() -> Result<(), CodesError> {
let file_path = Path::new("./data/iceland-surface.grib");
let mut handle = CodesHandle::new_from_file(file_path, ProductKind::GRIB)?;

while let Some(msg) = handle.next()? {
if msg.read_key("shortName")?.value == KeyType::Str("2d".to_string()) {
let (lons, lats, vals) = msg.to_lons_lats_values()?;

// values from cfgrib
assert_approx_eq!(f64, vals[[0, 0]], 276.37793, epsilon = 0.000_1);
assert_approx_eq!(f64, vals[[0, 48]], 276.65723, epsilon = 0.000_1);
assert_approx_eq!(f64, vals[[16, 0]], 277.91113, epsilon = 0.000_1);
assert_approx_eq!(f64, vals[[16, 48]], 280.34277, epsilon = 0.000_1);
assert_approx_eq!(f64, vals[[5, 5]], 276.03418, epsilon = 0.000_1);
assert_approx_eq!(f64, vals[[10, 10]], 277.59082, epsilon = 0.000_1);
assert_approx_eq!(f64, vals[[15, 15]], 277.68652, epsilon = 0.000_1);
assert_approx_eq!(f64, vals[[8, 37]], 273.2744, epsilon = 0.000_1);

assert_approx_eq!(f64, lons[[0, 0]], -25.0);
assert_approx_eq!(f64, lons[[0, 48]], -13.0);
assert_approx_eq!(f64, lons[[16, 0]], -25.0);
assert_approx_eq!(f64, lons[[16, 48]], -13.0);
assert_approx_eq!(f64, lons[[5, 5]], -23.75);
assert_approx_eq!(f64, lons[[10, 10]], -22.5);
assert_approx_eq!(f64, lons[[15, 15]], -21.25);
assert_approx_eq!(f64, lons[[8, 37]], -15.75);

assert_approx_eq!(f64, lats[[0, 0]], 67.0);
assert_approx_eq!(f64, lats[[0, 48]], 67.0);
assert_approx_eq!(f64, lats[[16, 0]], 63.0);
assert_approx_eq!(f64, lats[[16, 48]], 63.0);
assert_approx_eq!(f64, lats[[5, 5]], 65.75);
assert_approx_eq!(f64, lats[[10, 10]], 64.5);
assert_approx_eq!(f64, lats[[15, 15]], 63.25);
assert_approx_eq!(f64, lats[[8, 37]], 65.0);

break;
}
}

Ok(())
}
}

0 comments on commit a2c8f08

Please sign in to comment.