Skip to content

Commit

Permalink
Merge pull request #361 from ungarj/gdal_cog
Browse files Browse the repository at this point in the history
use GDAL COG driver if cog is activated
  • Loading branch information
ungarj authored Sep 30, 2021
2 parents 787e000 + 903db1d commit f6992a7
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 120 deletions.
4 changes: 2 additions & 2 deletions mapchete/formats/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def __init__(self, output_params, readonly=False, **kwargs):
)
self.crs = self.pyramid.crs
self._bucket = None
self._fs = output_params.get("fs") or None
self._fs_kwargs = output_params.get("fs_kwargs") or {}
self.fs = self._fs = output_params.get("fs") or None
self.fs_kwargs = self._fs_kwargs = output_params.get("fs_kwargs") or {}

def is_valid_with_config(self, config):
"""
Expand Down
109 changes: 12 additions & 97 deletions mapchete/formats/default/gtiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,42 +30,36 @@
CCITTFAX3, CCITTFAX4, lzma
"""

from affine import Affine
from contextlib import ExitStack
import logging
import math
import numpy as np
import numpy.ma as ma
import os
import rasterio
import warnings

from affine import Affine
import numpy as np
from numpy import ma
from rasterio.enums import Resampling
from rasterio.io import MemoryFile
from rasterio.profiles import Profile
from rasterio.rio.overview import get_maximum_overview_level
from rasterio.shutil import copy
from rasterio.windows import from_bounds
from shapely.geometry import box
from tempfile import NamedTemporaryFile
from tilematrix import Bounds
import warnings

from mapchete.config import validate_values, snap_bounds, _OUTPUT_PARAMETERS
from mapchete.errors import MapcheteConfigError
from mapchete.formats import base
from mapchete.io import (
fs_from_path,
get_boto3_bucket,
makedirs,
path_exists,
path_is_remote,
)
from mapchete.io.raster import (
write_raster_window,
prepare_array,
memory_file,
read_raster_no_crs,
extract_from_array,
read_raster_window,
rasterio_write,
)
from mapchete.tile import BufferedTile
from mapchete.validate import deprecated_kwargs
Expand Down Expand Up @@ -258,10 +252,6 @@ def _set_attributes(self, output_params):
output_params,
nodata=output_params.get("nodata", GTIFF_DEFAULT_PROFILE["nodata"]),
)
self._bucket = (
self.path.split("/")[2] if self.path.startswith("s3://") else None
)
self._fs = self.output_params.get("fs") or fs_from_path(self.path)


class GTiffTileDirectoryOutputReader(
Expand Down Expand Up @@ -388,9 +378,6 @@ def write(self, process_tile, data):
if data.mask.all():
logger.debug("data empty, nothing to write")
else:
# in case of S3 output, create an boto3 resource
bucket_resource = get_boto3_bucket(self._bucket) if self._bucket else None

# Convert from process_tile to output_tiles and write
for tile in self.pyramid.intersecting(process_tile):
out_path = self.get_path(tile)
Expand All @@ -403,7 +390,7 @@ def write(self, process_tile, data):
out_tile=out_tile,
out_path=out_path,
tags=tags,
bucket_resource=bucket_resource,
fs=self.fs,
)


Expand All @@ -424,8 +411,6 @@ def __init__(self, output_params, **kwargs):
self.zoom = output_params["delimiters"]["zoom"][0]
self.cog = output_params.get("cog", False)
self.in_memory = output_params.get("in_memory", True)
_bucket = self.path.split("/")[2] if self.path.startswith("s3://") else None
self._bucket_resource = get_boto3_bucket(_bucket) if _bucket else None

