diff --git a/power_perceiver/time.py b/power_perceiver/time.py index 12668ad..94095b0 100644 --- a/power_perceiver/time.py +++ b/power_perceiver/time.py @@ -226,3 +226,22 @@ def time_periods_to_datetime_index(time_periods: pd.DataFrame, freq: str) -> pd. new_dt_index = single_period_to_datetime_index(time_period, freq=freq) dt_index = dt_index.union(new_dt_index) return dt_index + + +def set_new_sample_period_and_t0_idx_attrs(xr_data, new_sample_period) -> xr.DataArray: + orig_sample_period = xr_data.attrs["sample_period_duration"] + orig_t0_idx = xr_data.attrs["t0_idx"] + new_sample_period = pd.Timedelta(new_sample_period) + assert new_sample_period >= orig_sample_period + new_t0_idx = orig_t0_idx / (new_sample_period / orig_sample_period) + np.testing.assert_almost_equal( + int(new_t0_idx), + new_t0_idx, + err_msg=( + "The original t0_idx must be exactly divisible by" + " (new_sample_period / orig_sample_period)" + ), + ) + xr_data.attrs["sample_period_duration"] = new_sample_period + xr_data.attrs["t0_idx"] = int(new_t0_idx) + return xr_data diff --git a/power_perceiver/transforms/pv.py b/power_perceiver/transforms/pv.py index 8747aeb..27bd6b6 100644 --- a/power_perceiver/transforms/pv.py +++ b/power_perceiver/transforms/pv.py @@ -1,10 +1,11 @@ from dataclasses import dataclass from typing import Optional, Union -import numpy as np import pandas as pd import xarray as xr +from power_perceiver.time import set_new_sample_period_and_t0_idx_attrs + @dataclass class PVPowerRollingWindow: @@ -91,12 +92,5 @@ def __call__(self, xr_data: Union[xr.Dataset, xr.DataArray]) -> Union[xr.Dataset resampled = xr_data # Change the pv_t0_idx and the sample_period_duration attributes: - orig_sample_period = xr_data.attrs["sample_period_duration"] - orig_t0_idx = xr_data.attrs["t0_idx"] - new_sample_period = pd.Timedelta(self.freq) - assert new_sample_period >= orig_sample_period - new_t0_idx = orig_t0_idx / (new_sample_period / orig_sample_period) - np.testing.assert_almost_equal(int(new_t0_idx), new_t0_idx) - resampled.attrs["sample_period_duration"] = new_sample_period - resampled.attrs["t0_idx"] = int(new_t0_idx) + resampled = set_new_sample_period_and_t0_idx_attrs(resampled, new_sample_period=self.freq) return resampled