Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Build lut updates #3

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 164 additions & 131 deletions aggregated_combined_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from isofit.radiative_transfer.modtran import ModtranRT
from isofit.radiative_transfer.six_s import SixSRT
from isofit.radiative_transfer.engines.modtran import ModtranRT
from isofit.radiative_transfer.engines.six_s import SixSRT
from isofit.configs import configs
import argparse
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
import sklearn.metrics
import ray
from isofit.radiative_transfer.radiative_transfer import confPriority



Expand All @@ -39,163 +40,195 @@ def d2_subset(data,ranges):
a = a[:,ranges[1]]
return a

def readModtran(modtran_obj, filename):
try:
solzen = modtran_obj.load_tp6(f"{filename}.tp6")
coszen = np.cos(solzen * np.pi / 180.0)
params = modtran_obj.load_chn(f"{filename}.chn", coszen)

# Remove thermal terms in VSWIR runs to avoid incorrect usage
if modtran_obj.treat_as_emissive is False:
for key in ["thermal_upwelling", "thermal_downwelling"]:
if key in params:
Logger.debug(
f"Deleting key because treat_as_emissive is False: {key}"
)
del params[key]

params["solzen"] = solzen
params["coszen"] = coszen

return params
except:
return None

def readSixs(sixs_obj, filename, wl, multipart_transmittance=False, wl_size=0):

return sixs_obj.parse_file(
file=file,
wl=sixs_obj.wl,
multipart_transmittance=sixs_obj.multipart_transmittance,
wl_size=sixs_obj.wl.size,
)

def main():

# Parse arguments
parser = argparse.ArgumentParser(description="built luts for emulation.")
parser.add_argument('-config_file', type=str, default='templates/isofit_template.json')
parser.add_argument('-keys', type=str, default=['transm', 'rhoatm', 'sphalb'], nargs='+')
parser.add_argument('-keys', type=str, default=['rhoatm', 'sphalb', 'transm_down_dir', 'transm_down_dif', 'transm_up_dir', 'transm_up_dif' ], nargs='+')
parser.add_argument('-munge_dir', type=str, default='munged')
parser.add_argument('-figs_dir', type=str, default=None)
parser.add_argument('-unstruct', type=int, default=0, choices=[0,1])

args = parser.parse_args()

np.random.seed(13)

for key_ind, key in enumerate(args.keys):
munge_file = os.path.join(args.munge_dir, key + '.npz')

if os.path.isfile(munge_file) is False:
config = configs.create_new_config(args.config_file)

# Note - this goes way faster if you comment out the Vector Interpolater build section in each of these
isofit_modtran = ModtranRT(config.forward_model.radiative_transfer.radiative_transfer_engines[0],
config, build_lut = False)
isofit_sixs = SixSRT(config.forward_model.radiative_transfer.radiative_transfer_engines[1],
config, build_lut = False)

sixs_results = get_obj_res(isofit_sixs, key, resample=False)
modtran_results = get_obj_res(isofit_modtran, key)

if os.path.isdir(os.path.dirname(munge_file) is False):
os.mkdir(os.path.dirname(munge_file))

for fn in isofit_modtran.files:
mod_output = isofit_modtran.load_rt(fn)
sol_irr = mod_output['sol']
if np.all(np.isfinite(sol_irr)):
break

np.savez(munge_file, modtran_results=modtran_results, sixs_results=sixs_results, sol_irr=sol_irr)

modtran_results = None
sixs_results = None
for key_ind, key in enumerate(args.keys):
munge_file = os.path.join(args.munge_dir, key + '.npz')

npzf = np.load(munge_file)

dim1 = int(np.product(np.array(npzf['modtran_results'].shape)[:-1]))
dim2 = npzf['modtran_results'].shape[-1]
if modtran_results is None:
modtran_results = np.zeros((dim1,dim2*len(args.keys)))
modtran_results[:,dim2*key_ind:dim2*(key_ind+1)] = npzf['modtran_results']

