Skip to content

Commit

Permalink
Add testing
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Jul 10, 2024
1 parent 00de859 commit 569dde2
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 5 deletions.
34 changes: 29 additions & 5 deletions src/tdastro/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from astropy.table import Table


def read_grid_data(input_file, format="ascii"):
def read_grid_data(input_file, format="ascii", validate=False):
"""Read 2-d grid data from a text, csv, ecsv, or fits file.
Each line is of the form 'x0 x1 value' where x0 and x1 are the grid
Expand All @@ -16,6 +16,10 @@ def read_grid_data(input_file, format="ascii"):
format : `str`
The file format. Should be one of ascii, csv, ecsv,
or fits.
Default = 'ascii'
validate : `bool`
Perform additional validation on the input data.
Default = False
Returns
-------
Expand All @@ -26,24 +30,44 @@ def read_grid_data(input_file, format="ascii"):
values : `numpy.ndarray`
A 2-d array with the values at each point in the grid with
shape (len(x0), len(x1)).
Raises
------
``ValueError`` if any data validation fails.
"""
data = Table.read(input_file, format=format)
data = Table.read(input_file, format=format, comment=r"\s*#")
if len(data.colnames) != 3:
raise ValueError(
f"Incorrect format for grid data in {input_file} with format {format}. "
f"Expected 3 columns but found {len(data.colnames)}."
)
x0_col = data.colnames[0]
x1_col = data.colnames[1]
v_col = data.colnames[2]

# Get the values along the x0 and x1 dimensions.
x0 = np.sort(np.unique(data[data.colnames[0]].data))
x1 = np.sort(np.unique(data[data.colnames[1]].data))
x0 = np.sort(np.unique(data[x0_col].data))
x1 = np.sort(np.unique(data[x1_col].data))

# Get the array of values.
if len(data) != len(x0) * len(x1):
raise ValueError(
f"Incomplete data for {input_file} with format {format}. Expected "
f"{len(x0) * len(x1)} entries but found {len(data)}."
)
values = data[data.colnames[2]].data.reshape((len(x0), len(x1)))

# If we are validating, walk the table row by row and check that the x0 and
# x1 values appear in row-major order (x0 varies slowest, x1 fastest), which
# is the layout the reshape of the values column relies on.
if validate:
    num_x1 = len(x1)
    for i, x0_expected in enumerate(x0):
        for j, x1_expected in enumerate(x1):
            # Table row that should hold grid cell (i, j). The earlier length
            # check guarantees this index falls inside the table.
            row = i * num_x1 + j
            if data[x0_col][row] != x0_expected:
                raise ValueError(f"Incorrect x0 ordering in {input_file} at row={row}.")
            if data[x1_col][row] != x1_expected:
                raise ValueError(f"Incorrect x1 ordering in {input_file} at row={row}.")

# Build the values matrix.
values = data[v_col].data.reshape((len(x0), len(x1)))

return x0, x1, values
24 changes: 24 additions & 0 deletions tests/tdastro/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os.path

import pytest

# Name of the sub-directory (relative to TEST_DIR) that holds the test data files.
DATA_DIR_NAME = "data"
# Directory that contains this conftest module.
TEST_DIR = os.path.dirname(__file__)


@pytest.fixture
def test_data_dir():
    """Provide the path of the base directory holding the test data files."""
    data_dir = os.path.join(TEST_DIR, DATA_DIR_NAME)
    return data_dir


@pytest.fixture
def grid_data_good_file(test_data_dir):
    """Provide the path of the good grid input file inside the test data directory."""
    file_name = "grid_input_good.ecsv"
    return os.path.join(test_data_dir, file_name)


@pytest.fixture
def grid_data_bad_file(test_data_dir):
    """Provide the path of the bad grid input file inside the test data directory."""
    file_name = "grid_input_bad.txt"
    return os.path.join(test_data_dir, file_name)
6 changes: 6 additions & 0 deletions tests/tdastro/data/grid_input_bad.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
0.0 1.0 0.0
0.0 1.5 1.0
1.0 1.0 2.0
1.0 1.5 3.0
2.0 1.5 5.0
2.0 1.0 4.0
8 changes: 8 additions & 0 deletions tests/tdastro/data/grid_input_good.ecsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Comment up here.
x0, x1, values
0.0, 1.0, 0.0
0.0, 1.5, 1.0
1.0, 1.0, 2.0
1.0, 1.5, 3.0
2.0, 1.0, 4.0
2.0, 1.5, 5.0
25 changes: 25 additions & 0 deletions tests/tdastro/test_io_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np
import pytest
from tdastro.io_utils import read_grid_data


def test_read_grid_data_good(grid_data_good_file):
    """Check that a well formatted grid data file is read into the expected arrays."""
    # Expected axes and values for the 3 x 2 grid stored in the good input file.
    want_x0 = np.array([0.0, 1.0, 2.0])
    want_x1 = np.array([1.0, 1.5])
    want_values = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])

    got_x0, got_x1, got_values = read_grid_data(grid_data_good_file, format="ascii.csv")

    # Compare in the same order as the returned tuple: x0, x1, then values.
    for got, want in ((got_x0, want_x0), (got_x1, want_x1), (got_values, want_values)):
        np.testing.assert_allclose(got, want, atol=1e-5)


def test_read_grid_data_bad(grid_data_bad_file):
    """Check how a badly ordered grid data file is handled with and without validation."""
    # With validation off (the default) the file loads without a problem.
    _, _, grid = read_grid_data(grid_data_bad_file, format="ascii")
    assert grid.shape == (3, 2)

    # With validation on, the same file is rejected.
    with pytest.raises(ValueError):
        _, _, _ = read_grid_data(grid_data_bad_file, format="ascii", validate=True)

0 comments on commit 569dde2

Please sign in to comment.