-
Notifications
You must be signed in to change notification settings - Fork 3
/
apply_glt.py
362 lines (279 loc) · 14.7 KB
/
apply_glt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
"""
Apply a (possibly multi-file) per-pixel spatial reference.
Author: Philip G. Brodrick, [email protected]
"""
import argparse
import numpy as np
import pandas as pd
from osgeo import gdal
from spectral.io import envi
import logging
import ray
from typing import List
import time
import os
import multiprocessing
import emit_utils.common_logs
import emit_utils.file_checks
import emit_utils.multi_raster_info
from emit_utils.file_checks import envi_header
# GLT entries equal to this value are treated as invalid (no source pixel)
# when the GLT is applied.
GLT_NODATA_VALUE: int = -9999
# GLT_NODATA_VALUE = 0  # alternative value, left disabled
# Not referenced by the code in this file; presumably the nodata value for a
# mosaicking-criteria raster -- confirm before relying on it.
CRITERIA_NODATA_VALUE: int = -9999
def _read_band_metadata(rawspace_files: List) -> tuple:
    """
    Determine the band names and band count of the rawspace inputs.

    Files are scanned in order and files GDAL cannot open are skipped. The
    first readable file whose ENVI header has a 'band names' entry supplies
    the names and the band count. If no readable file has band names, generic
    'Band N' names are generated from the first readable file's band count.

    Args:
        rawspace_files: rawspace file paths, in mosaic-index order
    Returns:
        (band_names, raster_count) tuple
    Raises:
        RuntimeError: if none of the rawspace files can be opened
    """
    fallback_count = None
    for filename in rawspace_files:
        dataset = gdal.Open(filename, gdal.GA_ReadOnly)
        if dataset is None:
            continue
        metadata = envi.open(envi_header(filename)).metadata
        if 'band names' in metadata:
            return list(metadata['band names']), dataset.RasterCount
        if fallback_count is None:
            fallback_count = dataset.RasterCount
    if fallback_count is None:
        raise RuntimeError('Could not open any rawspace file: {}'.format(rawspace_files))
    return [f'Band {x}' for x in range(fallback_count)], fallback_count


def _select_band_names(all_band_names: List, band_numbers) -> List:
    """
    Pick the names of the requested output bands.

    Args:
        all_band_names: names of every band in the source file
        band_numbers: list of 0-based band indices, or -1 for all bands
    Returns:
        list of names in the same order as band_numbers, so that entry i
        describes output band i
    Raises:
        ValueError: if a requested band index does not exist
    """
    if band_numbers == -1:
        return list(all_band_names)
    try:
        return [all_band_names[b] for b in band_numbers]
    except IndexError:
        raise ValueError('band_numbers {} out of range for {} bands'.format(
            band_numbers, len(all_band_names))) from None