dim1 = int(np.product(np.array(npzf['sixs_results'].shape)[:-1]))
dim2 = npzf['sixs_results'].shape[-1]
if sixs_results is None:
sixs_results = np.zeros((dim1,dim2*len(args.keys)))
sixs_results[:,dim2*key_ind:dim2*(key_ind+1)] = npzf['sixs_results']

sol_irr = npzf['sol_irr']


config = configs.create_new_config(args.config_file)
isofit_modtran = ModtranRT(config.forward_model.radiative_transfer.radiative_transfer_engines[0],
config, build_lut=False)
isofit_sixs = SixSRT(config.forward_model.radiative_transfer.radiative_transfer_engines[1],
config, build_lut=False)
sixs_names = isofit_sixs.lut_names
modtran_names = isofit_modtran.lut_names

if 'elev' in sixs_names:
sixs_names[sixs_names.index('elev')] = 'GNDALT'
if 'viewzen' in sixs_names:
sixs_names[sixs_names.index('viewzen')] = 'OBSZEN'
if 'viewaz' in sixs_names:
sixs_names[sixs_names.index('viewaz')] = 'TRUEAZ'
if 'alt' in sixs_names:
sixs_names[sixs_names.index('alt')] = 'H1ALT'
if 'AOT550' in sixs_names:
sixs_names[sixs_names.index('AOT550')] = 'AERFRAC_2'

reorder_sixs = [sixs_names.index(x) for x in modtran_names]

points = isofit_modtran.points.copy()
points_sixs = isofit_sixs.points.copy()[:,reorder_sixs]

if 'OBSZEN' in modtran_names:
print('adjusting')
points_sixs[:, modtran_names.index('OBSZEN')] = 180 - points_sixs[:, modtran_names.index('OBSZEN')]


ind = np.lexsort(tuple([points[:,x] for x in range(points.shape[-1])]))
points = points[ind,:]
modtran_results = modtran_results[ind,:]

ind_sixs = np.lexsort(tuple([points_sixs[:,x] for x in range(points_sixs.shape[-1])]))
points_sixs = points_sixs[ind_sixs,:]
sixs_results = sixs_results[ind_sixs,:]
# Note - this goes way faster if you comment out the Vector Interpolater build section in each of these

rt_config = config.forward_model.radiative_transfer
instrument_config = config.forward_model.instrument

good_data = np.all(np.isnan(modtran_results) == False,axis=1)
good_data[np.any(np.isnan(sixs_results),axis=1)] = False

params = {'engine_config': rt_config.radiative_transfer_engines[0]}

params['lut_grid'] = confPriority('lut_grid', [params['engine_config'], instrument_config, rt_config])
params['lut_grid'] = {key: params['lut_grid'][key] for key in params['engine_config'].lut_names.keys()}
params['wavelength_file'] = confPriority('wavelength_file', [params['engine_config'], instrument_config, rt_config])
if args.unstruct:
params['engine_config'].rte_configure_and_exit = True
else:
params['engine_config'].rte_configure_and_exit = False
#params['engine_config'].rt_mode = 'rdn'
isofit_modtran = ModtranRT(**params)

params = {'engine_config' : rt_config.radiative_transfer_engines[1]}

params['lut_grid'] = confPriority('lut_grid', [params['engine_config'], instrument_config, rt_config])
params['lut_grid'] = {key: params['lut_grid'][key] for key in params['engine_config'].lut_names.keys()}
if args.unstruct:
params['engine_config'].rte_configure_and_exit = True
else:
params['engine_config'].rte_configure_and_exit = False
#params['engine_config'].rt_mode = 'rdn'

modtran_results = modtran_results[good_data,:]
sixs_results = sixs_results[good_data,:]
points = points[good_data,...]
points_sixs = points_sixs[good_data,...]
# Get raw 6s return, not wavelength convolved (this is what we'll use for inference too)
sixs_wl = np.arange(350, 2500 + 2.5, 2.5)
sixs_fwhm = np.full(sixs_wl.size, 2.0)
#params['wavelength_file'] = confPriority('wavelength_file', [params['engine_config'], instrument_config, rt_config])
params['wl'] = sixs_wl
params['fwhm'] = sixs_fwhm
isofit_sixs = SixSRT(**params)

