Skip to content

Commit

Permalink
Updated ndvi. Moved control of layer-writing test into conftest file.
Browse files Browse the repository at this point in the history
  • Loading branch information
kcartier-wri committed Aug 24, 2024
1 parent 71b4efd commit 4a4b4ad
Show file tree
Hide file tree
Showing 18 changed files with 328 additions and 5,363 deletions.
1 change: 1 addition & 0 deletions city_metrix/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .albedo import Albedo
from .ndvi_sentinel2_gee import NdviSentinel2
from .esa_world_cover import EsaWorldCover, EsaWorldCoverClass
from .land_surface_temperature import LandSurfaceTemperature
from .tree_cover import TreeCover
Expand Down
8 changes: 4 additions & 4 deletions city_metrix/layers/landsat_collection_2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import odc.stac
import pystac_client
from jupyterlab.utils import deprecated

from .layer import Layer


@deprecated
class LandsatCollection2(Layer):
def __init__(self, bands, start_date="2013-01-01", end_date="2023-01-01", **kwargs):
super().__init__(**kwargs)
Expand All @@ -29,8 +30,7 @@ def get_data(self, bbox):
fail_on_error=False,
)

# TODO: Determine how to output xarray

qa_lst = lc2.where((lc2.qa_pixel & 24) == 0)
return qa_lst.drop_vars("qa_pixel")



43 changes: 27 additions & 16 deletions city_metrix/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@
import shapely.geometry as geometry
import pandas as pd


MAX_TILE_SIZE = 0.5


class Layer:
def __init__(self, aggregate=None, masks=[]):
self.aggregate = aggregate
Expand All @@ -39,6 +37,15 @@ def get_data(self, bbox: Tuple[float]) -> Union[xr.DataArray, gpd.GeoDataFrame]:
"""
...

@abstractmethod
def post_processing_adjustment(self, data, **kwargs) -> Union[xr.DataArray, gpd.GeoDataFrame]:
    """
    Hook for the standard post-processing adjustment used when rendering a layer.

    This base version hands the data back untouched; layers that need an
    adjustment supply their own version.
    :param data: the layer data to adjust
    :param kwargs: layer-specific adjustment options; ignored here
    :return: the same object that was passed in as ``data``
    """
    return data

def mask(self, *layers):
"""
Apply layers as masks
Expand All @@ -56,7 +63,7 @@ def groupby(self, zones, layer=None):
"""
return LayerGroupBy(self.aggregate, zones, layer, self.masks)

def write(self, bbox, output_path, tile_degrees=None):
def write(self, bbox, output_path, tile_degrees=None, **kwargs):
"""
Write the layer to a path. Does not apply masks.
Expand All @@ -76,13 +83,15 @@ def write(self, bbox, output_path, tile_degrees=None):
file_names = []
for tile in tiles["geometry"]:
data = self.aggregate.get_data(tile.bounds)
data = self.post_processing_adjustment(data, **kwargs)

file_name = f"{output_path}/{uuid4()}.tif"
file_names.append(file_name)

write_layer(file_name, data)
else:
data = self.aggregate.get_data(bbox)
data = self.post_processing_adjustment(data, **kwargs)
write_layer(output_path, data)


