Skip to content

Commit

Permalink
Merge branch 'main' into feature/write_layers
Browse files Browse the repository at this point in the history
  • Loading branch information
jterry64 authored Jun 6, 2024
2 parents 028c467 + cc41db2 commit de9c6e5
Show file tree
Hide file tree
Showing 5 changed files with 1,328 additions and 13 deletions.
38 changes: 33 additions & 5 deletions city_metrix/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,26 @@ def count(self):

def _zonal_stats(self, stats_func):
if box(*self.zones.total_bounds).area <= MAX_TILE_SIZE**2:
return self._zonal_stats_tile(self.zones, [stats_func])[stats_func]
stats = self._zonal_stats_tile(self.zones, [stats_func])
else:
return self._zonal_stats_fishnet(stats_func)
stats = self._zonal_stats_fishnet(stats_func)

if self.layer is not None:
# decode zone and layer value using bit operations
stats["layer"] = stats["zone"].astype("uint32").values >> 16
stats["zone"] = stats["zone"].astype("uint32").values & 65535

# group layer values together into a dictionary per zone
def group_layer_values(df):
layer_values = df.drop(columns="zone").groupby("layer").sum()
layer_dicts = layer_values.to_dict()
return layer_dicts[stats_func]

stats = stats.groupby("zone").apply(group_layer_values)

return stats

return stats[stats_func]

def _zonal_stats_fishnet(self, stats_func):
# fishnet GeoDataFrame into smaller tiles
Expand All @@ -124,31 +141,42 @@ def _zonal_stats_fishnet(self, stats_func):
tile_funcs = get_stats_funcs(stats_func)

# run zonal stats per data frame
print(f"Input covers too much area, splitting into {len(tile_gdfs)} tiles")
tile_stats = pd.concat([
self._zonal_stats_tile(tile_gdf, tile_funcs)
for tile_gdf in tile_gdfs
])

aggregated = tile_stats.groupby("zone").apply(_aggregate_stats, stats_func)
aggregated.name = stats_func

return aggregated
return aggregated.reset_index()

def _zonal_stats_tile(self, tile_gdf, stats_func):
bbox = tile_gdf.total_bounds
aggregate_data = self.aggregate.get_data(bbox)
mask_datum = [mask.get_data(bbox) for mask in self.masks]
layer_data = self.layer.get_data(bbox) if self.layer is not None else None

# align to highest resolution raster, which should be the largest raster
# since all are clipped to the extent
raster_data = [data for data in mask_datum + [aggregate_data] if isinstance(data, xr.DataArray)]
raster_data = [data for data in mask_datum + [aggregate_data] + [layer_data] if isinstance(data, xr.DataArray)]
align_to = sorted(raster_data, key=lambda data: data.size, reverse=True).pop()
aggregate_data = self._align(aggregate_data, align_to)
mask_datum = [self._align(data, align_to) for data in mask_datum]

if self.layer is not None:
layer_data = self._align(layer_data, align_to)

for mask in mask_datum:
aggregate_data = aggregate_data.where(~np.isnan(mask))

zones = self._rasterize(tile_gdf, align_to)

if self.layer is not None:
# encode layer into zones by bitshifting
zones = zones + (layer_data.astype("uint32") << 16)

stats = zonal_stats(zones, aggregate_data, stats_funcs=stats_func)

return stats
Expand Down Expand Up @@ -261,7 +289,7 @@ def get_image_collection(
)

with ProgressBar():
print(f"Extracting layer {name} from Google Earth Engine:")
print(f"Extracting layer {name} from Google Earth Engine for bbox {bbox}:")
data = ds.compute()

# get in rioxarray format
Expand Down
Loading

0 comments on commit de9c6e5

Please sign in to comment.