print(sixs_results.shape)
print(modtran_results.shape)

tmp = isofit_sixs.load_rt(isofit_sixs.files[0])

np.savez(os.path.join(args.munge_dir, 'combined_training_data.npz'), sixs_results=sixs_results, modtran_results=modtran_results,
points=points, points_sixs=points_sixs, keys=args.keys, point_names=modtran_names, modtran_wavelengths=isofit_modtran.wl,
sixs_wavelengths=isofit_sixs.grid,
sol_irr=sol_irr)


@ray.remote
def read_data_piece(ind, maxind, point, fn, key, resample, obj):
if ind % 100 == 0:
print('{}: {}/{}'.format(key, ind, maxind))
try:
if resample is False:
mod_output = obj.load_rt(fn, resample=False)
else:
mod_output = obj.load_rt(fn)
res = mod_output[key]
except:
res = None
return ind, res


outdict_sixs, outdict_modtran = {}, {}
munge_file = os.path.join(args.munge_dir, 'combined_data.npz')
good_points = np.ones(isofit_sixs.lut['rhoatm'].shape[0]).astype(bool)
for key_ind, key in enumerate(args.keys):

def get_obj_res(obj, key, resample=True):
if args.figs_dir is not None:
point_dims = list(isofit_modtran.lut_grid.keys())
point_dims_s = list(isofit_sixs.lut_grid.keys())
rtm_key = 'transm_down_dir'
for _ind in range(isofit_sixs.lut['rhoatm'].shape[0]):

sp = isofit_sixs.points[_ind,:]
mp = isofit_modtran.points[_ind,:]

std_dir = isofit_sixs.lut['transm_down_dir'][_ind,:]
mtd_dir = isofit_modtran.lut['transm_down_dir'][_ind,:]
std_dif = isofit_sixs.lut['transm_down_dif'][_ind,:]
mtd_dif = isofit_modtran.lut['transm_down_dif'][_ind,:]
stu_dir = isofit_sixs.lut['transm_up_dir'][_ind,:]
mtu_dir = isofit_modtran.lut['transm_up_dir'][_ind,:]
stu_dif = isofit_sixs.lut['transm_up_dif'][_ind,:]
mtu_dif = isofit_modtran.lut['transm_up_dif'][_ind,:]
s_r = isofit_sixs.lut['rhoatm'][_ind,:]
m_r = isofit_modtran.lut['rhoatm'][_ind,:]

name='_'.join([f'{point_dims_s[x]}_{sp[x]}' for x in range(len(sp))])
plt.plot(isofit_sixs.wl, std_dir, color='black', label='t_down_dir')
plt.plot(isofit_sixs.wl, std_dif, color='grey', label='t_down_dif')
plt.plot(isofit_sixs.wl, stu_dir, color='red', label='t_up_dir')
plt.plot(isofit_sixs.wl, stu_dif, color='purple', label='t_up_dif')
plt.plot(isofit_sixs.wl, s_r, color='green', label='rhoatm')

name2='_'.join([f'{point_dims[x]}_{mp[x]}' for x in range(len(mp))])
plt.plot(isofit_modtran.wl, mtd_dir, color='black', ls = '--')
plt.plot(isofit_modtran.wl, mtd_dif, color='grey', ls = '--')
plt.plot(isofit_modtran.wl, mtu_dir, color='red', ls = '--')
plt.plot(isofit_modtran.wl, mtu_dif, color='purple', ls = '--')
plt.plot(isofit_modtran.wl, m_r, color='green', ls = '--')

plt.legend(fontsize=4, loc='lower right')

plt.title(f'S: {name}\n M: {name2}',fontsize=6)

plt.savefig(f'{args.figs_dir}/{name.replace("\n","_")}.png',dpi=100)
plt.clf()

