import csv
import itertools
from math import sqrt
from numbers import Number
from tempfile import mkstemp
from typing import List, Union, Sequence
import torch
import torch.nn.functional as F
# import torch should be first. Unclear issue, mentioned here: https://github.com/pytorch/pytorch/issues/2083
import numpy as np
import time
import fiona # keep this import. it sets GDAL_DATA to right value
import rasterio
import ttach as tta
from scipy.special import softmax
from collections import OrderedDict
from fiona.crs import to_string
from ruamel_yaml.comments import CommentedSeq
from tqdm import tqdm
from rasterio import features
from rasterio.windows import Window
from rasterio.plot import reshape_as_image, reshape_as_raster
from pathlib import Path
from omegaconf import OmegaConf, DictConfig, open_dict
from omegaconf.listconfig import ListConfig
from dataset.stacitem import SingleBandItemEO
from utils.aoiutils import aois_from_csv
from utils.geoutils import create_new_raster_from_base
from utils.inference import stretch_heatmap, class_from_heatmap
from utils.logger import get_logger, set_tracker
from models.model_choice import define_model, read_checkpoint
from utils import augmentation
from utils.inference import generate_patch_list
from utils.utils import get_device_ids, get_key_def, \
add_metadata_from_raster_to_sample, set_device
# Set the logging file
logging = get_logger(__name__)
def _pad(arr, chunk_size):
    """Pads an image array to (chunk_size, chunk_size) with reflect padding."""
    h_diff = chunk_size - arr.shape[0]
    w_diff = chunk_size - arr.shape[1]
    if len(arr.shape) > 2:
        padded_arr = np.pad(arr, ((0, h_diff), (0, w_diff), (0, 0)), mode="reflect")
    else:
        padded_arr = np.pad(arr, ((0, h_diff), (0, w_diff)), mode="reflect")
    return padded_arr
def ras2vec(raster_file, output_path):
# Create a generic polygon schema for the output vector file
i = 0
feat_schema = {'geometry': 'Polygon',
'properties': OrderedDict([('value', 'int')])
}
class_value_domain = set()
out_features = []
print(" - Processing raster file: {}".format(raster_file))
with rasterio.open(raster_file, 'r') as src:
raster = src.read(1)
mask = raster != 0
# Vectorize the polygons
polygons = features.shapes(raster, mask, transform=src.transform)
# Create shapely polygon features
for polygon in polygons:
feature = {'geometry': {
'type': 'Polygon',
'coordinates': None},
'properties': OrderedDict([('value', 0)])}
feature['geometry']['coordinates'] = polygon[0]['coordinates']
value = int(polygon[1]) # Pixel value of the class (layer)
class_value_domain.add(value)
feature['properties']['value'] = value
i += 1
out_features.append(feature)
print(" - Writing output vector file: {}".format(output_path))
    num_layers = list(class_value_domain)  # list of unique class pixel values; one vector layer is written per value
for num_layer in num_layers:
polygons = [feature for feature in out_features if feature['properties']['value'] == num_layer]
layer_name = 'vector_' + str(num_layer).rjust(3, '0')
print(" - Writing layer: {}".format(layer_name))
with fiona.open(output_path, 'w',
crs=to_string(src.crs),
layer=layer_name,
schema=feat_schema,
driver='GPKG') as dest:
for polygon in polygons:
dest.write(polygon)
print("")
print("Number of features written: {}".format(i))
def gen_img_samples(src, patch_list, chunk_size, *band_order):
"""
TODO
Args:
src: input image (rasterio object)
patch_list: list of patches index
chunk_size: image tile size
*band_order: ignore
Returns: generator object
"""
for patch in patch_list:
patch_x, patch_y, patch_width, patch_height, hann_window = patch
window = Window.from_slices(slice(patch_y, patch_y + patch_height),
slice(patch_x, patch_x + patch_width))
if band_order:
patch_array = reshape_as_image(src.read(band_order[0], window=window))
else:
patch_array = reshape_as_image(src.read(window=window))
patch_array = _pad(patch_array, chunk_size)
yield patch_array, (patch_y, patch_height), (patch_x, patch_width), hann_window
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
@torch.no_grad()
def segmentation(param,
input_image,
num_classes: int,
model,
chunk_size: int,
use_hanning: bool,
device,
scale: List,
tp_mem,
heatmap_dtype=np.uint16,
debug=False,
):
"""
Args:
param: parameter dict
input_image: opened image (rasterio object)
num_classes: number of classes
model: model weights
chunk_size: image tile size
device: cuda/cpu device
scale: scale range
tp_mem: memory temp file for saving numpy array to disk
debug: True/False
heatmap_dtype:
Output data type for heatmap. Ex.: Uint16 captures more information, but takes more space in memory or disk.
Returns:
"""
sample = {"image": None, "mask": None, 'metadata': None}
start_seg = time.time()
    print_log = logging.level == 20  # 20 is INFO
model.eval() # switch to evaluate mode
# initialize test time augmentation
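    # d4 = the 8 dihedral transforms (4 rotations x 2 flip states); predictions are accumulated per transform and averaged later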
transforms = tta.aliases.d4_transform()
tf_len = len(transforms)
h_padded, w_padded = input_image.height + chunk_size, input_image.width + chunk_size
patch_list = generate_patch_list(w_padded, h_padded, chunk_size, use_hanning)
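    # Disk-backed accumulator of shape (n_transforms, padded_h, padded_w, n_classes); float16 keeps the temp file small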
fp = np.memmap(tp_mem, dtype='float16', mode='w+', shape=(tf_len, h_padded, w_padded, num_classes))
img_gen = gen_img_samples(src=input_image, patch_list=patch_list, chunk_size=chunk_size)
    single_class_mode = num_classes == 1
for sub_image, h_idxs, w_idxs, hann_win in tqdm(
img_gen, position=0, leave=True, desc='Inferring on patches',
total=len(patch_list)
):
hann_win = np.expand_dims(hann_win, -1)
image_metadata = add_metadata_from_raster_to_sample(sat_img_arr=sub_image,
raster_handle=input_image,
raster_info={})
sample['metadata'] = image_metadata
totensor_transform = augmentation.compose_transforms(param,
dataset="tst",
scale=scale,
aug_type='totensor',
print_log=print_log)
sample["image"] = sub_image
sample = totensor_transform(sample)
inputs = sample["image"].unsqueeze_(0)
inputs = inputs.to(device)
if inputs.shape[1] == 4 and any("module.modelNIR" in s for s in model.state_dict().keys()):
            # Init NIR. TODO: implement a proper way to read the NIR channel
            # and add an option to specify the index of the NIR channel
            inputs_NIR = inputs[:, -1, ...]  # extract the NIR channel -> [batch size, H, W] since it's a single channel
            inputs_NIR.unsqueeze_(1)  # restore the channel dimension -> [batch size, 1, H, W]
            inputs = inputs[:, :-1, ...]  # keep only the RGB channels in the main input
            inputs = [inputs, inputs_NIR]
output_lst = []
for transformer in transforms:
# augment inputs
augmented_input = transformer.augment_image(inputs)
with torch.cuda.amp.autocast():
augmented_output = model(augmented_input)
if isinstance(augmented_output, OrderedDict) and 'out' in augmented_output.keys():
augmented_output = augmented_output['out']
logging.debug(f'Shape of augmented output: {augmented_output.shape}')
# reverse augmentation for outputs
deaugmented_output = transformer.deaugment_mask(augmented_output).squeeze(dim=0)
output_lst.append(deaugmented_output)
outputs = torch.stack(output_lst)
outputs = outputs.permute(0, 2, 3, 1).squeeze(dim=0)
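        # Weight the patch predictions by the Hann window so overlapping patches blend smoothly at the seams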
outputs = outputs.cpu().numpy() * hann_win
fp[:, h_idxs[0]:h_idxs[0] + h_idxs[1], w_idxs[0]:w_idxs[0] + w_idxs[1], :] += outputs
fp.flush()
del fp
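    # Re-open the accumulator read-only and collapse the TTA dimension chunk by chunk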
fp = np.memmap(tp_mem, dtype='float16', mode='r', shape=(tf_len, h_padded, w_padded, num_classes))
pred_heatmap = np.zeros((h_padded, w_padded, num_classes), dtype=heatmap_dtype)
for row, col in tqdm(itertools.product(range(0, h_padded, chunk_size),
range(0, w_padded, chunk_size)),
leave=False, position=0, desc="Writing to array"):
arr1 = (fp[:, row:row + chunk_size, col:col + chunk_size, :]).mean(axis=0)
if single_class_mode:
arr1 = sigmoid(arr1)
else:
arr1 = softmax(arr1, axis=-1)
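        # Stretch the [0, 1] probabilities to the integer range of heatmap_dtype (out_max) before casting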
heatmap_max = np.iinfo(heatmap_dtype).max
arr1 = stretch_heatmap(heatmap_arr=arr1, out_max=heatmap_max)
pred_heatmap[row:row + chunk_size, col:col + chunk_size, :] = arr1.astype(heatmap_dtype)
end_seg = time.time() - start_seg
logging.info('Segmentation operation completed in {:.0f}m {:.0f}s'.format(end_seg // 60, end_seg % 60))
if debug:
logging.debug(f'Bin count of final output: {np.unique(pred_heatmap, return_counts=True)}')
input_image.close()
return pred_heatmap[:input_image.height, :input_image.width]
def calc_inference_chunk_size(gpu_devices_dict: dict, max_pix_per_mb_gpu: int = 200, default: int = 512) -> int:
"""
Calculate maximum chunk_size that could fit on GPU during inference based on thumb rule with hardcoded
"pixels per MB of GPU RAM" as threshold. Threshold based on inference with a large model (Deeplabv3_resnet101)
:param gpu_devices_dict: dictionary containing info on GPU devices as returned by lst_device_ids (utils.py)
:param max_pix_per_mb_gpu: Maximum number of pixels that can fit on each MB of GPU (better to underestimate)
:return: returns a downgraded evaluation batch size if the original batch size is considered too high
"""
if not gpu_devices_dict:
return default
# get max ram for smallest gpu
smallest_gpu_ram = min(gpu_info['max_ram'] for _, gpu_info in gpu_devices_dict.items())
# rule of thumb to determine max chunk size based on approximate max pixels a gpu can handle during inference
max_chunk_size = sqrt(max_pix_per_mb_gpu * smallest_gpu_ram)
max_chunk_size_rd = int(max_chunk_size - (max_chunk_size % 256)) # round to the closest multiple of 256
logging.info(f'Data will be split into chunks of {max_chunk_size_rd}')
return max_chunk_size_rd
def override_model_params_from_checkpoint(
params: DictConfig,
checkpoint_params) -> DictConfig:
"""
Overrides model-architecture related parameters from provided checkpoint parameters
@param params: Original parameters as inputted through hydra
@param checkpoint_params: Checkpoint parameters as saved during checkpoint creation when training
@return:
"""
bands = get_key_def('bands', params['dataset'], expected_type=Sequence)
classes = get_key_def('classes_dict', params['dataset'], expected_type=(dict, DictConfig))
clip_limit = get_key_def('clahe_clip_limit', params['tiling'], expected_type=int)
normalization = get_key_def('normalization', params['augmentation'], expected_type=DictConfig)
scale_data = get_key_def('scale_data', params['augmentation'], expected_type=ListConfig)
bands_ckpt = get_key_def('bands', checkpoint_params['dataset'], expected_type=Sequence)
classes_ckpt = get_key_def('classes_dict', checkpoint_params['dataset'], expected_type=(dict, DictConfig))
model_ckpt = get_key_def('model', checkpoint_params, expected_type=(dict, DictConfig))
clip_limit_ckpt = get_key_def('clahe_clip_limit', checkpoint_params['tiling'], expected_type=int)
normalization_ckpt = get_key_def('normalization', checkpoint_params['augmentation'], expected_type=(dict, DictConfig))
# Workaround for "omegaconf.errors.UnsupportedValueType: Value 'CommentedSeq' is not a supported primitive type"
if normalization_ckpt is not None and isinstance(list(normalization_ckpt.values())[0], CommentedSeq):
normalization_ckpt = {k: [float(val) for val in v] for k, v in normalization_ckpt.items()}
scale_data_ckpt = get_key_def('scale_data', checkpoint_params['augmentation'], expected_type=(List, ListConfig))
scale_data_ckpt = list(scale_data_ckpt)
if model_ckpt != params.model or classes_ckpt != classes or bands_ckpt != bands \
or clip_limit != clip_limit_ckpt:
logging.info("\nParameters from checkpoint will override inputted parameters."
f"\n\t\t\t Inputted | Overriden"
f"\nModel:\t\t {params.model} | {model_ckpt}"
f"\nInput bands:\t\t{bands} | {bands_ckpt}"
f"\nOutput classes:\t\t{classes} | {classes_ckpt}"
f"\nNormalization means and stds:\t\t{normalization} | {normalization_ckpt}"
f"\nScale data range:\t\t{scale_data} | {scale_data_ckpt}"
f"\nRaster enhance clip limit:\t\t{clip_limit} | {clip_limit_ckpt}")
with open_dict(params):
params['model'] = model_ckpt
params['dataset']['bands'] = bands_ckpt
params['dataset']['classes_dict'] = classes_ckpt
params['augmentation']['normalization'] = normalization_ckpt
params['augmentation']['scale_data'] = scale_data_ckpt
params['tiling']['clahe_clip_limit'] = clip_limit_ckpt
return params
def stac_input_to_temp_csv(input_stac_item: Union[str, Path]) -> Path:
"""Saves a stac item path or url to a temporary csv"""
_, stac_temp_csv = mkstemp(suffix=".csv")
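    # Row format consumed downstream by aois_from_csv: raster path, ground truth (none at inference), split name, AOI id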
with open(stac_temp_csv, "w", newline="") as fh:
csv.writer(fh).writerow([str(input_stac_item), None, "inference", Path(input_stac_item).stem])
return Path(stac_temp_csv)
def main(params: Union[DictConfig, dict]) -> None:
"""
Function to manage details about the inference on segmentation task.
1. Read the parameters from the config given.
2. Read and load the state dict from the previous training or the given one.
3. Make the inference on the data specified in the config.
-------
:param params: (dict) Parameters inputted during execution.
"""
# Main params
working_folder = get_key_def('root_dir', params['inference'], default="inference", to_path=True)
working_folder.mkdir(exist_ok=True)
state_dict = get_key_def('state_dict_path', params['inference'], to_path=True,
validate_path_exists=True,
wildcard='*pth.tar')
inference_image = get_key_def(key='output_path', config=params['inference'], to_path=True, expected_type=str)
if inference_image:
inference_image.parent.mkdir(exist_ok=True)
models_dir = get_key_def('checkpoint_dir', params['inference'], default=working_folder / 'checkpoints', to_path=True)
models_dir.mkdir(exist_ok=True)
data_dir = get_key_def('raw_data_dir', params['dataset'], default="data", to_path=True, validate_path_exists=True)
download_data = get_key_def('download_data', params['inference'], default=False, expected_type=bool)
override = get_key_def('override_model_params', params['inference'], default=False, expected_type=bool)
# Override params from checkpoint
checkpoint = read_checkpoint(state_dict, out_dir=models_dir, update=True)
if override:
params = override_model_params_from_checkpoint(params=params,checkpoint_params=checkpoint['params'])
# Dataset params
bands_requested = get_key_def('bands', params['dataset'], default=[1, 2, 3], expected_type=Sequence)
classes_dict = get_key_def('classes_dict', params['dataset'], expected_type=DictConfig)
num_classes = len(classes_dict)
    num_classes = num_classes + 1 if num_classes > 1 else num_classes  # multiclass: add one for the background class
num_bands = len(bands_requested)
# Default input directory based on default output directory
raw_data_csv = get_key_def('raw_data_csv', params['inference'], expected_type=str, to_path=True,
validate_path_exists=True)
input_stac_item = get_key_def('input_stac_item', params['inference'], expected_type=str, to_path=True,
validate_path_exists=True)
prep_data_only = get_key_def('prep_data_only', params['inference'], default=False, expected_type=bool)
# LOGGING PARAMETERS
exper_name = get_key_def('project_name', params['general'], default='gdl-training')
run_name = get_key_def(['tracker', 'run_name'], params, default='gdl')
tracker_uri = get_key_def(['tracker', 'uri'], params, default=None, expected_type=str, to_path=False)
set_tracker(mode='inference', type='mlflow', task='segmentation', experiment_name=exper_name, run_name=run_name,
tracker_uri=tracker_uri, params=params, keys2log=['general', 'dataset', 'model', 'inference'])
# OPTIONAL PARAMETERS
num_devices = get_key_def('gpu', params['inference'], default=0, expected_type=(int, bool))
if num_devices > 1:
logging.warning(f"Inference is not yet implemented for multi-gpu use. Will request only 1 GPU.")
num_devices = 1
max_used_ram = get_key_def('max_used_ram', params['inference'], default=25, expected_type=int)
if not (0 <= max_used_ram <= 100):
raise ValueError(f'\nMax used ram parameter should be a percentage. Got {max_used_ram}.')
max_used_perc = get_key_def('max_used_perc', params['inference'], default=25, expected_type=int)
scale = get_key_def('scale_data', params['augmentation'], default=[0, 1], expected_type=ListConfig)
raster_to_vec = get_key_def('ras2vec', params['inference'], default=False)
debug = get_key_def('debug', params, default=False, expected_type=bool)
if debug:
logging.warning(f'\nDebug mode activated. Some debug features may mobilize extra disk space and '
f'cause delays in execution.')
# list of GPU devices that are available and unused. If no GPUs, returns empty dict
gpu_devices_dict = get_device_ids(num_devices, max_used_ram_perc=max_used_ram, max_used_perc=max_used_perc)
max_pix_per_mb_gpu = get_key_def('max_pix_per_mb_gpu', params['inference'], default=25, expected_type=int)
auto_chunk_size = calc_inference_chunk_size(gpu_devices_dict=gpu_devices_dict,
max_pix_per_mb_gpu=max_pix_per_mb_gpu, default=512)
chunk_size = get_key_def('chunk_size', params['inference'], default=auto_chunk_size, expected_type=int)
device = set_device(gpu_devices_dict=gpu_devices_dict)
clahe_clip_limit = get_key_def('clahe_clip_limit', params['tiling'], expected_type=Number, default=0)
heatmap_dtype = get_key_def('heatmap_dtype', params['inference'], default=np.uint16)
save_heatmap = get_key_def('save_heatmap', params['inference'], default=True, expected_type=bool)
use_hanning = get_key_def('use_hanning', params['inference'], default=True, expected_type=bool)
heatmap_threshold = get_key_def('heatmap_threshold', params['inference'], default=0.5, expected_type=float)
if raw_data_csv and input_stac_item:
raise ValueError(f"Input imagery should be either a csv of stac item. Got inputs from both \"raw_data_csv\" "
f"and \"input stac item\"")
if input_stac_item:
raw_data_csv = stac_input_to_temp_csv(input_stac_item)
if not all([SingleBandItemEO.is_valid_cname(band) for band in bands_requested]):
logging.warning(f"Requested bands are not valid stac item common names. Got: {bands_requested}")
bands_requested = [SingleBandItemEO.band_to_cname(band) for band in bands_requested]
logging.warning(f"Will request: {bands_requested}")
model = define_model(
net_params=params.model,
in_channels=num_bands,
out_classes=num_classes,
main_device=device,
devices=[list(gpu_devices_dict.keys())],
checkpoint_dict=checkpoint,
)
# GET LIST OF INPUT IMAGES FOR INFERENCE
list_aois = aois_from_csv(
csv_path=raw_data_csv,
bands_requested=bands_requested,
download_data=download_data,
data_dir=data_dir,
equalize_clahe_clip_limit=clahe_clip_limit,
)
if len(list_aois) > 1 and inference_image:
raise ValueError(f"\n\"inference.output_path\" should be set for a single inference only. \n"
f"Got {len(list_aois)} AOIs for inference.\n")
if prep_data_only:
logging.info(f"[prep_data_only mode] Data preparation for inference is complete. Exiting...")
exit()
# LOOP THROUGH LIST OF INPUT IMAGES
for aoi in tqdm(list_aois, desc='Inferring from images', position=0, leave=True):
output_path = working_folder / f"{aoi.aoi_id}_pred.tif" if not inference_image else inference_image
inference_heatmap = output_path.parent / f"{output_path.stem}_heatmap.tif"
temp_file = output_path.parent / f"{output_path.stem}_heatmap.dat"
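        # temp_file backs the np.memmap used by segmentation() to accumulate per-transform predictions on disk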
logging.info(f'\nReading image: {aoi.aoi_id}')
inf_meta = aoi.raster.meta
pred_heatmap = segmentation(
param=params,
input_image=aoi.raster,
num_classes=num_classes,
model=model,
chunk_size=chunk_size,
use_hanning=use_hanning,
device=device,
scale=scale,
tp_mem=temp_file,
heatmap_dtype=heatmap_dtype,
debug=debug
)
inf_meta.update({"driver": "GTiff",
"height": pred_heatmap.shape[1],
"width": pred_heatmap.shape[2],
"count": pred_heatmap.shape[0],
"dtype": 'uint8',
"compress": 'lzw'})
logging.info(f'\nSuccessfully inferred on {aoi.aoi_id}\nWriting to file: {output_path}')
pred_img = class_from_heatmap(heatmap_arr=pred_heatmap, heatmap_threshold=heatmap_threshold)
if save_heatmap:
logging.info(f"\nSaving heatmap...")
pred_heatmap = reshape_as_raster(pred_heatmap)
create_new_raster_from_base(
input_raster=aoi.raster,
output_raster=inference_heatmap,
write_array=pred_heatmap,
dtype=heatmap_dtype,
checkpoint_path=state_dict,
classes_dict=classes_dict,
)
logging.info(f'\nSaved heatmap to {inference_heatmap}')
create_new_raster_from_base(
input_raster=aoi.raster,
output_raster=output_path,
write_array=pred_img,
checkpoint_path=state_dict,
classes_dict=classes_dict,
)
del pred_heatmap
try:
temp_file.unlink()
except OSError as e:
            logging.warning(f'File Error: {temp_file}: {e.strerror}')
if raster_to_vec:
start_vec = time.time()
inference_vec = working_folder.joinpath(f"{aoi.aoi_id}_pred.gpkg")
ras2vec(output_path, inference_vec)
end_vec = time.time() - start_vec
logging.info('Vectorization completed in {:.0f}m {:.0f}s'.format(end_vec // 60, end_vec % 60))