def prepare(self, process_area=None, **kwargs):
bounds = (
Expand Down Expand Up @@ -453,7 +438,7 @@ def prepare(self, process_area=None, **kwargs):
k: v for k, v in self.output_params.items() if k not in _OUTPUT_PARAMETERS
}
self._profile = dict(
GTIFF_DEFAULT_PROFILE,
DefaultGTiffProfile(driver="COG" if self.cog else "GTiff"),
transform=Affine(
self.pyramid.pixel_x_size(self.zoom),
0,
Expand All @@ -466,17 +451,9 @@ def prepare(self, process_area=None, **kwargs):
width=width,
count=self.output_params["bands"],
crs=self.pyramid.crs,
**dict(
{
k: self.output_params.get(k, GTIFF_DEFAULT_PROFILE[k])
for k in GTIFF_DEFAULT_PROFILE.keys()
},
**creation_options,
),
bigtiff=self.output_params.get("bigtiff", "NO"),
**creation_options,
)
logger.debug("single GTiff profile: %s", self._profile)

logger.debug(
get_maximum_overview_level(
width, height, minsize=self._profile["blockxsize"]
Expand Down Expand Up @@ -521,23 +498,9 @@ def prepare(self, process_area=None, **kwargs):
makedirs(os.path.dirname(self.path))
logger.debug("open output file: %s", self.path)
self._ctx = ExitStack()
# (1) use memfile if output is remote or COG
if self.cog or path_is_remote(self.path):
if self.in_memory:
logger.debug("create MemoryFile")
self._memfile = self._ctx.enter_context(MemoryFile())
self.dst = self._ctx.enter_context(self._memfile.open(**self._profile))
else:
# in case output raster is too big, use tempfile on disk
self._tempfile = self._ctx.enter_context(NamedTemporaryFile())
logger.debug(f"create tempfile in {self._tempfile.name}")
self.dst = self._ctx.enter_context(
rasterio.open(self._tempfile.name, "w+", **self._profile)
)
else:
self.dst = self._ctx.enter_context(
rasterio.open(self.path, "w+", **self._profile)
)
self.dst = self._ctx.enter_context(
rasterio_write(self.path, "w+", **self._profile)
)

def read(self, output_tile, **kwargs):
"""
Expand Down Expand Up @@ -670,54 +633,6 @@ def close(self, exc_type=None, exc_value=None, exc_traceback=None):
self.overviews_resampling
].name.upper()
)
# write
if self.cog:
if path_is_remote(self.path):
# remote COG: copy to tempfile and upload to destination
logger.debug("upload to %s", self.path)
# TODO this writes a memoryfile to disk and uploads the file,
# this is inefficient but until we find a solution to copy
# from one memoryfile to another the rasterio way (rasterio needs
# to rearrange the data so the overviews are at the beginning of
# the GTiff in order to be a valid COG).
with NamedTemporaryFile() as tmp_dst:
copy(
self.dst,
tmp_dst.name,
copy_src_overviews=True,
**self._profile,
)
self._bucket_resource.upload_file(
Filename=tmp_dst.name,
Key="/".join(self.path.split("/")[3:]),
)
else:
# local COG: copy to destination
logger.debug("write to %s", self.path)
copy(
self.dst,
self.path,
copy_src_overviews=True,
**self._profile,
)
else:
if path_is_remote(self.path):
# remote GTiff: upload memfile or tempfile to destination
logger.debug("upload to %s", self.path)
if self.in_memory:
self._bucket_resource.put_object(
Body=self._memfile,
Key="/".join(self.path.split("/")[3:]),
)
else:
self._bucket_resource.upload_file(
Filename=self._tempfile.name,
Key="/".join(self.path.split("/")[3:]),
)
else:
# local GTiff: already written, do nothing
pass

finally:
self._ctx.close()

Expand Down
97 changes: 77 additions & 20 deletions mapchete/io/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
from rasterio.vrt import WarpedVRT
from rasterio.warp import reproject
from rasterio.windows import from_bounds
from tempfile import NamedTemporaryFile
from tilematrix import clip_geometry_to_srs_bounds, Shape, Bounds
from types import GeneratorType
import warnings

from mapchete.errors import MapcheteIOError
from mapchete.io import path_is_remote, get_gdal_options, path_exists
from mapchete.io import path_is_remote, get_gdal_options, path_exists, fs_from_path
from mapchete.io._misc import MAPCHETE_IO_RETRY_SETTINGS
from mapchete.tile import BufferedTile
from mapchete.validate import validate_write_window_params
Expand Down Expand Up @@ -428,7 +429,8 @@ def write_raster_window(
out_tile=None,
out_path=None,
tags=None,
bucket_resource=None,
fs=None,
**kwargs,
):
"""
Write a window from a numpy array to an output file.
Expand All @@ -445,7 +447,6 @@ def write_raster_window(
out_path : string
output path to write to
tags : optional tags to be added to GeoTIFF file
bucket_resource : boto3 bucket resource to write to in case of S3 output
"""
if not isinstance(out_path, str):
raise TypeError("out_path must be a string")
Expand Down Expand Up @@ -475,23 +476,10 @@ def write_raster_window(
if window_data.all() is not ma.masked:

try:
if out_path.startswith("s3://"):
with RasterWindowMemoryFile(
in_tile=out_tile,
in_data=window_data,
out_profile=out_profile,
out_tile=out_tile,
tags=tags,
) as memfile:
logger.debug((out_tile.id, "upload tile", out_path))
bucket_resource.put_object(
Key="/".join(out_path.split("/")[3:]), Body=memfile
)
else:
with rasterio.open(out_path, "w", **out_profile) as dst:
logger.debug((out_tile.id, "write tile", out_path))
dst.write(window_data.astype(out_profile["dtype"], copy=False))
_write_tags(dst, tags)
with rasterio_write(out_path, "w", fs=fs, **out_profile) as dst:
logger.debug((out_tile.id, "write tile", out_path))
dst.write(window_data.astype(out_profile["dtype"], copy=False))
_write_tags(dst, tags)
except Exception as e:
logger.exception("error while writing file %s: %s", out_path, e)
raise
Expand All @@ -510,6 +498,75 @@ def _write_tags(dst, tags):
dst.update_tags(**{k: v})


def rasterio_write(path, mode=None, fs=None, in_memory=True, *args, **kwargs):
"""
Wrap rasterio.open() but handle bucket upload if path is remote.
Parameters
----------
path : str
Path to write to.
mode : str
One of the rasterio.open() modes.
fs : fsspec.FileSystem
Target filesystem.
in_memory : bool
On remote output store an in-memory file instead of writing to a tempfile.
args : list
Arguments to be passed on to rasterio.open()
kwargs : dict
Keyword arguments to be passed on to rasterio.open()
Returns
-------
RasterioRemoteWriter if target is remote, otherwise return rasterio.open().
"""
if path.startswith("s3://"):
return RasterioRemoteWriter(path, fs=fs, in_memory=in_memory, *args, **kwargs)
else:
return rasterio.open(path, mode=mode, *args, **kwargs)


class RasterioRemoteWriter:
def __init__(self, path, *args, fs=None, in_memory=True, **kwargs):
logger.debug("open RasterioRemoteWriter for path %s", path)
self.path = path
self.fs = fs or fs_from_path(path)
self.in_memory = in_memory
if self.in_memory:
self._dst = MemoryFile()
else:
self._dst = NamedTemporaryFile(suffix=".tif")
self._open_args = args
self._open_kwargs = kwargs
self._sink = None

def __enter__(self):
if self.in_memory:
self._sink = self._dst.open(*self._open_args, **self._open_kwargs)
else:
self._sink = rasterio.open(
self._dst.name, "w+", *self._open_args, **self._open_kwargs
)
return self._sink

def __exit__(self, *args):
try:
self._sink.close()
if self.in_memory:
logger.debug("write rasterio MemoryFile to %s", self.path)
with self.fs.open(self.path, "wb") as dst:
dst.write(self._dst.getbuffer())
else:
self.fs.put_file(self._dst.name, self.path)
finally:
if self.in_memory:
logger.debug("close rasterio MemoryFile")
else:
logger.debug("close and remove tempfile")
self._dst.close()


def extract_from_array(in_raster=None, in_affine=None, out_tile=None):
"""
Extract raster data window array.
Expand Down
6 changes: 5 additions & 1 deletion test/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,11 @@ def test_convert_mapchete(cleantopo_br, mp_tmpdir):
assert len(job)

job = convert(
cleantopo_br.path, mp_tmpdir, output_pyramid="geodetic", output_metatiling=1
cleantopo_br.path,
mp_tmpdir,
output_pyramid="geodetic",
output_metatiling=1,
zoom=[1, 4],
)
assert len(job)
for zoom, row, col in [(4, 15, 31), (3, 7, 15), (2, 3, 7), (1, 1, 3)]:
Expand Down

0 comments on commit f6992a7

Please sign in to comment.