point_names = isofit_sixs.lut.point.to_index().names
bad_points = np.zeros(isofit_sixs.lut[key].shape[0],dtype=bool)
if 'surface_elevation_km' in point_names and 'observer_altitude_km' in point_names:
bad_points = isofit_sixs.lut['surface_elevation_km'] >= isofit_sixs.lut['observer_altitude_km'] -2
bad_points[np.any(isofit_sixs.lut['transm_down_dif'].data[:,:] > 10,axis=1)] = True
bad_points[np.any(isofit_modtran.lut['transm_down_dif'].data[:,:] > 10,axis=1)] = True
good_points = np.logical_not(bad_points)

outdict_sixs[key] = np.array(isofit_sixs.lut[key])[good_points,:]
outdict_modtran[key] = np.array(isofit_modtran.lut[key])[good_points,:]


# hack at some things to prevent runaway values
outdict_sixs['transm_up_dif'][outdict_sixs['transm_up_dif'][:,:] > 1] = 0
outdict_sixs['transm_up_dif'][outdict_sixs['transm_up_dif'][:,:] < 0] = 0
outdict_modtran['transm_up_dif'][outdict_modtran['transm_up_dif'][:,:] > 1] = 0
outdict_modtran['transm_up_dif'][outdict_modtran['transm_up_dif'][:,:] < 0] = 0

outdict_sixs['sphalb'][:,isofit_sixs.wl > 1200][outdict_sixs['sphalb'][:,isofit_sixs.wl > 1200] > 0.1] = 0
outdict_sixs['sphalb'][outdict_sixs['sphalb'][:,:] < 0] = 0

subs = outdict_modtran['sphalb'][:,isofit_modtran.wl > 1200].copy()
subs[subs > 0.1] = 0
outdict_modtran['sphalb'][:,isofit_modtran.wl > 1200] = subs
print(np.sum(outdict_modtran['sphalb'][:,isofit_modtran.wl > 1200] > 0.1))

outdict_modtran['sphalb'][outdict_modtran['sphalb'][:,:] < 0] = 0



#keys=list(isofit_modtran.lut_grid.keys())
keys=['rhoatm','sphalb','transm_down_dir','transm_down_dif', 'transm_up_dir','transm_up_dif']
#keys=list(isofit_sixs.lut.point.to_index().names)
stacked_modtran = np.zeros((outdict_modtran[keys[0]].shape[0], outdict_modtran[keys[0]].shape[1] * len(keys)))
stacked_sixs = np.zeros((outdict_sixs[keys[0]].shape[0], outdict_sixs[keys[0]].shape[1] * len(keys)))
n_bands_modtran = int(stacked_modtran.shape[-1]/len(keys))
n_bands_sixs = int(stacked_sixs.shape[-1]/len(keys))
for n in range(len(keys)):
stacked_modtran[:,n*n_bands_modtran:(n+1)*n_bands_modtran] = outdict_modtran[keys[n]]
stacked_sixs[:,n*n_bands_sixs:(n+1)*n_bands_sixs] = outdict_sixs[keys[n]]


np.savez(munge_file, modtran_results=stacked_modtran,
sixs_results=stacked_sixs,
points=isofit_sixs.points[good_points,:],
sol_irr=isofit_modtran.lut.solar_irr,
sixs_wavelengths=isofit_sixs.wl,
modtran_wavelengths=isofit_modtran.wl,
point_names=list(isofit_sixs.lut.point.to_index().names),
keys=keys
)

# We don't want the VectorInterpolator, but rather the raw inputs
ray.init(temp_dir='/tmp/ray/brodrick/')
if hasattr(obj,'sixs_ngrid_init'):
results = np.zeros((obj.points.shape[0],obj.sixs_ngrid_init), dtype=float)
else:
results = np.zeros((obj.points.shape[0],obj.n_chan), dtype=float)
objid = ray.put(obj)
jobs = []
for ind, (point, fn) in enumerate(zip(obj.points, obj.files)):
jobs.append(read_data_piece.remote(ind, results.shape[0], point, fn, key, resample, objid))
rreturn = [ray.get(jid) for jid in jobs]
for ind, res in rreturn:
if res is not None:
try:
results[ind,:] = res
except:
results[ind,:] = np.nan
else:
results[ind,:] = np.nan
ray.shutdown()
return results


if __name__ == '__main__':
Expand Down
Loading