diff --git a/.isort.cfg b/.isort.cfg
index bb0418102..9aaf0e2d9 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -2,4 +2,4 @@
 line_length = 88
 multi_line_output = 3
 include_trailing_comma = True
-known_third_party = _pytest,aenum,affine,alembic,asgi_lifespan,async_lru,asyncpg,aws_utils,boto3,botocore,click,docker,ee,errors,fastapi,fiona,gdal_utils,geoalchemy2,geojson,gfw_pixetl,gino,gino_starlette,google,httpx,httpx_auth,logger,logging_utils,moto,numpy,orjson,osgeo,pandas,pendulum,pglast,psutil,psycopg2,pydantic,pyproj,pytest,pytest_asyncio,rasterio,shapely,sqlalchemy,sqlalchemy_utils,starlette,tileputty,typer
+known_third_party = _pytest,aenum,affine,alembic,asgi_lifespan,async_lru,asyncpg,aws_utils,boto3,botocore,click,docker,ee,errors,fastapi,fiona,gdal_utils,geoalchemy2,geojson,gfw_pixetl,gino,gino_starlette,google,httpx,httpx_auth,logger,logging_utils,moto,numpy,orjson,osgeo,pandas,pendulum,pglast,psutil,psycopg2,pydantic,pyproj,pytest,pytest_asyncio,rasterio,shapely,sqlalchemy,sqlalchemy_utils,starlette,tileputty,tiles_geojson,typer
diff --git a/batch/python/apply_colormap.py b/batch/python/apply_colormap.py
index a21f886a1..00a6bf877 100644
--- a/batch/python/apply_colormap.py
+++ b/batch/python/apply_colormap.py
@@ -15,11 +15,12 @@
 import rasterio
 
 # Use relative imports because these modules get copied into container
-from aws_utils import get_s3_client, get_s3_path_parts
+from aws_utils import get_aws_files, get_s3_client, get_s3_path_parts, upload_s3
 from errors import GDALError, SubprocessKilledError
-from gdal_utils import from_vsi_path, run_gdal_subcommand
+from gdal_utils import from_vsi_path, run_gdal_subcommand, to_vsi_path
 from logging_utils import listener_configurer, log_client_configurer, log_listener
 from pydantic import BaseModel, Extra, Field, StrictInt
+from tiles_geojson import generate_geojsons
 from typer import Option, run
 
 NUM_PROCESSES = int(
@@ -267,16 +268,37 @@ def apply_symbology(
         for tile_id in executor.map(create_rgb_tile, process_args):
             logger.log(logging.INFO, f"Finished processing tile {tile_id}")
 
-    # Now run pixetl_prep.create_geojsons to generate a tiles.geojson and
-    # extent.geojson in the target prefix. But that code appends /geotiff
-    # to the prefix so remove it first
-    create_geojsons_prefix = target_prefix.split(f"{dataset}/{version}/")[1].replace(
-        "/geotiff", ""
-    )
-    logger.log(logging.INFO, "Uploading tiles.geojson to {create_geojsons_prefix}")
-    from gfw_pixetl.pixetl_prep import create_geojsons
+    # Now generate a tiles.geojson and extent.geojson and upload to the target prefix.
+    bucket, _ = get_s3_path_parts(source_uri)
+    tile_paths = [to_vsi_path(uri) for uri in get_aws_files(bucket, target_prefix)]
+
+    tiles_output_file = "tiles.geojson"
+    extent_output_file = "extent.geojson"
+
+    logger.log(logging.INFO, "Generating geojsons")
+    tiles_fc, extent_fc = generate_geojsons(tile_paths, min(16, NUM_PROCESSES))
+    logger.log(logging.INFO, "Finished generating geojsons")
+
+    tiles_txt = json.dumps(tiles_fc)
+    with open(tiles_output_file, "w") as f:
+        print(tiles_txt, file=f)
 
-    create_geojsons(list(), dataset, version, create_geojsons_prefix, True)
+    extent_txt = json.dumps(extent_fc)
+    with open(extent_output_file, "w") as f:
+        print(extent_txt, file=f)
+
+    logger.log(logging.INFO, f"Uploading geojsons to {target_prefix}")
+    upload_s3(
+        tiles_output_file,
+        bucket,
+        os.path.join(target_prefix, tiles_output_file),
+    )
+    upload_s3(
+        extent_output_file,
+        bucket,
+        os.path.join(target_prefix, extent_output_file),
+    )
+    logger.log(logging.INFO, f"Finished uploading geojsons to {target_prefix}")
 
     log_queue.put_nowait(None)
     listener.join()
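A note on the new `bucket, _ = get_s3_path_parts(source_uri)` line above: `get_s3_path_parts` is a pre-existing `aws_utils` helper which, as used here, appears to split an S3 URI into its bucket and key parts. A minimal sketch of that assumption, with a hypothetical URI:

    bucket, key = get_s3_path_parts("s3://my-bucket/my_dataset/v1/source.tif")
    # bucket == "my-bucket", key == "my_dataset/v1/source.tif"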
diff --git a/batch/python/aws_utils.py b/batch/python/aws_utils.py
index 32b236fcf..b788bbd4a 100644
--- a/batch/python/aws_utils.py
+++ b/batch/python/aws_utils.py
@@ -1,5 +1,5 @@
 import os
-from typing import Tuple
+from typing import List, Sequence, Tuple
 
 import boto3
 
@@ -29,3 +29,33 @@ def exists_in_s3(target_bucket, target_key):
     for obj in response.get("Contents", []):
         if obj["Key"] == target_key:
             return obj["Size"] > 0
+
+
+def get_aws_files(
+    bucket: str, prefix: str, extensions: Sequence[str] = (".tif",)
+) -> List[str]:
+    """Get all matching files in S3."""
+    files: List[str] = list()
+
+    s3_client = get_s3_client()
+    paginator = s3_client.get_paginator("list_objects_v2")
+
+    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
+        try:
+            contents = page["Contents"]
+        except KeyError:
+            break
+
+        for obj in contents:
+            key = str(obj["Key"])
+            if any(key.endswith(ext) for ext in extensions):
+                files.append(f"s3://{bucket}/{key}")
+
+    return files
+
+
+def upload_s3(path: str, bucket: str, dst: str) -> None:
+    """Upload a local file to S3.
+
+    boto3's upload_file returns None, so there is nothing to return.
+    """
+    s3_client = get_s3_client()
+    s3_client.upload_file(path, bucket, dst)
diff --git a/batch/python/gdal_utils.py b/batch/python/gdal_utils.py
index c4fc68c8b..c3b0b8a1b 100644
--- a/batch/python/gdal_utils.py
+++ b/batch/python/gdal_utils.py
@@ -1,6 +1,7 @@
 import os
 import subprocess
 from typing import Dict, List, Optional, Tuple
+from urllib.parse import urlparse
 
 from errors import GDALError, SubprocessKilledError
 
@@ -21,6 +22,18 @@ def from_vsi_path(file_name: str) -> str:
     return vsi
 
 
+def to_vsi_path(file_name: str) -> str:
+    """Convert an s3:// or gs:// URI into a GDAL /vsi path."""
+    prefix = {"s3": "vsis3", "gs": "vsigs"}
+
+    parts = urlparse(file_name)
+    try:
+        path = f"/{prefix[parts.scheme]}/{parts.netloc}{parts.path}"
+    except KeyError:
+        raise ValueError(f"Unknown protocol: {parts.scheme}")
+
+    return path
+
+
 def run_gdal_subcommand(cmd: List[str], env: Optional[Dict] = None) -> Tuple[str, str]:
     """Run GDAL as sub command and catch common errors."""
 
@@ -53,3 +66,10 @@ def run_gdal_subcommand(cmd: List[str], env: Optional[Dict] = None) -> Tuple[st
         raise GDALError(e)
 
     return o, e
+
+
+def from_gdal_data_type(data_type: str) -> str:
+    """Map a GDAL data type name to its numpy-style name (e.g. Byte -> uint8)."""
+    if data_type == "Byte":
+        return "uint8"
+    else:
+        return data_type.lower()
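Taken together, the new helpers above compose like this (a minimal sketch; bucket and key names are hypothetical):

    from aws_utils import get_aws_files
    from gdal_utils import from_gdal_data_type, to_vsi_path

    # List .tif keys under a prefix as s3:// URIs...
    uris = get_aws_files("my-bucket", "my_dataset/v1/geotiff")
    # ...and convert each to a GDAL-readable /vsi path.
    vsi_paths = [to_vsi_path(uri) for uri in uris]
    # e.g. "s3://my-bucket/a.tif" -> "/vsis3/my-bucket/a.tif"

    # GDAL reports the 8-bit type as "Byte"; downstream code expects "uint8".
    assert from_gdal_data_type("Byte") == "uint8"
    assert from_gdal_data_type("Float32") == "float32"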
diff --git a/batch/python/resample.py b/batch/python/resample.py
index 732522247..f065688b6 100644
--- a/batch/python/resample.py
+++ b/batch/python/resample.py
@@ -18,15 +18,21 @@
 import rasterio
 
 # Use relative imports because these modules get copied into container
-from aws_utils import exists_in_s3, get_s3_client, get_s3_path_parts
+from aws_utils import (
+    exists_in_s3,
+    get_aws_files,
+    get_s3_client,
+    get_s3_path_parts,
+    upload_s3,
+)
 from errors import SubprocessKilledError
-from gdal_utils import from_vsi_path
+from gdal_utils import from_vsi_path, to_vsi_path
 from gfw_pixetl.grids import grid_factory
-from gfw_pixetl.pixetl_prep import create_geojsons
 from logging_utils import listener_configurer, log_client_configurer, log_listener
 from pyproj import CRS, Transformer
 from shapely.geometry import MultiPolygon, Polygon, shape
 from shapely.ops import unary_union
+from tiles_geojson import generate_geojsons
 from typer import Option, run
 
 # Use at least 1 process
@@ -656,12 +662,38 @@ def resample(
         for tile_id in executor.map(process_tile, process_tile_args):
             logger.log(logging.INFO, f"Finished processing tile {tile_id}")
 
-    # Now run pixetl_prep.create_geojsons to generate a tiles.geojson and
-    # extent.geojson in the target prefix.
-    create_geojsons_prefix = target_prefix.split(f"{dataset}/{version}/")[1]
-    logger.log(logging.INFO, f"Uploading tiles.geojson to {create_geojsons_prefix}")
+    # Now generate a tiles.geojson and extent.geojson and upload to the target prefix.
+    tile_paths = [to_vsi_path(uri) for uri in get_aws_files(bucket, target_prefix)]
+
+    tiles_output_file = "tiles.geojson"
+    extent_output_file = "extent.geojson"
+
+    logger.log(logging.INFO, "Generating geojsons")
+    tiles_fc, extent_fc = generate_geojsons(tile_paths, min(16, NUM_PROCESSES))
+    logger.log(logging.INFO, "Finished generating geojsons")
+
+    tiles_txt = json.dumps(tiles_fc)
+    with open(tiles_output_file, "w") as f:
+        print(tiles_txt, file=f)
 
-    create_geojsons(list(), dataset, version, create_geojsons_prefix, True)
+    extent_txt = json.dumps(extent_fc)
+    with open(extent_output_file, "w") as f:
+        print(extent_txt, file=f)
+
+    geojsons_prefix = os.path.join(target_prefix, "geotiff")
+
+    logger.log(logging.INFO, f"Uploading geojsons to {geojsons_prefix}")
+    upload_s3(
+        tiles_output_file,
+        bucket,
+        os.path.join(geojsons_prefix, tiles_output_file),
+    )
+    upload_s3(
+        extent_output_file,
+        bucket,
+        os.path.join(geojsons_prefix, extent_output_file),
+    )
+    logger.log(logging.INFO, f"Finished uploading geojsons to {geojsons_prefix}")
 
     log_queue.put_nowait(None)
     listener.join()
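One detail the hunks above rely on: the `FeatureCollection` objects returned by `generate_geojsons` are `dict` subclasses (via the `geojson` package), so `json.dumps` serializes them without a custom encoder. A minimal sketch:

    import json

    from geojson import Feature, FeatureCollection

    fc = FeatureCollection(
        [Feature(geometry={"type": "Point", "coordinates": [0.0, 0.0]}, properties={})]
    )
    json.dumps(fc)  # '{"type": "FeatureCollection", "features": [...]}'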
diff --git a/batch/python/tiles_geojson.py b/batch/python/tiles_geojson.py
new file mode 100644
index 000000000..ed35360a1
--- /dev/null
+++ b/batch/python/tiles_geojson.py
@@ -0,0 +1,135 @@
+import json
+import math
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from typing import Any, Dict, List, Optional, Tuple
+
+from geojson import Feature, FeatureCollection
+from pyproj import CRS, Transformer
+from shapely.geometry import Polygon
+from shapely.ops import unary_union
+
+from errors import GDALError
+from gdal_utils import from_gdal_data_type, run_gdal_subcommand
+
+
+def to_4326(crs: CRS, x: float, y: float) -> Tuple[float, float]:
+    """Reproject a single coordinate pair to EPSG:4326."""
+    transformer = Transformer.from_crs(crs, CRS.from_epsg(4326), always_xy=True)
+    return transformer.transform(x, y)
+
+
+def extract_metadata_from_gdalinfo(gdalinfo_json: Dict[str, Any]) -> Dict[str, Any]:
+    """Extract necessary metadata from the gdalinfo JSON output."""
+    corner_coordinates = gdalinfo_json["cornerCoordinates"]
+    geo_transform = gdalinfo_json["geoTransform"]
+
+    bands = [
+        {
+            "data_type": (
+                from_gdal_data_type(band["type"])
+                if band.get("type") is not None
+                else None
+            ),
+            "no_data": (
+                "nan"
+                if (
+                    band.get("noDataValue") is not None
+                    and math.isnan(band["noDataValue"])
+                )
+                else band.get("noDataValue")
+            ),
+            "nbits": band.get("metadata", {}).get("IMAGE_STRUCTURE", {}).get("NBITS"),
+            "blockxsize": band.get("block", [None, None])[0],
+            "blockysize": band.get("block", [None, None])[1],
+            "stats": (
+                {
+                    "min": band.get("minimum"),
+                    "max": band.get("maximum"),
+                    "mean": band.get("mean"),
+                    "std_dev": band.get("stdDev"),
+                }
+                if "minimum" in band and "maximum" in band
+                else None
+            ),
+            "histogram": band.get("histogram"),
+        }
+        for band in gdalinfo_json.get("bands", [])
+    ]
+
+    crs: CRS = CRS.from_string(gdalinfo_json["coordinateSystem"]["wkt"])
+    metadata = {
+        # NOTE: pixetl seems to always write features in tiles.geojson in
+        # degrees (when the tiles themselves are epsg:3857 I think
+        # the units should be meters). Reproduce that behavior for
+        # backwards compatibility. If it ever changes, remove the call to
+        # to_4326 here.
+        "extent": [
+            *to_4326(crs, *corner_coordinates["lowerLeft"]),
+            *to_4326(crs, *corner_coordinates["upperRight"]),
+        ],
+        "width": gdalinfo_json["size"][0],
+        "height": gdalinfo_json["size"][1],
+        "pixelxsize": geo_transform[1],
+        "pixelysize": abs(geo_transform[5]),
+        "crs": gdalinfo_json["coordinateSystem"]["wkt"],
+        "driver": gdalinfo_json.get("driverShortName"),
+        "compression": gdalinfo_json.get("metadata", {})
+        .get("IMAGE_STRUCTURE", {})
+        .get("COMPRESSION"),
+        "bands": bands,
+        "name": gdalinfo_json["description"],
+    }
+
+    return metadata
+
+
+def process_file(file_path: str) -> Dict[str, Any]:
+    """Run gdalinfo and extract metadata for a single file."""
+    print(f"Running gdalinfo on {file_path}")
+    try:
+        stdout, _stderr = run_gdal_subcommand(["gdalinfo", "-json", file_path])
+    except GDALError as e:
+        raise RuntimeError(f"Failed to run gdalinfo on {file_path}: {e}")
+
+    gdalinfo_json: Dict = json.loads(stdout)
+    return extract_metadata_from_gdalinfo(gdalinfo_json)
+
+
+def generate_geojsons(
+    geotiffs: List[str],
+    max_workers: Optional[int] = None,
+) -> Tuple[FeatureCollection, FeatureCollection]:
+    """Generate the contents of the tiles.geojson and extent.geojson files."""
+    features = []
+    polygons = []
+
+    with ProcessPoolExecutor(max_workers=max_workers) as executor:
+        future_to_file = {
+            executor.submit(process_file, file): file for file in geotiffs
+        }
+        for future in as_completed(future_to_file):
+            file = future_to_file[future]
+            try:
+                metadata = future.result()
+                extent = metadata["extent"]
+                # Create a Polygon from the extent
+                polygon_coords = [
+                    [extent[0], extent[1]],
+                    [extent[0], extent[3]],
+                    [extent[2], extent[3]],
+                    [extent[2], extent[1]],
+                    [extent[0], extent[1]],
+                ]
+                polygon = Polygon(polygon_coords)
+
+                # Add to GeoJSON features
+                feature = Feature(
+                    geometry=polygon.__geo_interface__, properties=metadata
+                )
+                features.append(feature)
+
+                # Collect for union
+                polygons.append(polygon)
+            except Exception as e:
+                raise RuntimeError(f"Error processing file {file}: {e}")
+
+    tiles_fc = FeatureCollection(features)
+
+    union_geometry = unary_union(polygons)
+    extent_fc = FeatureCollection(
+        [Feature(geometry=union_geometry.__geo_interface__, properties={})]
+    )
+
+    return tiles_fc, extent_fc
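Callers pass GDAL-style /vsi paths and get back two feature collections: one feature per tile, with the extracted metadata attached as GeoJSON properties, plus a single dissolved footprint. A usage sketch with a hypothetical path:

    from tiles_geojson import generate_geojsons

    tiles_fc, extent_fc = generate_geojsons(
        ["/vsis3/my-bucket/my_dataset/v1/geotiff/10N_010E.tif"], max_workers=4
    )
    for feature in tiles_fc["features"]:
        print(feature["properties"]["name"], feature["properties"]["extent"])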
diff --git a/batch/scripts/run_pixetl_prep.sh b/batch/scripts/run_pixetl_prep.sh
deleted file mode 100644
index 7dbd5167c..000000000
--- a/batch/scripts/run_pixetl_prep.sh
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/bin/bash
-
-set -e
-
-# requires arguments
-# -s | --source
-# -d | --dataset
-# -v | --version
-
-# optional arguments
-# --overwrite
-# --prefix
-
-ME=$(basename "$0")
-. get_arguments.sh "$@"
-
-echo "Fetch remote GeoTIFFs headers to generate tiles.geojson"
-
-# Build an array of arguments to pass to pixetl_prep
-ARG_ARRAY=$SRC
-ARG_ARRAY+=("--dataset" "${DATASET}" "--version" "${VERSION}")
-
-
-if [ -n "${PREFIX}" ]; then
-  ARG_ARRAY+=("--prefix" "${PREFIX}")
-fi
-
-if [ -z "${OVERWRITE}" ]; then
-  ARG_ARRAY+=("--merge_existing")
-fi
-
-# Run pixetl_prep with the array of arguments
-pixetl_prep "${ARG_ARRAY[@]}"
\ No newline at end of file
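With this script removed, the tiles.geojson/extent.geojson step it drove via the `pixetl_prep` CLI now happens in-process, as in the apply_colormap.py and resample.py hunks above. The replacement flow boils down to roughly the following (a sketch; bucket and prefix are hypothetical):

    import json

    from aws_utils import get_aws_files, upload_s3
    from gdal_utils import to_vsi_path
    from tiles_geojson import generate_geojsons

    tile_paths = [
        to_vsi_path(u) for u in get_aws_files("my-bucket", "my_dataset/v1/geotiff")
    ]
    tiles_fc, extent_fc = generate_geojsons(tile_paths)
    for name, fc in (("tiles.geojson", tiles_fc), ("extent.geojson", extent_fc)):
        with open(name, "w") as f:
            json.dump(fc, f)
        upload_s3(name, "my-bucket", f"my_dataset/v1/geotiff/{name}")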