def main():
    """
    Command-line entry point: apply a single or mosaic GLT to rawspace data.

    Parses arguments, creates a float32 ENVI BIL output raster on the GLT's
    grid (projection and geotransform copied from the GLT, nodata -9999, band
    descriptions taken from the rawspace inputs), then uses ray to fill it one
    output line at a time with apply_mosaic_glt_line.
    """
    parser = argparse.ArgumentParser(description='Integrate multiple GLTs with a mosaicing rule')
    parser.add_argument('glt_file')
    parser.add_argument('rawspace_file', help='filename of rawspace source file or, in the case of a mosaic_glt, a text-file list of raw space files')
    parser.add_argument('output_filename')
    parser.add_argument('-band_numbers', nargs='+', type=int, default=-1, help='list of 0-based band numbers, or -1 for all')
    parser.add_argument('-n_cores', type=int, default=-1)
    parser.add_argument('-log_file', type=str, default=None)
    parser.add_argument('-log_level', type=str, default='INFO')
    parser.add_argument('-run_with_missing_files', type=int, default=0, choices=[0,1])
    parser.add_argument('-ip_head', type=str)
    parser.add_argument('-redis_password', type=str)
    parser.add_argument('-one_based_glt', type=int, choices=[0,1], default=0)
    parser.add_argument('-mosaic', type=int, choices=[0,1], default=0)
    args = parser.parse_args()

    # Set up logging per arguments
    if args.log_file is None:
        logging.basicConfig(format='%(message)s', level=args.log_level)
    else:
        logging.basicConfig(format='%(message)s', level=args.log_level, filename=args.log_file)

    args.one_based_glt = args.one_based_glt == 1
    args.run_with_missing_files = args.run_with_missing_files == 1
    args.mosaic = args.mosaic == 1
    # Passing "-band_numbers -1" explicitly yields [-1]; treat it like the
    # default (all bands) instead of numpy-indexing only the last band.
    if args.band_numbers == [-1]:
        args.band_numbers = -1

    # Log the current time
    logging.info('Starting apply_glt, arguments given as: {}'.format(args))
    emit_utils.common_logs.logtime()

    # Do some checks on input raster files
    #emit_utils.file_checks.check_raster_files([args.glt_file], map_space=True)

    # Open the GLT: GDAL supplies projection/geotransform, the memmap supplies
    # the output grid size (glt.shape[0] lines by glt.shape[1] samples).
    glt_dataset = gdal.Open(args.glt_file, gdal.GA_ReadOnly)
    glt = envi.open(envi_header(args.glt_file)).open_memmap(writeable=False, interleave='bip')

    if args.mosaic:
        # One rawspace path per line; list position is the file index used in the GLT.
        with open(args.rawspace_file, 'r') as list_file:
            rawspace_files = [x.strip() for x in list_file]
        # TODO: make this check more elegant, should run, catch all files present exception, and proceed
        if not args.run_with_missing_files:
            emit_utils.file_checks.check_raster_files(rawspace_files, map_space=False)
        # TODO: check that all rawspace files have same number of bands
    else:
        emit_utils.file_checks.check_raster_files([args.rawspace_file], map_space=False)
        rawspace_files = [args.rawspace_file]

    # TODO: consider adding check for the right number of rawspace_files - requires
    # reading the GLT through, which isn't free

    all_band_names, raster_count = _read_band_metadata(rawspace_files)
    # Names are selected in output-band order so descriptions match the data.
    band_names = _select_band_names(all_band_names, args.band_numbers)
    if args.band_numbers == -1:
        output_bands = np.arange(raster_count)
    else:
        output_bands = np.array(args.band_numbers)

    # Build output dataset
    driver = gdal.GetDriverByName('ENVI')
    driver.Register()
    #TODO: careful about output datatypes / format
    outDataset = driver.Create(args.output_filename, glt.shape[1], glt.shape[0],
                               len(output_bands), gdal.GDT_Float32, options=['INTERLEAVE=BIL'])
    outDataset.SetProjection(glt_dataset.GetProjection())
    outDataset.SetGeoTransform(glt_dataset.GetGeoTransform())
    for _b in range(1, len(output_bands) + 1):
        out_band = outDataset.GetRasterBand(_b)
        out_band.SetNoDataValue(-9999)
        out_band.SetDescription(band_names[_b - 1])
    # Dereference to close (and flush) the GDAL dataset before workers write into the file.
    del outDataset

    if args.n_cores == -1:
        args.n_cores = multiprocessing.cpu_count()

    rayargs = {'address': args.ip_head,
               '_redis_password': args.redis_password,
               'local_mode': args.n_cores == 1}
    if args.n_cores < 40:
        rayargs['num_cpus'] = args.n_cores
    ray.init(**rayargs)
    print(ray.cluster_resources())

    # One remote task per output line; each task writes its own line of the file.
    jobs = [apply_mosaic_glt_line.remote(args.glt_file,
                                         args.output_filename,
                                         rawspace_files,
                                         output_bands,
                                         idx_y,
                                         args)
            for idx_y in range(glt.shape[0])]
    # Block until every line is written; re-raises the first task failure.
    ray.get(jobs)
    ray.shutdown()
    # (An earlier multiprocessing driver -- per line, or per file via
    # apply_mosaic_glt_image -- was replaced by the ray path above.)

    # Log final time and exit
    logging.info('GLT application complete, output available at: {}'.format(args.output_filename))
    emit_utils.common_logs.logtime()
def _write_bil_chunk(dat: np.array, outfile: str, line: int, shape: tuple, dtype: str = 'float32') -> None:
    """
    Write one line of data into an existing binary, BIL formatted data cube.

    The cube is treated as C-ordered (lines, bands, samples), so line ``line``
    starts at byte offset ``line * bands * samples * itemsize``. The file is
    updated in place; writing past its current end extends it.

    Args:
        dat: data to write; must hold exactly shape[1] * shape[2] values, which
             are written in C order (i.e. as a (bands, samples) array)
        outfile: path of an existing output file to write to
        line: line of the output file to write to
        shape: shape of the output file as (lines, bands, samples)
        dtype: output data type
    Returns:
        None
    Raises:
        ValueError: if dat does not contain exactly one line's worth of values,
                    which would otherwise overwrite neighbouring lines
    """
    dat = np.asarray(dat)
    values_per_line = shape[1] * shape[2]
    if dat.size != values_per_line:
        raise ValueError('chunk has {} values, expected {} for one line of shape {}'.format(
            dat.size, values_per_line, shape))
    itemsize = np.dtype(dtype).itemsize
    # Context manager guarantees the handle is closed even if the write fails.
    with open(outfile, 'rb+') as fout:
        fout.seek(line * values_per_line * itemsize)
        fout.write(dat.astype(dtype).tobytes())
