Merge branch 'develop' into uv_at_last
dmannarino committed Dec 19, 2024
2 parents ec899dd + dd0616e commit ed5569b
Showing 7 changed files with 260 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
@@ -2,4 +2,4 @@
line_length = 88
multi_line_output = 3
include_trailing_comma = True
-known_third_party = _pytest,aenum,affine,alembic,asgi_lifespan,async_lru,asyncpg,aws_utils,boto3,botocore,click,docker,ee,errors,fastapi,fiona,gdal_utils,geoalchemy2,geojson,gfw_pixetl,gino,gino_starlette,google,httpx,httpx_auth,logger,logging_utils,moto,numpy,orjson,osgeo,pandas,pendulum,pglast,psutil,psycopg2,pydantic,pyproj,pytest,pytest_asyncio,rasterio,shapely,sqlalchemy,sqlalchemy_utils,starlette,tileputty,typer
+known_third_party = _pytest,aenum,affine,alembic,asgi_lifespan,async_lru,asyncpg,aws_utils,boto3,botocore,click,docker,ee,errors,fastapi,fiona,gdal_utils,geoalchemy2,geojson,gfw_pixetl,gino,gino_starlette,google,httpx,httpx_auth,logger,logging_utils,moto,numpy,orjson,osgeo,pandas,pendulum,pglast,psutil,psycopg2,pydantic,pyproj,pytest,pytest_asyncio,rasterio,shapely,sqlalchemy,sqlalchemy_utils,starlette,tileputty,tiles_geojson,typer
44 changes: 33 additions & 11 deletions batch/python/apply_colormap.py
@@ -15,11 +15,12 @@
import rasterio

# Use relative imports because these modules get copied into container
-from aws_utils import get_s3_client, get_s3_path_parts
+from aws_utils import get_aws_files, get_s3_client, get_s3_path_parts, upload_s3
from errors import GDALError, SubprocessKilledError
-from gdal_utils import from_vsi_path, run_gdal_subcommand
+from gdal_utils import from_vsi_path, run_gdal_subcommand, to_vsi_path
from logging_utils import listener_configurer, log_client_configurer, log_listener
from pydantic import BaseModel, Extra, Field, StrictInt
+from tiles_geojson import generate_geojsons
from typer import Option, run

NUM_PROCESSES = int(
@@ -267,16 +268,37 @@ def apply_symbology(
        for tile_id in executor.map(create_rgb_tile, process_args):
            logger.log(logging.INFO, f"Finished processing tile {tile_id}")

-    # Now run pixetl_prep.create_geojsons to generate a tiles.geojson and
-    # extent.geojson in the target prefix. But that code appends /geotiff
-    # to the prefix so remove it first
-    create_geojsons_prefix = target_prefix.split(f"{dataset}/{version}/")[1].replace(
-        "/geotiff", ""
-    )
-    logger.log(logging.INFO, "Uploading tiles.geojson to {create_geojsons_prefix}")
-    from gfw_pixetl.pixetl_prep import create_geojsons
-
-    create_geojsons(list(), dataset, version, create_geojsons_prefix, True)
+    # Now generate a tiles.geojson and extent.geojson and upload to the target prefix.
+    bucket, _ = get_s3_path_parts(source_uri)
+    tile_paths = [to_vsi_path(uri) for uri in get_aws_files(bucket, target_prefix)]
+
+    tiles_output_file = "tiles.geojson"
+    extent_output_file = "extent.geojson"
+
+    logger.log(logging.INFO, "Generating geojsons")
+    tiles_fc, extent_fc = generate_geojsons(tile_paths, min(16, NUM_PROCESSES))
+    logger.log(logging.INFO, "Finished generating geojsons")
+
+    tiles_txt = json.dumps(tiles_fc)
+    with open(tiles_output_file, "w") as f:
+        print(tiles_txt, file=f)
+
+    extent_txt = json.dumps(extent_fc)
+    with open(extent_output_file, "w") as f:
+        print(extent_txt, file=f)
+
+    logger.log(logging.INFO, f"Uploading geojsons to {target_prefix}")
+    upload_s3(
+        tiles_output_file,
+        bucket,
+        os.path.join(target_prefix, tiles_output_file),
+    )
+    upload_s3(
+        extent_output_file,
+        bucket,
+        os.path.join(target_prefix, extent_output_file),
+    )
+    logger.log(logging.INFO, f"Finished uploading geojsons to {target_prefix}")

    log_queue.put_nowait(None)
    listener.join()
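Condensed, the replacement flow above is: list the freshly written tiles, convert them to GDAL VSI paths, build the two FeatureCollections in parallel, serialize them locally, and upload both next to the tiles. A minimal sketch of that sequence using the helpers imported in this diff (the S3 values are hypothetical, for illustration only):

bucket, _ = get_s3_path_parts("s3://my-bucket/some_dataset/v1/raster/source.tif")  # hypothetical URI
tile_paths = [to_vsi_path(uri) for uri in get_aws_files(bucket, "some_dataset/v1/raster")]
tiles_fc, extent_fc = generate_geojsons(tile_paths, max_workers=16)
# tiles_fc: one Feature per tile with gdalinfo-derived properties;
# extent_fc: a single Feature holding the union of all tile footprints.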
32 changes: 31 additions & 1 deletion batch/python/aws_utils.py
@@ -1,5 +1,5 @@
import os
-from typing import Tuple
+from typing import List, Sequence, Tuple

import boto3

@@ -29,3 +29,33 @@ def exists_in_s3(target_bucket, target_key):
    for obj in response.get("Contents", []):
        if obj["Key"] == target_key:
            return obj["Size"] > 0
+
+
+def get_aws_files(
+    bucket: str, prefix: str, extensions: Sequence[str] = (".tif",)
+) -> List[str]:
+    """Get all matching files in S3."""
+    files: List[str] = list()
+
+    s3_client = get_s3_client()
+    paginator = s3_client.get_paginator("list_objects_v2")
+
+    print("get_aws_files")
+    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
+        try:
+            contents = page["Contents"]
+        except KeyError:
+            break
+
+        for obj in contents:
+            key = str(obj["Key"])
+            if any(key.endswith(ext) for ext in extensions):
+                files.append(f"s3://{bucket}/{key}")
+
+    print("done get_aws_files")
+    return files
+
+
+def upload_s3(path: str, bucket: str, dst: str) -> None:
+    # boto3's upload_file returns None, so don't annotate or return a Dict.
+    s3_client = get_s3_client()
+    s3_client.upload_file(path, bucket, dst)
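For context, a minimal usage sketch of the two new helpers (the bucket, prefix, and key below are hypothetical):

tifs = get_aws_files("my-bucket", "some_dataset/v1/raster")  # hypothetical bucket/prefix
# -> ["s3://my-bucket/some_dataset/v1/raster/10N_010E.tif", ...]
upload_s3("tiles.geojson", "my-bucket", "some_dataset/v1/raster/tiles.geojson")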
20 changes: 20 additions & 0 deletions batch/python/gdal_utils.py
@@ -1,6 +1,7 @@
import os
import subprocess
from typing import Dict, List, Optional, Tuple
+from urllib.parse import urlparse

from errors import GDALError, SubprocessKilledError

@@ -21,6 +22,18 @@ def from_vsi_path(file_name: str) -> str:
    return vsi


+def to_vsi_path(file_name: str) -> str:
+    prefix = {"s3": "vsis3", "gs": "vsigs"}
+
+    parts = urlparse(file_name)
+    try:
+        path = f"/{prefix[parts.scheme]}/{parts.netloc}{parts.path}"
+    except KeyError:
+        raise ValueError(f"Unknown protocol: {parts.scheme}")
+
+    return path


def run_gdal_subcommand(cmd: List[str], env: Optional[Dict] = None) -> Tuple[str, str]:
    """Run GDAL as sub command and catch common errors."""

@@ -53,3 +66,10 @@ def run_gdal_subcommand(cmd: List[str], env: Optional[Dict] = None) -> Tuple[str, str]:
        raise GDALError(e)

    return o, e
+
+
+def from_gdal_data_type(data_type: str) -> str:
+    if data_type == "Byte":
+        return "uint8"
+    else:
+        return data_type.lower()
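Quick sanity checks for the two new helpers (paths hypothetical):

to_vsi_path("s3://my-bucket/tiles/10N_010E.tif")  # -> "/vsis3/my-bucket/tiles/10N_010E.tif"
to_vsi_path("gs://my-bucket/tiles/10N_010E.tif")  # -> "/vsigs/my-bucket/tiles/10N_010E.tif"
to_vsi_path("ftp://example.com/a.tif")            # raises ValueError: Unknown protocol: ftp
from_gdal_data_type("Byte")     # -> "uint8"
from_gdal_data_type("Float32")  # -> "float32"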
48 changes: 40 additions & 8 deletions batch/python/resample.py
@@ -18,15 +18,21 @@
import rasterio

# Use relative imports because these modules get copied into container
-from aws_utils import exists_in_s3, get_s3_client, get_s3_path_parts
+from aws_utils import (
+    exists_in_s3,
+    get_aws_files,
+    get_s3_client,
+    get_s3_path_parts,
+    upload_s3,
+)
from errors import SubprocessKilledError
-from gdal_utils import from_vsi_path
+from gdal_utils import from_vsi_path, to_vsi_path
from gfw_pixetl.grids import grid_factory
-from gfw_pixetl.pixetl_prep import create_geojsons
from logging_utils import listener_configurer, log_client_configurer, log_listener
from pyproj import CRS, Transformer
from shapely.geometry import MultiPolygon, Polygon, shape
from shapely.ops import unary_union
+from tiles_geojson import generate_geojsons
from typer import Option, run

# Use at least 1 process
@@ -656,12 +662,38 @@ def resample(
        for tile_id in executor.map(process_tile, process_tile_args):
            logger.log(logging.INFO, f"Finished processing tile {tile_id}")

-    # Now run pixetl_prep.create_geojsons to generate a tiles.geojson and
-    # extent.geojson in the target prefix.
-    create_geojsons_prefix = target_prefix.split(f"{dataset}/{version}/")[1]
-    logger.log(logging.INFO, f"Uploading tiles.geojson to {create_geojsons_prefix}")
-
-    create_geojsons(list(), dataset, version, create_geojsons_prefix, True)
+    # Now generate a tiles.geojson and extent.geojson and upload to the target prefix.
+    tile_paths = [to_vsi_path(uri) for uri in get_aws_files(bucket, target_prefix)]
+
+    tiles_output_file = "tiles.geojson"
+    extent_output_file = "extent.geojson"
+
+    logger.log(logging.INFO, "Generating geojsons")
+    tiles_fc, extent_fc = generate_geojsons(tile_paths, min(16, NUM_PROCESSES))
+    logger.log(logging.INFO, "Finished generating geojsons")
+
+    tiles_txt = json.dumps(tiles_fc)
+    with open(tiles_output_file, "w") as f:
+        print(tiles_txt, file=f)
+
+    extent_txt = json.dumps(extent_fc)
+    with open(extent_output_file, "w") as f:
+        print(extent_txt, file=f)
+
+    geojsons_prefix = os.path.join(target_prefix, "geotiff")
+
+    logger.log(logging.INFO, f"Uploading geojsons to {geojsons_prefix}")
+    upload_s3(
+        tiles_output_file,
+        bucket,
+        os.path.join(geojsons_prefix, tiles_output_file),
+    )
+    upload_s3(
+        extent_output_file,
+        bucket,
+        os.path.join(geojsons_prefix, extent_output_file),
+    )
+    logger.log(logging.INFO, f"Finished uploading geojsons to {geojsons_prefix}")

    log_queue.put_nowait(None)
    listener.join()
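Note one difference from apply_colormap.py above: resample.py writes the geojsons under a geotiff/ subprefix of the target prefix rather than into the target prefix itself, e.g. (hypothetical prefix):

os.path.join("some_dataset/v1/resampled", "geotiff")  # -> "some_dataset/v1/resampled/geotiff"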
135 changes: 135 additions & 0 deletions batch/python/tiles_geojson.py
@@ -0,0 +1,135 @@
+import json
+import math
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from typing import Any, Dict, List, Optional, Tuple
+
+from geojson import Feature, FeatureCollection
+from pyproj import CRS, Transformer
+from shapely.geometry import Polygon
+from shapely.ops import unary_union
+
+from errors import GDALError
+from gdal_utils import from_gdal_data_type, run_gdal_subcommand
+
+
+def to_4326(crs: CRS, x: float, y: float) -> Tuple[float, float]:
+    transformer = Transformer.from_crs(crs, CRS.from_epsg(4326), always_xy=True)
+    return transformer.transform(x, y)
+
+
+def extract_metadata_from_gdalinfo(gdalinfo_json: Dict[str, Any]) -> Dict[str, Any]:
+    """Extract necessary metadata from the gdalinfo JSON output."""
+    corner_coordinates = gdalinfo_json["cornerCoordinates"]
+    geo_transform = gdalinfo_json["geoTransform"]
+
+    bands = [
+        {
+            "data_type": (
+                from_gdal_data_type(band.get("type"))
+                if band.get("type") is not None
+                else None
+            ),
+            "no_data": (
+                "nan"
+                if (
+                    band.get("noDataValue", None) is not None
+                    and math.isnan(band.get("noDataValue"))
+                )
+                else band.get("noDataValue", None)
+            ),
+            "nbits": band.get("metadata", {}).get("IMAGE_STRUCTURE", {}).get("NBITS", None),
+            "blockxsize": band.get("block", [None, None])[0],
+            "blockysize": band.get("block", [None, None])[1],
+            "stats": (
+                {
+                    "min": band.get("minimum"),
+                    "max": band.get("maximum"),
+                    "mean": band.get("mean"),
+                    "std_dev": band.get("stdDev"),
+                }
+                if "minimum" in band and "maximum" in band
+                else None
+            ),
+            "histogram": band.get("histogram", None),
+        }
+        for band in gdalinfo_json.get("bands", [])
+    ]
+
+    crs: CRS = CRS.from_string(gdalinfo_json["coordinateSystem"]["wkt"])
+    metadata = {
+        # NOTE: pixetl seems to always write features in tiles.geojson in
+        # degrees (when the tiles themselves are epsg:3857 I think
+        # the units should be meters). Reproduce that behavior for
+        # backwards compatibility. If it ever changes, remove the call to
+        # to_4326 here.
+        "extent": [
+            *to_4326(crs, *corner_coordinates["lowerLeft"]),
+            *to_4326(crs, *corner_coordinates["upperRight"]),
+        ],
+        "width": gdalinfo_json["size"][0],
+        "height": gdalinfo_json["size"][1],
+        "pixelxsize": geo_transform[1],
+        "pixelysize": abs(geo_transform[5]),
+        "crs": gdalinfo_json["coordinateSystem"]["wkt"],
+        "driver": gdalinfo_json.get("driverShortName", None),
+        "compression": gdalinfo_json.get("metadata", {}).get("IMAGE_STRUCTURE", {}).get("COMPRESSION", None),
+        "bands": bands,
+        "name": gdalinfo_json["description"],
+    }
+
+    return metadata
+
+
+def process_file(file_path: str) -> Dict[str, Any]:
+    """Run gdalinfo and extract metadata for a single file."""
+    print(f"Running gdalinfo on {file_path}")
+    try:
+        stdout, stderr = run_gdal_subcommand(
+            ["gdalinfo", "-json", file_path],
+        )
+    except GDALError as e:
+        raise RuntimeError(f"Failed to run gdalinfo on {file_path}: {e}")
+
+    gdalinfo_json: Dict = json.loads(stdout)
+    return extract_metadata_from_gdalinfo(gdalinfo_json)
+
+
+def generate_geojsons(
+    geotiffs: List[str],
+    max_workers: Optional[int] = None,
+) -> Tuple[FeatureCollection, FeatureCollection]:
+    """Generate tiles.geojson and extent.geojson files."""
+    features = []
+    polygons = []
+
+    with ProcessPoolExecutor(max_workers=max_workers) as executor:
+        future_to_file = {
+            executor.submit(process_file, file): file for file in geotiffs
+        }
+        for future in as_completed(future_to_file):
+            file = future_to_file[future]
+            try:
+                metadata = future.result()
+                extent = metadata["extent"]
+                # Create a Polygon from the extent
+                polygon_coords = [
+                    [extent[0], extent[1]],
+                    [extent[0], extent[3]],
+                    [extent[2], extent[3]],
+                    [extent[2], extent[1]],
+                    [extent[0], extent[1]],
+                ]
+                polygon = Polygon(polygon_coords)
+
+                # Add to GeoJSON features
+                feature = Feature(geometry=polygon.__geo_interface__, properties=metadata)
+                features.append(feature)
+
+                # Collect for union
+                polygons.append(polygon)
+            except Exception as e:
+                raise RuntimeError(f"Error processing file {file}: {e}")
+
+    tiles_fc = FeatureCollection(features)
+
+    union_geometry = unary_union(polygons)
+    extent_fc = FeatureCollection(
+        [Feature(geometry=union_geometry.__geo_interface__, properties={})]
+    )
+
+    return tiles_fc, extent_fc
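A minimal driver for the new module, assuming the imports above plus a GDAL install (process_file shells out to gdalinfo); the tile paths are hypothetical:

tiles = ["/vsis3/my-bucket/prefix/10N_010E.tif", "/vsis3/my-bucket/prefix/20N_010E.tif"]
tiles_fc, extent_fc = generate_geojsons(tiles, max_workers=4)
# tiles_fc: FeatureCollection with one Feature per GeoTIFF; each Feature's
# properties carry the gdalinfo-derived metadata (extent, width, height, bands, ...).
# extent_fc: FeatureCollection with a single Feature for the unioned footprint.
print(json.dumps(tiles_fc))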
33 changes: 0 additions & 33 deletions batch/scripts/run_pixetl_prep.sh

This file was deleted.