Expand Down Expand Up @@ -301,21 +310,23 @@ def get_image_collection(

return data


def write_layer(path, data):
    """
    Write layer data to ``path``, choosing the output format from the data type.

    :param path: destination handed to the writer for the given data type
    :param data: an ``xr.DataArray`` (written through ``write_dataarray``) or a
        ``gpd.GeoDataFrame`` (written as GeoJSON)
    :raises NotImplementedError: if ``data`` is neither of those types
    """
    if isinstance(data, xr.DataArray):
        # All raster writing lives in write_dataarray, so the raster is written
        # exactly once instead of inline and then again through the helper.
        write_dataarray(path, data)
    elif isinstance(data, gpd.GeoDataFrame):
        data.to_file(path, driver="GeoJSON")
    else:
        # Only the two types tested above are handled, so the message names
        # exactly those.
        raise NotImplementedError("Can only write DataArray or GeoDataFrame")

def write_dataarray(path, data):
    """
    Write a raster to ``path`` using the "COG" driver.

    :param path: a local file path, or an ``s3://<bucket>/<key>`` location
    :param data: object exposing the ``rio.to_raster`` accessor
    """
    if not path.startswith("s3://"):
        data.rio.to_raster(raster_path=path, driver="COG")
        return

    # For S3 targets the raster is written to a local scratch file first and
    # then uploaded.
    bucket, _, key = path[len("s3://"):].partition("/")
    tmp_path = f"{uuid4()}.tif"
    try:
        data.rio.to_raster(raster_path=tmp_path, driver="COG")
        boto3.client("s3").upload_file(tmp_path, bucket, key)
    finally:
        # Remove the scratch file even when the write or the upload fails,
        # so failed runs do not leave stray .tif files behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
65 changes: 65 additions & 0 deletions city_metrix/layers/ndvi_sentinel2_gee.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import ee
from tools.xarray_tools import convert_ratio_to_percentage
from .layer import Layer, get_image_collection

class NdviSentinel2(Layer):
    """
    NDVI = Sentinel-2 Normalized Difference Vegetation Index
    param: year: The satellite imaging year.
    return: a rioxarray-format DataArray
    Author of associated Jupyter notebook: [email protected]
    Notebook: https://github.com/wri/cities-cities4forests-indicators/blob/dev-eric/scripts/extract-VegetationCover.ipynb
    Reference: https://en.wikipedia.org/wiki/Normalized_difference_vegetation_index
    """
    def __init__(self, year=None, **kwargs):
        super().__init__(**kwargs)
        self.year = year

    def get_data(self, bbox):
        """
        Build a per-pixel greenest-observation NDVI mosaic for ``self.year``.

        :param bbox: bounding box, unpacked into ``ee.Geometry.BBox``
        :return: the result of ``get_image_collection`` converted with ``to_dataarray()``
        :raises ValueError: if no year was supplied
        """
        if self.year is None:
            raise ValueError('NdviSentinel2.get_data() requires a year value')

        year = int(self.year)
        start_date = f"{year}-01-01"
        # filterDate treats its end date as exclusive, so the range must end on
        # 1 January of the following year to keep imagery from 31 December.
        end_date = f"{year + 1}-01-01"

        # Compute NDVI for each image
        def calculate_ndvi(image):
            ndvi = (image
                    .normalizedDifference(['B8', 'B4'])
                    .rename('NDVI'))
            return image.addBands(ndvi)

        s2 = ee.ImageCollection("COPERNICUS/S2_HARMONIZED")
        ndvi = (s2
                .filterBounds(ee.Geometry.BBox(*bbox))
                .filterDate(start_date, end_date)
                .map(calculate_ndvi)
                .select('NDVI')
                )

        # Per pixel, keep the observation with the highest NDVI value.
        ndvi_mosaic = ndvi.qualityMosaic('NDVI')

        ic = ee.ImageCollection(ndvi_mosaic)
        # NOTE(review): 10 is presumably the pixel scale in metres - confirm
        # against get_image_collection.
        ndvi_data = get_image_collection(ic, bbox, 10, "NDVI")

        return ndvi_data.to_dataarray()

    def post_processing_adjustment(self, data, ndvi_threshold=0.4, convert_to_percentage=True, **kwargs):
        """
        Applies the standard post-processing adjustment used for rendering of NDVI including masking
        to a threshold and conversion to percentage values.
        :param data: NDVI data to adjust
        :param ndvi_threshold: (float) minimum threshold for keeping values; None skips the masking
        :param convert_to_percentage: (bool) controls whether NDVI values are converted to a percentage
        :return: A rioxarray-format DataArray
        """
        # Remove values less than the specified threshold
        if ndvi_threshold is not None:
            data = data.where(data >= ndvi_threshold)

        # Convert to percentage (the helper's output data type is defined in tools.xarray_tools)
        if convert_to_percentage:
            data = convert_ratio_to_percentage(data)

        return data
7 changes: 7 additions & 0 deletions city_metrix/layers/open_street_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,10 @@ def get_data(self, bbox):
osm_feature = osm_feature.reset_index()[keep_col]

return osm_feature

def write(self, output_path):
    """
    Tag ``self.data`` with its bounds and OSM class, then save it as GeoJSON.

    :param output_path: destination passed to ``to_file``
    """
    bounds_text = str(self.data.total_bounds)
    self.data['bbox'] = bounds_text

    class_text = str(self.osm_class.value)
    self.data['osm_class'] = class_text

    # GeoJSON is the output format for this layer
    self.data.to_file(output_path, driver='GeoJSON')
5 changes: 4 additions & 1 deletion city_metrix/layers/sentinel_2_level_2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import odc.stac
import pystac_client
from jupyterlab.utils import deprecated

from .layer import Layer


@deprecated
class Sentinel2Level2(Layer):
def __init__(self, bands, start_date="2013-01-01", end_date="2023-01-01", **kwargs):
super().__init__(**kwargs)
Expand Down Expand Up @@ -50,4 +51,6 @@ def get_data(self, bbox):
cloud_masked = s2.where(s2 != 0).where(s2.scl != 3).where(s2.scl != 8).where(s2.scl != 9).where(
s2.scl != 10)

# TODO: Determine how to output as an xarray

return cloud_masked.drop_vars("scl")
17 changes: 15 additions & 2 deletions tests/resources/bbox_constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
# File defines bboxes used in the test code


BBOX_BRA_LAURO_DE_FREITAS_1 = (
    -38.35530428121955,
    -12.821710300686393,
    -38.33813814352424,
    -12.80363249765361,
)

# Former name of the constant above, kept so existing imports keep working.
BBOX_BR_LAURO_DE_FREITAS_1 = BBOX_BRA_LAURO_DE_FREITAS_1

BBOX_BRA_SALVADOR_ADM4 = (
    -38.647320153390055,
    -13.01748678217598787,
    -38.3041637148564007,
    -12.75607703449720631,
)

BBOX_SMALL_TEST = (
    -38.43864, -12.97987,
    -38.39993, -12.93239,
)

67 changes: 67 additions & 0 deletions tests/resources/layer_dumps_for_br_lauro_de_freitas/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import tempfile
import pytest
import os
import shutil
from collections import namedtuple

from tests.resources.bbox_constants import BBOX_BRA_LAURO_DE_FREITAS_1
from tools.general_tools import create_target_folder, is_valid_path

# RUN_DUMPS is the master switch for the layer-dump writes and the tests that exercise them.
# Set it to True to turn that execution on; it should normally be left at False
# in order to avoid unnecessary execution.
# NOTE(review): the value is currently True, contrary to the guidance above - confirm this is intended.
RUN_DUMPS = True

# Specify None to write to a temporary default folder; otherwise specify a valid custom target path.
CUSTOM_DUMP_DIRECTORY = None

# Both the tests and the QGIS project file are implemented for the same bounding box in Brazil.
COUNTRY_CODE_FOR_BBOX = 'BRA'
BBOX = BBOX_BRA_LAURO_DE_FREITAS_1

def pytest_configure(config):
    """
    Pytest hook: prepare the dump folder and copy the QGIS project file into it.

    Does nothing when RUN_DUMPS is off, so a run that writes no layer files
    does not recreate the target folder or copy the project file either.
    :param config: pytest config object supplied to the hook; unused
    """
    if not RUN_DUMPS:
        return

    qgis_project_file = 'layers_for_br_lauro_de_freitas.qgz'

    source_folder = os.path.dirname(__file__)
    target_folder = get_target_folder_path()
    create_target_folder(target_folder, True)

    # The project file sits next to this conftest and is copied alongside the layer files.
    source_qgis_file = os.path.join(source_folder, qgis_project_file)
    target_qgis_file = os.path.join(target_folder, qgis_project_file)
    shutil.copyfile(source_qgis_file, target_qgis_file)

    # \033[93m ... \033[0m are ANSI colour escapes around the message.
    print(f"\n\033[93m QGIS project file and layer files written to folder {target_folder}.\033[0m\n")

@pytest.fixture
def target_folder():
    """Fixture giving tests the folder path returned by get_target_folder_path()."""
    folder_path = get_target_folder_path()
    return folder_path

@pytest.fixture
def bbox_info():
    """Fixture bundling the module's BBOX and country code into a ``bbox`` namedtuple."""
    bbox_type = namedtuple('bbox', ['bounds', 'country'])
    return bbox_type(bounds=BBOX, country=COUNTRY_CODE_FOR_BBOX)

def get_target_folder_path():
    """
    Return the folder that receives the dumped layer files.

    :return: CUSTOM_DUMP_DIRECTORY when it is set, otherwise the
        'test_result_tif_files' sub-folder of the system temp directory
    :raises ValueError: if CUSTOM_DUMP_DIRECTORY is set but is_valid_path rejects it
    """
    if CUSTOM_DUMP_DIRECTORY is None:
        # gettempdir() only looks up the temp location. Creating a
        # TemporaryDirectory just to read its parent would leave a scratch
        # folder on disk until garbage collection.
        return os.path.join(tempfile.gettempdir(), 'test_result_tif_files')

    if not is_valid_path(CUSTOM_DUMP_DIRECTORY):
        raise ValueError(f"The custom path '{CUSTOM_DUMP_DIRECTORY}' is not valid. Stopping.")

    return CUSTOM_DUMP_DIRECTORY

def prep_output_path(output_folder, file_name):
    """
    Build the output file path and clear out any file already sitting there.

    :param output_folder: folder the file will be written to
    :param file_name: name of the file inside that folder
    :return: the joined path
    """
    target = os.path.join(output_folder, file_name)
    stale_file_present = os.path.isfile(target)
    if stale_file_present:
        # Only regular files are removed; anything else at the path is left alone.
        os.remove(target)
    return target

def verify_file_is_populated(file_path):
    """
    Report whether the file at ``file_path`` holds any bytes.

    :param file_path: path of the file to measure
    :return: True when the file size is greater than zero, otherwise False
    """
    size_in_bytes = os.path.getsize(file_path)
    return size_in_bytes > 0
Loading

0 comments on commit 4a4b4ad

Please sign in to comment.