Skip to content

Commit

Permalink
Correct VI metadata
Browse files Browse the repository at this point in the history
Fixes #11
  • Loading branch information
chuckwondo committed Jun 14, 2024
1 parent c804467 commit d86f978
Show file tree
Hide file tree
Showing 3 changed files with 410 additions and 56 deletions.
108 changes: 67 additions & 41 deletions hls_vi/generate_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@
BandData: TypeAlias = Mapping["Band", np.ma.masked_array]
IndexFunction = Callable[[BandData], np.ma.masked_array]

fixed_tags = (
"add_offset",
"ACCODE",
"AREA_OR_POINT",
"cloud_coverage",
"HORIZONTAL_CS_NAME",
"MEAN_SUN_AZIMUTH_ANGLE",
"MEAN_SUN_ZENITH_ANGLE",
"MEAN_VIEW_AZIMUTH_ANGLE",
"MEAN_VIEW_ZENITH_ANGLE",
"NBAR_SOLAR_ZENITH",
"NCOLS",
"NROWS",
"SENSING_TIME",
"spatial_coverage",
"SPATIAL_RESOLUTION",
"ULX",
"ULY",
)


@unique
class Band(Enum):
Expand Down Expand Up @@ -59,11 +79,20 @@ class S30Band(InstrumentBand):


class Instrument(Enum):
L30 = L30Band
S30 = S30Band

def __init__(self, band_type: Type[InstrumentBand]) -> None:
# Example LANDSAT_PRODUCT_ID: LC08_L1TP_069014_20240429_20240430_02_RT
# Pull out satellite number ("08"), convert to an integer, and prepend "L" to get
# the satellite name.
L30 = L30Band, lambda tags: f"L{int(tags.get('LANDSAT_PRODUCT_ID', '')[2:4])}"
# Example PRODUCT_URI: S2B_MSIL1C_...
# Simply pull off the first part of the string before the first underscore to get
# the satellite name.
S30 = S30Band, lambda tags: tags.get("PRODUCT_URI", "").split("_")[0]

def __init__(
self, band_type: Type[InstrumentBand], parse_satellite: Callable[[Tags], str]
) -> None:
self.bands = list(band_type)
self.parse_satellite = parse_satellite

@classmethod
def named(cls, name: str) -> "Instrument":
Expand All @@ -73,6 +102,9 @@ def named(cls, name: str) -> "Instrument":

raise ValueError(f"Invalid instrument name: {name}")

def satellite(self, tags: Tags) -> str:
return self.parse_satellite(tags)


@dataclass
class GranuleId:
Expand Down Expand Up @@ -137,7 +169,7 @@ def read_granule_bands(input_dir: Path) -> Granule:
with rasterio.open(input_dir / tifnames[0]) as tif:
crs = tif.crs
transform = tif.transform
tags = select_tags(tif.tags())
tags = select_tags(id_, tif.tags())

return Granule(id_, crs, transform, tags, dict(zip(harmonized_bands, data)))

Expand All @@ -158,7 +190,7 @@ def apply_fmask(data: np.ndarray, fmask: np.ndarray) -> np.ma.masked_array:
return np.ma.masked_array(data, fmask & cloud_like != 0)


def select_tags(tags: Tags) -> Tags:
def select_tags(granule_id: GranuleId, tags: Tags) -> Tags:
"""
Selects tags from the input tags that are relevant to the HLS VI product.
Expand All @@ -169,19 +201,9 @@ def select_tags(tags: Tags) -> Tags:
Mapping of relevant VI tags.
"""
return {
"ACCODE": tags["ACCODE"],
"cloud_coverage": tags.get("cloud_coverage"),
"HORIZONTAL_CS_NAME": tags.get("HORIZONTAL_CS_NAME"),
"MEAN_SUN_AZIMUTH_ANGLE": tags.get("MEAN_SUN_AZIMUTH_ANGLE"),
"MEAN_SUN_ZENITH_ANGLE": tags.get("MEAN_SUN_ZENITH_ANGLE"),
"MEAN_VIEW_AZIMUTH_ANGLE": tags.get("MEAN_VIEW_AZIMUTH_ANGLE"),
"MEAN_VIEW_ZENITH_ANGLE": tags.get("MEAN_VIEW_ZENITH_ANGLE"),
"NBAR_SOLAR_ZENITH": tags.get("NBAR_SOLAR_ZENITH"),
"SPACECRAFT_NAME": tags.get("SPACECRAFT_NAME"),
"TILE_ID": tags.get("SENTINEL2_TILEID"),
"SENSING_TIME": tags.get("SENSING_TIME"),
"SENSOR": tags.get("SENSOR"),
"spatial_coverage": tags.get("spatial_coverage"),
**{tag: tags.get(tag) for tag in fixed_tags},
"MGRS_TILE_ID": granule_id.tile_id,
"SATELLITE": granule_id.instrument.satellite(tags),
}