@ray.remote
def apply_mosaic_glt_line(glt_filename: str, output_filename: str, rawspace_files: List, output_bands: np.array,
                          line_index: int, args: argparse.Namespace):
    """
    Create one line of an output mosaic in mapspace.

    Reads line ``line_index`` of the GLT, gathers the referenced rawspace
    pixels, and writes the resulting line into the pre-initialized BIL output.
    Output pixels without a valid GLT entry, or whose rawspace file is not
    present on disk, are set to -9999. The line is written even when it has no
    valid GLT entries, so every output line carries the declared nodata value
    instead of whatever the freshly created file contained.

    Args:
        glt_filename: pre-built single or mosaic glt
        output_filename: output destination, assumed to be a location where a pre-initialized raster exists
        rawspace_files: list of rawspace input locations; in mosaic mode indexed by the GLT file-index column
        output_bands: numpy array of 0-based bands to use from the rawspace files in the output
        line_index: line of the glt to process
        args: parsed command-line arguments; log_level, log_file, one_based_glt and mosaic are read
    Returns:
        None
    """
    logging.basicConfig(format='%(message)s', level=args.log_level, filename=args.log_file)

    glt_dataset = envi.open(envi_header(glt_filename))
    glt = glt_dataset.open_memmap(writeable=False, interleave='bip')
    n_lines, n_samples = glt.shape[0], glt.shape[1]

    if line_index % 100 == 0:
        logging.info('Beginning application of line {}/{}'.format(line_index, n_lines))

    # (samples, <=3 GLT bands) for this line. Index the row directly rather than
    # np.squeeze so a one-sample-wide GLT keeps its 2-D shape.
    glt_line = np.array(glt[line_index, :, :3], dtype=int)
    valid_glt = np.all(glt_line != GLT_NODATA_VALUE, axis=-1)
    # Coordinates are used by magnitude; presumably the sign flags a filled /
    # nearest-neighbour GLT entry -- confirm against the GLT builder.
    glt_line[valid_glt, 1] = np.abs(glt_line[valid_glt, 1])
    glt_line[valid_glt, 0] = np.abs(glt_line[valid_glt, 0])
    if args.one_based_glt:
        glt_line[valid_glt, :] -= 1

    output_dat = np.full((n_samples, len(output_bands)), -9999, dtype=np.float32)

    if np.any(valid_glt):
        if args.mosaic:
            un_file_idx = np.unique(glt_line[valid_glt, -1])
        else:
            un_file_idx = [0]
        for _idx in un_file_idx:
            if not os.path.isfile(rawspace_files[_idx]):
                continue
            rawspace_dat = envi.open(envi_header(rawspace_files[_idx])).open_memmap(interleave='bip')
            if args.mosaic:
                linematch = np.logical_and(glt_line[:, -1] == _idx, valid_glt)
            else:
                linematch = valid_glt
            if np.any(linematch):
                # GLT column 1 is the rawspace line, column 0 the rawspace sample.
                output_dat[linematch, :] = rawspace_dat[glt_line[linematch, 1][:, None],
                                                        glt_line[linematch, 0][:, None],
                                                        output_bands[None, :]]

    # Always write, including all-nodata lines. NOTE(review): GDAL's ENVI
    # driver appears to create only a stub file, so skipping empty trailing
    # lines could otherwise leave the output short -- confirm.
    _write_bil_chunk(np.transpose(output_dat), output_filename, line_index,
                     (n_lines, len(output_bands), n_samples))
