This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 17, 2023
1 parent f761ddb commit 41514f5
Showing 47 changed files with 15 additions and 108 deletions.
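Only the commit message is shown above; the hook configuration itself is not part of this diff. Judging from the hunks below, the fixes are almost entirely stylistic: a stray blank line directly after a block opener (for, def, else:, class) is removed, and a few over-long lines are re-wrapped. A minimal before/after sketch of that change, mirroring the first hunk below and assuming a formatter such as black is among the configured hooks (the hook list is not visible here):

    # Before: blank line directly after the block opener, and an over-long call.
    for channel_index in channel_indexes:

        # renormalize
        satellite_data.append(data["sat_data"][batch_index, :, :, :, channel_index] * SAT_STD.values[channel_index])

    # After: the blank line is dropped and the call is wrapped to fit the line-length limit.
    for channel_index in channel_indexes:
        # renormalize
        satellite_data.append(
            data["sat_data"][batch_index, :, :, :, channel_index] * SAT_STD.values[channel_index]
        )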
1 change: 0 additions & 1 deletion notebooks/2021-08/2021-08-25/video.py
@@ -68,7 +68,6 @@
channel_indexes = [1, 8, 9]
satellite_data = []
for channel_index in channel_indexes:

# renormalize
satellite_data.append(
data["sat_data"][batch_index, :, :, :, channel_index] * SAT_STD.values[channel_index]
1 change: 0 additions & 1 deletion notebooks/2021-08/2021-08-26/video.py
@@ -88,7 +88,6 @@
channel_indexes = [1, 9, 8]
satellite_data = []
for channel_index in channel_indexes:

# renormalize
satellite_data.append(
data["sat_data"][batch_index, :, :, :, channel_index] * SAT_STD.values[channel_index]
1 change: 0 additions & 1 deletion notebooks/2021-09/2021-09-13/remove_hash.py
@@ -20,7 +20,6 @@

for filenames in [train_filenames, validation_filenames]:
for file in train_filenames:

print(file)

filename = file.split("/")[-1]
1 change: 0 additions & 1 deletion notebooks/2021-09/2021-09-14/gsp_centroid.py
@@ -15,7 +15,6 @@

# for index in range(0, len(shape_data_raw)):
for index in range(140, 150):

# just select the first one
shape_data = shape_data_raw.iloc[index : index + 1]
shapes_dict = json.loads(shape_data["geometry"].to_json())
1 change: 0 additions & 1 deletion notebooks/2021-09/2021-09-14/gsp_duplicated.py
@@ -12,7 +12,6 @@
duplicated_raw["Amount"] = range(0, len(duplicated_raw))

for i in range(0, 8, 2):

# just select the first one
duplicated = duplicated_raw.iloc[i : i + 2]
shapes_dict = json.loads(duplicated["geometry"].to_json())
1 change: 0 additions & 1 deletion notebooks/2021-09/2021-09-29/gsp_duplicated.py
@@ -14,7 +14,6 @@
duplicated_raw["Amount"] = range(0, len(duplicated_raw))

for i in range(0, 8, 2):

# just select the first one
duplicated = duplicated_raw.iloc[i : i + 2]
shapes_dict = json.loads(duplicated["geometry"].to_json())
2 changes: 0 additions & 2 deletions notebooks/2021-09/2021-09-29/video.py
@@ -41,7 +41,6 @@


def get_trace(dt):

# plot to check it looks right
return go.Choroplethmapbox(
geojson=shapes_dict,
@@ -54,7 +53,6 @@ def get_trace(dt):


def get_frame(dt):

# plot to check it looks right
return go.Choroplethmapbox(
z=gps_data[dt],
2 changes: 0 additions & 2 deletions notebooks/2021-10/2021-10-01/pydantic.py
@@ -11,7 +11,6 @@


class Satellite(BaseModel):

# width: int = Field(..., g=0, description="The width of the satellite image")
# height: int = Field(..., g=0, description="The width of the satellite image")
# num_channels: int = Field(..., g=0, description="The width of the satellite image")
@@ -49,7 +48,6 @@ class Config:


class Batch(BaseModel):

batch_size: int = Field(
...,
g=0,
1 change: 0 additions & 1 deletion notebooks/2021-10/2021-10-08/xr_compression.py
@@ -9,7 +9,6 @@
def get_satellite_xrarray_data_array(
batch_size, seq_length_5, satellite_image_size_pixels, number_sat_channels=10
):

r = np.random.randn(
# self.batch_size,
seq_length_5,
1 change: 0 additions & 1 deletion notebooks/2021-10/2021-10-08/xr_pydantic.py
@@ -25,7 +25,6 @@ def v_image_data(cls, v):


class Batch(BaseModel):

batch_size: int = 0
satellite: Satellite

1 change: 0 additions & 1 deletion nowcasting_dataset/data_sources/data_source.py
@@ -82,7 +82,6 @@ def __post_init__(self):
def _get_start_dt(
self, t0_datetime_utc: Union[pd.Timestamp, pd.DatetimeIndex]
) -> Union[pd.Timestamp, pd.DatetimeIndex]:

return t0_datetime_utc - self.history_duration

def _get_end_dt(
1 change: 0 additions & 1 deletion nowcasting_dataset/data_sources/fake/batch.py
@@ -504,7 +504,6 @@ def topographic_fake(
# make batch of arrays
xr_arrays = []
for i in range(batch_size):

x, y = make_image_coords_osgb(
size_x=image_size_pixels_width,
size_y=image_size_pixels_height,
1 change: 0 additions & 1 deletion nowcasting_dataset/data_sources/gsp/eso.py
@@ -164,7 +164,6 @@ def get_gsp_shape_from_eso(
shape_gpd["RegionID"] = range(1, len(shape_gpd) + 1)

if save_local_file:

# rename the columns to less than 10 characters
shape_gpd_to_save = shape_gpd.copy()
shape_gpd_to_save.rename(columns=rename_save_columns, inplace=True)
5 changes: 0 additions & 5 deletions nowcasting_dataset/data_sources/gsp/gsp_data_source.py
@@ -173,7 +173,6 @@ def get_all_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTim
if total_gsp_nan_count > 0:
assert Exception("There are nans in the GSP data. Can't get locations for all GSPs")
else:

t0_datetimes_utc.name = "t0_datetime_utc"

# get all locations
@@ -236,7 +235,6 @@ def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLoc

total_gsp_nan_count = self.gsp_power.isna().sum().sum()
if total_gsp_nan_count == 0:

# get random GSP metadata
indexes = sorted(
list(self.rng.integers(low=0, high=len(self.metadata), size=len(t0_datetimes_utc)))
@@ -249,7 +247,6 @@ def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLoc
ids = list(metadata.index)

else:

logger.warning(
"There are some nans in the gsp data, "
"so to get x,y locations we have to do a big loop"
Expand All @@ -262,7 +259,6 @@ def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLoc
ids = []

for t0_dt in t0_datetimes_utc:

# Choose start and end times
start_dt = self._get_start_dt(t0_dt)
end_dt = self._get_end_dt(t0_dt)
@@ -290,7 +286,6 @@ def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLoc

locations = []
for i in range(len(x_centers_osgb)):

locations.append(
SpaceTimeLocation(
t0_datetime_utc=t0_datetimes_utc[i],
1 change: 0 additions & 1 deletion nowcasting_dataset/data_sources/gsp/pvlive.py
@@ -89,7 +89,6 @@ def load_pv_gsp_raw_data_from_pvlive(
future_tasks = []
with futures.ThreadPoolExecutor(max_workers=1) as executor:
for gsp_id in gsp_ids:

# set the first chunk start and end times
start_chunk = first_start_chunk
end_chunk = first_end_chunk
1 change: 0 additions & 1 deletion nowcasting_dataset/data_sources/metadata/metadata_model.py
@@ -114,7 +114,6 @@ def save_to_csv(self, path):
metadata_df = pd.DataFrame(metadata_dict)

else:

metadata_df = pd.read_csv(filename)

metadata_df_extra = pd.DataFrame(metadata_dict)
2 changes: 0 additions & 2 deletions nowcasting_dataset/data_sources/pv/live.py
@@ -43,7 +43,6 @@ def get_metadata_from_database(providers: List[str] = None) -> pd.DataFrame:

pv_system_all_df = []
for provider in providers:

logger.debug(f"Get PV systems from database for {provider}")

with db_connection.get_session() as session:
@@ -136,7 +135,6 @@ def get_pv_power_from_database(
logger.debug(f"Found {len(pv_yields_df)} pv yields")

if len(pv_yields_df) == 0:

data = create_empty_pv_data(end_utc=now, providers=providers, start_utc=start_utc)

return data
3 changes: 1 addition & 2 deletions nowcasting_dataset/data_sources/pv/pv_data_source.py
@@ -98,7 +98,6 @@ def get_data_model_for_batch():
return PV

def _load_metadata(self):

logger.debug(f"Loading PV metadata from {self.files_groups}")

# collect all metadata together
@@ -155,7 +154,6 @@ def _load_metadata(self):
logger.debug(f"Found {len(pv_metadata)} pv systems")

def _load_pv_power(self):

logger.debug(f"Loading PV Power data from {self.files_groups}")

if not self.is_live:
@@ -453,6 +451,7 @@ def get_locations(self, t0_datetimes_utc: pd.DatetimeIndex) -> List[SpaceTimeLoc
Returns: x_locations, y_locations. Each has one entry per t0_datetime.
Locations are in OSGB coordinates.
"""

# Set this up as a separate function, so we can cache the result!
@functools.cache # functools.cache requires Python >= 3.9
def _get_pv_system_ids(t0_datetime: pd.Timestamp) -> pd.Int64Dtype:
3 changes: 0 additions & 3 deletions nowcasting_dataset/data_sources/sun/raw_data_load_save.py
@@ -49,16 +49,13 @@ def get_azimuth_and_elevation(
names = []
# loop over locations and find azimuth and elevation angles,
with futures.ThreadPoolExecutor() as executor:

logger.debug("Setting up jobs")

# Submit tasks to the executor.
future_azimuth_and_elevation_per_location = []
for i in tqdm(range(len(x_centers))):

name = x_y_to_name(x_centers[i], y_centers[i])
if name not in names:

lat, lon = geospatial.osgb_to_lat_lon(x=x_centers[i], y=y_centers[i])

future_azimuth_and_elevation = executor.submit(
3 changes: 0 additions & 3 deletions nowcasting_dataset/data_sources/sun/sun_data_source.py
@@ -69,7 +69,6 @@ def get_example(self, location: SpaceTimeLocation) -> xr.Dataset:
end_dt = self._get_end_dt(t0_datetime_utc)

if not self.load_live:

# The names of the columns get truncated when saving, therefore we need to look for the
# name of the columns near the location we are looking for
locations = np.array(
@@ -96,7 +95,6 @@ def get_example(self, location: SpaceTimeLocation) -> xr.Dataset:
elevation = self.elevation.loc[start_dt:end_dt][name]

else:

latitude, longitude = osgb_to_lat_lon(x=x_center_osgb, y=y_center_osgb)

datestamps = pd.date_range(start=start_dt, end=end_dt, freq="5T").tolist()
@@ -115,7 +113,6 @@ def get_example(self, location: SpaceTimeLocation) -> xr.Dataset:
return sun

def _load(self):

logger.info(f"Loading Sun data from {self.zarr_path}")

if not self.load_live:
2 changes: 0 additions & 2 deletions nowcasting_dataset/dataset/batch.py
@@ -137,7 +137,6 @@ def load_netcdf(

# loop over data sources
for data_source_name in data_sources_names:

local_netcdf_filename = os.path.join(
local_netcdf_path, data_source_name, get_netcdf_filename(batch_idx)
)
@@ -193,7 +192,6 @@ def load_netcdf(

# legacy NWP
if "nwp" in batch_dict.keys():

nwp_rename_dict = {
"x_index": "x_osgb_index",
"y_index": "y_osgb_index",
2 changes: 0 additions & 2 deletions nowcasting_dataset/dataset/split/method.py
@@ -85,7 +85,6 @@ def split_method(
test_periods = unique_periods[unique_periods["modulo"].isin(test_indexes)]["period"]

elif method == "random":

# randomly sort indexes
rng = np.random.default_rng(seed)
unique_periods_in_dataset = rng.permutation(unique_periods_in_dataset)
@@ -108,7 +107,6 @@ def split_method(
test_periods = pd.to_datetime(unique_periods_in_dataset[validation_test_split:])

elif method == "specific":

train_periods = unique_periods_in_dataset[
unique_periods_in_dataset.isin(train_test_validation_specific.train)
]
1 change: 0 additions & 1 deletion nowcasting_dataset/filesystem/utils.py
@@ -90,7 +90,6 @@ def delete_all_files_in_temp_path(path: Union[Path, str], delete_dirs: bool = Fa
else:
# loop over folder structure, but only delete files
for root, dirs, files in filesystem.walk(path):

for f in files:
filesystem.rm(f"{root}/{f}")

3 changes: 0 additions & 3 deletions nowcasting_dataset/manager/manager.py
@@ -273,15 +273,13 @@ def sample_spatial_and_temporal_locations_for_examples(
shuffled_t0_datetimes = pd.DatetimeIndex(shuffled_t0_datetimes)

if get_all_locations:

# note that the returned 'shuffled_t0_datetimes'
# has duplicate datetimes for each location
locations = self.data_source_which_defines_geospatial_locations.get_all_locations(
t0_datetimes_utc=shuffled_t0_datetimes
)

else:

locations = self.data_source_which_defines_geospatial_locations.get_locations(
shuffled_t0_datetimes
)
@@ -404,7 +402,6 @@ def create_batches(self, overwrite_batches: bool) -> None:
for worker_id, (data_source_name, data_source) in enumerate(
self.data_sources.items()
):

# Get indexes of first batch and example. And subset locations_for_split.
idx_of_first_batch = first_batches_to_create[split_name][data_source_name]
idx_of_first_example = idx_of_first_batch * self.config.process.batch_size
2 changes: 0 additions & 2 deletions nowcasting_dataset/manager/manager_live.py
@@ -186,7 +186,6 @@ def create_batches(self, use_async: Optional[bool] = True) -> None:
async_results_from_create_batches = []
an_error_has_occured = multiprocessing.Event()
for worker_id, (data_source_name, data_source) in enumerate(self.data_sources.items()):

# Get indexes of first batch and example. And subset locations_for_split.
idx_of_first_batch = 0
locations = locations_for_each_example
@@ -226,7 +225,6 @@ def create_batches(self, use_async: Optional[bool] = True) -> None:
# Sometimes when debugging it is easier to not use async
data_source.create_batches(**kwargs_for_create_batches)
else:

async_result = pool.apply_async(
data_source.create_batches,
kwds=kwargs_for_create_batches,
2 changes: 2 additions & 0 deletions nowcasting_dataset/utils.py
@@ -180,6 +180,7 @@ def shutdown(self, wait=True):

def arg_logger(func):
"""A function decorator to log all the args and kwargs passed into a function."""

# Adapted from https://stackoverflow.com/a/23983263/732596
@wraps(func)
def inner_func(*args, **kwargs):
@@ -191,6 +192,7 @@ def inner_func(*args, **kwargs):

def exception_logger(func):
"""A function decorator to log exceptions thrown by the inner function."""

# Adapted from
# www.blog.pythonlibrary.org/2016/06/09/python-how-to-create-an-exception-logging-decorator
@wraps(func)
20 changes: 12 additions & 8 deletions scripts/generate_raw_data/get_raw_pv_gsp_data.py
@@ -48,21 +48,24 @@ def fetch_data():
data_df = load_pv_gsp_raw_data_from_pvlive(start=start, end=end, normalize_data=False)

# pivot to index as datetime_gmt, and columns as gsp_id
data_generation_df = data_df.pivot(index="datetime_gmt", columns="gsp_id", values="generation_mw")
data_installedcapacity_df = data_df.pivot(index="datetime_gmt", columns="gsp_id", values="installedcapacity_mwp")
data_generation_df = data_df.pivot(
index="datetime_gmt", columns="gsp_id", values="generation_mw"
)
data_installedcapacity_df = data_df.pivot(
index="datetime_gmt", columns="gsp_id", values="installedcapacity_mwp"
)
data_capacity_df = data_df.pivot(index="datetime_gmt", columns="gsp_id", values="capacity_mwp")
data_updated_gmt_df = data_df.pivot(index="datetime_gmt", columns="gsp_id", values="updated_gmt")
data_updated_gmt_df = data_df.pivot(
index="datetime_gmt", columns="gsp_id", values="updated_gmt"
)
data_xarray = xr.Dataset(
data_vars={
"generation_mw": (("datetime_gmt", "gsp_id"), data_generation_df),
"installedcapacity_mwp": (("datetime_gmt", "gsp_id"), data_installedcapacity_df),
"capacity_mwp": (("datetime_gmt", "gsp_id"), data_capacity_df),
"updated_gmt": (("datetime_gmt", "gsp_id"), data_updated_gmt_df),
},
coords={
"datetime_gmt": data_generation_df.index,
"gsp_id": data_generation_df.columns
},
coords={"datetime_gmt": data_generation_df.index, "gsp_id": data_generation_df.columns},
)

# save config to file
@@ -71,7 +74,8 @@ def fetch_data():

# Make encoding
encoding = {
var: {"compressor": numcodecs.Blosc(cname="zstd", clevel=5)} for var in data_xarray.data_vars
var: {"compressor": numcodecs.Blosc(cname="zstd", clevel=5)}
for var in data_xarray.data_vars
}

# save data to file
(Diffs for the remaining changed files are not shown here.)