Expand Down Expand Up @@ -231,8 +253,10 @@ def write_granule_index(
dst.write(data.filled(), 1)
dst.update_tags(
**granule.tags,
longname=index.value,
long_name=index.long_name,
scale_factor=index.scale_factor,
HLS_VI_PROCESSING_TIME=processing_time,
_FillValue=data.fill_value,
)

# Create browse image using NDVI
Expand All @@ -242,14 +266,14 @@ def write_granule_index(

def evi(data: BandData) -> np.ma.masked_array:
b, r, nir = data[Band.B], data[Band.R], data[Band.NIR]
return 10_000 * 2.5 * (nir - r) / (nir + 6 * r - 7.5 * b + 1) # type: ignore
return 2.5 * (nir - r) / (nir + 6 * r - 7.5 * b + 1) # type: ignore


def msavi(data: BandData) -> np.ma.masked_array:
r, nir = data[Band.R], data[Band.NIR]
sqrt_term = (2 * nir + 1) ** 2 - 8 * (nir - r) # type: ignore

result: np.ma.masked_array = 10_000 * np.ma.where(
result: np.ma.masked_array = np.ma.where(
sqrt_term >= 0,
(2 * nir + 1 - np.sqrt(sqrt_term)) / 2, # type: ignore
np.nan,
Expand All @@ -261,32 +285,32 @@ def msavi(data: BandData) -> np.ma.masked_array:

def nbr(data: BandData) -> np.ma.masked_array:
nir, swir2 = data[Band.NIR], data[Band.SWIR2]
return 10_000 * (nir - swir2) / (nir + swir2) # type: ignore
return (nir - swir2) / (nir + swir2) # type: ignore


def nbr2(data: BandData) -> np.ma.masked_array:
swir1, swir2 = data[Band.SWIR1], data[Band.SWIR2]
return 10_000 * (swir1 - swir2) / (swir1 + swir2) # type: ignore
return (swir1 - swir2) / (swir1 + swir2) # type: ignore


def ndmi(data: BandData) -> np.ma.masked_array:
nir, swir1 = data[Band.NIR], data[Band.SWIR1]
return 10_000 * (nir - swir1) / (nir + swir1) # type: ignore
return (nir - swir1) / (nir + swir1) # type: ignore


def ndvi(data: BandData) -> np.ma.masked_array:
r, nir = data[Band.R], data[Band.NIR]
return 10_000 * (nir - r) / (nir + r) # type: ignore
return (nir - r) / (nir + r) # type: ignore


def ndwi(data: BandData) -> np.ma.masked_array:
g, nir = data[Band.G], data[Band.NIR]
return 10_000 * (g - nir) / (g + nir) # type: ignore
return (g - nir) / (g + nir) # type: ignore


def savi(data: BandData) -> np.ma.masked_array:
r, nir = data[Band.R], data[Band.NIR]
return 10_000 * 1.5 * (nir - r) / (nir + r + 0.5) # type: ignore
return 1.5 * (nir - r) / (nir + r + 0.5) # type: ignore


def tvi(data: BandData) -> np.ma.masked_array:
Expand All @@ -296,28 +320,30 @@ def tvi(data: BandData) -> np.ma.masked_array:


class Index(Enum):
EVI = "Enhanced Vegetation Index"
MSAVI = "Modified Soil-Adjusted Vegetation Index"
NBR = "Normalized Burn Ratio"
NBR2 = "Normalized Burn Ratio 2"
NDMI = "Normalized Difference Moisture Index"
NDVI = "Normalized Difference Vegetation Index"
NDWI = "Normalized Difference Water Index"
SAVI = "Soil-Adjusted Vegetation Index"
TVI = "Triangular Vegetation Index"

def __init__(self, longname: str) -> None:
EVI = ("Enhanced Vegetation Index",)
MSAVI = ("Modified Soil-Adjusted Vegetation Index",)
NBR = ("Normalized Burn Ratio",)
NBR2 = ("Normalized Burn Ratio 2",)
NDMI = ("Normalized Difference Moisture Index",)
NDVI = ("Normalized Difference Vegetation Index",)
NDWI = ("Normalized Difference Water Index",)
SAVI = ("Soil-Adjusted Vegetation Index",)
TVI = ("Triangular Vegetation Index", 1)

def __init__(self, long_name: str, scale_factor: float = 0.0001) -> None:
function_name = self.name.lower()
index_function: Optional[IndexFunction] = globals().get(function_name)

if not index_function or not callable(index_function):
raise ValueError(f"Index function not found: {function_name}")

self.longname = longname
self.long_name = long_name
self.compute_index = index_function
self.scale_factor = scale_factor

def __call__(self, data: BandData) -> np.ma.masked_array:
return np.ma.round(self.compute_index(data)).astype(np.int16)
scaled_index = self.compute_index(data) / self.scale_factor
return np.ma.round(scaled_index).astype(np.int16)


def parse_args() -> Tuple[Path, Path]:
Expand Down
Loading

0 comments on commit d86f978

Please sign in to comment.