#for file_index in necessary_file_idxs:
# if glt_line.shape[0] == 3:
# pixel_subset = glt_line[..., -1] == file_index
# else:
# pixel_subset = valid_glt.copy()
# for ind in range(len(glt_line)):
#
# min_glt_y = np.min(glt_line[pixel_subset, 1])
# max_glt_y = np.max(glt_line[pixel_subset, 1])
# # Open up the criteria dataset
# #rawspace_dataset = gdal.Open(rawspace_files[file_index], gdal.GA_ReadOnly)
# if os.path.isfile(rawspace_files[file_index]):
# for
# # Read in the block of data necessary to get the criteria
# rawspace_block = rawspace_dat[np.zeros((len(output_bands), max_glt_y - min_glt_y + 1, rawspace_dataset.RasterXSize))
# for gltindex in range(min_glt_y, max_glt_y+1):
# direct_read = np.squeeze(rawspace_dataset.ReadAsArray(0, gltindex,rawspace_dataset.RasterXSize,1))
# # TODO: account for extra bands, if this isn't a single-band dataset (gdal drops the 0th dimension if it is)
# if rawspace_dataset.RasterCount > 0:
# direct_read = direct_read[output_bands,...]
# # assign data to block
# rawspace_block[:, gltindex - min_glt_y, :] = direct_read
# # convert rawspace to mapspace through lookup
# mapspace_line = rawspace_block[:,glt_line[1, pixel_subset].flatten() - min_glt_y, glt_line[0, pixel_subset].flatten()]
# # write the output
# #TODO: careful about output datatype
# output_memmap = np.memmap(output_filename, mode='r+', shape=(glt[0].shape[1], len(output_bands), glt[0].shape[2]), dtype=np.float32)
# output_memmap[line_index, :, np.squeeze(pixel_subset)] = np.transpose(mapspace_line)
# del output_memmap
def apply_mosaic_glt_image(glt_filename: str, output_filename: str, rawspace_file: str, rawspace_file_index: int, output_bands: np.array):
    """
    Apply glt to one file's worth of raw-space data, suitable for instances with large numbers of files.

    Every output pixel whose GLT file-index band equals ``rawspace_file_index + 1``
    is filled from ``rawspace_file``; the output is updated in place one output
    line at a time.

    Args:
        glt_filename: pre-built single or mosaic glt
        output_filename: output destination, assumed to be a location where a pre-initialized raster exists
        rawspace_file: rawspace input location for this file
        rawspace_file_index: 0-based index of rawspace file in the mosaic_glt
        output_bands: array-like of bands to use from the rawspace file in the output
    Returns:
        None
    Raises:
        RuntimeError: if the GLT cannot be opened
    """
    # Open up the raw-space dataset; files GDAL cannot open are skipped.
    rawspace_dataset = gdal.Open(rawspace_file, gdal.GA_ReadOnly)
    if rawspace_dataset is None:
        return
    # NOTE(review): the raw file is memory-mapped directly as headerless
    # float32 BIL (lines, bands, samples) -- confirm this matches the inputs.
    rawspace = np.memmap(rawspace_file, mode='r',
                         shape=(rawspace_dataset.RasterYSize, rawspace_dataset.RasterCount, rawspace_dataset.RasterXSize),
                         dtype=np.float32)
    logging.info('Successfully completed read of file {}'.format(rawspace_file))

    # The GLT was previously referenced through an undefined name; load it from
    # glt_filename. GDAL's ReadAsArray returns (bands, lines, samples).
    glt_ds = gdal.Open(glt_filename, gdal.GA_ReadOnly)
    if glt_ds is None:
        raise RuntimeError('Could not open GLT: {}'.format(glt_filename))
    glt = glt_ds.ReadAsArray()

    output_y_loc, output_x_loc = np.where(glt[2, ...] == rawspace_file_index + 1)
    n_pixels = len(output_y_loc)
    print('{} pixel values identified'.format(n_pixels))
    if n_pixels == 0:
        return

    # NOTE(review): this function reads GLT band 0 as the rawspace line and
    # band 1 as the sample, with a one-based file index and no one-based
    # coordinate correction. apply_mosaic_glt_line uses band 1 as line and
    # band 0 as sample. Confirm the convention before relying on this path.
    rawspace_y_loc = np.abs(glt[0, output_y_loc, output_x_loc])
    rawspace_x_loc = np.abs(glt[1, output_y_loc, output_x_loc])

    order = np.argsort(output_y_loc, kind='stable')
    output_y_loc = output_y_loc[order]
    output_x_loc = output_x_loc[order]
    rawspace_y_loc = rawspace_y_loc[order]
    rawspace_x_loc = rawspace_x_loc[order]

    # After sorting, each output line is one contiguous run of pixels.
    _, run_starts, run_counts = np.unique(output_y_loc, return_index=True, return_counts=True)
    # Guard against a zero step (fewer than 10 pixels) in the progress modulo.
    progress_step = max(1, n_pixels // 10)

    output_memmap = np.memmap(output_filename, mode='r+',
                              shape=(glt.shape[-2], len(output_bands), glt.shape[-1]), dtype=np.float32)
    for start, count in zip(run_starts, run_counts):
        run = slice(start, start + count)
        output_memmap[output_y_loc[run], :, output_x_loc[run]] = \
            rawspace[rawspace_y_loc[run], :, rawspace_x_loc[run]][:, output_bands]
        if start % progress_step < count:
            logging.debug('File {}, {} % written'.format(rawspace_file_index, round(start / float(n_pixels) * 100., 2)))
    output_memmap.flush()
    del output_memmap
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()