diff --git a/.github/workflows/test-B01_SL_load_single_file.yml b/.github/workflows/test-B01_SL_load_single_file.yml
index 4ef0d1fa..216f2b9c 100644
--- a/.github/workflows/test-B01_SL_load_single_file.yml
+++ b/.github/workflows/test-B01_SL_load_single_file.yml
@@ -24,8 +24,27 @@ jobs:
       - name: List dependencies
         run: pip list
       - name: test icesat2waves app
-        run: icesat2waves --help
-      - name: test command for step 1
-        run: load_single_file --help
+        run: |
+          icesat2waves --help
+          icesat2waves load-file --help
+          icesat2waves make-spectra --help
+          icesat2waves plot-spectra --help
+          icesat2waves make-iowaga-threads-prior --help
+          icesat2waves make-b04-angle --help # prelim name
+          icesat2waves define-angle --help
+          icesat2waves correct-separate --help # prelim name
+
       - name: first step B01_SL_load_single_file
-        run: python src/icesat2_tracks/analysis_db/B01_SL_load_single_file.py --track-name 20190502052058_05180312_005_01 --batch-key SH_testSLsinglefile2 --output-dir ./work
+        run: load-file --track-name 20190502052058_05180312_005_01 --batch-key SH_testSLsinglefile2 --output-dir ./work
+      - name: second step make_spectra
+        run: make-spectra --track-name SH_20190502_05180312 --batch-key SH_testSLsinglefile2 --output-dir ./work
+      - name: third step plot_spectra
+        run: plot-spectra --track-name SH_20190502_05180312 --batch-key SH_testSLsinglefile2 --output-dir ./work
+      - name: fourth step IOWAGA threads
+        run: make-iowaga-threads-prior --track-name SH_20190502_05180312 --batch-key SH_testSLsinglefile2 --output-dir ./work
+      - name: fifth step B04_angle
+        run: make-b04-angle --track-name SH_20190502_05180312 --batch-key SH_testSLsinglefile2 --output-dir ./work
+      - name: sixth step B04_define_angle
+        run: define-angle --track-name SH_20190502_05180312 --batch-key SH_testSLsinglefile2 --output-dir ./work
+      - name: seventh step B06_correct_separate
+        run: correct-separate --track-name SH_20190502_05180312 --batch-key SH_testSLsinglefile2 --output-dir ./work
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index a6f0b6af..3574230a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -38,13 +38,14 @@ analysis_db/support_files/
 *.egg-info/
 .installed.cfg
 *.egg
-
+logs/
 *__pycache__/
 *__pycache__/*
 
 # Environments
 .env
 .venv
+.venv39/
 env/
 venv/
 ENV/
@@ -57,4 +58,4 @@ dist/
 
 #visual code
 .vscode/
-*.h5
\ No newline at end of file
+*.h5
diff --git a/pyproject.toml b/pyproject.toml
index ccedb2f6..e340d58d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -134,8 +134,16 @@ dependencies = [ # Optional
 
 # Similar to `dependencies` above, these must be valid existing
 # projects.
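Editorial note: each new workflow step above invokes a console script that the `[project.scripts]` table below binds to a Typer app built by the repo's `makeapp` helper. A minimal sketch of that wiring, with hypothetical names (`run_step`, `mypkg.steps`) standing in for the project's actual modules:

```python
# Hypothetical sketch of the entry-point pattern assumed here (not the repo's
# actual makeapp code): a plain function with typer.Option parameters becomes
# a CLI command, and a [project.scripts] entry exposes it as an executable.
import typer


def run_step(
    track_name: str = typer.Option(...),  # required; mirrors --track-name above
    batch_key: str = typer.Option(...),   # required; mirrors --batch-key above
    output_dir: str = typer.Option(...),  # required; mirrors --output-dir above
):
    """One pipeline step; echoes what it would process."""
    typer.echo(f"processing {track_name} ({batch_key}) -> {output_dir}")


app = typer.Typer()
app.command()(run_step)

if __name__ == "__main__":
    app()
```

With an entry such as `load-file = "mypkg.steps:app"`, installing the package generates a `load-file` executable, which is why the workflow can call `load-file --track-name ... --output-dir ./work` directly.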
[project.optional-dependencies] # Optional -dev = ["check-manifest","black","icesat2-tracks[test]"] -test = ["coverage", "pytest >=7.4.4, <8.0.0", "pytest-xdist >=3.5.0, <4.0.0"] +dev = [ + "check-manifest", + "black", + "icesat2-tracks[test]" +] +test = [ + "coverage", + "pytest >=7.4.4, <8.0.0", + "pytest-xdist >=3.5.0, <4.0.0" +] # List URLs that are relevant to your project # @@ -157,7 +165,13 @@ test = ["coverage", "pytest >=7.4.4, <8.0.0", "pytest-xdist >=3.5.0, <4.0.0"] [project.scripts] # Optional #TODO: ADD ANY SCRIPTS WE WANT TO HAVE download = "icesat2_tracks.icesat2_tools_scripts.nsidc_icesat2_associated2:main" -load_single_file = "icesat2_tracks.analysis_db.B01_SL_load_single_file:step1app" +load-file = "icesat2_tracks.analysis_db.B01_SL_load_single_file:load_file_app" +make-spectra = "icesat2_tracks.analysis_db.B02_make_spectra_gFT:make_spectra_app" +plot-spectra = "icesat2_tracks.analysis_db.B03_plot_spectra_ov:plot_spectra" +make-iowaga-threads-prior = "icesat2_tracks.analysis_db.A02c_IOWAGA_thredds_prior:make_iowaga_threads_prior_app" +make-b04-angle = "icesat2_tracks.analysis_db.B04_angle:make_b04_angle_app" +define-angle = "icesat2_tracks.analysis_db.B05_define_angle:define_angle_app" +correct-separate = "icesat2_tracks.analysis_db.B06_correct_separate_var:correct_separate_app" icesat2waves = "icesat2_tracks.app:app" diff --git a/src/icesat2_tracks/ICEsat2_SI_tools/iotools.py b/src/icesat2_tracks/ICEsat2_SI_tools/iotools.py index e92ee495..45444490 100644 --- a/src/icesat2_tracks/ICEsat2_SI_tools/iotools.py +++ b/src/icesat2_tracks/ICEsat2_SI_tools/iotools.py @@ -39,7 +39,7 @@ def init_from_input(arguments): batch_key = arguments[2] # $(hemisphere) $(coords) $(config) - print("read vars from file: " + str(arguments[1])) + # print("read vars from file: " + str(arguments[1])) if len(arguments) >= 4: if arguments[3] == "True": @@ -49,13 +49,16 @@ def init_from_input(arguments): else: test_flag = arguments[3] - print("test_flag found, test_flag= " + str(test_flag)) + # print("test_flag found, test_flag= " + str(test_flag)) else: test_flag = False - print(track_name) - print("----- batch =" + batch_key) - print("----- test_flag: " + str(test_flag)) + # TODO: print statements to be handled with logger + # # print(track_name) + + # print("----- batch =" + batch_key) + # print("----- test_flag: " + str(test_flag)) + return track_name, batch_key, test_flag diff --git a/src/icesat2_tracks/analysis_db/A02c_IOWAGA_thredds_prior.py b/src/icesat2_tracks/analysis_db/A02c_IOWAGA_thredds_prior.py index 1fa8068a..7cf10b3c 100644 --- a/src/icesat2_tracks/analysis_db/A02c_IOWAGA_thredds_prior.py +++ b/src/icesat2_tracks/analysis_db/A02c_IOWAGA_thredds_prior.py @@ -1,13 +1,17 @@ -import sys +#!/usr/bin/env python3 + import datetime +from pathlib import Path import h5py import pandas as pd import numpy as np +import matplotlib import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpec import xarray as xr from siphon.catalog import TDSCatalog +import typer from icesat2_tracks.config.IceSAT2_startup import mconfig import icesat2_tracks.ICEsat2_SI_tools.iotools as io @@ -17,527 +21,632 @@ from icesat2_tracks.config.IceSAT2_startup import color_schemes from icesat2_tracks.config.IceSAT2_startup import font_for_print -color_schemes.colormaps2(21) +from icesat2_tracks.clitools import ( + echo, + validate_batch_key, + validate_output_dir, + suppress_stdout, + update_paths_mconfig, + report_input_parameters, + validate_track_name_steps_gt_1, + makeapp, +) -track_name, batch_key, ID_flag 
= io.init_from_input(sys.argv) +color_schemes.colormaps2(21) +matplotlib.use("Agg") # prevent plot windows from opening -track_name_short = track_name[0:-16] +dtime = 4 # in hours -ID, track_names, hemis, batch = io.init_data( - track_name, batch_key, ID_flag, mconfig["paths"]["work"] -) +# IOWAGA constants +data_url = "https://tds3.ifremer.fr/thredds/IOWAGA-WW3-FORECAST/IOWAGA-WW3-FORECAST_GLOBMULTI_GLOB-30M.xml" +dataset_key = "IOWAGA-WW3-FORECAST_GLOBMULTI_GLOB-30M_FIELD_NC_MARC_WW3-GLOB-30M" + + +def get_iowaga(data_url=data_url, dataset_key=dataset_key): + ## load WW3 data + # ECMWF hindcast + # data_url = 'https://tds3.ifremer.fr/thredds/IOWAGA-WW3-HINDCAST/IOWAGA-GLOBAL_ECMWF-WW3-HINDCAST_FULL_TIME_SERIE.xml' + + # CFSR hindcast + # data_url = 'https://tds3.ifremer.fr/thredds/IOWAGA-WW3-HINDCAST/IOWAGA-GLOBAL_CFSR-WW3-HINDCAST_FULL_TIME_SERIE.xml' + cat = TDSCatalog(data_url) + ncss = cat.datasets[dataset_key].remote_access(use_xarray=True) + + var_list = [ + "dir", + "dp", + "fp", + "hs", + "ice", + "spr", + "t01", + "t02", + "plp0", + "pdir0", + "pdir1", + "pdir2", + "pdir3", + "pdir4", + "pdir5", + "pspr0", + "pspr1", + "pspr2", + "pspr3", + "pspr4", + "pspr5", + "ptp0", + "ptp1", + "ptp2", + "ptp3", + "ptp4", + "ptp5", + "phs0", + "phs1", + "phs2", + "phs3", + "phs4", + "phs5", + ] + names_map = { + "pdir0": "pdp0", + "pdir1": "pdp1", + "pdir2": "pdp2", + "pdir3": "pdp3", + "pdir4": "pdp4", + "pdir5": "pdp5", + } + IOWAGA = ncss[var_list] + IOWAGA["time"] = np.array([np.datetime64(k0) for k0 in IOWAGA.time.data]).astype( + "M8[h]" + ) + IOWAGA = IOWAGA.rename(name_dict=names_map) -hemis, batch = batch_key.split("_") -ATlevel = "ATL03" + return IOWAGA -save_path = mconfig["paths"]["work"] + batch_key + "/A02_prior/" -plot_path = ( - mconfig["paths"]["plot"] + "/" + hemis + "/" + batch_key + "/" + track_name + "/" -) -save_name = "A02_" + track_name -plot_name = "A02_" + track_name_short -MT.mkdirs_r(plot_path) -MT.mkdirs_r(save_path) -bad_track_path = mconfig["paths"]["work"] + "bad_tracks/" + batch_key + "/" -all_beams = mconfig["beams"]["all_beams"] -high_beams = mconfig["beams"]["high_beams"] -low_beams = mconfig["beams"]["low_beams"] +def sel_data(Ibeam, lon_range, lat_range, time_range, timestamp=None): + """ + this method returns the selected data in the lon-lat box at an interpolated timestamp + """ + # TODO: refactor to avoid code duplication + lon_flag = (lon_range[0] < Ibeam.longitude.data) & ( + Ibeam.longitude.data < lon_range[1] + ) + lat_flag = (lat_range[0] < Ibeam.latitude.data) & ( + Ibeam.latitude.data < lat_range[1] + ) + time_flag = (time_range[0] < Ibeam.time.data) & (Ibeam.time.data < time_range[1]) -load_path_WAVE_GLO = mconfig["paths"]["work"] + "/GLOBMULTI_ERA5_GLOBCUR_01/" -file_name_base = "LOPS_WW3-GLOB-30M_" + if timestamp is None: + Ibeam = Ibeam.isel(latitude=lat_flag, longitude=lon_flag) + else: + Ibeam = ( + Ibeam.isel(latitude=lat_flag, longitude=lon_flag, time=time_flag) + .sortby("time") + .interp(time=np.datetime64(timestamp)) + ) + return Ibeam -load_path = mconfig["paths"]["work"] + batch_key + "/B01_regrid/" -Gd = h5py.File(load_path + "/" + track_name + "_B01_binned.h5", "r") +def build_prior(Tend: pd.DataFrame): + """ + Build a prior dictionary from the Tend DataFrame + Args: + Tend (pd.DataFrame): DataFrame containing the mean and standard deviation of the wave parameters -G1 = dict() -for b in all_beams: - Gi = io.get_beam_hdf_store(Gd[b]) - G1[b] = Gi.iloc[abs(Gi["lats"]).argmin()] + Returns: + dict: A dictionary containing the prior 
wave parameters + """ + Prior = dict() + key_mapping = { + "incident_angle": "dp", + "spread": "spr", + "Hs": "hs", + "center_lon": "lon", + "center_lat": "lat", + } -Gd.close() -G1 = pd.DataFrame.from_dict(G1).T + # populate Prior + for key, tend_key in key_mapping.items(): + Prior[key] = { + "value": Tend["mean"][tend_key].astype("float"), + "name": Tend["name"][tend_key], + } + # Handle "peak_period" separately + Prior["peak_period"] = { + "value": 1 / Tend["mean"]["fp"].astype("float"), + "name": "1/" + Tend["name"]["fp"], + } -if hemis == "SH": - # DEFINE SEARCH REGION AND SIZE OF BOXES FOR AVERAGES - dlon_deg = 1 # degree range aroud 1st point - dlat_deg = 30, 5 # degree range aroud 1st point - dlat_deg_prior = 2, 1 # degree range aroud 1st point + return Prior - dtime = 4 # in hours +def calculate_ranges(hemis, G1, dlon_deg, dlat_deg, dlat_deg_prior): lon_range = G1["lons"].min() - dlon_deg, G1["lons"].max() + dlon_deg - lat_range = np.sign(G1["lats"].min()) * 78, G1["lats"].max() + dlat_deg[1] + + if hemis == "SH": + lat_range = np.sign(G1["lats"].min()) * 78, G1["lats"].max() + dlat_deg[1] + else: + lat_range = G1["lats"].min() - dlat_deg[0], G1["lats"].max() + dlat_deg[1] + lat_range_prior = ( G1["lats"].min() - dlat_deg_prior[0], G1["lats"].max() + dlat_deg_prior[1], ) -else: - # DEFINE SEARCH REGION AND SIZE OF BOXES FOR AVERAGES - dlon_deg = 2 # lon degree range aroud 1st point - dlat_deg = 20, 20 # lat degree range aroud 1st point - dlat_deg_prior = 2, 1 # degree range aroud 1st point + return lon_range, lat_range, lat_range_prior - dtime = 4 # in hours - lon_range = G1["lons"].min() - dlon_deg, G1["lons"].max() + dlon_deg - lat_range = G1["lats"].min() - dlat_deg[0], G1["lats"].max() + dlat_deg[1] - lat_range_prior = ( - G1["lats"].min() - dlat_deg_prior[0], - G1["lats"].max() + dlat_deg_prior[1], - ) +def define_timestamp_and_time_range(ID, dtime): + timestamp = pd.to_datetime(ID["pars"]["start"]["delta_time"], unit="s") + time_range = np.datetime64(timestamp) - np.timedelta64(dtime, "h"), np.datetime64( + timestamp + ) + np.timedelta64(dtime, "h") + return timestamp, time_range -timestamp = pd.to_datetime(ID["pars"]["start"]["delta_time"], unit="s") -time_range = np.datetime64(timestamp) - np.timedelta64(dtime, "h"), np.datetime64( - timestamp -) + np.timedelta64(dtime, "h") +def run_A02c_IOWAGA_thredds_prior( + track_name: str = typer.Option(..., callback=validate_track_name_steps_gt_1), + batch_key: str = typer.Option(..., callback=validate_batch_key), + ID_flag: bool = True, + data_url: str = typer.Option(data_url), + dataset_key: str = typer.Option(dataset_key), + output_dir: str = typer.Option(..., callback=validate_output_dir), + verbose: bool = False, +): + """ + TODO: add docstring + """ -## load WW3 data -# ECMWF hindcast -# data_url = 'https://tds3.ifremer.fr/thredds/IOWAGA-WW3-HINDCAST/IOWAGA-GLOBAL_ECMWF-WW3-HINDCAST_FULL_TIME_SERIE.xml' + track_name, batch_key, _ = io.init_from_input( + [ + None, + track_name, + batch_key, + ID_flag, + ] # init_from_input expects sys.argv with 4 elements + ) -# CFSR hindcast -# data_url = 'https://tds3.ifremer.fr/thredds/IOWAGA-WW3-HINDCAST/IOWAGA-GLOBAL_CFSR-WW3-HINDCAST_FULL_TIME_SERIE.xml' + kwargs = { + "track_name": track_name, + "batch_key": batch_key, + "ID_flag": ID_flag, + "output_dir": output_dir, + } + report_input_parameters(**kwargs) -# ECMWF forecast -data_url = "https://tds3.ifremer.fr/thredds/IOWAGA-WW3-FORECAST/IOWAGA-WW3-FORECAST_GLOBMULTI_GLOB-30M.xml" + track_name_short = track_name[0:-16] -cat = 
TDSCatalog(data_url) - -ncss = cat.datasets[ - "IOWAGA-WW3-FORECAST_GLOBMULTI_GLOB-30M_FIELD_NC_MARC_WW3-GLOB-30M" -].remote_access(use_xarray=True) - -var_list = [ - "dir", - "dp", - "fp", - "hs", - "ice", - "spr", - "t01", - "t02", - "plp0", - "pdir0", - "pdir1", - "pdir2", - "pdir3", - "pdir4", - "pdir5", - "pspr0", - "pspr1", - "pspr2", - "pspr3", - "pspr4", - "pspr5", - "ptp0", - "ptp1", - "ptp2", - "ptp3", - "ptp4", - "ptp5", - "phs0", - "phs1", - "phs2", - "phs3", - "phs4", - "phs5", -] - - -# chunk data -IOWAGA = ncss[var_list] -IOWAGA["time"] = np.array([np.datetime64(k0) for k0 in IOWAGA.time.data]).astype( - "M8[h]" -) -IOWAGA = IOWAGA.rename( - name_dict={ - "pdir0": "pdp0", - "pdir1": "pdp1", - "pdir2": "pdp2", - "pdir3": "pdp3", - "pdir4": "pdp4", - "pdir5": "pdp5", + workdir, plotsdir = update_paths_mconfig(output_dir, mconfig) + + ID, track_names, hemis, batch = io.init_data( + str(track_name), str(batch_key), str(ID_flag), str(workdir) + ) # TODO: clean up application of str() to all arguments + + kwargs = { + "ID": ID, + "track_names": track_names, + "hemis": hemis, + "batch": batch, + "mconfig": workdir, + "heading": "** Revised input parameters:", } -) + report_input_parameters(**kwargs) + with suppress_stdout(verbose): + hemis, batch = batch_key.split("_") -def sel_data(I, lon_range, lat_range, timestamp=None): - """ - this method returns the selected data in the lon-lat box at an interpolated timestamp - """ - lon_flag = (lon_range[0] < I.longitude.data) & (I.longitude.data < lon_range[1]) - lat_flag = (lat_range[0] < I.latitude.data) & (I.latitude.data < lat_range[1]) - time_flag = (time_range[0] < I.time.data) & (I.time.data < time_range[1]) - if timestamp is None: - I = I.isel(latitude=lat_flag, longitude=lon_flag) - else: - I = ( - I.isel(latitude=lat_flag, longitude=lon_flag, time=time_flag) - .sortby("time") - .interp(time=np.datetime64(timestamp)) - ) - return I + save_path = Path(workdir, batch_key, "A02_prior") + plot_path = Path(plotsdir, hemis, batch_key, track_name) + save_name = "A02_" + track_name + plot_name = "A02_" + track_name_short + plot_path.mkdir(parents=True, exist_ok=True) + save_path.mkdir(parents=True, exist_ok=True) + all_beams = mconfig["beams"]["all_beams"] -try: - G_beam = sel_data(IOWAGA, lon_range, lat_range, timestamp).load() - G_prior = sel_data(G_beam, lon_range, lat_range_prior) + load_path = Path(workdir, batch_key, "B01_regrid") + with h5py.File(load_path / (track_name + "_B01_binned.h5"), "r") as Gd: - if hemis == "SH": - # create Ice mask - ice_mask = (G_beam.ice > 0) | np.isnan(G_beam.ice) - - lats = list(ice_mask.latitude.data) - lats.sort(reverse=True) - - # find 1st latitude that is completely full with sea ice. - ice_lat_pos = next( - ( - i - for i, j in enumerate( - (ice_mask.sum("longitude") == ice_mask.longitude.size).sel( - latitude=lats - ) - ) - if j - ), - None, - ) - # recreate lat mask based on this criteria - lat_mask = lats < lats[ice_lat_pos] - lat_mask = xr.DataArray( - lat_mask.repeat(ice_mask.longitude.size).reshape(ice_mask.shape), - dims=ice_mask.dims, - coords=ice_mask.coords, - ) - lat_mask["latitude"] = lats + # Select the beam with the minimum absolute latitude for each beam in the dataset Gd + # and store the corresponding row as a dictionary in G1? 
CP + G1 = { + b: io.get_beam_hdf_store(Gd[b]).iloc[ + abs(io.get_beam_hdf_store(Gd[b])["lats"]).argmin() + ] + for b in all_beams + } - # combine ice mask and new lat mask - ice_mask = ice_mask + lat_mask + G1 = pd.DataFrame.from_dict(G1).T - else: - ice_mask = np.isnan(G_beam.ice) - lats = ice_mask.latitude - - # find closed latituyde with with non-nan data - ice_lat_pos = ( - abs( - lats.where(ice_mask.sum("longitude") > 4, np.nan) - - np.array(lat_range).mean() + if hemis == "SH": + lon_range, lat_range, lat_range_prior = calculate_ranges( + hemis=hemis, G1=G1, dlon_deg=1, dlat_deg=(30, 5), dlat_deg_prior=(2, 1) + ) + else: + lon_range, lat_range, lat_range_prior = calculate_ranges( + hemis=hemis, G1=G1, dlon_deg=2, dlat_deg=(20, 20), dlat_deg_prior=(2, 1) ) - .argmin() - .data - ) - # redefine lat-range - lat_range = lats[ice_lat_pos].data - 2, lats[ice_lat_pos].data + 2 - lat_flag2 = (lat_range[0] < lats.data) & (lats.data < lat_range[1]) + IOWAGA = get_iowaga(data_url=data_url, dataset_key=dataset_key) + + timestamp, time_range = define_timestamp_and_time_range(ID, dtime) + + # TODO: refactor this try-except block -- too much complexity within. CP + try: + G_beam = sel_data( + Ibeam=IOWAGA, + lon_range=lon_range, + lat_range=lat_range, + time_range=time_range, + timestamp=timestamp, + ).load() + G_prior = sel_data( + Ibeam=G_beam, + lon_range=lon_range, + lat_range=lat_range_prior, + time_range=time_range, + ) - lat_mask = xr.DataArray( - lat_flag2.repeat(ice_mask.longitude.size).reshape(ice_mask.shape), - dims=ice_mask.dims, - coords=ice_mask.coords, - ) - lat_mask["latitude"] = lats - - # plot 1st figure - def draw_range(lon_range, lat_range, *args, **kargs): - plt.plot( - [lon_range[0], lon_range[1], lon_range[1], lon_range[0], lon_range[0]], - [lat_range[0], lat_range[0], lat_range[1], lat_range[1], lat_range[0]], - *args, - **kargs, - ) + if hemis == "SH": + # create Ice mask + ice_mask = (G_beam.ice > 0) | np.isnan(G_beam.ice) + + lats = list(ice_mask.latitude.data) + lats.sort(reverse=True) + + # find 1st latitude that is completely full with sea ice. 
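+                # (lats was sorted in descending order above, so this scans from the
+                # equatorward side poleward; `next` yields the first row whose entire
+                # longitude band is ice-flagged, or None if no such latitude exists.)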
+ ice_lat_pos = next( + ( + i + for i, j in enumerate( + (ice_mask.sum("longitude") == ice_mask.longitude.size).sel( + latitude=lats + ) + ) + if j + ), + None, + ) + # recreate lat mask based on this criteria + lat_mask = lats < lats[ice_lat_pos] + lat_mask = xr.DataArray( + lat_mask.repeat(ice_mask.longitude.size).reshape(ice_mask.shape), + dims=ice_mask.dims, + coords=ice_mask.coords, + ) + lat_mask["latitude"] = lats - dir_clev = np.arange(0, 380, 20) - f_clev = np.arange(1 / 40, 1 / 5, 0.01) - fvar = ["ice", "dir", "dp", "spr", "fp", "hs"] - fcmap = [ - plt.cm.Blues_r, - color_schemes.circle_medium_triple, - color_schemes.circle_medium_triple, - plt.cm.Blues, - plt.cm.Blues, - plt.cm.Blues, - ] - fpos = [0, 1, 2, 3, 4, 5] - clevs = [ - np.arange(0, 1, 0.2), - dir_clev, - dir_clev, - np.arange(0, 90, 10), - f_clev, - np.arange(0.5, 9, 0.5), - ] + # combine ice mask and new lat mask + ice_mask = ice_mask + lat_mask - font_for_print() + else: + ice_mask = np.isnan(G_beam.ice) + lats = ice_mask.latitude - F = M.figure_axis_xy(4, 3.5, view_scale=0.9, container=True) - plt.suptitle( - track_name_short + " | " + file_name_base[0:-1].replace("_", " "), y=1.3 - ) - lon, lat = G_beam.longitude, G_beam.latitude + # find closed latitude with with non-nan data + ice_lat_pos = ( + abs( + lats.where(ice_mask.sum("longitude") > 4, np.nan) + - np.array(lat_range).mean() + ) + .argmin() + .data + ) - gs = GridSpec(9, 6, wspace=0.1, hspace=0.4) + # redefine lat-range + lat_range = lats[ice_lat_pos].data - 2, lats[ice_lat_pos].data + 2 + lat_flag2 = (lat_range[0] < lats.data) & (lats.data < lat_range[1]) - for fv, fp, fc, cl in zip(fvar, fpos, fcmap, clevs): - ax1 = F.fig.add_subplot(gs[0:7, fp]) - if fp == 0: - ax1.spines["bottom"].set_visible(False) - ax1.spines["left"].set_visible(False) - ax1.tick_params(labelbottom=True, bottom=True) + lat_mask = xr.DataArray( + lat_flag2.repeat(ice_mask.longitude.size).reshape(ice_mask.shape), + dims=ice_mask.dims, + coords=ice_mask.coords, + ) + lat_mask["latitude"] = lats + + # plot 1st figure + def draw_range(lon_range, lat_range, *args, **kwargs): + plt.plot( + [ + lon_range[0], + lon_range[1], + lon_range[1], + lon_range[0], + lon_range[0], + ], + [ + lat_range[0], + lat_range[0], + lat_range[1], + lat_range[1], + lat_range[0], + ], + *args, + **kwargs, + ) - else: - ax1.axis("off") + dir_clev = np.arange(0, 380, 20) + f_clev = np.arange(1 / 40, 1 / 5, 0.01) + fvar = ["ice", "dir", "dp", "spr", "fp", "hs"] + fcmap = [ + plt.cm.Blues_r, + color_schemes.circle_medium_triple, + color_schemes.circle_medium_triple, + plt.cm.Blues, + plt.cm.Blues, + plt.cm.Blues, + ] + fpos = [0, 1, 2, 3, 4, 5] + clevs = [ + np.arange(0, 1, 0.2), + dir_clev, + dir_clev, + np.arange(0, 90, 10), + f_clev, + np.arange(0.5, 9, 0.5), + ] + + font_for_print() + + F = M.figure_axis_xy(4, 3.5, view_scale=0.9, container=True) + + file_name_base = "LOPS_WW3-GLOB-30M_" + plt.suptitle( + track_name_short + " | " + file_name_base[0:-1].replace("_", " "), y=1.3 + ) + lon, lat = G_beam.longitude, G_beam.latitude - plt.plot(G1["lons"], G1["lats"], ".r", markersize=5) - draw_range(lon_range, lat_range_prior, c="red", linewidth=1, zorder=12) - draw_range(lon_range, lat_range, c="blue", linewidth=0.7, zorder=10) + gs = GridSpec(9, 6, wspace=0.1, hspace=0.4) - if fv != "ice": - cm = plt.pcolor(lon, lat, G_beam[fv], vmin=cl[0], vmax=cl[-1], cmap=fc) - if G_beam.ice.shape[0] > 1: - plt.contour(lon, lat, G_beam.ice, colors="black", linewidths=0.6) - else: - cm = plt.pcolor(lon, lat, G_beam[fv], 
vmin=cl[0], vmax=cl[-1], cmap=fc) + for fv, fp, fc, cl in zip(fvar, fpos, fcmap, clevs): + ax1 = F.fig.add_subplot(gs[0:7, fp]) + if fp == 0: + ax1.spines["bottom"].set_visible(False) + ax1.spines["left"].set_visible(False) + ax1.tick_params(labelbottom=True, bottom=True) - plt.title(G_beam[fv].long_name.replace(" ", "\n") + "\n" + fv, loc="left") - ax1.axis("equal") + else: + ax1.axis("off") - ax2 = F.fig.add_subplot(gs[-1, fp]) - cbar = plt.colorbar(cm, cax=ax2, orientation="horizontal", aspect=1, fraction=1) - cl_ticks = np.linspace(cl[0], cl[-1], 3) + plt.plot(G1["lons"], G1["lats"], ".r", markersize=5) + draw_range(lon_range, lat_range_prior, c="red", linewidth=1, zorder=12) + draw_range(lon_range, lat_range, c="blue", linewidth=0.7, zorder=10) - cbar.set_ticks(np.round(cl_ticks, 3)) - cbar.set_ticklabels(np.round(cl_ticks, 2)) + if fv != "ice": + cm = plt.pcolor( + lon, lat, G_beam[fv], vmin=cl[0], vmax=cl[-1], cmap=fc + ) + if G_beam.ice.shape[0] > 1: + plt.contour( + lon, lat, G_beam.ice, colors="black", linewidths=0.6 + ) + else: + cm = plt.pcolor( + lon, lat, G_beam[fv], vmin=cl[0], vmax=cl[-1], cmap=fc + ) - F.save_pup(path=plot_path, name=plot_name + "_hindcast_data") + plt.title( + G_beam[fv].long_name.replace(" ", "\n") + "\n" + fv, loc="left" + ) + ax1.axis("equal") - G_beam_masked = G_beam.where(~ice_mask, np.nan) - ice_mask_prior = ice_mask.sel(latitude=G_prior.latitude) - G_prior_masked = G_prior.where(~ice_mask_prior, np.nan) + ax2 = F.fig.add_subplot(gs[-1, fp]) + cbar = plt.colorbar( + cm, cax=ax2, orientation="horizontal", aspect=1, fraction=1 + ) + cl_ticks = np.linspace(cl[0], cl[-1], 3) - def test_nan_frac(imask): - "test if False is less then 0.3" - return ((~imask).sum() / imask.size).data < 0.3 + cbar.set_ticks(np.round(cl_ticks, 3)) + cbar.set_ticklabels(np.round(cl_ticks, 2)) - while test_nan_frac(ice_mask_prior): - print(lat_range_prior) - lat_range_prior = lat_range_prior[0] + 0.5, lat_range_prior[1] + 0.5 - G_prior = sel_data(G_beam, lon_range, lat_range_prior) - ice_mask_prior = ice_mask.sel(latitude=G_prior.latitude) + F.save_pup(path=plot_path, name=plot_name + "_hindcast_data") - G_prior_masked = G_prior.where(~ice_mask_prior, np.nan) + ice_mask_prior = ice_mask.sel(latitude=G_prior.latitude) + G_prior_masked = G_prior.where(~ice_mask_prior, np.nan) - ### make pandas table with obs track end postitions + def test_nan_frac(imask): + "test if False is less then 0.3" + return ((~imask).sum() / imask.size).data < 0.3 - key_list = list(G_prior_masked.keys()) - # define directional and amplitude pairs - # pack as (amp, angle) - key_list_pairs = { - "mean": ("hs", "dir"), - "peak": ("hs", "dp"), - "partion0": ("phs0", "pdp0"), - "partion1": ("phs1", "pdp1"), - "partion2": ("phs2", "pdp2"), - "partion3": ("phs3", "pdp3"), - "partion4": ("phs4", "pdp4"), - } + while test_nan_frac(ice_mask_prior): + print(lat_range_prior) + lat_range_prior = lat_range_prior[0] + 0.5, lat_range_prior[1] + 0.5 + G_prior = sel_data(G_beam, lon_range, lat_range_prior) + ice_mask_prior = ice_mask.sel(latitude=G_prior.latitude) - key_list_pairs2 = list() - for k in key_list_pairs.values(): - key_list_pairs2.append(k[0]) - key_list_pairs2.append(k[1]) + G_prior_masked = G_prior.where(~ice_mask_prior, np.nan) - key_list_scaler = set(key_list) - set(key_list_pairs2) + ### make pandas table with obs track end positions - ### derive angle avearge - Tend = pd.DataFrame(index=key_list, columns=["mean", "std", "name"]) + key_list = list(G_prior_masked.keys()) + # define directional and 
amplitude pairs + # pack as (amp, angle) + key_list_pairs = { + "mean": ("hs", "dir"), + "peak": ("hs", "dp"), + "partion0": ("phs0", "pdp0"), + "partion1": ("phs1", "pdp1"), + "partion2": ("phs2", "pdp2"), + "partion3": ("phs3", "pdp3"), + "partion4": ("phs4", "pdp4"), + } - for k, pair in key_list_pairs.items(): - ave_amp, ave_deg, std_amp, std_deg = waves.get_ave_amp_angle( - G_prior_masked[pair[0]].data, G_prior_masked[pair[1]].data - ) - Tend.loc[pair[0]] = ave_amp, std_amp, G_prior_masked[pair[0]].long_name - Tend.loc[pair[1]] = ave_deg, std_deg, G_prior_masked[pair[1]].long_name - - for k in key_list_scaler: - Tend.loc[k] = ( - G_prior_masked[k].mean().data, - G_prior_masked[k].std().data, - G_prior_masked[k].long_name, - ) + # flatten key_list_pairs.values to a list + key_list_pairs2 = [ + item for pair in key_list_pairs.values() for item in pair + ] - Tend = Tend.T - Tend["lon"] = [ - ice_mask_prior.longitude.mean().data, - ice_mask_prior.longitude.std().data, - "lontigude", - ] - Tend["lat"] = [ - ice_mask_prior.latitude[ice_mask_prior.sum("longitude") == 0].mean().data, - ice_mask_prior.latitude[ice_mask_prior.sum("longitude") == 0].std().data, - "latitude", - ] - Tend = Tend.T + key_list_scaler = set(key_list) - set(key_list_pairs2) - Prior = dict() - Prior["incident_angle"] = { - "value": Tend["mean"]["dp"].astype("float"), - "name": Tend["name"]["dp"], - } - Prior["spread"] = { - "value": Tend["mean"]["spr"].astype("float"), - "name": Tend["name"]["spr"], - } - Prior["Hs"] = { - "value": Tend["mean"]["hs"].astype("float"), - "name": Tend["name"]["hs"], - } - Prior["peak_period"] = { - "value": 1 / Tend["mean"]["fp"].astype("float"), - "name": "1/" + Tend["name"]["fp"], - } + ### derive angle average + Tend = pd.DataFrame(index=key_list, columns=["mean", "std", "name"]) - Prior["center_lon"] = { - "value": Tend["mean"]["lon"].astype("float"), - "name": Tend["name"]["lon"], - } - Prior["center_lat"] = { - "value": Tend["mean"]["lat"].astype("float"), - "name": Tend["name"]["lat"], - } + for k, pair in key_list_pairs.items(): + ave_amp, ave_deg, std_amp, std_deg = waves.get_ave_amp_angle( + G_prior_masked[pair[0]].data, G_prior_masked[pair[1]].data + ) + Tend.loc[pair[0]] = ave_amp, std_amp, G_prior_masked[pair[0]].long_name + Tend.loc[pair[1]] = ave_deg, std_deg, G_prior_masked[pair[1]].long_name + + for k in key_list_scaler: + Tend.loc[k] = ( + G_prior_masked[k].mean().data, + G_prior_masked[k].std().data, + G_prior_masked[k].long_name, + ) - target_name = "A02_" + track_name + "_hindcast_success" - - MT.save_pandas_table({"priors_hindcast": Tend}, save_name, save_path) -except: - target_name = "A02_" + track_name + "_hindcast_fail" - - -def plot_prior(Prior, axx): - angle = Prior["incident_angle"][ - "value" - ] # incident direction in degrees from North clockwise (Meerological convention) - # use - angle_plot = -angle - 90 - axx.quiver( - Prior["center_lon"]["value"], - Prior["center_lat"]["value"], - -np.cos(angle_plot * np.pi / 180), - -np.sin(angle_plot * np.pi / 180), - scale=4.5, - zorder=12, - width=0.1, - headlength=4.5, - minshaft=2, - alpha=0.6, - color="black", - ) - axx.plot( - Prior["center_lon"]["value"], - Prior["center_lat"]["value"], - ".", - markersize=6, - zorder=12, - alpha=1, - color="black", - ) - tstring = ( - " " - + str(np.round(Prior["peak_period"]["value"], 1)) - + "sec \n " - + str(np.round(Prior["Hs"]["value"], 1)) - + "m\n " - + str(np.round(angle, 1)) - + "deg" - ) - plt.text(lon_range[1], Prior["center_lat"]["value"], tstring) + Tend = Tend.T + 
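+            # Tend has one row per variable (columns: mean, std, name); the transpose
+            # above flips it so the lon/lat summaries below can be attached as columns
+            # before it is transposed back.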
Tend["lon"] = [ + ice_mask_prior.longitude.mean().data, + ice_mask_prior.longitude.std().data, + "lontigude", # TODO: fix typo? + ] + Tend["lat"] = [ + ice_mask_prior.latitude[ice_mask_prior.sum("longitude") == 0] + .mean() + .data, + ice_mask_prior.latitude[ice_mask_prior.sum("longitude") == 0] + .std() + .data, + "latitude", + ] + Tend = Tend.T + + Prior = build_prior(Tend) + target_name = "A02_" + track_name + "_hindcast_success" + MT.save_pandas_table( + {"priors_hindcast": Tend}, save_name, str(save_path) + ) # TODO: refactor save_pandas_table to use Path objects + except Exception: + target_name = "A02_" + track_name + "_hindcast_fail" + + def plot_prior(Prior, axx): + angle = Prior["incident_angle"][ + "value" + ] # incident direction in degrees from North clockwise (Meterological convention) + # use + angle_plot = -angle - 90 + axx.quiver( + Prior["center_lon"]["value"], + Prior["center_lat"]["value"], + -np.cos(angle_plot * np.pi / 180), + -np.sin(angle_plot * np.pi / 180), + scale=4.5, + zorder=12, + width=0.1, + headlength=4.5, + minshaft=2, + alpha=0.6, + color="black", + ) + axx.plot( + Prior["center_lon"]["value"], + Prior["center_lat"]["value"], + ".", + markersize=6, + zorder=12, + alpha=1, + color="black", + ) + tstring = ( + " " + + str(np.round(Prior["peak_period"]["value"], 1)) + + "sec \n " + + str(np.round(Prior["Hs"]["value"], 1)) + + "m\n " + + str(np.round(angle, 1)) + + "deg" + ) + plt.text(lon_range[1], Prior["center_lat"]["value"], tstring) + try: + # plot 2nd figure -try: - # plot 2nd figure + font_for_print() + F = M.figure_axis_xy(2, 4.5, view_scale=0.9, container=False) - font_for_print() - F = M.figure_axis_xy(2, 4.5, view_scale=0.9, container=False) + ax1 = F.ax + lon, lat = G_beam.longitude, G_beam.latitude + ax1.spines["bottom"].set_visible(False) + ax1.spines["left"].set_visible(False) + ax1.tick_params(labelbottom=True, bottom=True) - ax1 = F.ax - lon, lat = G_beam.longitude, G_beam.latitude - ax1.spines["bottom"].set_visible(False) - ax1.spines["left"].set_visible(False) - ax1.tick_params(labelbottom=True, bottom=True) + plot_prior(Prior, ax1) + + # TODO: refactor as comprehension -- it might be less readable and more lines?. 
CP + str_list = list() + for i in np.arange(0, 6): + str_list.append( + " " + + str(np.round(Tend.loc["ptp" + str(i)]["mean"], 1)) + + "sec\n " + + str(np.round(Tend.loc["phs" + str(i)]["mean"], 1)) + + "m " + + str(np.round(Tend.loc["pdp" + str(i)]["mean"], 1)) + + "d" + ) - plot_prior(Prior, ax1) + plt.text(lon_range[1], lat_range[0], "\n ".join(str_list)) + + for vv in zip( + ["pdp0", "pdp1", "pdp2", "pdp3", "pdp4", "pdp5"], + ["phs0", "phs1", "phs3", "phs4", "phs5"], + ): + angle_plot = -Tend.loc[vv[0]]["mean"] - 90 + vsize = (1 / Tend.loc[vv[1]]["mean"]) ** (1 / 2) * 5 + ax1.quiver( + Prior["center_lon"]["value"], + Prior["center_lat"]["value"], + -np.cos(angle_plot * np.pi / 180), + -np.sin(angle_plot * np.pi / 180), + scale=vsize, + zorder=5, + width=0.1, + headlength=4.5, + minshaft=4, + alpha=0.6, + color="green", + ) - str_list = list() - for i in np.arange(0, 6): - str_list.append( - " " - + str(np.round(Tend.loc["ptp" + str(i)]["mean"], 1)) - + "sec\n " - + str(np.round(Tend.loc["phs" + str(i)]["mean"], 1)) - + "m " - + str(np.round(Tend.loc["pdp" + str(i)]["mean"], 1)) - + "d" - ) + plt.plot(G1["lons"], G1["lats"], ".r", markersize=5) + draw_range(lon_range, lat_range_prior, c="red", linewidth=1, zorder=11) + draw_range(lon_range, lat_range, c="blue", linewidth=0.7, zorder=10) - plt.text(lon_range[1], lat_range[0], "\n ".join(str_list)) - - for vv in zip( - ["pdp0", "pdp1", "pdp2", "pdp3", "pdp4", "pdp5"], - ["phs0", "phs1", "phs3", "phs4", "phs5"], - ): - angle_plot = -Tend.loc[vv[0]]["mean"] - 90 - vsize = (1 / Tend.loc[vv[1]]["mean"]) ** (1 / 2) * 5 - print(vsize) - ax1.quiver( - Prior["center_lon"]["value"], - Prior["center_lat"]["value"], - -np.cos(angle_plot * np.pi / 180), - -np.sin(angle_plot * np.pi / 180), - scale=vsize, - zorder=5, - width=0.1, - headlength=4.5, - minshaft=4, - alpha=0.6, - color="green", - ) + fv = "ice" + if fv != "ice": + plt.pcolor(lon, lat, G_beam[fv].where(~ice_mask, np.nan), cmap=fc) + plt.contour(lon, lat, G_beam.ice, colors="black", linewidths=0.6) + else: + plt.pcolor(lon, lat, G_beam[fv], cmap=fc) + + plt.title( + "Prior\n" + + file_name_base[0:-1].replace("_", " ") + + "\n" + + track_name_short + + "\nIncident angle", + loc="left", + ) + ax1.axis("equal") - plt.plot(G1["lons"], G1["lats"], ".r", markersize=5) - draw_range(lon_range, lat_range_prior, c="red", linewidth=1, zorder=11) - draw_range(lon_range, lat_range, c="blue", linewidth=0.7, zorder=10) + F.save_pup(path=plot_path, name=plot_name + "_hindcast_prior") + except Exception as e: + print(e) + echo("print 2nd figure failed", "red") - fv = "ice" - if fv != "ice": - plt.pcolor(lon, lat, G_beam[fv].where(~ice_mask, np.nan), cmap=fc) - plt.contour(lon, lat, G_beam.ice, colors="black", linewidths=0.6) - else: - plt.pcolor(lon, lat, G_beam[fv], cmap=fc) - - plt.title( - "Prior\n" - + file_name_base[0:-1].replace("_", " ") - + "\n" - + track_name_short - + "\nIncident angle", - loc="left", - ) - ax1.axis("equal") + MT.json_save( + target_name, + save_path, + str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M")), + ) - F.save_pup(path=plot_path, name=plot_name + "_hindcast_prior") -except Exception as e: - print(e) - print("print 2nd figure failed") + echo("done") -MT.json_save( - target_name, save_path, str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M")) -) -print("done") +make_iowaga_threads_prior_app = makeapp(run_A02c_IOWAGA_thredds_prior, name="threads-prior") + +if __name__ == "__main__": + make_iowaga_threads_prior_app() diff --git 
a/src/icesat2_tracks/analysis_db/B01_SL_load_single_file.py b/src/icesat2_tracks/analysis_db/B01_SL_load_single_file.py index 168f7956..756ca3a8 100644 --- a/src/icesat2_tracks/analysis_db/B01_SL_load_single_file.py +++ b/src/icesat2_tracks/analysis_db/B01_SL_load_single_file.py @@ -33,6 +33,7 @@ validate_batch_key, validate_output_dir, suppress_stdout, + report_input_parameters, update_paths_mconfig, echo, echoparam, @@ -90,18 +91,19 @@ def run_B01_SL_load_single_file( ID_flag: bool = True, plot_flag: bool = True, output_dir: str = typer.Option(..., callback=validate_output_dir), - verbose: bool = False + verbose: bool = False, ): """ Open an ICEsat2 tbeam_stats.pyrack, apply filters and corrections, and output smoothed photon heights on a regular grid in an .nc file. """ # report input parameters - echo("** Input parameters:") - echoparam("track_name", track_name) - echoparam("batch_key", batch_key) - echoparam("ID_flag", ID_flag) - echoparam("plot_flag", plot_flag) - echoparam("output_dir", output_dir) + kwargs = { + "track_name": track_name, + "batch_key": batch_key, + "ID_flag": ID_flag, + "output_dir": output_dir, + } + report_input_parameters(**kwargs) xr.set_options(display_style="text") matplotlib.use("Agg") # prevent plot windows from opening @@ -274,7 +276,7 @@ def run_B01_SL_load_single_file( echo("done") -step1app = makeapp(run_B01_SL_load_single_file, name="load-file") +load_file_app = makeapp(run_B01_SL_load_single_file, name="load-file") if __name__ == "__main__": - step1app() + load_file_app() diff --git a/src/icesat2_tracks/analysis_db/B02_make_spectra_gFT.py b/src/icesat2_tracks/analysis_db/B02_make_spectra_gFT.py index a9cb6b90..81baf24e 100644 --- a/src/icesat2_tracks/analysis_db/B02_make_spectra_gFT.py +++ b/src/icesat2_tracks/analysis_db/B02_make_spectra_gFT.py @@ -1,40 +1,53 @@ +#!/usr/bin/env python """ -This file open a ICEsat2 track applied filters and corections and returns smoothed photon heights on a regular grid in an .nc file. +This file open a ICEsat2 track applied filters and corrections and returns smoothed photon heights on a regular grid in an .nc file. 
This is python 3 """ -import sys - -import matplotlib.pyplot as plt -from icesat2_tracks.config.IceSAT2_startup import mconfig +import copy +import datetime +import h5py +from pathlib import Path +from functools import partial -from threadpoolctl import threadpool_info, threadpool_limits -from pprint import pprint import numpy as np import xarray as xr +from pprint import pprint +from scipy.ndimage import label +from threadpoolctl import threadpool_info, threadpool_limits +import matplotlib +from matplotlib import pyplot as plt +import typer -import h5py +import icesat2_tracks.ICEsat2_SI_tools.generalized_FT as gFT import icesat2_tracks.ICEsat2_SI_tools.iotools as io import icesat2_tracks.ICEsat2_SI_tools.spectral_estimates as spec - -import time -import imp -import copy -import icesat2_tracks.ICEsat2_SI_tools.spicke_remover as spicke_remover -import datetime -import icesat2_tracks.ICEsat2_SI_tools.generalized_FT as gFT -from scipy.ndimage import label +import icesat2_tracks.local_modules.m_general_ph3 as M +import icesat2_tracks.local_modules.m_spectrum_ph3 as spicke_remover import icesat2_tracks.local_modules.m_tools_ph3 as MT -from icesat2_tracks.local_modules import m_general_ph3 as M +from icesat2_tracks.config.IceSAT2_startup import mconfig + +from icesat2_tracks.clitools import ( + echo, + validate_batch_key, + validate_output_dir, + suppress_stdout, + update_paths_mconfig, + report_input_parameters, + validate_track_name_steps_gt_1, + makeapp, +) + +# import tracemalloc # removing this for now. CP -import tracemalloc +matplotlib.use("Agg") # prevent plot windows from opening def linear_gap_fill(F, key_lead, key_int): """ F pd.DataFrame - key_lead key in F that determined the independent coordindate + key_lead key in F that determined the independent coordinate key_int key in F that determined the dependent data """ y_g = np.array(F[key_int]) @@ -45,468 +58,521 @@ def linear_gap_fill(F, key_lead, key_int): return y_g -track_name, batch_key, test_flag = io.init_from_input( - sys.argv -) # loads standard experiment -hemis, batch = batch_key.split("_") -ATlevel = "ATL03" - -load_path = mconfig["paths"]["work"] + "/" + batch_key + "/B01_regrid/" -load_file = load_path + "processed_" + ATlevel + "_" + track_name + ".h5" - -save_path = mconfig["paths"]["work"] + "/" + batch_key + "/B02_spectra/" -save_name = "B02_" + track_name - -plot_path = ( - mconfig["paths"]["plot"] - + "/" - + hemis - + "/" - + batch_key - + "/" - + track_name - + "/B03_spectra/" -) -MT.mkdirs_r(plot_path) -MT.mkdirs_r(save_path) -bad_track_path = mconfig["paths"]["work"] + "bad_tracks/" + batch_key + "/" - -all_beams = mconfig["beams"]["all_beams"] -high_beams = mconfig["beams"]["high_beams"] -low_beams = mconfig["beams"]["low_beams"] - -N_process = 4 -print("N_process=", N_process) - -Gd = h5py.File(load_path + "/" + track_name + "_B01_binned.h5", "r") - - -# test amount of nans in the data -nan_fraction = list() -for k in all_beams: - heights_c_std = io.get_beam_var_hdf_store(Gd[k], "x") - nan_fraction.append(np.sum(np.isnan(heights_c_std)) / heights_c_std.shape[0]) - -del heights_c_std - -# test if beam pairs have bad ratio -bad_ratio_flag = False -for group in mconfig["beams"]["groups"]: - Ia = Gd[group[0]] - Ib = Gd[group[1]] - ratio = Ia["x"][:].size / Ib["x"][:].size - if (ratio > 10) | (ratio < 0.1): - print("bad data ratio ", ratio, 1 / ratio) - bad_ratio_flag = True +def run_B02_make_spectra_gFT( + track_name: str = typer.Option(..., callback=validate_track_name_steps_gt_1), + batch_key: str = 
typer.Option(..., callback=validate_batch_key), + ID_flag: bool = True, + output_dir: str = typer.Option(None, callback=validate_output_dir), + verbose: bool = False, +): + """ + TODO: add docstring + """ -if (np.array(nan_fraction).mean() > 0.95) | bad_ratio_flag: - print( - "nan fraction > 95%, or bad ratio of data, pass this track, add to bad tracks" - ) - MT.json_save( - track_name, - bad_track_path, - { - "nan_fraction": np.array(nan_fraction).mean(), - "date": str(datetime.date.today()), - }, + track_name, batch_key, _ = io.init_from_input( + [ + None, + track_name, + batch_key, + ID_flag, + ] # init_from_input expects sys.argv with 4 elements ) - print("exit.") - exit() - -# test LS with an even grid where missing values are set to 0 -imp.reload(spec) -print(Gd.keys()) -Gi = Gd[list(Gd.keys())[0]] # to select a test beam -dist = io.get_beam_var_hdf_store(Gd[list(Gd.keys())[0]], "x") - -# make dataframe form hdf5 -# derive spectal limits -# Longest deserved period: -T_max = 40 # sec -k_0 = (2 * np.pi / T_max) ** 2 / 9.81 -x = np.array(dist).squeeze() -dx = np.round(np.median(np.diff(x)), 1) -min_datapoint = 2 * np.pi / k_0 / dx - -Lpoints = int(np.round(min_datapoint) * 10) -Lmeters = Lpoints * dx + kargs = { + "track_name": track_name, + "batch_key": batch_key, + "ID_flag": ID_flag, + "output_dir": output_dir, + } + report_input_parameters(**kargs) -print("L number of gridpoint:", Lpoints) -print("L length in km:", Lmeters / 1e3) -print("approx number windows", 2 * dist.iloc[-1] / Lmeters - 1) + with suppress_stdout(verbose): -T_min = 6 -lambda_min = 9.81 * T_min**2 / (2 * np.pi) -flim = 1 / T_min + workdir, _ = update_paths_mconfig(output_dir, mconfig) + load_path = Path(workdir, batch_key, "B01_regrid") -oversample = 2 -dlambda = Lmeters * oversample -dk = 2 * np.pi / dlambda -kk = np.arange(0, 1 / lambda_min, 1 / dlambda) * 2 * np.pi -kk = kk[k_0 <= kk] + save_path = Path(workdir, batch_key, "B02_spectra") + save_name = f"B02_{track_name}" -print("2 M = ", kk.size * 2) + save_path.mkdir(parents=True, exist_ok=True) -print("define global xlims") -dist_list = np.array([np.nan, np.nan]) -for k in all_beams: - print(k) - x = Gd[k + "/x"][:] - print(x[0], x[-1]) - dist_list = np.vstack([dist_list, [x[0], x[-1]]]) + bad_track_path = Path(workdir, "bad_tracks", batch_key) -xlims = np.nanmin(dist_list[:, 0]) - dx, np.nanmin(dist_list[:, 1]) + all_beams = mconfig["beams"]["all_beams"] + N_process = 4 + print("N_process=", N_process) -for k in all_beams: - dist_i = io.get_beam_var_hdf_store(Gd[k], "x") - x_mask = (dist_i > xlims[0]) & (dist_i < xlims[1]) - print(k, sum(x_mask["x"]) / (xlims[1] - xlims[0])) + Gd = h5py.File(Path(load_path) / (track_name + "_B01_binned.h5"), "r") - -print("-reduced frequency resolution") -kk = kk[::2] -print("set xlims: ", xlims) -print( - "Loop start: ", - tracemalloc.get_traced_memory()[0] / 1e6, - tracemalloc.get_traced_memory()[1] / 1e6, -) - -G_gFT = dict() -G_gFT_x = dict() -G_rar_fft = dict() -Pars_optm = dict() - - -k = all_beams[1] -# sliderule version -hkey = "h_mean" -hkey_sigma = "h_sigma" -for k in all_beams: - tracemalloc.start() - # ------------------------------- use gridded data - Gi = io.get_beam_hdf_store(Gd[k]) - x_mask = (Gi["x"] > xlims[0]) & (Gi["x"] < xlims[1]) - if sum(x_mask) / (xlims[1] - xlims[0]) < 0.005: - print("------------------- not data in beam found; skip") - - Gd_cut = Gi[x_mask] - x = Gd_cut["x"] - del Gi - # cut data: - x_mask = (x >= xlims[0]) & (x <= xlims[1]) - x = x[x_mask] - dd = np.copy(Gd_cut[hkey]) - - dd_error 
= np.copy(Gd_cut[hkey_sigma]) - - dd_error[np.isnan(dd_error)] = 100 - - # compute slope spectra !! - dd = np.gradient(dd) - dd, _ = spicke_remover.spicke_remover(dd, spreed=10, verbose=False) - dd_nans = (np.isnan(dd)) + (Gd_cut["N_photos"] <= 5) - - # using gappy data - dd_no_nans = dd[~dd_nans] # windowing is applied here - x_no_nans = x[~dd_nans] - dd_error_no_nans = dd_error[~dd_nans] - - print("gFT") - with threadpool_limits(limits=N_process, user_api="blas"): - pprint(threadpool_info()) - - S = gFT.wavenumber_spectrogram_gFT( - np.array(x_no_nans), - np.array(dd_no_nans), - Lmeters, - dx, - kk, - data_error=dd_error_no_nans, - ov=None, - ) - GG, GG_x, Params = S.cal_spectrogram( - xlims=xlims, max_nfev=8000, plot_flag=False - ) - - print( - "after ", - k, - tracemalloc.get_traced_memory()[0] / 1e6, - tracemalloc.get_traced_memory()[1] / 1e6, - ) - - plot_data_model = False - if plot_data_model: - for i in np.arange(0, 16, 2): - c1 = "blue" - c2 = "red" - - GGi = GG.isel(x=i) - - xi_1 = GG_x.x[i] - xi_2 = GG_x.x[i + 1] - - F = M.figure_axis_xy(16, 2) - eta = GG_x.eta - - y_model = GG_x.y_model[:, i] - plt.plot(eta + xi_1, y_model, "-", c=c1, linewidth=0.8, alpha=1, zorder=12) - y_model = GG_x.y_model[:, i + 1] - plt.plot(eta + xi_2, y_model, "-", c=c2, linewidth=0.8, alpha=1, zorder=12) - - FT = gFT.generalized_Fourier(eta + xi_1, None, GG.k) - _ = FT.get_H() - FT.p_hat = np.concatenate([GGi.gFT_cos_coeff, GGi.gFT_sin_coeff]) - plt.plot( - eta + xi_1, - FT.model(), - "-", - c="orange", - linewidth=0.8, - alpha=1, - zorder=2, + # test amount of nans in the data TODO: rewrite as a comprehension. CP + nan_fraction = list() + for k in all_beams: + heights_c_std = io.get_beam_var_hdf_store(Gd[k], "x") + nan_fraction.append( + np.sum(np.isnan(heights_c_std)) / heights_c_std.shape[0] ) - FT = gFT.generalized_Fourier(eta + xi_2, None, GG.k) - _ = FT.get_H() - FT.p_hat = np.concatenate([GGi.gFT_cos_coeff, GGi.gFT_sin_coeff]) - plt.plot( - eta + xi_2, - FT.model(), - "-", - c="orange", - linewidth=0.8, - alpha=1, - zorder=2, + del heights_c_std + + # test if beam pairs have bad ratio + bad_ratio_flag = False + for group in mconfig["beams"]["groups"]: + Ia = Gd[group[0]] + Ib = Gd[group[1]] + ratio = Ia["x"][:].size / Ib["x"][:].size + if (ratio > 10) | (ratio < 0.1): + print("bad data ratio ", ratio, 1 / ratio) + bad_ratio_flag = True + + if (np.array(nan_fraction).mean() > 0.95) | bad_ratio_flag: + print( + "nan fraction > 95%, or bad ratio of data, pass this track, add to bad tracks" ) - - # oringial data - plt.plot(x, dd, "-", c="k", linewidth=2, alpha=0.6, zorder=11) - - F.ax.axvline(xi_1 + eta[0].data, linewidth=4, color=c1, alpha=0.5) - F.ax.axvline(xi_1 + eta[-1].data, linewidth=4, color=c1, alpha=0.5) - F.ax.axvline(xi_2 + eta[0].data, linewidth=4, color=c2, alpha=0.5) - F.ax.axvline(xi_2 + eta[-1].data, linewidth=4, color=c2, alpha=0.5) - - ylims = -np.nanstd(dd) * 2, np.nanstd(dd) * 2 - plt.text( - xi_1 + eta[0].data, - ylims[-1], - " N=" - + str(GG.sel(x=xi_1, method="nearest").N_per_stancil.data) - + " N/2M= " - + str( - GG.sel(x=xi_1, method="nearest").N_per_stancil.data / 2 / kk.size - ), + MT.json_save( + track_name, + bad_track_path, + { + "nan_fraction": np.array(nan_fraction).mean(), + "date": str(datetime.date.today()), + }, ) - plt.text( - xi_2 + eta[0].data, - ylims[-1], - " N=" - + str(GG.sel(x=xi_2, method="nearest").N_per_stancil.data) - + " N/2M= " - + str( - GG.sel(x=xi_2, method="nearest").N_per_stancil.data / 2 / kk.size - ), + print("exit.") + exit() + + # test LS 
with an even grid where missing values are set to 0 + print(Gd.keys()) + Gi = Gd[list(Gd.keys())[0]] # to select a test beam + dist = io.get_beam_var_hdf_store(Gd[list(Gd.keys())[0]], "x") + # make dataframe form hdf5 + # derive spectral limits + # Longest deserved period: + T_max = 40 # sec + k_0 = (2 * np.pi / T_max) ** 2 / 9.81 + x = np.array(dist).squeeze() + dx = np.round(np.median(np.diff(x)), 1) + min_datapoint = 2 * np.pi / k_0 / dx + + Lpoints = int(np.round(min_datapoint) * 10) + Lmeters = Lpoints * dx + + print("L number of gridpoint:", Lpoints) + print("L length in km:", Lmeters / 1e3) + print("approx number windows", 2 * dist.iloc[-1] / Lmeters - 1) + + T_min = 6 + lambda_min = 9.81 * T_min**2 / (2 * np.pi) + + oversample = 2 + dlambda = Lmeters * oversample + kk = np.arange(0, 1 / lambda_min, 1 / dlambda) * 2 * np.pi + kk = kk[k_0 <= kk] + + print("2 M = ", kk.size * 2) + + print("define global xlims") + dist_list = np.array([np.nan, np.nan]) + for k in all_beams: + print(k) + x = Gd[k + "/x"][:] + print(x[0], x[-1]) + dist_list = np.vstack([dist_list, [x[0], x[-1]]]) + + xlims = np.nanmin(dist_list[:, 0]) - dx, np.nanmin(dist_list[:, 1]) + + for k in all_beams: + dist_i = io.get_beam_var_hdf_store(Gd[k], "x") + x_mask = (dist_i > xlims[0]) & (dist_i < xlims[1]) + print(k, sum(x_mask["x"]) / (xlims[1] - xlims[0])) + + print("-reduced frequency resolution") + kk = kk[::2] + + print("set xlims: ", xlims) + + # Commented out for now. CP + # print( + # "Loop start: ", + # tracemalloc.get_traced_memory()[0] / 1e6, + # tracemalloc.get_traced_memory()[1] / 1e6, + # ) + + G_gFT = dict() + G_gFT_x = dict() + G_rar_fft = dict() + Pars_optm = dict() + + # sliderule version + hkey = "h_mean" + hkey_sigma = "h_sigma" + for k in all_beams: + # tracemalloc.start() + # ------------------------------- use gridded data + Gi = io.get_beam_hdf_store(Gd[k]) + x_mask = (Gi["x"] > xlims[0]) & (Gi["x"] < xlims[1]) + if sum(x_mask) / (xlims[1] - xlims[0]) < 0.005: + print("------------------- no data in beam found; skip") + + Gd_cut = Gi[x_mask] + x = Gd_cut["x"] + del Gi + # cut data: + x_mask = (x >= xlims[0]) & (x <= xlims[1]) + x = x[x_mask] + + dd = np.copy(Gd_cut[hkey]) + + dd_error = np.copy(Gd_cut[hkey_sigma]) + dd_error[np.isnan(dd_error)] = 100 + + # compute slope spectra !! + dd = np.gradient(dd) + dd, _ = spicke_remover.spicke_remover(dd, spreed=10, verbose=False) + dd_nans = (np.isnan(dd)) + (Gd_cut["N_photos"] <= 5) + + # using gappy data + dd_no_nans = dd[~dd_nans] # windowing is applied here + x_no_nans = x[~dd_nans] + dd_error_no_nans = dd_error[~dd_nans] + + print("gFT") + + with threadpool_limits(limits=N_process, user_api="blas"): + pprint(threadpool_info()) + + S = gFT.wavenumber_spectrogram_gFT( + np.array(x_no_nans), + np.array(dd_no_nans), + Lmeters, + dx, + kk, + data_error=dd_error_no_nans, + ov=None, + ) + GG, GG_x, Params = S.cal_spectrogram( + xlims=xlims, max_nfev=8000, plot_flag=False + ) + + # Commented out for now. 
CP + # print( + # "after ", + # k, + # tracemalloc.get_traced_memory()[0] / 1e6, + # tracemalloc.get_traced_memory()[1] / 1e6, + # ) + + plot_data_model = False + if plot_data_model: + for i in np.arange(0, 16, 2): + c1 = "blue" + c2 = "red" + + GGi = GG.isel(x=i) + + xi_1 = GG_x.x[i] + xi_2 = GG_x.x[i + 1] + + F = M.figure_axis_xy(16, 2) + eta = GG_x.eta + + y_model = GG_x.y_model[:, i] + plt.plot( + eta + xi_1, + y_model, + "-", + c=c1, + linewidth=0.8, + alpha=1, + zorder=12, + ) + y_model = GG_x.y_model[:, i + 1] + plt.plot( + eta + xi_2, + y_model, + "-", + c=c2, + linewidth=0.8, + alpha=1, + zorder=12, + ) + + FT = gFT.generalized_Fourier(eta + xi_1, None, GG.k) + _ = FT.get_H() + FT.p_hat = np.concatenate([GGi.gFT_cos_coeff, GGi.gFT_sin_coeff]) + plt.plot( + eta + xi_1, + FT.model(), + "-", + c="orange", + linewidth=0.8, + alpha=1, + zorder=2, + ) + + FT = gFT.generalized_Fourier(eta + xi_2, None, GG.k) + _ = FT.get_H() + FT.p_hat = np.concatenate([GGi.gFT_cos_coeff, GGi.gFT_sin_coeff]) + plt.plot( + eta + xi_2, + FT.model(), + "-", + c="orange", + linewidth=0.8, + alpha=1, + zorder=2, + ) + + # original data + plt.plot(x, dd, "-", c="k", linewidth=2, alpha=0.6, zorder=11) + + F.ax.axvline(xi_1 + eta[0].data, linewidth=4, color=c1, alpha=0.5) + F.ax.axvline(xi_1 + eta[-1].data, linewidth=4, color=c1, alpha=0.5) + F.ax.axvline(xi_2 + eta[0].data, linewidth=4, color=c2, alpha=0.5) + F.ax.axvline(xi_2 + eta[-1].data, linewidth=4, color=c2, alpha=0.5) + + ylims = -np.nanstd(dd) * 2, np.nanstd(dd) * 2 + plt.text( + xi_1 + eta[0].data, + ylims[-1], + " N=" + + str(GG.sel(x=xi_1, method="nearest").N_per_stancil.data) + + " N/2M= " + + str( + GG.sel(x=xi_1, method="nearest").N_per_stancil.data + / 2 + / kk.size + ), + ) + plt.text( + xi_2 + eta[0].data, + ylims[-1], + " N=" + + str(GG.sel(x=xi_2, method="nearest").N_per_stancil.data) + + " N/2M= " + + str( + GG.sel(x=xi_2, method="nearest").N_per_stancil.data + / 2 + / kk.size + ), + ) + plt.xlim(xi_1 + eta[0].data * 1.2, xi_2 + eta[-1].data * 1.2) + + plt.ylim(ylims[0], ylims[-1]) + plt.show() + + # add x-mean spectral error estimate to xarray + S.parceval(add_attrs=True, weight_data=False) + + # assign beam coordinate + GG.coords["beam"] = GG_x.coords["beam"] = str(k) + GG, GG_x = GG.expand_dims(dim="beam", axis=1), GG_x.expand_dims( + dim="beam", axis=1 + ) + # repack such that all coords are associated with beam + GG.coords["N_per_stancil"] = ( + ("x", "beam"), + np.expand_dims(GG["N_per_stancil"], 1), + ) + GG.coords["spec_adjust"] = ( + ("x", "beam"), + np.expand_dims(GG["spec_adjust"], 1), ) - plt.xlim(xi_1 + eta[0].data * 1.2, xi_2 + eta[-1].data * 1.2) - - plt.ylim(ylims[0], ylims[-1]) - plt.show() - - # add x-mean spectal error estimate to xarray - S.parceval(add_attrs=True, weight_data=False) - - # assign beam coordinate - GG.coords["beam"] = GG_x.coords["beam"] = str(k) - GG, GG_x = GG.expand_dims(dim="beam", axis=1), GG_x.expand_dims(dim="beam", axis=1) - # repack such that all coords are associated with beam - GG.coords["N_per_stancil"] = (("x", "beam"), np.expand_dims(GG["N_per_stancil"], 1)) - GG.coords["spec_adjust"] = (("x", "beam"), np.expand_dims(GG["spec_adjust"], 1)) - - # add more coodindates to the Dataset - x_coord_no_gaps = linear_gap_fill(Gd_cut, "x", "x") - y_coord_no_gaps = linear_gap_fill(Gd_cut, "x", "y") - mapped_coords = spec.sub_sample_coords( - Gd_cut["x"], x_coord_no_gaps, y_coord_no_gaps, S.stancil_iter, map_func=None - ) - - GG.coords["x_coord"] = GG_x.coords["x_coord"] = ( - ("x", "beam"), - 
np.expand_dims(mapped_coords[:, 1], 1), - ) - GG.coords["y_coord"] = GG_x.coords["y_coord"] = ( - ("x", "beam"), - np.expand_dims(mapped_coords[:, 2], 1), - ) - - # if data staarts with nans replace coords with nans again - if (GG.coords["N_per_stancil"] == 0).squeeze()[0].data: - nlabel = label((GG.coords["N_per_stancil"] == 0).squeeze())[0] - nan_mask = nlabel == nlabel[0] - GG.coords["x_coord"][nan_mask] = np.nan - GG.coords["y_coord"][nan_mask] = np.nan - - lons_no_gaps = linear_gap_fill(Gd_cut, "x", "lons") - lats_no_gaps = linear_gap_fill(Gd_cut, "x", "lats") - mapped_coords = spec.sub_sample_coords( - Gd_cut["x"], lons_no_gaps, lats_no_gaps, S.stancil_iter, map_func=None - ) - - GG.coords["lon"] = GG_x.coords["lon"] = ( - ("x", "beam"), - np.expand_dims(mapped_coords[:, 1], 1), - ) - GG.coords["lat"] = GG_x.coords["lat"] = ( - ("x", "beam"), - np.expand_dims(mapped_coords[:, 2], 1), - ) - - # calculate number data points - def get_stancil_nans(stancil): - x_mask = (stancil[0] < x) & (x <= stancil[-1]) - idata = Gd_cut["N_photos"][x_mask] - return stancil[1], idata.sum() - - photon_list = np.array( - list(dict(map(get_stancil_nans, copy.copy(S.stancil_iter))).values()) - ) - GG.coords["N_photons"] = (("x", "beam"), np.expand_dims(photon_list, 1)) - - # Save to dict - G_gFT[k] = GG - G_gFT_x[k] = GG_x - Pars_optm[k] = Params - - # plot - plt.subplot(2, 1, 2) - G_gFT_power = GG.gFT_PSD_data.squeeze() - plt.plot( - G_gFT_power.k, np.nanmean(G_gFT_power, 1), "gray", label="mean gFT power data " - ) - G_gFT_power = GG.gFT_PSD_model.squeeze() - plt.plot(GG.k, np.nanmean(S.G, 1), "k", label="mean gFT power model") - - # standard FFT - print("FFT") - dd[dd_nans] = 0 - - S = spec.wavenumber_spectrogram(x, dd, Lpoints) - G = S.cal_spectrogram() - S.mean_spectral_error() # add x-mean spectal error estimate to xarray - S.parceval(add_attrs=True) - - # assign beam coordinate - G.coords["beam"] = str(k) - G = G.expand_dims(dim="beam", axis=2) - G.coords["mean_El"] = (("k", "beam"), np.expand_dims(G["mean_El"], 1)) - G.coords["mean_Eu"] = (("k", "beam"), np.expand_dims(G["mean_Eu"], 1)) - G.coords["x"] = G.coords["x"] * dx - - stancil_iter = spec.create_chunk_boundaries(int(Lpoints), dd_nans.size) - - def get_stancil_nans(stancil): - idata = dd_nans[stancil[0] : stancil[-1]] - return stancil[1], idata.size - idata.sum() - - N_list = np.array(list(dict(map(get_stancil_nans, stancil_iter)).values())) - - # repack such that all coords are associated with beam - G.coords["N_per_stancil"] = (("x", "beam"), np.expand_dims(N_list, 1)) - - # save to dict and cut to the same size gFT - try: - G_rar_fft[k] = G.sel(x=slice(GG.x[0], GG.x[-1].data)) - except: - G_rar_fft[k] = G.isel(x=(GG.x[0].data < G.x.data) & (G.x.data < GG.x[-1].data)) - - # for plotting - try: - G_rar_fft_p = G.squeeze() - plt.plot( - G_rar_fft_p.k, - G_rar_fft_p[:, G_rar_fft_p["N_per_stancil"] > 10].mean("x"), - "darkblue", - label="mean FFT", - ) - plt.legend() - - except: - pass - time.sleep(3) - plt.close("all") - - -del Gd_cut -Gd.close() - -# save fitting parameters -MT.save_pandas_table(Pars_optm, save_name + "_params", save_path) - - -# repack data -def repack_attributes(DD): - attr_dim_list = list(DD.keys()) - for k in attr_dim_list: - for ka in list(DD[k].attrs.keys()): - I = DD[k] - I.coords[ka] = ("beam", np.expand_dims(I.attrs[ka], 0)) - return DD - - -beams_missing = set(all_beams) - set(G_gFT.keys()) - - -def make_dummy_beam(GG, beam): - dummy = GG.copy(deep=True) - for var in list(dummy.var()): - dummy[var] = dummy[var] * 
np.nan - dummy["beam"] = [beam] - return dummy - - -for beam in beams_missing: - GG = list(G_gFT.values())[0] - dummy = make_dummy_beam(GG, beam) - dummy["N_photons"] = dummy["N_photons"] * 0 - dummy["N_per_stancil"] = dummy["N_per_stancil"] * 0 - G_gFT[beam] = dummy - - GG = list(G_gFT_x.values())[0] - G_gFT_x[beam] = make_dummy_beam(GG, beam) - - GG = list(G_rar_fft.values())[0].copy(deep=True) - GG.data = GG.data * np.nan - GG["beam"] = [beam] - G_rar_fft[beam] = GG - -G_rar_fft.keys() - -G_gFT = repack_attributes(G_gFT) -G_gFT_x = repack_attributes(G_gFT_x) -G_rar_fft = repack_attributes(G_rar_fft) - -# save results -G_gFT_DS = xr.merge(G_gFT.values()) -G_gFT_DS["Z_hat_imag"] = G_gFT_DS.Z_hat.imag -G_gFT_DS["Z_hat_real"] = G_gFT_DS.Z_hat.real -G_gFT_DS = G_gFT_DS.drop("Z_hat") -G_gFT_DS.attrs["name"] = "gFT_estimates" -G_gFT_DS.to_netcdf(save_path + save_name + "_gFT_k.nc") + # add more coordinates to the Dataset + x_coord_no_gaps = linear_gap_fill(Gd_cut, "x", "x") + y_coord_no_gaps = linear_gap_fill(Gd_cut, "x", "y") + mapped_coords = spec.sub_sample_coords( + Gd_cut["x"], + x_coord_no_gaps, + y_coord_no_gaps, + S.stancil_iter, + map_func=None, + ) -G_gFT_x_DS = xr.merge(G_gFT_x.values()) -G_gFT_x_DS.attrs["name"] = "gFT_estimates_real_space" -G_gFT_x_DS.to_netcdf(save_path + save_name + "_gFT_x.nc") + GG.coords["x_coord"] = GG_x.coords["x_coord"] = ( + ("x", "beam"), + np.expand_dims(mapped_coords[:, 1], 1), + ) + GG.coords["y_coord"] = GG_x.coords["y_coord"] = ( + ("x", "beam"), + np.expand_dims(mapped_coords[:, 2], 1), + ) + # if data starts with nans replace coords with nans again + if (GG.coords["N_per_stancil"] == 0).squeeze()[0].data: + nlabel = label((GG.coords["N_per_stancil"] == 0).squeeze())[0] + nan_mask = nlabel == nlabel[0] + GG.coords["x_coord"][nan_mask] = np.nan + GG.coords["y_coord"][nan_mask] = np.nan + + lons_no_gaps = linear_gap_fill(Gd_cut, "x", "lons") + lats_no_gaps = linear_gap_fill(Gd_cut, "x", "lats") + mapped_coords = spec.sub_sample_coords( + Gd_cut["x"], + lons_no_gaps, + lats_no_gaps, + S.stancil_iter, + map_func=None, + ) -G_fft_DS = xr.merge(G_rar_fft.values()) -G_fft_DS.attrs["name"] = "FFT_power_spectra" -G_fft_DS.to_netcdf(save_path + save_name + "_FFT.nc") + GG.coords["lon"] = GG_x.coords["lon"] = ( + ("x", "beam"), + np.expand_dims(mapped_coords[:, 1], 1), + ) + GG.coords["lat"] = GG_x.coords["lat"] = ( + ("x", "beam"), + np.expand_dims(mapped_coords[:, 2], 1), + ) -print("saved and done") + # calculate number data points + def _get_stancil_nans(stancil, Gd_cut=Gd_cut): + x_mask = (stancil[0] < x) & (x <= stancil[-1]) + idata = Gd_cut["N_photos"][x_mask].sum() + return stancil[1], idata + + get_stancil_nans = partial(_get_stancil_nans, Gd_cut=Gd_cut) + photon_list = np.array( + list(dict(map(get_stancil_nans, copy.copy(S.stancil_iter))).values()) + ) # TODO: make more readable. 
CP + GG.coords["N_photons"] = (("x", "beam"), np.expand_dims(photon_list, 1)) + + # Save to dict + G_gFT[k] = GG + G_gFT_x[k] = GG_x + Pars_optm[k] = Params + + # plot + plt.subplot(2, 1, 2) + G_gFT_power = GG.gFT_PSD_data.squeeze() + plt.plot( + G_gFT_power.k, + np.nanmean(G_gFT_power, 1), + "gray", + label="mean gFT power data ", + ) + G_gFT_power = GG.gFT_PSD_model.squeeze() + plt.plot(GG.k, np.nanmean(S.G, 1), "k", label="mean gFT power model") + + # standard FFT + print("FFT") + dd[dd_nans] = 0 + + S = spec.wavenumber_spectrogram(x, dd, Lpoints) + G = S.cal_spectrogram() + S.mean_spectral_error() # add x-mean spectral error estimate to xarray + S.parceval(add_attrs=True) + + # assign beam coordinate + G.coords["beam"] = str(k) + G = G.expand_dims(dim="beam", axis=2) + G.coords["mean_El"] = (("k", "beam"), np.expand_dims(G["mean_El"], 1)) + G.coords["mean_Eu"] = (("k", "beam"), np.expand_dims(G["mean_Eu"], 1)) + G.coords["x"] = G.coords["x"] * dx + + stancil_iter = spec.create_chunk_boundaries(int(Lpoints), dd_nans.size) + + def get_stancil_nans(stancil): + idata = dd_nans[stancil[0] : stancil[-1]] + result = idata.size - idata.sum() + return stancil[1], result + + N_list = np.array( + list(dict(map(get_stancil_nans, stancil_iter)).values()) + ) # TODO: make more readable. CP + + # repack such that all coords are associated with beam + G.coords["N_per_stancil"] = (("x", "beam"), np.expand_dims(N_list, 1)) + + # save to dict and cut to the same size gFT + try: + G_rar_fft[k] = G.sel(x=slice(GG.x[0], GG.x[-1].data)) + except Exception: + G_rar_fft[k] = G.isel( + x=(GG.x[0].data < G.x.data) & (G.x.data < GG.x[-1].data) + ) + + # for plotting + try: + G_rar_fft_p = G.squeeze() + plt.plot( + G_rar_fft_p.k, + G_rar_fft_p[:, G_rar_fft_p["N_per_stancil"] > 10].mean("x"), + "darkblue", + label="mean FFT", + ) + plt.legend() + plt.show() + except Exception as e: + print(e, "An error occurred. 
Nothing to plot.")
+
+        del Gd_cut
+        Gd.close()
+
+        # save fitting parameters
+        MT.save_pandas_table(Pars_optm, save_name + "_params", str(save_path))
+
+        # repack data
+        def repack_attributes(DD):
+            attr_dim_list = list(DD.keys())
+            for k in attr_dim_list:
+                for ka in list(DD[k].attrs.keys()):
+                    I = DD[k]
+                    I.coords[ka] = ("beam", np.expand_dims(I.attrs[ka], 0))
+            return DD
+
+        beams_missing = set(all_beams) - set(G_gFT.keys())
+
+        def make_dummy_beam(GG, beam):
+            dummy = GG.copy(deep=True)
+            for var in list(dummy.var()):
+                dummy[var] = dummy[var] * np.nan
+            dummy["beam"] = [beam]
+            return dummy
+
+        for beam in beams_missing:
+            GG = list(G_gFT.values())[0]
+            dummy = make_dummy_beam(GG, beam)
+            dummy["N_photons"] = dummy["N_photons"] * 0
+            dummy["N_per_stancil"] = dummy["N_per_stancil"] * 0
+            G_gFT[beam] = dummy
+
+            GG = list(G_gFT_x.values())[0]
+            G_gFT_x[beam] = make_dummy_beam(GG, beam)
+
+            GG = list(G_rar_fft.values())[0].copy(deep=True)
+            GG.data = GG.data * np.nan
+            GG["beam"] = [beam]
+            G_rar_fft[beam] = GG
+
+        G_rar_fft.keys()
+
+        G_gFT = repack_attributes(G_gFT)
+        G_gFT_x = repack_attributes(G_gFT_x)
+        G_rar_fft = repack_attributes(G_rar_fft)
+
+        # save results
+        G_gFT_DS = xr.merge(G_gFT.values())
+        G_gFT_DS["Z_hat_imag"] = G_gFT_DS.Z_hat.imag
+        G_gFT_DS["Z_hat_real"] = G_gFT_DS.Z_hat.real
+        G_gFT_DS = G_gFT_DS.drop_vars("Z_hat")
+        G_gFT_DS.attrs["name"] = "gFT_estimates"
+
+        savepathname = str(save_path / save_name)
+        G_gFT_DS.to_netcdf(savepathname + "_gFT_k.nc")
+
+        G_gFT_x_DS = xr.merge(G_gFT_x.values())
+        G_gFT_x_DS.attrs["name"] = "gFT_estimates_real_space"
+        G_gFT_x_DS.to_netcdf(savepathname + "_gFT_x.nc")
+
+        G_fft_DS = xr.merge(G_rar_fft.values())
+        G_fft_DS.attrs["name"] = "FFT_power_spectra"
+        G_fft_DS.to_netcdf(savepathname + "_FFT.nc")
+
+        echo("saved and done")
+
+
+make_spectra_app = makeapp(run_B02_make_spectra_gFT, name="makespectra")
+
+if __name__ == "__main__":
+    make_spectra_app()
diff --git a/src/icesat2_tracks/analysis_db/B03_plot_spectra_ov.py b/src/icesat2_tracks/analysis_db/B03_plot_spectra_ov.py
index bae0de4a..4dcd197f 100644
--- a/src/icesat2_tracks/analysis_db/B03_plot_spectra_ov.py
+++ b/src/icesat2_tracks/analysis_db/B03_plot_spectra_ov.py
@@ -1,11 +1,16 @@
+#!/usr/bin/env python3
 """
-This file open a ICEsat2 track applied filters and corections and returns smoothed photon heights on a regular grid in an .nc file.
+This file opens an ICESat-2 track, applies filters and corrections, and returns smoothed photon heights on a regular grid in an .nc file.
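+The heavy lifting lives in run_B03_plot_spectra_ov, which is exposed as a Typer app (plot_spectra = makeapp(...) at the bottom of this module).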
This is python 3 """ -import sys +from ast import comprehension +from pathlib import Path +import matplotlib import numpy as np import xarray as xr from matplotlib.gridspec import GridSpec +import typer + import icesat2_tracks.ICEsat2_SI_tools.iotools as io import icesat2_tracks.ICEsat2_SI_tools.generalized_FT as gFT import icesat2_tracks.local_modules.m_tools_ph3 as MT @@ -17,123 +22,16 @@ font_for_print, ) -track_name, batch_key, test_flag = io.init_from_input( - sys.argv # TODO: Handle via CLI -) # loads standard experiment -hemis, batch = batch_key.split("_") - -load_path = mconfig["paths"]["work"] + batch_key + "/B02_spectra/" -load_file = load_path + "B02_" + track_name -plot_path = ( - mconfig["paths"]["plot"] + "/" + hemis + "/" + batch_key + "/" + track_name + "/" -) # TODO: Update with pathlib -MT.mkdirs_r(plot_path) - -Gk = xr.open_dataset(load_file + "_gFT_k.nc") -Gx = xr.open_dataset(load_file + "_gFT_x.nc") - -Gfft = xr.open_dataset(load_file + "_FFT.nc") - -all_beams = mconfig["beams"]["all_beams"] -high_beams = mconfig["beams"]["high_beams"] -low_beams = mconfig["beams"]["low_beams"] -color_schemes.colormaps2(21) - -col_dict = color_schemes.rels -F = M.figure_axis_xy(9, 3, view_scale=0.5) - -plt.subplot(1, 3, 1) -plt.title(track_name, loc="left") -for k in all_beams: - I = Gk.sel(beam=k) - I2 = Gx.sel(beam=k) - plt.plot(I["lon"], I["lat"], ".", c=col_dict[k], markersize=0.7, linewidth=0.3) - plt.plot(I2["lon"], I2["lat"], "|", c=col_dict[k], markersize=0.7) - - -plt.xlabel("lon") -plt.ylabel("lat") - -plt.subplot(1, 3, 2) - -xscale = 1e3 -for k in all_beams: - I = Gk.sel(beam=k) - plt.plot( - I["x_coord"] / xscale, - I["y_coord"] / xscale, - ".", - c=col_dict[k], - linewidth=0.8, - markersize=0.8, - ) - -plt.xlabel("x_coord (km)") -plt.ylabel("y_coord (km)") - -plt.subplot(1, 3, 3) - -xscale = 1e3 -for k in all_beams: - I = Gk.sel(beam=k) - plt.plot( - I["x_coord"] / xscale, - (I["y_coord"] - I["y_coord"][0]), - ".", - c=col_dict[k], - linewidth=0.8, - markersize=0.8, - ) - -plt.xlabel("x_coord (km)") -plt.ylabel("y_coord deviation (m)") - - -F.save_light(path=plot_path, name="B03_specs_coord_check") - - -def dict_weighted_mean(Gdict, weight_key): - """ - returns the weighted meean of a dict of xarray, data_arrays - weight_key must be in the xr.DataArrays - """ - - akey = list(Gdict.keys())[0] - GSUM = Gdict[akey].copy() - GSUM.data = np.zeros(GSUM.shape) - N_per_stancil = GSUM.N_per_stancil * 0 - N_photons = np.zeros(GSUM.N_per_stancil.size) - - counter = 0 - for I in Gdict.items(): - I = I.squeeze() - if len(I.x) != 0: - GSUM += I.where(~np.isnan(I), 0) * I[weight_key] - N_per_stancil += I[weight_key] - if "N_photons" in GSUM.coords: - N_photons += I["N_photons"] - counter += 1 - - GSUM = GSUM / N_per_stancil - - if "N_photons" in GSUM.coords: - GSUM.coords["N_photons"] = (("x", "beam"), np.expand_dims(N_photons, 1)) - - GSUM["beam"] = ["weighted_mean"] - GSUM.name = "power_spec" - - return GSUM - - -G_gFT_wmean = ( - Gk["gFT_PSD_data"].where(~np.isnan(Gk["gFT_PSD_data"]), 0) * Gk["N_per_stancil"] -).sum("beam") / Gk["N_per_stancil"].sum("beam") -G_gFT_wmean["N_per_stancil"] = Gk["N_per_stancil"].sum("beam") - -G_fft_wmean = (Gfft.where(~np.isnan(Gfft), 0) * Gfft["N_per_stancil"]).sum( - "beam" -) / Gfft["N_per_stancil"].sum("beam") -G_fft_wmean["N_per_stancil"] = Gfft["N_per_stancil"].sum("beam") +from icesat2_tracks.clitools import ( + echo, + validate_batch_key, + validate_output_dir, + suppress_stdout, + update_paths_mconfig, + report_input_parameters, + 
validate_track_name_steps_gt_1, + makeapp, +) def plot_wavenumber_spectrogram(ax, Gi, clev, title=None, plot_photon_density=True): @@ -167,188 +65,6 @@ def plot_wavenumber_spectrogram(ax, Gi, clev, title=None, plot_photon_density=Tr plt.title(title, loc="left") -Gmean = G_gFT_wmean.rolling(k=5, center=True).mean() - -try: - k_max = Gmean.k[Gmean.isel(x=slice(0, 5)).mean("x").argmax().data].data -except: - k_max = Gmean.k[Gmean.isel(x=slice(0, 20)).mean("x").argmax().data].data - -k_max_range = (k_max * 0.75, k_max * 1, k_max * 1.25) -font_for_print() -F = M.figure_axis_xy(6.5, 5.6, container=True, view_scale=1) -Lmeters = Gk.L.data[0] - -plt.suptitle("gFT Slope Spectrograms\n" + track_name, y=0.98) -gs = GridSpec(3, 3, wspace=0.2, hspace=0.5) - -Gplot = ( - G_gFT_wmean.squeeze() - .rolling(k=10, min_periods=1, center=True) - .median() - .rolling(x=3, min_periods=1, center=True) - .median() -) -dd = 10 * np.log10(Gplot) -dd = dd.where(~np.isinf(dd), np.nan) -clev_log = M.clevels([dd.quantile(0.01).data, dd.quantile(0.98).data * 1.2], 31) * 1 - -xlims = Gmean.x[0] / 1e3, Gmean.x[-1] / 1e3 - -k = high_beams[0] -for pos, k, pflag in zip( - [gs[0, 0], gs[0, 1], gs[0, 2]], high_beams, [True, False, False] -): - ax0 = F.fig.add_subplot(pos) - Gplot = Gk.sel(beam=k).gFT_PSD_data.squeeze() - dd2 = 10 * np.log10(Gplot) - dd2 = dd2.where(~np.isinf(dd2), np.nan) - plot_wavenumber_spectrogram( - ax0, dd2, clev_log, title=k + " unsmoothed", plot_photon_density=True - ) - plt.xlim(xlims) - if pflag: - plt.ylabel("Wave length\n(meters)") - plt.legend() - -for pos, k, pflag in zip( - [gs[1, 0], gs[1, 1], gs[1, 2]], low_beams, [True, False, False] -): - ax0 = F.fig.add_subplot(pos) - Gplot = Gk.sel(beam=k).gFT_PSD_data.squeeze() - dd2 = 10 * np.log10(Gplot) - dd2 = dd2.where(~np.isinf(dd2), np.nan) - plot_wavenumber_spectrogram( - ax0, dd2, clev_log, title=k + " unsmoothed", plot_photon_density=True - ) - plt.xlim(xlims) - if pflag: - plt.ylabel("Wave length\n(meters)") - plt.legend() - -ax0 = F.fig.add_subplot(gs[2, 0]) - -plot_wavenumber_spectrogram( - ax0, - dd, - clev_log, - title="smoothed weighted mean \n10 $\log_{10}( (m/m)^2 m )$", - plot_photon_density=True, -) -plt.xlim(xlims) - -line_styles = ["--", "-", "--"] -for k_max, style in zip(k_max_range, line_styles): - ax0.axhline(2 * np.pi / k_max, color="red", linestyle=style, linewidth=0.5) - -if pflag: - plt.ylabel("Wave length\n(meters)") - plt.legend() - -pos = gs[2, 1] -ax0 = F.fig.add_subplot(pos) -plt.title("Photons density ($m^{-1}$)", loc="left") - -for k in all_beams: - I = Gk.sel(beam=k)["gFT_PSD_data"] - plt.plot(Gplot.x / 1e3, I.N_photons / I.L.data, label=k, linewidth=0.8) -plt.plot( - Gplot.x / 1e3, - G_gFT_wmean.N_per_stancil / 3 / I.L.data, - c="black", - label="ave Photons", - linewidth=0.8, -) -plt.xlim(xlims) -plt.xlabel("Distance from the Ice Edge (km)") - -pos = gs[2, 2] - -ax0 = F.fig.add_subplot(pos) -ax0.set_yscale("log") - -plt.title("Peak Spectal Power", loc="left") - -x0 = Gk.x[0].data -for k in all_beams: - I = Gk.sel(beam=k)["gFT_PSD_data"] - plt.scatter( - I.x.data / 1e3, - I.sel(k=slice(k_max_range[0], k_max_range[2])).integrate("k").data, - s=0.5, - marker=".", - color="red", - alpha=0.3, - ) - I = Gfft.sel(beam=k) - plt.scatter( - (x0 + I.x.data) / 1e3, - I.power_spec.sel(k=slice(k_max_range[0], k_max_range[2])).integrate("k"), - s=0.5, - marker=".", - c="blue", - alpha=0.3, - ) - - -Gplot = G_fft_wmean.squeeze() -Gplot = Gplot.power_spec[:, Gplot.N_per_stancil >= Gplot.N_per_stancil.max().data * 0.9] -plt.plot( - 
(x0 + Gplot.x) / 1e3, - Gplot.sel(k=slice(k_max_range[0], k_max_range[2])).integrate("k"), - ".", - markersize=1.5, - c="blue", - label="FFT", -) - -Gplot = G_gFT_wmean.squeeze() -plt.plot( - Gplot.x / 1e3, - Gplot.sel(k=slice(k_max_range[0], k_max_range[2])).integrate("k"), - ".", - markersize=1.5, - c="red", - label="gFT", -) - -plt.ylabel("1e-3 $(m)^2~m$") -plt.legend() - -F.save_light(path=plot_path, name="B03_specs_L" + str(Lmeters)) - -Gk.sel(beam=k).gFT_PSD_data.plot() - - -def plot_model_eta(D, ax, offset=0, xscale=1e3, **kargs): - eta = D.eta + D.x - y_data = D.y_model + offset - plt.plot(eta / xscale, y_data, **kargs) - - ax.axvline(eta[0].data / xscale, linewidth=2, color=kargs["color"], alpha=0.5) - ax.axvline(eta[-1].data / xscale, linewidth=2, color=kargs["color"], alpha=0.5) - - -def add_info(D, Dk, ylims): - eta = D.eta + D.x - N_per_stancil, ksize = Dk.N_per_stancil.data, Dk.k.size - plt.text( - eta[0].data, - ylims[-1], - " N=" - + numtostr(N_per_stancil) - + " N/2M= " - + fltostr(N_per_stancil / 2 / ksize, 1), - ) - - -def plot_data_eta(D, offset=0, xscale=1e3, **kargs): - eta_1 = D.eta + D.x - y_data = D.y_model + offset - plt.plot(eta_1 / xscale, y_data, **kargs) - return eta_1 - - def plot_data_eta(D, offset=0, **kargs): eta_1 = D.eta # + D.x y_data = D.y_model + offset @@ -365,204 +81,512 @@ def plot_model_eta(D, ax, offset=0, **kargs): ax.axvline(eta[-1].data, linewidth=0.1, color=kargs["color"], alpha=0.5) -if "y_data" in Gx.sel(beam="gt3r").keys(): - print("ydata is ", ("y_data" in Gx.sel(beam="gt3r").keys())) -else: - print("ydata is ", ("y_data" in Gx.sel(beam="gt3r").keys())) - MT.json_save("B03_fail", plot_path, {"reason": "no y_data"}) - print("failed, exit") - exit() +matplotlib.use("Agg") # prevent plot windows from opening -fltostr, numtostr = MT.float_to_str, MT.num_to_str -font_for_print() - -MT.mkdirs_r(plot_path + "B03_spectra/") - -x_pos_sel = np.arange(Gk.x.size)[~np.isnan(Gk.mean("beam").mean("k").gFT_PSD_data.data)] -x_pos_max = ( - Gk.mean("beam") - .mean("k") - .gFT_PSD_data[~np.isnan(Gk.mean("beam").mean("k").gFT_PSD_data)] - .argmax() - .data -) -xpp = x_pos_sel[[int(i) for i in np.round(np.linspace(0, x_pos_sel.size - 1, 4))]] -xpp = np.insert(xpp, 0, x_pos_max) - -for i in xpp: - F = M.figure_axis_xy(6, 8, container=True, view_scale=0.8) +def run_B03_plot_spectra_ov( + track_name: str = typer.Option(..., callback=validate_track_name_steps_gt_1), + batch_key: str = typer.Option(..., callback=validate_batch_key), + ID_flag: bool = True, + output_dir: str = typer.Option(None, callback=validate_output_dir), + verbose: bool = False, +): + """ + TODO: add docstring + """ - plt.suptitle( - "gFT Model and Spectrograms | x=" + str(Gk.x[i].data) + " \n" + track_name, - y=0.95, + track_name, batch_key, _ = io.init_from_input( + [ + None, + track_name, + batch_key, + ID_flag, + ] # init_from_input expects sys.argv with 4 elements ) - gs = GridSpec(5, 6, wspace=0.2, hspace=0.7) - - ax0 = F.fig.add_subplot(gs[0:2, :]) - col_d = color_schemes.__dict__["rels"] - - neven = True - offs = 0 - for k in all_beams: - Gx_1 = Gx.isel(x=i).sel(beam=k) - Gk_1 = Gk.isel(x=i).sel(beam=k) - plot_model_eta( - Gx_1, - ax0, - offset=offs, - linestyle="-", - color=col_d[k], - linewidth=0.4, - alpha=1, - zorder=12, + kargs = { + "track_name": track_name, + "batch_key": batch_key, + "ID_flag": ID_flag, + "output_dir": output_dir, + } + report_input_parameters(**kargs) + + with suppress_stdout(verbose): + hemis, _ = batch_key.split("_") + + workdir, plotsdir = 
update_paths_mconfig(output_dir, mconfig) + + load_path = Path(workdir, batch_key, "B02_spectra") + load_file = str(load_path / ("B02_" + track_name)) + plot_path = Path(plotsdir, hemis, batch_key, track_name) + plot_path.mkdir(parents=True, exist_ok=True) + + # TODO: use list comprehension to load all the files + Gk = xr.open_dataset(load_file + "_gFT_k.nc") + Gx = xr.open_dataset(load_file + "_gFT_x.nc") + Gfft = xr.open_dataset(load_file + "_FFT.nc") + + all_beams = mconfig["beams"]["all_beams"] + high_beams = mconfig["beams"]["high_beams"] + low_beams = mconfig["beams"]["low_beams"] + color_schemes.colormaps2(21) + + col_dict = color_schemes.rels + F = M.figure_axis_xy(9, 3, view_scale=0.5) + + plt.subplot(1, 3, 1) + plt.title(track_name, loc="left") + for k in all_beams: + I = Gk.sel(beam=k) + I2 = Gx.sel(beam=k) + plt.plot( + I["lon"], I["lat"], ".", c=col_dict[k], markersize=0.7, linewidth=0.3 + ) + plt.plot(I2["lon"], I2["lat"], "|", c=col_dict[k], markersize=0.7) + + plt.xlabel("lon") + plt.ylabel("lat") + + plt.subplot(1, 3, 2) + + xscale = 1e3 + for k in all_beams: + I = Gk.sel(beam=k) + plt.plot( + I["x_coord"] / xscale, + I["y_coord"] / xscale, + ".", + c=col_dict[k], + linewidth=0.8, + markersize=0.8, + ) + + plt.xlabel("x_coord (km)") + plt.ylabel("y_coord (km)") + + plt.subplot(1, 3, 3) + + xscale = 1e3 + for k in all_beams: + I = Gk.sel(beam=k) + plt.plot( + I["x_coord"] / xscale, + (I["y_coord"] - I["y_coord"][0]), + ".", + c=col_dict[k], + linewidth=0.8, + markersize=0.8, + ) + + plt.xlabel("x_coord (km)") + plt.ylabel("y_coord deviation (m)") + + F.save_light(path=plot_path, name="B03_specs_coord_check") + + # TODO: refactor to make more readable. CP + G_gFT_wmean = ( + Gk["gFT_PSD_data"].where(~np.isnan(Gk["gFT_PSD_data"]), 0) + * Gk["N_per_stancil"] + ).sum("beam") / Gk["N_per_stancil"].sum("beam") + G_gFT_wmean["N_per_stancil"] = Gk["N_per_stancil"].sum("beam") + + G_fft_wmean = (Gfft.where(~np.isnan(Gfft), 0) * Gfft["N_per_stancil"]).sum( + "beam" + ) / Gfft["N_per_stancil"].sum("beam") + G_fft_wmean["N_per_stancil"] = Gfft["N_per_stancil"].sum("beam") + Gmean = G_gFT_wmean.rolling(k=5, center=True).mean() + + # TODO: make function to compute k_max. 
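+        # A possible shape for that helper, sketched with suggested names; it
+        # mirrors the try/except fallback below:
+        # def compute_k_max(Gmean, windows=(5, 20)):
+        #     """Wavenumber of the spectral peak over the first stancils."""
+        #     for n in windows:
+        #         try:
+        #             return Gmean.k[
+        #                 Gmean.isel(x=slice(0, n)).mean("x").argmax().data
+        #             ].data
+        #         except Exception:
+        #             continue
+        #     raise ValueError("no spectral peak found")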
CP + try: + k_max = Gmean.k[Gmean.isel(x=slice(0, 5)).mean("x").argmax().data].data + except Exception: + k_max = Gmean.k[Gmean.isel(x=slice(0, 20)).mean("x").argmax().data].data + + k_max_range = (k_max * 0.75, k_max, k_max * 1.25) + font_for_print() + F = M.figure_axis_xy(6.5, 5.6, container=True, view_scale=1) + Lmeters = Gk.L.data[0] + + plt.suptitle("gFT Slope Spectrograms\n" + track_name, y=0.98) + gs = GridSpec(3, 3, wspace=0.2, hspace=0.5) + + Gplot = ( + G_gFT_wmean.squeeze() + .rolling(k=10, min_periods=1, center=True) + .median() + .rolling(x=3, min_periods=1, center=True) + .median() ) - ylims = -np.nanstd(Gx_1.y_data) * 3, np.nanstd(Gx_1.y_data) * 3 - - # oringial data - eta_1 = plot_data_eta( - Gx_1, offset=offs, linestyle="-", c="k", linewidth=1, alpha=0.5, zorder=11 + dd = 10 * np.log10(Gplot) + dd = dd.where(~np.isinf(dd), np.nan) + clev_log = ( + M.clevels([dd.quantile(0.01).data, dd.quantile(0.98).data * 1.2], 31) * 1 ) - # reconstruct in gaps - FT = gFT.generalized_Fourier(Gx_1.eta + Gx_1.x, None, Gk_1.k) - _ = FT.get_H() - FT.p_hat = np.concatenate([Gk_1.gFT_cos_coeff, Gk_1.gFT_sin_coeff]) - plt.plot( - Gx_1.eta, - FT.model() + offs, - "-", - c="orange", - linewidth=0.3, - alpha=1, - zorder=2, + xlims = Gmean.x[0] / 1e3, Gmean.x[-1] / 1e3 + + k = high_beams[0] + for pos, k, pflag in zip( + [gs[0, 0], gs[0, 1], gs[0, 2]], high_beams, [True, False, False] + ): + ax0 = F.fig.add_subplot(pos) + Gplot = Gk.sel(beam=k).gFT_PSD_data.squeeze() + dd2 = 10 * np.log10(Gplot) + dd2 = dd2.where(~np.isinf(dd2), np.nan) + plot_wavenumber_spectrogram( + ax0, dd2, clev_log, title=k + " unsmoothed", plot_photon_density=True + ) + plt.xlim(xlims) + if pflag: + plt.ylabel("Wave length\n(meters)") + plt.legend() + + for pos, k, pflag in zip( + [gs[1, 0], gs[1, 1], gs[1, 2]], low_beams, [True, False, False] + ): + ax0 = F.fig.add_subplot(pos) + Gplot = Gk.sel(beam=k).gFT_PSD_data.squeeze() + dd2 = 10 * np.log10(Gplot) + dd2 = dd2.where(~np.isinf(dd2), np.nan) + plot_wavenumber_spectrogram( + ax0, dd2, clev_log, title=k + " unsmoothed", plot_photon_density=True + ) + plt.xlim(xlims) + if pflag: + plt.ylabel("Wave length\n(meters)") + plt.legend() + + ax0 = F.fig.add_subplot(gs[2, 0]) + + plot_wavenumber_spectrogram( + ax0, + dd, + clev_log, + title="smoothed weighted mean \n10 $\log_{10}( (m/m)^2 m )$", + plot_photon_density=True, ) + plt.xlim(xlims) - if neven: - neven = False - offs += 0.3 - else: - neven = True - offs += 0.6 - - dx = eta_1.diff("eta").mean().data + line_styles = ["--", "-", "--"] + for k_max, style in zip(k_max_range, line_styles): + ax0.axhline(2 * np.pi / k_max, color="red", linestyle=style, linewidth=0.5) - eta_ticks = np.linspace(Gx_1.eta.data[0], Gx_1.eta.data[-1], 11) + if pflag: + plt.ylabel("Wave length\n(meters)") + plt.legend() - ax0.set_xticks(eta_ticks) - ax0.set_xticklabels(eta_ticks / 1e3) - plt.xlim(eta_1[0].data - 40 * dx, eta_1[-1].data + 40 * dx) - plt.title("Model reconst.", loc="left") + pos = gs[2, 1] + ax0 = F.fig.add_subplot(pos) + plt.title("Photons density ($m^{-1}$)", loc="left") - plt.ylabel("relative slope (m/m)") - plt.xlabel( - "segment distance $\eta$ (km) @ x=" + fltostr(Gx_1.x.data / 1e3, 2) + "km" - ) + for k in all_beams: + I = Gk.sel(beam=k)["gFT_PSD_data"] + plt.plot(Gplot.x / 1e3, I.N_photons / I.L.data, label=k, linewidth=0.8) + plt.plot( + Gplot.x / 1e3, + G_gFT_wmean.N_per_stancil / 3 / I.L.data, + c="black", + label="ave Photons", + linewidth=0.8, + ) + plt.xlim(xlims) + plt.xlabel("Distance from the Ice Edge (km)") + + pos = 
gs[2, 2] + + ax0 = F.fig.add_subplot(pos) + ax0.set_yscale("log") + + plt.title("Peak Spectral Power", loc="left") + + x0 = Gk.x[0].data + for k in all_beams: + I = Gk.sel(beam=k)["gFT_PSD_data"] + plt.scatter( + I.x.data / 1e3, + I.sel(k=slice(k_max_range[0], k_max_range[2])).integrate("k").data, + s=0.5, + marker=".", + color="red", + alpha=0.3, + ) + I = Gfft.sel(beam=k) + plt.scatter( + (x0 + I.x.data) / 1e3, + I.power_spec.sel(k=slice(k_max_range[0], k_max_range[2])).integrate( + "k" + ), + s=0.5, + marker=".", + c="blue", + alpha=0.3, + ) + + Gplot = G_fft_wmean.squeeze() + Gplot = Gplot.power_spec[ + :, Gplot.N_per_stancil >= Gplot.N_per_stancil.max().data * 0.9 + ] + plt.plot( + (x0 + Gplot.x) / 1e3, + Gplot.sel(k=slice(k_max_range[0], k_max_range[2])).integrate("k"), + ".", + markersize=1.5, + c="blue", + label="FFT", + ) - # spectra - # define threshold - k_thresh = 0.085 - ax1_list = list() - dd_max = list() - for pos, kgroup, lflag in zip( - [gs[2, 0:2], gs[2, 2:4], gs[2, 4:]], - [["gt1l", "gt1r"], ["gt2l", "gt2r"], ["gt3l", "gt3r"]], - [True, False, False], - ): - ax11 = F.fig.add_subplot(pos) - ax11.tick_params(labelleft=lflag) - ax1_list.append(ax11) - for k in kgroup: - Gx_1 = Gx.isel(x=i).sel(beam=k) - Gk_1 = Gk.isel(x=i).sel(beam=k) - - klim = Gk_1.k[0], Gk_1.k[-1] - - if "l" in k: - dd = Gk_1.gFT_PSD_data - plt.plot(Gk_1.k, dd, color="gray", linewidth=0.5, alpha=0.5) - - dd = Gk_1.gFT_PSD_data.rolling(k=10, min_periods=1, center=True).mean() - plt.plot(Gk_1.k, dd, color=col_d[k], linewidth=0.8) - # handle the 'All-NaN slice encountered' warning - if np.all(np.isnan(dd.data)): - dd_max.append(np.nan) - else: - dd_max.append(np.nanmax(dd.data)) - - plt.xlim(klim) - if lflag: - plt.ylabel("$(m/m)^2/k$") - plt.title("Energy Spectra", loc="left") - - plt.xlabel("wavenumber k (2$\pi$ m$^{-1}$)") - - ax11.axvline(k_thresh, linewidth=1, color="gray", alpha=1) - ax11.axvspan(k_thresh, klim[-1], color="gray", alpha=0.5, zorder=12) - - if not np.all(np.isnan(dd_max)): - max_vale = np.nanmax(dd_max) - for ax in ax1_list: - ax.set_ylim(0, max_vale * 1.1) - - ax0 = F.fig.add_subplot(gs[-2:, :]) - - neven = True - offs = 0 - for k in all_beams: - Gx_1 = Gx.isel(x=i).sel(beam=k) - Gk_1 = Gk.isel(x=i).sel(beam=k) - - ylims = -np.nanstd(Gx_1.y_data) * 3, np.nanstd(Gx_1.y_data) * 3 - - # oringial data - eta_1 = plot_data_eta( - Gx_1, offset=offs, linestyle="-", c="k", linewidth=1.5, alpha=0.5, zorder=11 + Gplot = G_gFT_wmean.squeeze() + plt.plot( + Gplot.x / 1e3, + Gplot.sel(k=slice(k_max_range[0], k_max_range[2])).integrate("k"), + ".", + markersize=1.5, + c="red", + label="gFT", ) - # reconstruct in gaps - FT = gFT.generalized_Fourier(Gx_1.eta + Gx_1.x, None, Gk_1.k) - _ = FT.get_H() - FT.p_hat = np.concatenate([Gk_1.gFT_cos_coeff, Gk_1.gFT_sin_coeff]) + plt.ylabel("1e-3 $(m)^2~m$") + plt.legend() - p_hat_k = np.concatenate([Gk_1.k, Gk_1.k]) - k_mask = p_hat_k < k_thresh - FT.p_hat[~k_mask] = 0 + F.save_light(path=plot_path, name="B03_specs_L" + str(Lmeters)) - plt.plot( - Gx_1.eta, - FT.model() + offs, - "-", - c=col_d[k], - linewidth=0.8, - alpha=1, - zorder=12, - ) + Gk.sel(beam=k).gFT_PSD_data.plot() - if neven: - neven = False - offs += 0.3 + if "y_data" in Gx.sel(beam="gt3r").keys(): + print("ydata is ", ("y_data" in Gx.sel(beam="gt3r").keys())) else: - neven = True - offs += 0.6 + print("ydata is ", ("y_data" in Gx.sel(beam="gt3r").keys())) + MT.json_save("B03_fail", plot_path, {"reason": "no y_data"}) + echo("failed, exit", "red") + exit() + + fltostr, _ = MT.float_to_str, 
MT.num_to_str + + font_for_print() + + (plot_path / "B03_spectra").mkdir(parents=True, exist_ok=True) + + x_pos_sel = np.arange(Gk.x.size)[ + ~np.isnan(Gk.mean("beam").mean("k").gFT_PSD_data.data) + ] + x_pos_max = ( + Gk.mean("beam") + .mean("k") + .gFT_PSD_data[~np.isnan(Gk.mean("beam").mean("k").gFT_PSD_data)] + .argmax() + .data + ) + xpp = x_pos_sel[ + [int(i) for i in np.round(np.linspace(0, x_pos_sel.size - 1, 4))] + ] + xpp = np.insert(xpp, 0, x_pos_max) + + for i in xpp: + F = M.figure_axis_xy(6, 8, container=True, view_scale=0.8) + + plt.suptitle( + "gFT Model and Spectrograms | x=" + + str(Gk.x[i].data) + + " \n" + + track_name, + y=0.95, + ) + gs = GridSpec(5, 6, wspace=0.2, hspace=0.7) + + ax0 = F.fig.add_subplot(gs[0:2, :]) + col_d = color_schemes.__dict__["rels"] - dx = eta_1.diff("eta").mean().data + neven = True + offs = 0 + for k in all_beams: + Gx_1 = Gx.isel(x=i).sel(beam=k) + Gk_1 = Gk.isel(x=i).sel(beam=k) + + plot_model_eta( + Gx_1, + ax0, + offset=offs, + linestyle="-", + color=col_d[k], + linewidth=0.4, + alpha=1, + zorder=12, + ) + + # original data + eta_1 = plot_data_eta( + Gx_1, + offset=offs, + linestyle="-", + c="k", + linewidth=1, + alpha=0.5, + zorder=11, + ) + + # reconstruct in gaps + FT = gFT.generalized_Fourier(Gx_1.eta + Gx_1.x, None, Gk_1.k) + _ = FT.get_H() + FT.p_hat = np.concatenate([Gk_1.gFT_cos_coeff, Gk_1.gFT_sin_coeff]) + plt.plot( + Gx_1.eta, + FT.model() + offs, + "-", + c="orange", + linewidth=0.3, + alpha=1, + zorder=2, + ) + + if neven: + neven = False + offs += 0.3 + else: + neven = True + offs += 0.6 + + dx = eta_1.diff("eta").mean().data + + eta_ticks = np.linspace(Gx_1.eta.data[0], Gx_1.eta.data[-1], 11) + + ax0.set_xticks(eta_ticks) + ax0.set_xticklabels(eta_ticks / 1e3) + plt.xlim(eta_1[0].data - 40 * dx, eta_1[-1].data + 40 * dx) + plt.title("Model reconst.", loc="left") + + plt.ylabel("relative slope (m/m)") + # TODO: compute xlabel as fstring. 
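+            # f-string form of the label built below, assuming fltostr(v, 2)
+            # formats like f"{v:.2f}":
+            # plt.xlabel(
+            #     rf"segment distance $\eta$ (km) @ x={float(Gx_1.x.data) / 1e3:.2f}km"
+            # )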
CP + plt.xlabel( + "segment distance $\eta$ (km) @ x=" + + fltostr(Gx_1.x.data / 1e3, 2) + + "km" + ) + + # spectra + # define threshold + k_thresh = 0.085 + ax1_list = list() + dd_max = list() + for pos, kgroup, lflag in zip( + [gs[2, 0:2], gs[2, 2:4], gs[2, 4:]], + [["gt1l", "gt1r"], ["gt2l", "gt2r"], ["gt3l", "gt3r"]], + [True, False, False], + ): + ax11 = F.fig.add_subplot(pos) + ax11.tick_params(labelleft=lflag) + ax1_list.append(ax11) + for k in kgroup: + Gx_1 = Gx.isel(x=i).sel(beam=k) + Gk_1 = Gk.isel(x=i).sel(beam=k) + + klim = Gk_1.k[0], Gk_1.k[-1] + + if "l" in k: + dd = Gk_1.gFT_PSD_data + plt.plot(Gk_1.k, dd, color="gray", linewidth=0.5, alpha=0.5) + + dd = Gk_1.gFT_PSD_data.rolling( + k=10, min_periods=1, center=True + ).mean() + plt.plot(Gk_1.k, dd, color=col_d[k], linewidth=0.8) + # handle the 'All-NaN slice encountered' warning + if np.all(np.isnan(dd.data)): + dd_max.append(np.nan) + else: + dd_max.append(np.nanmax(dd.data)) + + plt.xlim(klim) + if lflag: + plt.ylabel("$(m/m)^2/k$") + plt.title("Energy Spectra", loc="left") + + plt.xlabel("wavenumber k (2$\pi$ m$^{-1}$)") + + ax11.axvline(k_thresh, linewidth=1, color="gray", alpha=1) + ax11.axvspan(k_thresh, klim[-1], color="gray", alpha=0.5, zorder=12) + + if not np.all(np.isnan(dd_max)): + max_vale = np.nanmax(dd_max) + for ax in ax1_list: + ax.set_ylim(0, max_vale * 1.1) + + ax0 = F.fig.add_subplot(gs[-2:, :]) - eta_ticks = np.linspace(Gx_1.eta.data[0], Gx_1.eta.data[-1], 11) + neven = True + offs = 0 + for k in all_beams: + Gx_1 = Gx.isel(x=i).sel(beam=k) + Gk_1 = Gk.isel(x=i).sel(beam=k) + + # original data + eta_1 = plot_data_eta( + Gx_1, + offset=offs, + linestyle="-", + c="k", + linewidth=1.5, + alpha=0.5, + zorder=11, + ) + + # reconstruct in gaps + FT = gFT.generalized_Fourier(Gx_1.eta + Gx_1.x, None, Gk_1.k) + _ = FT.get_H() + FT.p_hat = np.concatenate([Gk_1.gFT_cos_coeff, Gk_1.gFT_sin_coeff]) + + p_hat_k = np.concatenate([Gk_1.k, Gk_1.k]) + k_mask = p_hat_k < k_thresh + FT.p_hat[~k_mask] = 0 + + plt.plot( + Gx_1.eta, + FT.model() + offs, + "-", + c=col_d[k], + linewidth=0.8, + alpha=1, + zorder=12, + ) + + if neven: + neven = False + offs += 0.3 + else: + neven = True + offs += 0.6 + + dx = eta_1.diff("eta").mean().data + + eta_ticks = np.linspace(Gx_1.eta.data[0], Gx_1.eta.data[-1], 11) + + ax0.set_xticks(eta_ticks) + ax0.set_xticklabels(eta_ticks / 1e3) + plt.xlim(eta_1[1000].data - 40 * dx, eta_1[-1000].data + 40 * dx) + plt.title("Low-Wavenumber Model reconst.", loc="left") + + plt.ylabel("relative slope (m/m)") + # TODO: compute xlabel as fstring. 
CP
+            plt.xlabel(
+                "segment distance $\eta$ (km) @ x="
+                + fltostr(Gx_1.x.data / 1e3, 2)
+                + "km"
+            )
+
+            F.save_pup(
+                path=str(plot_path / "B03_spectra"), name=f"B03_freq_reconst_x{i}"
+            )
+
+        MT.json_save(
+            "B03_success",
+            plot_path,
+            {"time": "time.asctime( time.localtime(time.time()) )"},
+        )
-
-    ax0.set_xticks(eta_ticks)
-    ax0.set_xticklabels(eta_ticks / 1e3)
-    plt.xlim(eta_1[1000].data - 40 * dx, eta_1[-1000].data + 40 * dx)
-    plt.title("Low-Wavenumber Model reconst.", loc="left")
+        echo("success", "green")
-    plt.ylabel("relative slope (m/m)")
-    plt.xlabel(
-        "segment distance $\eta$ (km) @ x=" + fltostr(Gx_1.x.data / 1e3, 2) + "km"
-    )
-    F.save_pup(path=plot_path + "B03_spectra/", name="B03_freq_reconst_x" + str(i))
+plot_spectra = makeapp(run_B03_plot_spectra_ov, name="plotspectra")
-MT.json_save(
-    "B03_success", plot_path, {"time": "time.asctime( time.localtime(time.time()) )"}
-)
+if __name__ == "__main__":
+    plot_spectra()
diff --git a/src/icesat2_tracks/analysis_db/B04_angle.py b/src/icesat2_tracks/analysis_db/B04_angle.py
index 0b1e269b..68f02b31 100644
--- a/src/icesat2_tracks/analysis_db/B04_angle.py
+++ b/src/icesat2_tracks/analysis_db/B04_angle.py
@@ -1,13 +1,14 @@
-import os, sys
-
+#!/usr/bin/env python
 """
-This file open a ICEsat2 track applied filters and corections and returns smoothed photon heights on a regular grid in an .nc file.
+This file opens an ICESat-2 track, applies filters and corrections, and returns smoothed photon heights on a regular grid in an .nc file.
 This is python 3
 """
+
+import itertools
+
 from icesat2_tracks.config.IceSAT2_startup import (
     mconfig,
     color_schemes,
-    plt,
     font_for_print,
     font_for_pres,
 )
@@ -17,11 +18,12 @@
 import icesat2_tracks.ICEsat2_SI_tools.iotools as io
 import xarray as xr
 import numpy as np
+from scipy.constants import g
 from matplotlib.gridspec import GridSpec
-
-from numba import jit
+import matplotlib.pyplot as plt
+from numba import jit # maybe for later optimizations? 
# noqa: F401 from icesat2_tracks.ICEsat2_SI_tools import angle_optimizer @@ -33,872 +35,986 @@ import time -from contextlib import contextmanager - -color_schemes.colormaps2(21) - - -col_dict = color_schemes.rels - -track_name, batch_key, test_flag = io.init_from_input( - sys.argv -) # loads standard experiment - -hemis, batch = batch_key.split("_") - -ATlevel = "ATL03" +from typer import Option -save_path = mconfig["paths"]["work"] + batch_key + "/B04_angle/" -save_name = "B04_" + track_name - -plot_path = ( - mconfig["paths"]["plot"] + "/" + hemis + "/" + batch_key + "/" + track_name + "/" -) -MT.mkdirs_r(plot_path) -MT.mkdirs_r(save_path) -bad_track_path = mconfig["paths"]["work"] + "bad_tracks/" + batch_key + "/" - -all_beams = mconfig["beams"]["all_beams"] -high_beams = mconfig["beams"]["high_beams"] -low_beams = mconfig["beams"]["low_beams"] -beam_groups = mconfig["beams"]["groups"] - -load_path = mconfig["paths"]["work"] + batch_key + "/B01_regrid/" -G_binned_store = h5py.File(load_path + "/" + track_name + "_B01_binned.h5", "r") -G_binned = dict() -for b in all_beams: - G_binned[b] = io.get_beam_hdf_store(G_binned_store[b]) -G_binned_store.close() - -load_path = mconfig["paths"]["work"] + batch_key + "/B02_spectra/" -Gx = xr.load_dataset(load_path + "/B02_" + track_name + "_gFT_x.nc") # -Gk = xr.load_dataset(load_path + "/B02_" + track_name + "_gFT_k.nc") # - - -# load prior information -load_path = mconfig["paths"]["work"] + batch_key + "/A02_prior/" - -try: - Prior = MT.load_pandas_table_dict("/A02_" + track_name, load_path)[ - "priors_hindcast" - ] -except: - print("Prior not found. exit") - MT.json_save( - "B04_fail", - plot_path, - { - "time": time.asctime(time.localtime(time.time())), - "reason": "Prior not found", - }, - ) - exit() - -if np.isnan(Prior["mean"]["dir"]): - print("Prior failed, entries are nan. 
exit.") - MT.json_save( - "B04_fail", - plot_path, - { - "time": time.asctime(time.localtime(time.time())), - "reason": "Prior not found", - }, - ) - exit() - -#### Define Prior -Pperiod = Prior.loc[["ptp0", "ptp1", "ptp2", "ptp3", "ptp4", "ptp5"]]["mean"] -Pdir = Prior.loc[["pdp0", "pdp1", "pdp2", "pdp3", "pdp4", "pdp5"]]["mean"].astype( - "float" +from icesat2_tracks.clitools import ( + validate_batch_key, + validate_output_dir, + suppress_stdout, + update_paths_mconfig, + report_input_parameters, + validate_track_name_steps_gt_1, + makeapp, ) -Pspread = Prior.loc[["pspr0", "pspr1", "pspr2", "pspr3", "pspr4", "pspr5"]]["mean"] -Pperiod = Pperiod[~np.isnan(list(Pspread))] -Pdir = Pdir[~np.isnan(list(Pspread))] -Pspread = Pspread[~np.isnan(list(Pspread))] +def run_B04_angle( + track_name: str = Option(..., callback=validate_track_name_steps_gt_1), + batch_key: str = Option(..., callback=validate_batch_key), + ID_flag: bool = True, + output_dir: str = Option(..., callback=validate_output_dir), + verbose: bool = False, +): + """ + TODO: add docstring + """ -# this is a hack since the current data does not have a spread -Pspread[Pspread == 0] = 20 + color_schemes.colormaps2(21) -# reset dirs: -Pdir[Pdir > 180] = Pdir[Pdir > 180] - 360 -Pdir[Pdir < -180] = Pdir[Pdir < -180] + 360 + col_dict = color_schemes.rels -# reorder dirs -dir_best = [0] -for dir in Pdir: - ip = np.argmin( + track_name, batch_key, test_flag = io.init_from_input( [ - abs(dir_best[-1] - dir), - abs(dir_best[-1] - (dir - 360)), - abs(dir_best[-1] - (dir + 360)), + None, + track_name, + batch_key, + ID_flag, ] ) - new_dir = np.array([dir, (dir - 360), (dir + 360)])[ip] - dir_best.append(new_dir) -dir_best = np.array(dir_best[1:]) - -if len(Pperiod) == 0: - print("constant peak wave number") - kk = Gk.k - Pwavenumber = kk * 0 + (2 * np.pi / (1 / Prior.loc["fp"]["mean"])) ** 2 / 9.81 - dir_best = kk * 0 + Prior.loc["dp"]["mean"] - dir_interp_smth = dir_interp = kk * 0 + Prior.loc["dp"]["mean"] - spread_smth = spread_interp = kk * 0 + Prior.loc["spr"]["mean"] - - -else: - Pwavenumber = (2 * np.pi / Pperiod) ** 2 / 9.81 - kk = Gk.k - dir_interp = np.interp( - kk, Pwavenumber[Pwavenumber.argsort()], dir_best[Pwavenumber.argsort()] - ) - dir_interp_smth = M.runningmean(dir_interp, 30, tailcopy=True) - dir_interp_smth[-1] = dir_interp_smth[-2] - spread_interp = np.interp( - kk, - Pwavenumber[Pwavenumber.argsort()], - Pspread[Pwavenumber.argsort()].astype("float"), - ) - spread_smth = M.runningmean(spread_interp, 30, tailcopy=True) - spread_smth[-1] = spread_smth[-2] - - -font_for_pres() - -F = M.figure_axis_xy(5, 4.5, view_scale=0.5) -plt.subplot(2, 1, 1) -plt.title("Prior angle smoothed\n" + track_name, loc="left") - - -plt.plot(Pwavenumber, dir_best, ".r", markersize=8) -plt.plot(kk, dir_interp, "-", color="red", linewidth=0.8, zorder=11) -plt.plot(kk, dir_interp_smth, color=color_schemes.green1) - -plt.fill_between( - kk, - dir_interp_smth - spread_smth, - dir_interp_smth + spread_smth, - zorder=1, - color=color_schemes.green1, - alpha=0.2, -) -plt.ylabel("Angle (deg)") - -ax2 = plt.subplot(2, 1, 2) -plt.title("Prior angle adjusted ", loc="left") - -# adjust angle def: -dir_interp_smth[dir_interp_smth > 180] = dir_interp_smth[dir_interp_smth > 180] - 360 -dir_interp_smth[dir_interp_smth < -180] = dir_interp_smth[dir_interp_smth < -180] + 360 - -plt.fill_between( - kk, - dir_interp_smth - spread_smth, - dir_interp_smth + spread_smth, - zorder=1, - color=color_schemes.green1, - alpha=0.2, -) -plt.plot(kk, dir_interp_smth, ".", 
markersize=1, color=color_schemes.green1) - -ax2.axhline(85, color="gray", linewidth=2) -ax2.axhline(-85, color="gray", linewidth=2) - -plt.ylabel("Angle (deg)") -plt.xlabel("wavenumber ($2 \pi/\lambda$)") - -F.save_light(path=plot_path, name="B04_prior_angle") - -# save -dir_interp_smth = xr.DataArray( - data=dir_interp_smth * np.pi / 180, - dims="k", - coords={"k": kk}, - name="Prior_direction", -) -spread_smth = xr.DataArray( - data=spread_smth * np.pi / 180, dims="k", coords={"k": kk}, name="Prior_spread" -) -Prior_smth = xr.merge([dir_interp_smth, spread_smth]) - -prior_angle = Prior_smth.Prior_direction * 180 / np.pi -if (abs(prior_angle) > 80).all(): - print("Prior angle is ", prior_angle.min().data, prior_angle.max().data, ". quit.") - dd_save = { - "time": time.asctime(time.localtime(time.time())), - "angle": list( - [ - float(prior_angle.min().data), - float(prior_angle.max().data), - float(prior_angle.median()), - ] - ), + kargs = { + "track_name": track_name, + "batch_key": batch_key, + "ID_flag": ID_flag, + "output_dir": output_dir, } - MT.json_save("B04_fail", plot_path, dd_save) - print("exit()") - exit() + report_input_parameters(**kargs) -# define paramater range -params_dict = { - "alpha": [-0.85 * np.pi / 2, 0.85 * np.pi / 2, 5], - "phase": [0, 2 * np.pi, 10], -} + with suppress_stdout(verbose): -alpha_dx = 0.02 -max_wavenumbers = 25 + hemis, batch = batch_key.split("_") -sample_flag = True -optimize_flag = False -brute_flag = False + workdir, plotsdir = update_paths_mconfig(output_dir, mconfig) -plot_flag = False + save_path = workdir / batch_key / "B04_angle" + plot_path = plotsdir / hemis / batch_key / track_name + save_path.mkdir(parents=True, exist_ok=True) + plot_path.mkdir(parents=True, exist_ok=True) -Nworkers = 6 -N_sample_chain = 300 -N_sample_chain_burn = 30 + all_beams = mconfig["beams"]["all_beams"] + beam_groups = mconfig["beams"]["groups"] -max_x_pos = 8 -x_pos_jump = 2 + load_path = workdir / batch_key / "B01_regrid" + G_binned_store = h5py.File(load_path / (track_name + "_B01_binned.h5"), "r") + G_binned = dict() + for b in all_beams: + G_binned[b] = io.get_beam_hdf_store(G_binned_store[b]) + G_binned_store.close() + load_path = workdir / batch_key / "B02_spectra" + xtrack, ktrack = [ + load_path / ("B02_" + track_name + f"_gFT_{var}.nc") for var in ["x", "k"] + ] -def make_fake_data(xi, group): - ki = Gk.k[0:2] + Gx = xr.load_dataset(xtrack) + Gk = xr.load_dataset(ktrack) - bins = np.arange( - params_dict["alpha"][0], params_dict["alpha"][1] + alpha_dx, alpha_dx - ) - bins_pos = bins[0:-1] + np.diff(bins) / 2 - marginal_stack = xr.DataArray( - np.nan * np.vstack([bins_pos, bins_pos]).T, - dims=("angle", "k"), - coords={"angle": bins_pos, "k": ki.data}, - ) + # load prior information + load_path = workdir / batch_key / "A02_prior" - group_name = str("group" + group[0].split("gt")[1].split("l")[0]) - marginal_stack.coords["beam_group"] = group_name - marginal_stack.coords["x"] = xi - marginal_stack.name = "marginals" - marginal_stack.expand_dims(dim="x", axis=2).expand_dims(dim="beam_group", axis=3) - return marginal_stack + try: + Prior = MT.load_pandas_table_dict("/A02_" + track_name, str(load_path))[ + "priors_hindcast" + ] + except FileNotFoundError: + print("Prior not found. exit") + MT.json_save( + "B04_fail", + plot_path, + { + "time": time.asctime(time.localtime(time.time())), + "reason": "Prior not found", + }, + ) + exit() + + if np.isnan(Prior["mean"]["dir"]): + print("Prior failed, entries are nan. 
exit.") + MT.json_save( + "B04_fail", + plot_path, + { + "time": time.asctime(time.localtime(time.time())), + "reason": "Prior not found", + }, + ) + exit() + + #### Define Prior + Pperiod = Prior.loc[["ptp0", "ptp1", "ptp2", "ptp3", "ptp4", "ptp5"]]["mean"] + Pdir = Prior.loc[["pdp0", "pdp1", "pdp2", "pdp3", "pdp4", "pdp5"]][ + "mean" + ].astype("float") + Pspread = Prior.loc[["pspr0", "pspr1", "pspr2", "pspr3", "pspr4", "pspr5"]][ + "mean" + ] + Pperiod = Pperiod[~np.isnan(list(Pspread))] + Pdir = Pdir[~np.isnan(list(Pspread))] + Pspread = Pspread[~np.isnan(list(Pspread))] + + # this is a hack since the current data does not have a spread + Pspread[Pspread == 0] = 20 + + # reset dirs: + Pdir[Pdir > 180] = Pdir[Pdir > 180] - 360 + Pdir[Pdir < -180] = Pdir[Pdir < -180] + 360 + + # reorder dirs + dir_best = [0] + for dir in Pdir: + ip = np.argmin( + [ + abs(dir_best[-1] - dir), + abs(dir_best[-1] - (dir - 360)), + abs(dir_best[-1] - (dir + 360)), + ] + ) + new_dir = np.array([dir, (dir - 360), (dir + 360)])[ip] + dir_best.append(new_dir) + dir_best = np.array(dir_best[1:]) + + if len(Pperiod) == 0: + print("constant peak wave number") + kk = Gk.k + Pwavenumber = kk * 0 + (2 * np.pi / (1 / Prior.loc["fp"]["mean"])) ** 2 / g + dir_best = kk * 0 + Prior.loc["dp"]["mean"] + dir_interp_smth = dir_interp = kk * 0 + Prior.loc["dp"]["mean"] + spread_smth = spread_interp = kk * 0 + Prior.loc["spr"]["mean"] -def define_wavenumber_weights_tot_var( - dd, m=3, variance_frac=0.33, k_upper_lim=None, verbose=False -): - """ - return peaks of a power spectrum dd that in the format such that they can be used as weights for the frequencies based fitting + else: + Pwavenumber = (2 * np.pi / Pperiod) ** 2 / g + kk = Gk.k + dir_interp = np.interp( + kk, Pwavenumber[Pwavenumber.argsort()], dir_best[Pwavenumber.argsort()] + ) + dir_interp_smth = M.runningmean(dir_interp, 30, tailcopy=True) + dir_interp_smth[-1] = dir_interp_smth[-2] - inputs: - dd xarray with PSD as data amd coordindate wavenumber k - m running mean half-width in gridpoints - variance_frac (0 to 1) How much variance should be explained by the returned peaks - verbose if true it plots some stuff + spread_interp = np.interp( + kk, + Pwavenumber[Pwavenumber.argsort()], + Pspread[Pwavenumber.argsort()].astype("float"), + ) + spread_smth = M.runningmean(spread_interp, 30, tailcopy=True) + spread_smth[-1] = spread_smth[-2] + + font_for_pres() + + F = M.figure_axis_xy(5, 4.5, view_scale=0.5) + plt.subplot(2, 1, 1) + plt.title("Prior angle smoothed\n" + track_name, loc="left") + + plt.plot(Pwavenumber, dir_best, ".r", markersize=8) + plt.plot(kk, dir_interp, "-", color="red", linewidth=0.8, zorder=11) + plt.plot(kk, dir_interp_smth, color=color_schemes.green1) + + plt.fill_between( + kk, + dir_interp_smth - spread_smth, + dir_interp_smth + spread_smth, + zorder=1, + color=color_schemes.green1, + alpha=0.2, + ) + plt.ylabel("Angle (deg)") + ax2 = plt.subplot(2, 1, 2) + plt.title("Prior angle adjusted ", loc="left") - return: - mask size of dd. 
where True the data is identified as having significant amplitude - k wanumbers where mask is true - dd_rm smoothed version of dd - positions postions where of significant data in array - """ + # adjust angle def: + dir_interp_smth[dir_interp_smth > 180] = ( + dir_interp_smth[dir_interp_smth > 180] - 360 + ) + dir_interp_smth[dir_interp_smth < -180] = ( + dir_interp_smth[dir_interp_smth < -180] + 360 + ) - if len(dd.shape) == 2: - dd_use = dd.mean("beam") + plt.fill_between( + kk, + dir_interp_smth - spread_smth, + dir_interp_smth + spread_smth, + zorder=1, + color=color_schemes.green1, + alpha=0.2, + ) + plt.plot(kk, dir_interp_smth, ".", markersize=1, color=color_schemes.green1) - if m is None: - dd_rm = dd_use.data - else: - dd_rm = M.runningmean(dd_use, m, tailcopy=True) + ax2.axhline(85, color="gray", linewidth=2) + ax2.axhline(-85, color="gray", linewidth=2) - k = dd_use.k[~np.isnan(dd_rm)].data - dd_rm = dd_rm[~np.isnan(dd_rm)] + plt.ylabel("Angle (deg)") + plt.xlabel("wavenumber ($2 \pi/\lambda$)") - orders = dd_rm.argsort()[::-1] - var_mask = dd_rm[orders].cumsum() / dd_rm.sum() < variance_frac - pos_cumsum = orders[var_mask] - mask = var_mask[orders.argsort()] - if k_upper_lim is not None: - mask = (k < k_upper_lim) & mask + F.save_light(path=plot_path, name="B04_prior_angle") - if verbose: - plt.plot( - dd.k, - dd, - "-", - color=col_dict[str(amp_data.beam[0].data)], - markersize=20, - alpha=0.6, + # save + dir_interp_smth = xr.DataArray( + data=dir_interp_smth * np.pi / 180, + dims="k", + coords={"k": kk}, + name="Prior_direction", + ) + spread_smth = xr.DataArray( + data=spread_smth * np.pi / 180, + dims="k", + coords={"k": kk}, + name="Prior_spread", ) - plt.plot(k, dd_rm, "-k", markersize=20) + Prior_smth = xr.merge([dir_interp_smth, spread_smth]) + + prior_angle = Prior_smth.Prior_direction * 180 / np.pi + if (abs(prior_angle) > 80).all(): + print( + "Prior angle is ", + prior_angle.min().data, + prior_angle.max().data, + ". quit.", + ) + dd_save = { + "time": time.asctime(time.localtime(time.time())), + "angle": list( + [ + float(prior_angle.min().data), + float(prior_angle.max().data), + float(prior_angle.median()), + ] + ), + } + MT.json_save("B04_fail", plot_path, dd_save) + print("exit()") + exit() + + # define parameter range + params_dict = { + "alpha": [-0.85 * np.pi / 2, 0.85 * np.pi / 2, 5], + "phase": [0, 2 * np.pi, 10], + } - plt.plot(k[mask], dd_rm[mask], ".r", markersize=10, zorder=12) - if k_upper_lim is not None: - plt.gca().axvline(k_upper_lim, color="black") + alpha_dx = 0.02 + max_wavenumbers = 25 - return mask, k, dd_rm, pos_cumsum + sample_flag = True + optimize_flag = False + brute_flag = False + plot_flag = False -def define_wavenumber_weights_threshold(dd, m=3, Nstd=2, verbose=False): - if m is None: - dd_rm = dd - else: - dd_rm = M.runningmean(dd, m, tailcopy=True) + # Nworkers = 6 for later parallelization? 
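+        # Sampler configuration: each case draws N_sample_chain MCMC samples and
+        # discards the first N_sample_chain_burn as burn-in before the angle
+        # marginals are histogrammed (see the get_marginal_dist call further down);
+        # max_x_pos and x_pos_jump below subsample the along-track x positions.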
+ N_sample_chain = 300 + N_sample_chain_burn = 30 - k = dd.k[~np.isnan(dd_rm)] - dd_rm = dd_rm[~np.isnan(dd_rm)] + max_x_pos = 8 + x_pos_jump = 2 - treshold = np.nanmean(dd_rm) + np.nanstd(dd_rm) * Nstd - mask = dd_rm > treshold + def make_fake_data(xi, group): + ki = Gk.k[0:2] - if verbose: - plt.plot(dd.k, dd, "-k", markersize=20) - plt.plot(k, dd_rm, "-b", markersize=20) + bins = np.arange( + params_dict["alpha"][0], params_dict["alpha"][1] + alpha_dx, alpha_dx + ) + bins_pos = bins[0:-1] + np.diff(bins) / 2 + marginal_stack = xr.DataArray( + np.nan * np.vstack([bins_pos, bins_pos]).T, + dims=("angle", "k"), + coords={"angle": bins_pos, "k": ki.data}, + ) - k_list = k[mask] - dd_list = dd_rm[mask] + group_name = str("group" + group[0].split("gt")[1].split("l")[0]) + marginal_stack.coords["beam_group"] = group_name + marginal_stack.coords["x"] = xi + marginal_stack.name = "marginals" + marginal_stack.expand_dims(dim="x", axis=2).expand_dims( + dim="beam_group", axis=3 + ) + return marginal_stack + + def define_wavenumber_weights_tot_var( + dd, m=3, variance_frac=0.33, k_upper_lim=None, verbose=False + ): + """ + return peaks of a power spectrum dd that in the format such that they can be used as weights for the frequencies based fitting + + inputs: + dd xarray with PSD as data amd coordinate wavenumber k + m running mean half-width in gridpoints + variance_frac (0 to 1) How much variance should be explained by the returned peaks + verbose if true it plots some stuff + + + return: + mask size of dd. where True the data is identified as having significant amplitude + k wanumbers where mask is true + dd_rm smoothed version of dd + positions positions where of significant data in array + """ + + if len(dd.shape) == 2: + dd_use = dd.mean("beam") + + if m is None: + dd_rm = dd_use.data + else: + dd_rm = M.runningmean(dd_use, m, tailcopy=True) + + k = dd_use.k[~np.isnan(dd_rm)].data + dd_rm = dd_rm[~np.isnan(dd_rm)] + + orders = dd_rm.argsort()[::-1] + var_mask = dd_rm[orders].cumsum() / dd_rm.sum() < variance_frac + pos_cumsum = orders[var_mask] + mask = var_mask[orders.argsort()] + if k_upper_lim is not None: + mask = (k < k_upper_lim) & mask + + if verbose: + plt.plot( + dd.k, + dd, + "-", + color=col_dict[str(amp_data.beam[0].data)], + markersize=20, + alpha=0.6, + ) + plt.plot(k, dd_rm, "-k", markersize=20) + + plt.plot(k[mask], dd_rm[mask], ".r", markersize=10, zorder=12) + if k_upper_lim is not None: + plt.gca().axvline(k_upper_lim, color="black") + + return mask, k, dd_rm, pos_cumsum + + def define_wavenumber_weights_threshold(dd, m=3, Nstd=2, verbose=False): + if m is None: + dd_rm = dd + else: + dd_rm = M.runningmean(dd, m, tailcopy=True) + + k = dd.k[~np.isnan(dd_rm)] + dd_rm = dd_rm[~np.isnan(dd_rm)] + + treshold = np.nanmean(dd_rm) + np.nanstd(dd_rm) * Nstd + mask = dd_rm > treshold + + if verbose: + plt.plot(dd.k, dd, "-k", markersize=20) + plt.plot(k, dd_rm, "-b", markersize=20) + + k_list = k[mask] + dd_list = dd_rm[mask] + + plt.plot(k_list, dd_list, ".r", markersize=10, zorder=12) + + return mask, k, dd_rm, np.arange(0, mask.size)[mask] + + def plot_instance( + z_model, + fargs, + key, + SM, + non_dim=False, + title_str=None, + brute=False, + optimze=False, + sample=False, + view_scale=0.3, + ): + x_concat, y_concat, z_concat = fargs + + F = M.figure_axis_xy(5, 6, view_scale=view_scale, container=True) + plt.suptitle(title_str) + gs = GridSpec(4, 3, wspace=0.4, hspace=1.2) + F.gs = gs + + col_list = itertools.cycle( + [ + color_schemes.cascade2, + color_schemes.rascade2, + 
color_schemes.cascade1, + color_schemes.rascade1, + color_schemes.cascade3, + color_schemes.rascade3, + ] + ) - plt.plot(k_list, dd_list, ".r", markersize=10, zorder=12) + beam_list = list(set(y_concat)) + for y_pos, pos in zip(beam_list, [gs[0, :], gs[1, :]]): + F.ax2 = F.fig.add_subplot(pos) - return mask, k, dd_rm, np.arange(0, mask.size)[mask] + plt.title(str(y_pos)) + plt.plot( + x_concat[y_concat == y_pos], + z_concat[y_concat == y_pos], + c=color_schemes.gray, + linewidth=1, + ) + plt.plot( + x_concat[y_concat == y_pos], + z_model[y_concat == y_pos], + "-", + c=next(col_list), + ) + plt.xlim( + x_concat[y_concat == y_pos][0], x_concat[y_concat == y_pos][-1] + ) -def plot_instance( - z_model, - fargs, - key, - SM, - non_dim=False, - title_str=None, - brute=False, - optimze=False, - sample=False, - view_scale=0.3, -): - x_concat, y_concat, z_concat = fargs + plt.xlabel("meter") + F.ax3 = F.fig.add_subplot(gs[2:, 0:-1]) + if brute is True: + plt.title("Brute-force costs", loc="left") + SM.plot_brute( + marker=".", color="blue", markersize=15, label="Brute", zorder=10 + ) - F = M.figure_axis_xy(5, 6, view_scale=view_scale, container=True) - plt.suptitle(title_str) - gs = GridSpec(4, 3, wspace=0.4, hspace=1.2) - F.gs = gs + if optimze is True: + SM.plot_optimze( + color="r", markersize=10, zorder=12, label="Dual Annealing" + ) - import itertools + if sample is True: + SM.plot_sample( + markersize=2, linewidth=0.8, alpha=0.2, color="black", zorder=8 + ) - col_list = itertools.cycle( - [ - color_schemes.cascade2, - color_schemes.rascade2, - color_schemes.cascade1, - color_schemes.rascade1, - color_schemes.cascade3, - color_schemes.rascade3, - ] - ) + F.ax4 = F.fig.add_subplot(gs[2:, -1]) + return F - beam_list = list(set(y_concat)) - for y_pos, pos in zip(beam_list, [gs[0, :], gs[1, :]]): - F.ax2 = F.fig.add_subplot(pos) + # isolate x positions with data + data_mask = Gk.gFT_PSD_data.mean("k") + data_mask.coords["beam_group"] = ( + "beam", + ["beam_group" + g_[2] for g_ in data_mask.beam.data], + ) + data_mask_group = data_mask.groupby("beam_group").mean(skipna=False) + # these stencils are actually used + data_sel_mask = data_mask_group.sum("beam_group") != 0 + + x_list = data_sel_mask.x[data_sel_mask] # iterate over these x positions + x_list_flag = ~np.isnan( + data_mask_group.sel(x=x_list) + ) # flag that is False if there is no data + + #### limit number of x coordinates + + x_list = x_list[::x_pos_jump] + if len(x_list) > max_x_pos: + x_list = x_list[0:max_x_pos] + x_list_flag = x_list_flag.sel(x=x_list) + + # plot + font_for_print() + F = M.figure_axis_xy(5.5, 3, view_scale=0.8) + plt.suptitle(track_name) + ax1 = plt.subplot(2, 1, 1) + plt.title("Data in Beam", loc="left") + plt.pcolormesh(data_mask.x / 1e3, data_mask.beam, data_mask, cmap=plt.cm.OrRd) + for i in np.arange(1.5, 6, 2): + ax1.axhline(i, color="black", linewidth=0.5) + plt.xlabel("Distance from Ice Edge") + + ax2 = plt.subplot(2, 1, 2) + plt.title("Data in Group", loc="left") + plt.pcolormesh( + data_mask.x / 1e3, + data_mask_group.beam_group, + data_mask_group, + cmap=plt.cm.OrRd, + ) - plt.title(str(y_pos)) + for i in np.arange(0.5, 3, 1): + ax2.axhline(i, color="black", linewidth=0.5) plt.plot( - x_concat[y_concat == y_pos], - z_concat[y_concat == y_pos], - c=color_schemes.gray, - linewidth=1, + x_list / 1e3, + x_list * 0 + 0, + ".", + markersize=2, + color=color_schemes.cascade1, ) plt.plot( - x_concat[y_concat == y_pos], - z_model[y_concat == y_pos], - "-", - c=next(col_list), + x_list / 1e3, + x_list * 0 + 1, + 
".", + markersize=2, + color=color_schemes.cascade1, + ) + plt.plot( + x_list / 1e3, + x_list * 0 + 2, + ".", + markersize=2, + color=color_schemes.cascade1, ) - plt.xlim(x_concat[y_concat == y_pos][0], x_concat[y_concat == y_pos][-1]) - - plt.xlabel("meter") - F.ax3 = F.fig.add_subplot(gs[2:, 0:-1]) - if brute is True: - plt.title("Brute-force costs", loc="left") - SM.plot_brute(marker=".", color="blue", markersize=15, label="Brute", zorder=10) - - if optimze is True: - SM.plot_optimze(color="r", markersize=10, zorder=12, label="Dual Annealing") - - if sample is True: - SM.plot_sample(markersize=2, linewidth=0.8, alpha=0.2, color="black", zorder=8) - - F.ax4 = F.fig.add_subplot(gs[2:, -1]) - return F - - -# isolate x positions with data -data_mask = Gk.gFT_PSD_data.mean("k") -data_mask.coords["beam_group"] = ( - "beam", - ["beam_group" + g[2] for g in data_mask.beam.data], -) -data_mask_group = data_mask.groupby("beam_group").mean(skipna=False) -# these stancils are actually used -data_sel_mask = data_mask_group.sum("beam_group") != 0 - -x_list = data_sel_mask.x[data_sel_mask] # iterate over these x posistions -x_list_flag = ~np.isnan( - data_mask_group.sel(x=x_list) -) # flag that is False if there is no data - -#### limit number of x coordinates - -x_list = x_list[::x_pos_jump] -if len(x_list) > max_x_pos: - x_list = x_list[0:max_x_pos] -x_list_flag = x_list_flag.sel(x=x_list) - -# plot -font_for_print() -F = M.figure_axis_xy(5.5, 3, view_scale=0.8) -plt.suptitle(track_name) -ax1 = plt.subplot(2, 1, 1) -plt.title("Data in Beam", loc="left") -plt.pcolormesh(data_mask.x / 1e3, data_mask.beam, data_mask, cmap=plt.cm.OrRd) -for i in np.arange(1.5, 6, 2): - ax1.axhline(i, color="black", linewidth=0.5) -plt.xlabel("Distance from Ice Edge") - -ax2 = plt.subplot(2, 1, 2) -plt.title("Data in Group", loc="left") -plt.pcolormesh( - data_mask.x / 1e3, data_mask_group.beam_group, data_mask_group, cmap=plt.cm.OrRd -) - -for i in np.arange(0.5, 3, 1): - ax2.axhline(i, color="black", linewidth=0.5) - -plt.plot(x_list / 1e3, x_list * 0 + 0, ".", markersize=2, color=color_schemes.cascade1) -plt.plot(x_list / 1e3, x_list * 0 + 1, ".", markersize=2, color=color_schemes.cascade1) -plt.plot(x_list / 1e3, x_list * 0 + 2, ".", markersize=2, color=color_schemes.cascade1) - -plt.xlabel("Distance from Ice Edge") -F.save_pup(path=plot_path, name="B04_data_avail") + plt.xlabel("Distance from Ice Edge") -Marginals = dict() -L_collect = dict() + F.save_pup(path=plot_path, name="B04_data_avail") -group_number = np.arange(len(beam_groups)) -ggg, xxx = np.meshgrid(group_number, x_list.data) + Marginals = dict() + L_collect = dict() -for gi in zip(ggg.flatten(), xxx.flatten()): - print(gi) + group_number = np.arange(len(beam_groups)) + ggg, xxx = np.meshgrid(group_number, x_list.data) - group, xi = beam_groups[gi[0]], gi[1] + for gi in zip(ggg.flatten(), xxx.flatten()): + print(gi) - if bool(x_list_flag.sel(x=xi).isel(beam_group=gi[0]).data) is False: - print("no data, fill with dummy") - ikey = str(xi) + "_" + "_".join(group) - Marginals[ikey] = make_fake_data(xi, group) - continue + group, xi = beam_groups[gi[0]], gi[1] - GGx = Gx.sel(beam=group).sel(x=xi) - GGk = Gk.sel(beam=group).sel(x=xi) + if bool(x_list_flag.sel(x=xi).isel(beam_group=gi[0]).data) is False: + print("no data, fill with dummy") + ikey = str(xi) + "_" + "_".join(group) + Marginals[ikey] = make_fake_data(xi, group) + continue - ### define data - # normalize data - key = "y_data" - amp_Z = (GGx[key] - GGx[key].mean(["eta"])) / GGx[key].std(["eta"]) - N 
= amp_Z.shape[0] + GGx = Gx.sel(beam=group).sel(x=xi) + GGk = Gk.sel(beam=group).sel(x=xi) - # define x,y positions - eta_2d = GGx.eta + GGx.x_coord - GGx.x_coord.mean() - nu_2d = GGx.eta * 0 + GGx.y_coord - GGx.y_coord.mean() + ### define data + # normalize data + key = "y_data" + amp_Z = (GGx[key] - GGx[key].mean(["eta"])) / GGx[key].std(["eta"]) - # repack as np arrays - x_concat = eta_2d.data.T.flatten() - y_concat = nu_2d.data.T.flatten() - z_concat = amp_Z.data.flatten() + # define x,y positions + eta_2d = GGx.eta + GGx.x_coord - GGx.x_coord.mean() + nu_2d = GGx.eta * 0 + GGx.y_coord - GGx.y_coord.mean() - x_concat = x_concat[~np.isnan(z_concat)] - y_concat = y_concat[~np.isnan(z_concat)] - z_concat = z_concat[~np.isnan(z_concat)] - N_data = x_concat.size + # repack as np arrays + x_concat = eta_2d.data.T.flatten() + y_concat = nu_2d.data.T.flatten() + z_concat = amp_Z.data.flatten() - if np.isnan(z_concat).sum() != 0: - raise ValueError("There are still nans") + x_concat = x_concat[~np.isnan(z_concat)] + y_concat = y_concat[~np.isnan(z_concat)] + z_concat = z_concat[~np.isnan(z_concat)] + N_data = x_concat.size - mean_dist = (nu_2d.isel(beam=0) - nu_2d.isel(beam=1)).mean().data - k_upper_lim = 2 * np.pi / (mean_dist * 1) + if np.isnan(z_concat).sum() != 0: + raise ValueError("There are still nans") - print("k_upper_lim ", k_upper_lim) + mean_dist = (nu_2d.isel(beam=0) - nu_2d.isel(beam=1)).mean().data + k_upper_lim = 2 * np.pi / (mean_dist * 1) - # variance method - amp_data = np.sqrt(GGk.gFT_cos_coeff**2 + GGk.gFT_sin_coeff**2) - mask, k, weights, positions = define_wavenumber_weights_tot_var( - amp_data, m=1, k_upper_lim=k_upper_lim, variance_frac=0.20, verbose=False - ) + print("k_upper_lim ", k_upper_lim) - if len(k[mask]) == 0: - print("no good k found, fill with dummy") - ikey = str(xi) + "_" + "_".join(group) - Marginals[ikey] = make_fake_data(xi, group) - continue - - SM = angle_optimizer.sample_with_mcmc(params_dict) - SM.set_objective_func(angle_optimizer.objective_func) - nan_list = np.isnan(x_concat) | np.isnan(y_concat) | np.isnan(y_concat) - x_concat[nan_list] = [] - y_concat[nan_list] = [] - z_concat[nan_list] = [] - SM.fitting_args = fitting_args = (x_concat, y_concat, z_concat) - - # test: - k_prime_max = 0.02 - amp_Z = 1 - prior_sel = { - "alpha": ( - Prior_smth.sel(k=k_prime_max, method="nearest").Prior_direction.data, - Prior_smth.sel(k=k_prime_max, method="nearest").Prior_spread.data, - ) - } - SM.fitting_kargs = fitting_kargs = {"prior": prior_sel, "prior_weight": 3} - # test if it works - SM.params.add( - "K_prime", k_prime_max, vary=False, min=k_prime_max * 0.5, max=k_prime_max * 1.5 - ) - SM.params.add("K_amp", amp_Z, vary=False, min=amp_Z * 0.0, max=amp_Z * 5) - try: - SM.test_objective_func() - except: - raise ValueError("Objective function test fails") - - def get_instance(k_pair): - k_prime_max, Z_max = k_pair - - prior_sel = { - "alpha": ( - Prior_smth.sel(k=k_prime_max, method="nearest").Prior_direction.data, - Prior_smth.sel(k=k_prime_max, method="nearest").Prior_spread.data, + # variance method + amp_data = np.sqrt(GGk.gFT_cos_coeff**2 + GGk.gFT_sin_coeff**2) + mask, k, weights, positions = define_wavenumber_weights_tot_var( + amp_data, + m=1, + k_upper_lim=k_upper_lim, + variance_frac=0.20, + verbose=False, ) - } - SM.fitting_kargs = fitting_kargs = {"prior": prior_sel, "prior_weight": 2} - - amp_Z = 1 - SM.params.add( - "K_prime", - k_prime_max, - vary=False, - min=k_prime_max * 0.5, - max=k_prime_max * 1.5, - ) - SM.params.add("K_amp", amp_Z, 
vary=False, min=amp_Z * 0.0, max=amp_Z * 5) - - L_sample_i = None - L_optimize_i = None - L_brute_i = None - if sample_flag: - SM.sample(verbose=False, steps=N_sample_chain, progress=False, workers=None) - L_sample_i = list(SM.fitter.params.valuesdict().values()) # mcmc results - - elif optimize_flag: - SM.optimize(verbose=False) - L_optimize_i = list( - SM.fitter_optimize.params.valuesdict().values() - ) # mcmc results - - elif brute_flag: - SM.brute(verbose=False) - L_brute_i = list( - SM.fitter_brute.params.valuesdict().values() - ) # mcmc results - else: - raise ValueError( - "non of sample_flag,optimize_flag, or brute_flag are True" - ) - - y_hist, bins, bins_pos = SM.get_marginal_dist( - "alpha", alpha_dx, burn=N_sample_chain_burn, plot_flag=False - ) - fitter = SM.fitter # MCMC results - z_model = SM.objective_func(fitter.params, *fitting_args, test_flag=True) - cost = (fitter.residual**2).sum() / (z_concat**2).sum() - - if plot_flag: - F = plot_instance( - z_model, - fitting_args, - "y_data_normed", - SM, - brute=brute_flag, - optimze=optimize_flag, - sample=sample_flag, - title_str="k=" + str(np.round(k_prime_max, 4)), - view_scale=0.6, + if len(k[mask]) == 0: + print("no good k found, fill with dummy") + ikey = str(xi) + "_" + "_".join(group) + Marginals[ikey] = make_fake_data(xi, group) + continue + + SM = angle_optimizer.sample_with_mcmc(params_dict) + SM.set_objective_func(angle_optimizer.objective_func) + nan_list = np.isnan(x_concat) | np.isnan(y_concat) | np.isnan(y_concat) + x_concat[nan_list] = [] + y_concat[nan_list] = [] + z_concat[nan_list] = [] + SM.fitting_args = fitting_args = (x_concat, y_concat, z_concat) + + # test: + k_prime_max = 0.02 + amp_Z = 1 + prior_sel = { + "alpha": ( + Prior_smth.sel( + k=k_prime_max, method="nearest" + ).Prior_direction.data, + Prior_smth.sel(k=k_prime_max, method="nearest").Prior_spread.data, + ) + } + SM.fitting_kargs = { + "prior": prior_sel, + "prior_weight": 3, + } # not sure this might be used somewhere. CP + + # test if it works + SM.params.add( + "K_prime", + k_prime_max, + vary=False, + min=k_prime_max * 0.5, + max=k_prime_max * 1.5, ) + SM.params.add("K_amp", amp_Z, vary=False, min=amp_Z * 0.0, max=amp_Z * 5) + try: + SM.test_objective_func() + except ValueError: + raise ValueError("Objective function test fails") + + def get_instance(k_pair): + k_prime_max, Z_max = k_pair + + prior_sel = { + "alpha": ( + Prior_smth.sel( + k=k_prime_max, method="nearest" + ).Prior_direction.data, + Prior_smth.sel( + k=k_prime_max, method="nearest" + ).Prior_spread.data, + ) + } + + SM.fitting_kargs = { + "prior": prior_sel, + "prior_weight": 2, + } # not sure this might be used somewhere. 
CP
+
+            amp_Z = 1
+            SM.params.add(
+                "K_prime",
+                k_prime_max,
+                vary=False,
+                min=k_prime_max * 0.5,
+                max=k_prime_max * 1.5,
+            )
+            SM.params.add(
+                "K_amp", amp_Z, vary=False, min=amp_Z * 0.0, max=amp_Z * 5
+            )
-        )
-
-        if fitting_kargs["prior"] is not None:
-            F.ax3.axhline(
-                prior_sel["alpha"][0], color="green", linewidth=2, label="Prior"
-            )
-            F.ax3.axhline(
-                prior_sel["alpha"][0] - prior_sel["alpha"][1],
-                color="green",
-                linewidth=0.7,
-            )
-            F.ax3.axhline(
-                prior_sel["alpha"][0] + prior_sel["alpha"][1],
-                color="green",
-                linewidth=0.7,
-            )
-
-        F.ax3.axhline(
-            fitter.params["alpha"].min, color="gray", linewidth=2, alpha=0.6
-        )
-        F.ax3.axhline(
-            fitter.params["alpha"].max, color="gray", linewidth=2, alpha=0.6
-        )
+
+            L_sample_i = None
+            L_optimize_i = None
+            L_brute_i = None
+            if sample_flag:
+                SM.sample(
+                    verbose=False,
+                    steps=N_sample_chain,
+                    progress=False,
+                    workers=None,
+                )
+                L_sample_i = list(
+                    SM.fitter.params.valuesdict().values()
+                )  # mcmc results
+
+            elif optimize_flag:
+                SM.optimize(verbose=False)
+                L_optimize_i = list(
+                    SM.fitter_optimize.params.valuesdict().values()
+                )  # mcmc results
+
+            elif brute_flag:
+                SM.brute(verbose=False)
+                L_brute_i = list(
+                    SM.fitter_brute.params.valuesdict().values()
+                )  # mcmc results
+            else:
+                raise ValueError(
+                    "none of sample_flag, optimize_flag, or brute_flag is True"
+                )
+
+            y_hist, bins, bins_pos = SM.get_marginal_dist(
+                "alpha", alpha_dx, burn=N_sample_chain_burn, plot_flag=False
+            )
+            fitter = SM.fitter  # MCMC results
+            z_model = SM.objective_func(
+                fitter.params, *fitting_args, test_flag=True
+            )
+            cost = (fitter.residual**2).sum() / (z_concat**2).sum()
+
+            if plot_flag:
+                F = plot_instance(
+                    z_model,
+                    fitting_args,
+                    "y_data_normed",
+                    SM,
+                    brute=brute_flag,
+                    optimze=optimize_flag,
+                    sample=sample_flag,
+                    title_str="k=" + str(np.round(k_prime_max, 4)),
+                    view_scale=0.6,
+                )
+
+                if prior_sel:  # draw the prior lines only when a prior was set
+                    F.ax3.axhline(
+                        prior_sel["alpha"][0],
+                        color="green",
+                        linewidth=2,
+                        label="Prior",
+                    )
+                    F.ax3.axhline(
+                        prior_sel["alpha"][0] - prior_sel["alpha"][1],
+                        color="green",
+                        linewidth=0.7,
+                    )
+                    F.ax3.axhline(
+                        prior_sel["alpha"][0] + prior_sel["alpha"][1],
+                        color="green",
+                        linewidth=0.7,
+                    )
+
+                F.ax3.axhline(
+                    fitter.params["alpha"].min, color="gray", linewidth=2, alpha=0.6
+                )
+                F.ax3.axhline(
+                    fitter.params["alpha"].max, color="gray", linewidth=2, alpha=0.6
+                )
+
+                plt.sca(F.ax3)
+                plt.legend()
+                plt.xlabel("Phase")
+                plt.ylabel("Angle")
+                plt.xlim(0, np.pi * 2)
+
+                plt.sca(F.ax4)
+                plt.xlabel("Density")
+                plt.stairs(y_hist, bins, orientation="horizontal", color="k")
+
+                F.ax4.axhline(
+                    fitter.params["alpha"].min, color="gray", linewidth=2, alpha=0.6
+                )
+                F.ax4.axhline(
+                    fitter.params["alpha"].max, color="gray", linewidth=2, alpha=0.6
+                )
+
+                F.ax3.set_ylim(
+                    min(-np.pi / 2, prior_sel["alpha"][0] - 0.2),
+                    max(np.pi / 2, prior_sel["alpha"][0] + 0.2),
+                )
+                F.ax4.set_ylim(
+                    min(-np.pi / 2, prior_sel["alpha"][0] - 0.2),
+                    max(np.pi / 2, prior_sel["alpha"][0] + 0.2),
+                )
+
+                plt.show()
+                F.save_light(
+                    path=plot_path, name=track_name + "_fit_k" + str(k_prime_max)
+                )
+
+            marginal_stack_i = xr.DataArray(
+                y_hist, dims=("angle"), coords={"angle": bins_pos}
+            )
+            marginal_stack_i.coords["k"] = np.array(k_prime_max)
+
+            rdict = {
+                "marginal_stack_i": marginal_stack_i,
+                "L_sample_i": L_sample_i,
+                "L_optimize_i": L_optimize_i,
+                "L_brute_i": L_brute_i,
+                "cost": cost,
+            }
+            return k_prime_max, rdict
+
+        k_list, weight_list = k[mask], weights[mask]
+        print("# of wavenumbers: ", len(k_list))
+        if len(k_list) > max_wavenumbers:
+            print(f"cut wavenumber list to {max_wavenumbers}")
+            k_list = k_list[0:max_wavenumbers]
+            weight_list = weight_list[0:max_wavenumbers]
+
+        # # parallel version tends to fail...
+        # Let's keep this in case we decide to parallelize it later
+        # with futures.ProcessPoolExecutor(max_workers=Nworkers) as executor:
+        #     A = dict( executor.map(get_instance, zip(k_list, weight_list) ))
+
+        A = dict()
+        for k_pair in zip(k_list, weight_list):
+            kk, I = get_instance(k_pair)
+            A[kk] = I
+
+        cost_stack = dict()
+        marginal_stack = dict()
+        L_sample = pd.DataFrame(index=["alpha", "group_phase", "K_prime", "K_amp"])
+        L_optimize = pd.DataFrame(
+            index=["alpha", "group_phase", "K_prime", "K_amp"]
+        )
+        L_brute = pd.DataFrame(index=["alpha", "group_phase", "K_prime", "K_amp"])
-        plt.sca(F.ax3)
-        plt.legend()
-        plt.xlabel("Phase")
-        plt.ylabel("Angle")
-        plt.xlim(0, np.pi * 2)
+
+        for kk, I in A.items():
+            L_sample[kk] = I["L_sample_i"]
+            L_optimize[kk] = I["L_optimize_i"]
+            L_brute[kk] = I["L_brute_i"]
-        plt.sca(F.ax4)
-        plt.xlabel("Density")
-        plt.stairs(y_hist, bins, orientation="horizontal", color="k")
+
+            marginal_stack[kk] = I["marginal_stack_i"]
+            cost_stack[kk] = I["cost"]
-        F.ax4.axhline(
-            fitter.params["alpha"].min, color="gray", linewidth=2, alpha=0.6
-        )
-        F.ax4.axhline(
-            fitter.params["alpha"].max, color="gray", linewidth=2, alpha=0.6
-        )
+
+        # ## add beam_group dimension
+        marginal_stack = xr.concat(marginal_stack.values(), dim="k").sortby("k")
+        L_sample = L_sample.T.sort_values("K_prime")
+        L_optimize = L_optimize.T.sort_values("K_prime")
+        L_brute = L_brute.T.sort_values("K_prime")
-        F.ax3.set_ylim(
-            min(-np.pi / 2, prior_sel["alpha"][0] - 0.2),
-            max(np.pi / 2, prior_sel["alpha"][0] + 0.2),
-        )
-        F.ax4.set_ylim(
-            min(-np.pi / 2, prior_sel["alpha"][0] - 0.2),
-            max(np.pi / 2, prior_sel["alpha"][0] + 0.2),
-        )
+
+        print("done with ", group, xi / 1e3)
-        plt.show()
-        F.save_light(path=plot_path, name=track_name + "_fit_k" + str(k_prime_max))
+
+        # collect
+        ikey = str(xi) + "_" + "_".join(group)
-    marginal_stack_i = xr.DataArray(
-        y_hist, dims=("angle"), coords={"angle": bins_pos}
-    )
-    marginal_stack_i.coords["k"] = np.array(k_prime_max)
-
-    rdict = {
-        "marginal_stack_i": marginal_stack_i,
-        "L_sample_i": L_sample_i,
-        "L_optimize_i": L_optimize_i,
-        "L_brute_i": L_brute_i,
-        "cost": cost,
-    }
-    return k_prime_max, rdict
-
-k_list, weight_list = k[mask], weights[mask]
-print("# of wavenumber: ", len(k_list))
-if len(k_list) > max_wavenumbers:
-    print("cut wavenumber list to 20")
-    k_list = k_list[0:max_wavenumbers]
-    weight_list = weight_list[0:max_wavenumbers]
-
-# # parallel version tends to fail...
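+
+        # get_instance returns (k, rdict): rdict bundles the marginal angle PDF
+        # for that wavenumber with the parameter values of whichever method ran
+        # (sample/optimize/brute) and the normalized residual cost. A minimal,
+        # illustrative sketch of evaluating a single pair by hand (the 0.02 and
+        # 1.0 below are made-up values, not taken from the data):
+        #
+        #   kk, rdict = get_instance((0.02, 1.0))
+        #   rdict["marginal_stack_i"]  # xr.DataArray over "angle", labeled with k
+        #   rdict["cost"]              # (residual**2).sum() / (z**2).sum()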
- # Let's keep this in case we decided to work on parallellize it - # with futures.ProcessPoolExecutor(max_workers=Nworkers) as executor: - # A = dict( executor.map(get_instance, zip(k_list, weight_list) )) - - A = dict() - for k_pair in zip(k_list, weight_list): - kk, I = get_instance(k_pair) - A[kk] = I - - cost_stack = dict() - marginal_stack = dict() - L_sample = pd.DataFrame(index=["alpha", "group_phase", "K_prime", "K_amp"]) - L_optimize = pd.DataFrame(index=["alpha", "group_phase", "K_prime", "K_amp"]) - L_brute = pd.DataFrame(index=["alpha", "group_phase", "K_prime", "K_amp"]) - - for kk, I in A.items(): - L_sample[kk] = I["L_sample_i"] - L_optimize[kk] = I["L_optimize_i"] - L_brute[kk] = I["L_brute_i"] - - marginal_stack[kk] = I["marginal_stack_i"] - cost_stack[kk] = I["cost"] - - # ## add beam_group dimension - marginal_stack = xr.concat(marginal_stack.values(), dim="k").sortby("k") - L_sample = L_sample.T.sort_values("K_prime") - L_optimize = L_optimize.T.sort_values("K_prime") - L_brute = L_brute.T.sort_values("K_prime") - - print("done with ", group, xi / 1e3) - - # collect - ikey = str(xi) + "_" + "_".join(group) - - marginal_stack.name = "marginals" - marginal_stack = marginal_stack.to_dataset() - marginal_stack["cost"] = (("k"), list(cost_stack.values())) - marginal_stack["weight"] = (("k"), weight_list) - - group_name = str("group" + group[0].split("gt")[1].split("l")[0]) - marginal_stack.coords["beam_group"] = group_name - marginal_stack.coords["x"] = xi - - Marginals[ikey] = marginal_stack.expand_dims(dim="x", axis=0).expand_dims( - dim="beam_group", axis=1 - ) - Marginals[ikey].coords["N_data"] = ( - ("x", "beam_group"), - np.expand_dims(np.expand_dims(N_data, 0), 1), - ) - - L_sample["cost"] = cost_stack - L_sample["weight"] = weight_list - L_collect[group_name, str(int(xi))] = L_sample - - -MM = xr.merge(Marginals.values()) -MM = xr.merge([MM, Prior_smth]) -MM.to_netcdf(save_path + save_name + "_marginals.nc") - -try: - LL = pd.concat(L_collect) - MT.save_pandas_table({"L_sample": LL}, save_name + "_res_table", save_path) -except Exception as e: - print(f"This is a warning: {e}") - -# plot -font_for_print() -F = M.figure_axis_xy(6, 5.5, view_scale=0.7, container=True) - -gs = GridSpec(4, 6, wspace=0.2, hspace=0.8) - -ax0 = F.fig.add_subplot(gs[0:2, -1]) -ax0.tick_params(labelleft=False) - -klims = 0, LL["K_prime"].max() * 1.2 - - -for g in MM.beam_group: - MMi = MM.sel(beam_group=g) - plt.plot( - MMi.weight.T, - MMi.k, - ".", - color=col_dict[str(g.data)], - markersize=3, - linewidth=0.8, - ) - -plt.xlabel("Power") -plt.ylim(klims) - -ax1 = F.fig.add_subplot(gs[0:2, 0:-1]) - -for g in MM.beam_group: - Li = LL.loc[str(g.data)] - - angle_list = np.array(Li["alpha"]) * 180 / np.pi - kk_list = np.array(Li["K_prime"]) - weight_list_i = np.array(Li["weight"]) - - plt.scatter( - angle_list, - kk_list, - s=(weight_list_i * 8e1) ** 2, - c=col_dict[str(g.data)], - label="mode " + str(g.data), - ) + marginal_stack.name = "marginals" + marginal_stack = marginal_stack.to_dataset() + marginal_stack["cost"] = (("k"), list(cost_stack.values())) + marginal_stack["weight"] = (("k"), weight_list) + group_name = str("group" + group[0].split("gt")[1].split("l")[0]) + marginal_stack.coords["beam_group"] = group_name + marginal_stack.coords["x"] = xi -dir_best[dir_best > 180] = dir_best[dir_best > 180] - 360 -plt.plot(dir_best, Pwavenumber, ".r", markersize=6) + Marginals[ikey] = marginal_stack.expand_dims(dim="x", axis=0).expand_dims( + dim="beam_group", axis=1 + ) + 
Marginals[ikey].coords["N_data"] = ( + ("x", "beam_group"), + np.expand_dims(np.expand_dims(N_data, 0), 1), + ) -dir_interp[dir_interp > 180] = dir_interp[dir_interp > 180] - 360 -plt.plot(dir_interp, Gk.k, "-", color="red", linewidth=0.3, zorder=11) + L_sample["cost"] = cost_stack + L_sample["weight"] = weight_list + L_collect[group_name, str(int(xi))] = L_sample + MM = xr.merge(Marginals.values()) + MM = xr.merge([MM, Prior_smth]) -plt.fill_betweenx( - Gk.k, - (dir_interp_smth - spread_smth) * 180 / np.pi, - (dir_interp_smth + spread_smth) * 180 / np.pi, - zorder=1, - color=color_schemes.green1, - alpha=0.2, -) -plt.plot( - dir_interp_smth * 180 / np.pi, Gk.k, ".", markersize=1, color=color_schemes.green1 -) + save_name = "B04_" + track_name + MM.to_netcdf(save_path / (save_name + "_marginals.nc")) -ax1.axvline(85, color="gray", linewidth=2) -ax1.axvline(-85, color="gray", linewidth=2) + try: + LL = pd.concat(L_collect) + MT.save_pandas_table( + {"L_sample": LL}, save_name + "_res_table", str(save_path) + ) # TODO: clean up save_pandas_table to use pathlib + except Exception as e: + print(f"This is a warning: {e}") + else: + # plotting with LL + font_for_print() + F = M.figure_axis_xy(6, 5.5, view_scale=0.7, container=True) + + gs = GridSpec(4, 6, wspace=0.2, hspace=0.8) + + ax0 = F.fig.add_subplot(gs[0:2, -1]) + ax0.tick_params(labelleft=False) + + klims = 0, LL["K_prime"].max() * 1.2 + + for g_ in MM.beam_group: + MMi = MM.sel(beam_group=g_) + plt.plot( + MMi.weight.T, + MMi.k, + ".", + color=col_dict[str(g_.data)], + markersize=3, + linewidth=0.8, + ) + plt.xlabel("Power") + plt.ylim(klims) -plt.legend() -plt.ylabel("wavenumber (deg)") -plt.xlabel("Angle (deg)") + ax1 = F.fig.add_subplot(gs[0:2, 0:-1]) + for g_ in MM.beam_group: + Li = LL.loc[str(g_.data)] -plt.ylim(klims) + angle_list = np.array(Li["alpha"]) * 180 / np.pi + kk_list = np.array(Li["K_prime"]) + weight_list_i = np.array(Li["weight"]) -prior_angle_str = str(np.round((prior_sel["alpha"][0]) * 180 / np.pi)) -plt.title(track_name + "\nprior=" + prior_angle_str + "deg", loc="left") + plt.scatter( + angle_list, + kk_list, + s=(weight_list_i * 8e1) ** 2, + c=col_dict[str(g_.data)], + label="mode " + str(g_.data), + ) -plt.xlim(min([-90, np.nanmin(dir_best)]), max([np.nanmax(dir_best), 90])) + dir_best[dir_best > 180] = dir_best[dir_best > 180] - 360 + plt.plot(dir_best, Pwavenumber, ".r", markersize=6) + dir_interp[dir_interp > 180] = dir_interp[dir_interp > 180] - 360 + plt.plot(dir_interp, Gk.k, "-", color="red", linewidth=0.3, zorder=11) -ax3 = F.fig.add_subplot(gs[2, 0:-1]) + plt.fill_betweenx( + Gk.k, + (dir_interp_smth - spread_smth) * 180 / np.pi, + (dir_interp_smth + spread_smth) * 180 / np.pi, + zorder=1, + color=color_schemes.green1, + alpha=0.2, + ) + plt.plot( + dir_interp_smth * 180 / np.pi, + Gk.k, + ".", + markersize=1, + color=color_schemes.green1, + ) -for g in MM.beam_group: - MMi = MM.sel(beam_group=g) - wegihted_margins = (MMi.marginals * MMi.weight).sum(["x", "k"]) / MMi.weight.sum( - ["x", "k"] - ) - plt.plot( - MMi.angle * 180 / np.pi, - wegihted_margins, - ".", - color=col_dict[str(g.data)], - markersize=2, - linewidth=0.8, - ) + ax1.axvline(85, color="gray", linewidth=2) + ax1.axvline(-85, color="gray", linewidth=2) -plt.ylabel("Density") -plt.title("weight margins", loc="left") + plt.legend() + plt.ylabel("wavenumber (deg)") + plt.xlabel("Angle (deg)") + + plt.ylim(klims) + + prior_angle_str = str(np.round((prior_sel["alpha"][0]) * 180 / np.pi)) + plt.title(track_name + "\nprior=" + prior_angle_str + 
"deg", loc="left") + + plt.xlim(min([-90, np.nanmin(dir_best)]), max([np.nanmax(dir_best), 90])) + + ax3 = F.fig.add_subplot(gs[2, 0:-1]) # can the assignment be removed? CP + + for g_ in MM.beam_group: + MMi = MM.sel(beam_group=g_) + weighted_margins = (MMi.marginals * MMi.weight).sum( + ["x", "k"] + ) / MMi.weight.sum(["x", "k"]) + plt.plot( + MMi.angle * 180 / np.pi, + weighted_margins, + ".", + color=col_dict[str(g_.data)], + markersize=2, + linewidth=0.8, + ) -plt.xlim(-90, 90) + plt.ylabel("Density") + plt.title("weight margins", loc="left") + + plt.xlim(-90, 90) + + ax3 = F.fig.add_subplot( + gs[-1, 0:-1] + ) # can the assignment be removed? Not used later. CP + + for g_ in MM.beam_group: + MMi = MM.sel(beam_group=g_) + weighted_margins = MMi.marginals.mean(["x", "k"]) + plt.plot( + MMi.angle * 180 / np.pi, + weighted_margins, + ".", + color=col_dict[str(g_.data)], + markersize=2, + linewidth=0.8, + ) + plt.ylabel("Density") + plt.xlabel("Angle (deg)") + plt.title("unweighted margins", loc="left") -ax3 = F.fig.add_subplot(gs[-1, 0:-1]) + plt.xlim(-90, 90) -for g in MM.beam_group: - MMi = MM.sel(beam_group=g) - wegihted_margins = MMi.marginals.mean(["x", "k"]) - plt.plot( - MMi.angle * 180 / np.pi, - wegihted_margins, - ".", - color=col_dict[str(g.data)], - markersize=2, - linewidth=0.8, - ) + F.save_pup(path=plot_path, name="B04_marginal_distributions") -plt.ylabel("Density") -plt.xlabel("Angle (deg)") -plt.title("unweighted margins", loc="left") + MT.json_save( + "B04_success", + plot_path, + {"time": "time.asctime( time.localtime(time.time()) )"}, + ) -plt.xlim(-90, 90) -F.save_pup(path=plot_path, name="B04_marginal_distributions") +make_b04_angle_app = makeapp(run_B04_angle, name="B04_angle") -MT.json_save( - "B04_success", plot_path, {"time": "time.asctime( time.localtime(time.time()) )"} -) +if __name__ == "__main__": + make_b04_angle_app() diff --git a/src/icesat2_tracks/analysis_db/B05_define_angle.py b/src/icesat2_tracks/analysis_db/B05_define_angle.py index f201c207..1ca15f87 100644 --- a/src/icesat2_tracks/analysis_db/B05_define_angle.py +++ b/src/icesat2_tracks/analysis_db/B05_define_angle.py @@ -1,18 +1,21 @@ +#!/usr/bin/env python """ -This file open a ICEsat2 track applied filters and corections and returns smoothed photon heights on a regular grid in an .nc file. +This file open a ICEsat2 track applied filters and corrections and returns smoothed photon heights on a regular grid in an .nc file. 
This is python 3 """ -import sys +from pathlib import Path +import matplotlib +from matplotlib import pyplot as plt +from typer import Option from icesat2_tracks.config.IceSAT2_startup import ( mconfig, color_schemes, - plt, - font_for_print + font_for_print, ) -import icesat2_tracks.ICEsat2_SI_tools.iotools as io +from icesat2_tracks.ICEsat2_SI_tools.iotools import init_from_input, ID_to_str import icesat2_tracks.ICEsat2_SI_tools.spectral_estimates as spec import xarray as xr @@ -21,48 +24,27 @@ import icesat2_tracks.ICEsat2_SI_tools.lanczos as lanczos import icesat2_tracks.local_modules.m_tools_ph3 as MT import icesat2_tracks.local_modules.m_general_ph3 as M - + from matplotlib.gridspec import GridSpec from scipy.ndimage import label -color_schemes.colormaps2(21) -col_dict = color_schemes.rels - -track_name, batch_key, test_flag = io.init_from_input(sys.argv) -hemis, batch = batch_key.split("_") - -ATlevel = "ATL03" -plot_path = ( - mconfig["paths"]["plot"] - + "/" - + hemis - + "/" - + batch_key - + "/" - + track_name - + "/B05_angle/" +from icesat2_tracks.clitools import ( + validate_batch_key, + validate_output_dir, + suppress_stdout, + report_input_parameters, + validate_track_name_steps_gt_1, + makeapp, ) -MT.mkdirs_r(plot_path) - -all_beams = mconfig["beams"]["all_beams"] -high_beams = mconfig["beams"]["high_beams"] -low_beams = mconfig["beams"]["low_beams"] -beam_groups = mconfig["beams"]["groups"] -group_names = mconfig["beams"]["group_names"] - -load_path = mconfig["paths"]["work"] + batch_key + "/B02_spectra/" -Gk = xr.load_dataset(load_path + "/B02_" + track_name + "_gFT_k.nc") # - -load_path = mconfig["paths"]["work"] + batch_key + "/B04_angle/" -Marginals = xr.load_dataset(load_path + "/B04_" + track_name + "_marginals.nc") # - -load_path = mconfig["paths"]["work"] + batch_key + "/A02_prior/" -Prior = MT.load_pandas_table_dict("/A02_" + track_name, load_path)["priors_hindcast"] -save_path = mconfig["paths"]["work"] + batch_key + "/B04_angle/" +matplotlib.use("Agg") +color_schemes.colormaps2(21) def derive_weights(weights): + """ + Normalize weights to have a minimum of 0 + """ weights = (weights - weights.mean()) / weights.std() weights = weights - weights.min() return weights @@ -83,7 +65,9 @@ def weighted_means(data, weights, x_angle, color="k"): k = wi.k.data data_k = data.sel(k=k).squeeze() data_weight = data_k * wi - plt.stairs(data_weight.sum("k") / weight_norm, x_angle, linewidth=1, color="k") + plt.stairs( + data_weight.sum("k") / weight_norm, x_angle, linewidth=1, color=color + ) if data_k.k.size > 1: for k in data_k.k.data: plt.stairs( @@ -96,191 +80,35 @@ def weighted_means(data, weights, x_angle, color="k"): return data_weighted_mean -# cut out data at the boundary and redistibute variance -angle_mask = Marginals.angle * 0 == 0 -angle_mask[0], angle_mask[-1] = False, False -corrected_marginals = ( - Marginals.marginals.isel(angle=angle_mask) - + Marginals.marginals.isel(angle=~angle_mask).sum("angle") / sum(angle_mask).data -) - -# get groupweights -# ----------------- thius does not work jet.ckeck with data on server how to get number of data points per stancil -# Gx['x'] = Gx.x - Gx.x[0] - -# makde dummy variables -M_final = xr.full_like( - corrected_marginals.isel(k=0, beam_group=0).drop_vars("beam_group").drop_vars("k"), np.nan -) -M_final_smth = xr.full_like( - corrected_marginals.isel(k=0, beam_group=0).drop_vars("beam_group").drop_vars("k"), np.nan -) -if M_final.shape[0] > M_final.shape[1]: - M_final = M_final.T - M_final_smth = M_final_smth.T - 
corrected_marginals = corrected_marginals.T - -Gweights = corrected_marginals.N_data -Gweights = Gweights / Gweights.max() - -k_mask = corrected_marginals.mean("beam_group").mean("angle") - -xticks_2pi = np.arange(-np.pi, np.pi + np.pi / 4, np.pi / 4) -xtick_labels_2pi = [ - "-$\pi$", - "-$3\pi/4$", - "-$\pi/2$", - "-$\pi/4$", - "0", - "$\pi/4$", - "$\pi/2$", - "$3\pi/4$", - "$\pi$", -] - -xticks_pi = np.arange(-np.pi / 2, np.pi / 2 + np.pi / 4, np.pi / 4) -xtick_labels_pi = [ - "-$\pi/2$", - "-$\pi/4$", - "0", - "$\pi/4$", - "$\pi/2$", -] - - -font_for_print() -x_list = corrected_marginals.x -for xi in range(x_list.size): - F = M.figure_axis_xy(7, 3.5, view_scale=0.8, container=True) - gs = GridSpec(3, 2, wspace=0.1, hspace=0.8) - x_str = str(int(x_list[xi] / 1e3)) - - plt.suptitle( - "Weighted marginal PDFs\nx=" + x_str + "\n" + io.ID_to_str(track_name), - y=1.05, - x=0.125, - horizontalalignment="left", - ) - group_weight = Gweights.isel(x=xi) - - ax_list = dict() - ax_sum = F.fig.add_subplot(gs[1, 1]) - - ax_list["sum"] = ax_sum - - data_collect = dict() - for group, gpos in zip(Marginals.beam_group.data, [gs[0, 0], gs[0, 1], gs[1, 0]]): - ax0 = F.fig.add_subplot(gpos) - ax0.tick_params(labelbottom=False) - ax_list[group] = ax0 - - data = corrected_marginals.isel(x=xi).sel(beam_group=group) - weights = derive_weights(Marginals.weight.isel(x=xi).sel(beam_group=group)) - weights = weights**2 - - # derive angle axis - x_angle = data.angle.data - d_angle = np.diff(x_angle)[0] - x_angle = np.insert(x_angle, x_angle.size, x_angle[-1].data + d_angle) - - if ((~np.isnan(data)).sum().data == 0) | ((~np.isnan(weights)).sum().data == 0): - data_wmean = data.mean("k") - else: - data_wmean = weighted_means(data, weights, x_angle, color=col_dict[group]) - plt.stairs(data_wmean, x_angle, color=col_dict[group], alpha=1) - - plt.title("Marginal PDF " + group, loc="left") - plt.sca(ax_sum) - - data_collect[group] = data_wmean - - data_collect = xr.concat(data_collect.values(), dim="beam_group") - final_data = (group_weight * data_collect).sum("beam_group") / group_weight.sum( - "beam_group" - ).data - - plt.sca(ax_sum) - plt.stairs(final_data, x_angle, color="k", alpha=1, linewidth=0.8) - ax_sum.set_xlabel("Angle (rad)") - plt.title("Weighted mean over group & wavenumber", loc="left") - - # get relevant priors - for axx in ax_list.values(): - axx.set_ylim(0, final_data.max() * 1.5) - axx.set_xticks(xticks_pi) - axx.set_xticklabels(xtick_labels_pi) - - try: - ax_list["group3"].set_ylabel("PDF") - ax_list["group1"].set_ylabel("PDF") - ax_list["group3"].tick_params(labelbottom=True) - ax_list["group3"].set_xlabel("Angle (rad)") - except: - pass +def convert_to_kilo_string(x): + return str(int(x / 1e3)) - ax_final = F.fig.add_subplot(gs[-1, :]) - plt.title("Final angle PDF", loc="left") - priors_k = Marginals.Prior_direction[~np.isnan(k_mask.isel(x=xi))] - for pk in priors_k: - ax_final.axvline(pk, color=color_schemes.cascade2, linewidth=1, alpha=0.7) - - plt.stairs(final_data, x_angle, color="k", alpha=0.5, linewidth=0.8) - - final_data_smth = lanczos.lanczos_filter_1d(x_angle, final_data, 0.1) - - plt.plot(x_angle[0:-1], final_data_smth, color="black", linewidth=0.8) - - ax_final.axvline( - x_angle[0:-1][final_data_smth.argmax()], - color=color_schemes.orange, - linewidth=1.5, - alpha=1, - zorder=1, - ) - ax_final.axvline( - x_angle[0:-1][final_data_smth.argmax()], - color=color_schemes.black, - linewidth=3.2, - alpha=1, - zorder=0, - ) +def update_axes(ax, x_ticks, x_tick_labels, xlims): + 
ax.set_xticks(x_ticks) + ax.set_xticklabels(x_tick_labels) + ax.set_xlim(xlims) - plt.xlabel("Angle (rad)") - plt.xlim(-np.pi * 0.8, np.pi * 0.8) - ax_final.set_xticks(xticks_pi) - ax_final.set_xticklabels(xtick_labels_pi) +def get_first_and_last_nonzero_data(dir_data): + nonzero_data = dir_data.k[(dir_data.sum("angle") != 0)] + return (nonzero_data[0].data, nonzero_data[-1].data) - M_final[xi, :] = final_data - M_final_smth[xi, :] = final_data_smth - F.save_pup(path=plot_path, name="B05_weigthed_margnials_x" + x_str) - - -M_final.name = "weighted_angle_PDF" -M_final_smth.name = "weighted_angle_PDF_smth" -Gpdf = xr.merge([M_final, M_final_smth]) - -if len(Gpdf.x) < 2: - print("not enough x data, exit") - MT.json_save( - "B05_fail", - plot_path + "../", - { - "time": time.asctime(time.localtime(time.time())), - "reason": "not enough x segments", - }, - ) - print("exit()") - exit() +def build_plot_data(dir_data, i_spec, lims): + plot_data = dir_data * i_spec.mean("x") + plot_data = plot_data.rolling(angle=5, k=10).median() + plot_data = plot_data.sel(k=slice(lims[0], lims[-1])) + return plot_data -class plot_polarspectra(object): - def __init__(self, k, thetas, data, data_type="fraction", lims=None, verbose=False): +class PlotPolarSpectra: + def __init__( + self, k, thetas, data, dir_data, data_type="fraction", lims=None, verbose=False + ): """ data_type either 'fraction' or 'energy', default (fraction) - lims (None) limts of k. if None set by the limits of the vector k + lims (None) limits of k. if None set by the limits of the vector k """ self.k = k self.data = data @@ -306,15 +134,15 @@ def __init__(self, k, thetas, data, data_type="fraction", lims=None, verbose=Fal self.clevs = np.linspace(self.min + self.min * 0.05, self.max * 0.60, 21) def linear(self, radial_axis="period", ax=None, cbar_flag=True): - """ """ + """ + TODO: add docstring + """ if ax is None: ax = plt.subplot(111, polar=True) - else: - ax = ax + ax.set_theta_direction(-1) ax.set_theta_zero_location("W") - - grid = ax.grid(color="k", alpha=0.5, linestyle="-", linewidth=0.5) + ax.grid(color="k", alpha=0.5, linestyle="-", linewidth=0.5) if self.data_type == "fraction": cm = plt.cm.RdYlBu_r @@ -327,7 +155,7 @@ def linear(self, radial_axis="period", ax=None, cbar_flag=True): cm.set_bad = "w" colorax = ax.contourf( self.thetas, self.k, self.data, self.clevs, cmap=cm, zorder=1 - ) # , vmin=self.ctrs_min) + ) if cbar_flag: cbar = plt.colorbar( @@ -336,7 +164,7 @@ def linear(self, radial_axis="period", ax=None, cbar_flag=True): cbar.ax.get_yaxis().labelpad = 30 cbar.outline.set_visible(False) clev_tick_names, clev_ticks = MT.tick_formatter( - FP.clevs, expt_flag=False, shift=0, rounder=4, interval=1 + self.clevs, expt_flag=False, shift=0, rounder=4, interval=1 ) cbar.set_ticks(clev_ticks[::5]) cbar.set_ticklabels(clev_tick_names[::5]) @@ -349,7 +177,7 @@ def linear(self, radial_axis="period", ax=None, cbar_flag=True): xx_tick_names, xx_ticks = MT.tick_formatter( radial_ticks, expt_flag=False, shift=1, rounder=0, interval=1 ) - xx_tick_names = [" " + str(d) + "m" for d in xx_tick_names] + xx_tick_names = [f" {d}m" for d in xx_tick_names] ax.set_yticks(xx_ticks[::1]) ax.set_yticklabels(xx_tick_names[::1]) @@ -361,7 +189,9 @@ def linear(self, radial_axis="period", ax=None, cbar_flag=True): degrange_label[degrange_label > 180] - 360 ) - degrange_label = [str(d) + "$^{\circ}$" for d in degrange_label] + degrange_label = [ + str(d) + "$^{\circ}$" for d in degrange_label + ] # TODO: maybe replace the latex with "°"? 
lines, labels = plt.thetagrids(degrange, labels=degrange_label) @@ -374,143 +204,372 @@ def linear(self, radial_axis="period", ax=None, cbar_flag=True): self.ax = ax -font_for_print() -F = M.figure_axis_xy(6, 5.5, view_scale=0.7, container=True) -gs = GridSpec(8, 6, wspace=0.1, hspace=3.1) -color_schemes.colormaps2(21) +def define_angle( + track_name: str = Option(..., callback=validate_track_name_steps_gt_1), + batch_key: str = Option(..., callback=validate_batch_key), + ID_flag: bool = True, + output_dir: str = Option(..., callback=validate_output_dir), + verbose: bool = False, +): + """ + TODO: add docstring + """ -cmap_spec = plt.cm.ocean_r -clev_spec = np.linspace(-8, -1, 21) * 10 + track_name, batch_key, _ = init_from_input( + [ + None, + track_name, + batch_key, + ID_flag, + ] + ) -cmap_angle = color_schemes.cascade_r -clev_angle = np.linspace(0, 4, 21) + hemis, _ = batch_key.split("_") + plotsdir = Path(output_dir, mconfig["paths"]["plot"]) + workdir = Path(output_dir, mconfig["paths"]["work"]) -ax1 = F.fig.add_subplot(gs[0:3, :]) -ax1.tick_params(labelbottom=False) + kwargs = { + "track_name": track_name, + "batch_key": batch_key, + "ID_flag": ID_flag, + "output_dir": output_dir, + } + report_input_parameters(**kwargs) -weighted_spec = (Gk.gFT_PSD_data * Gk.N_per_stancil).sum("beam") / Gk.N_per_stancil.sum( - "beam" -) -x_spec = weighted_spec.x / 1e3 -k = weighted_spec.k - -xlims = x_spec[0], x_spec[-1] -clev_spec = np.linspace(-80, (10 * np.log(weighted_spec)).max() * 0.9, 21) - -plt.pcolor( - x_spec, - k, - 10 * np.log(weighted_spec), - vmin=clev_spec[0], - vmax=clev_spec[-1], - cmap=cmap_spec, -) + with suppress_stdout(verbose): + plot_path = plotsdir / hemis / batch_key / track_name / "B05_angle" + plot_path.mkdir(parents=True, exist_ok=True) + # Load Gk + load_path = workdir / batch_key / "B02_spectra" + Gk = xr.load_dataset(load_path / ("B02_" + track_name + "_gFT_k.nc")) -plt.title(track_name + "\nPower Spectra (m/m)$^2$ k$^{-1}$", loc="left") + # Load Marginals + load_path = workdir / batch_key / "B04_angle" + Marginals = xr.load_dataset(load_path / ("B04_" + track_name + "_marginals.nc")) -cbar = plt.colorbar(fraction=0.018, pad=0.01, orientation="vertical", label="Power") -cbar.outline.set_visible(False) -clev_ticks = np.round(clev_spec[::3], 0) -cbar.set_ticks(clev_ticks) -cbar.set_ticklabels(clev_ticks) + save_path = workdir / batch_key / "B04_angle" -plt.ylabel("wavenumber $k$") + # cut out data at the boundary and redistribute variance + angle_mask = Marginals.angle * 0 == 0 + angle_mask[0], angle_mask[-1] = False, False + corrected_marginals = ( + Marginals.marginals.isel(angle=angle_mask) + + Marginals.marginals.isel(angle=~angle_mask).sum("angle") + / sum(angle_mask).data + ) -ax2 = F.fig.add_subplot(gs[3:5, :]) -ax2.tick_params(labelleft=True) + # get group weights + # ----------------- this does not work yet. 
Check with data on server how to get number of data points per stancil + # Gx['x'] = Gx.x - Gx.x[0] -dir_data = Gpdf.interp(x=weighted_spec.x).weighted_angle_PDF_smth.T + # make dummy variables + M_final = xr.full_like( + corrected_marginals.isel(k=0, beam_group=0) + .drop_vars("beam_group") + .drop_vars("k"), + np.nan, + ) + M_final_smth = xr.full_like( + corrected_marginals.isel(k=0, beam_group=0) + .drop_vars("beam_group") + .drop_vars("k"), + np.nan, + ) + if M_final.shape[0] > M_final.shape[1]: + M_final = M_final.T + M_final_smth = M_final_smth.T + corrected_marginals = corrected_marginals.T + + Gweights = corrected_marginals.N_data + Gweights = Gweights / Gweights.max() + + k_mask = corrected_marginals.mean("beam_group").mean("angle") + + xticks_pi = np.arange(-np.pi / 2, np.pi / 2 + np.pi / 4, np.pi / 4) + xtick_labels_pi = [ + "-$\pi/2$", + "-$\pi/4$", + "0", + "$\pi/4$", + "$\pi/2$", + ] + + font_for_print() + col_dict = color_schemes.rels + x_list = corrected_marginals.x + for xi, xval in enumerate(x_list): + F = M.figure_axis_xy(7, 3.5, view_scale=0.8, container=True) + gs = GridSpec(3, 2, wspace=0.1, hspace=0.8) + x_str = convert_to_kilo_string(xval) + + plt.suptitle( + f"Weighted marginal PDFs\nx={x_str}\n{ID_to_str(track_name)}", + y=1.05, + x=0.125, + horizontalalignment="left", + ) + group_weight = Gweights.isel(x=xi) -x = Gpdf.x / 1e3 -angle = Gpdf.angle -plt.pcolor( - x_spec, angle, dir_data, vmin=clev_angle[0], vmax=clev_angle[-1], cmap=cmap_angle -) + ax_list = dict() + ax_sum = F.fig.add_subplot(gs[1, 1]) -cbar = plt.colorbar(fraction=0.01, pad=0.01, orientation="vertical", label="Density") -plt.title("Direction PDF", loc="left") + ax_list["sum"] = ax_sum -plt.xlabel("x (km)") -plt.ylabel("angle") + data_collect = dict() + for group, gpos in zip( + Marginals.beam_group.data, [gs[0, 0], gs[0, 1], gs[1, 0]] + ): + ax0 = F.fig.add_subplot(gpos) + ax0.tick_params(labelbottom=False) + ax_list[group] = ax0 -ax2.set_yticks(xticks_pi) -ax2.set_yticklabels(xtick_labels_pi) + data = corrected_marginals.isel(x=xi).sel(beam_group=group) + weights = derive_weights( + Marginals.weight.isel(x=xi).sel(beam_group=group) + ) + weights = weights**2 + + # derive angle axis + x_angle = data.angle.data + d_angle = np.diff(x_angle)[0] + x_angle = np.insert(x_angle, x_angle.size, x_angle[-1].data + d_angle) + + if any(np.all(np.isnan(x)) for x in [data, weights]): + data_wmean = data.mean("k") + else: + data_wmean = weighted_means( + data, weights, x_angle, color=col_dict[group] + ) + plt.stairs(data_wmean, x_angle, color=col_dict[group], alpha=1) + + plt.title(f"Marginal PDF {group}", loc="left") + plt.sca(ax_sum) + + data_collect[group] = data_wmean + + data_collect = xr.concat(data_collect.values(), dim="beam_group") + final_data = (group_weight * data_collect).sum( + "beam_group" + ) / group_weight.sum("beam_group").data + + plt.sca(ax_sum) + plt.stairs(final_data, x_angle, color="k", alpha=1, linewidth=0.8) + ax_sum.set_xlabel("Angle (rad)") + plt.title("Weighted mean over group & wavenumber", loc="left") + + # get relevant priors + for axx in ax_list.values(): + axx.set_ylim(0, final_data.max() * 1.5) + axx.set_xticks(xticks_pi) + axx.set_xticklabels(xtick_labels_pi) + + for key in ["group3", "group1"]: + if key in ax_list: + ax_list[key].set_ylabel("PDF") + if key == "group3": + ax_list[key].tick_params(labelbottom=True) + ax_list[key].set_xlabel("Angle (rad)") + else: + print(f"Key {key} not found in ax_list") + + ax_final = F.fig.add_subplot(gs[-1, :]) + plt.title("Final angle PDF", 
loc="left") + + priors_k = Marginals.Prior_direction[~np.isnan(k_mask.isel(x=xi))] + for pk in priors_k: + ax_final.axvline( + pk, color=color_schemes.cascade2, linewidth=1, alpha=0.7 + ) + plt.stairs(final_data, x_angle, color="k", alpha=0.5, linewidth=0.8) -x_ticks = np.arange(0, xlims[-1].data, 50) -x_tick_labels, x_ticks = MT.tick_formatter( - x_ticks, expt_flag=False, shift=0, rounder=1, interval=2 -) + final_data_smth = lanczos.lanczos_filter_1d(x_angle, final_data, 0.1) -ax1.set_xticks(x_ticks) -ax2.set_xticks(x_ticks) -ax1.set_xticklabels(x_tick_labels) -ax2.set_xticklabels(x_tick_labels) -ax1.set_xlim(xlims) -ax2.set_xlim(xlims) + plt.plot(x_angle[0:-1], final_data_smth, color="black", linewidth=0.8) + ax_final.axvline( + x_angle[0:-1][final_data_smth.argmax()], + color=color_schemes.orange, + linewidth=1.5, + alpha=1, + zorder=1, + ) + ax_final.axvline( + x_angle[0:-1][final_data_smth.argmax()], + color=color_schemes.black, + linewidth=3.2, + alpha=1, + zorder=0, + ) -xx_list = np.insert(corrected_marginals.x.data, 0, 0) -x_chunks = spec.create_chunk_boundaries( - int(xx_list.size / 3), xx_list.size, iter_flag=False -) -x_chunks = x_chunks[:, ::2] -x_chunks[-1, -1] = xx_list.size - 1 + plt.xlabel("Angle (rad)") + plt.xlim(-np.pi * 0.8, np.pi * 0.8) + ax_final.set_xticks(xticks_pi) + ax_final.set_xticklabels(xtick_labels_pi) -for x_pos, gs in zip(x_chunks.T, [gs[-3:, 0:2], gs[-3:, 2:4], gs[-3:, 4:]]): - x_range = xx_list[[x_pos[0], x_pos[-1]]] + M_final[xi, :] = final_data + M_final_smth[xi, :] = final_data_smth - ax1.axvline(x_range[0] / 1e3, linestyle=":", color="white", alpha=0.5) - ax1.axvline(x_range[-1] / 1e3, color="gray", alpha=0.5) + F.save_pup(path=plot_path, name="B05_weighted_marginals_x" + x_str) - ax2.axvline(x_range[0] / 1e3, linestyle=":", color="white", alpha=0.5) - ax2.axvline(x_range[-1] / 1e3, color="gray", alpha=0.5) + M_final.name = "weighted_angle_PDF" + M_final_smth.name = "weighted_angle_PDF_smth" + Gpdf = xr.merge([M_final, M_final_smth]) - i_spec = weighted_spec.sel(x=slice(x_range[0], x_range[-1])) - i_dir = corrected_marginals.sel(x=slice(x_range[0], x_range[-1])) + if len(Gpdf.x) < 2: + print("not enough x data, exit") + MT.json_save( + "B05_fail", + plot_path.parent, + { + "time": time.asctime(time.localtime(time.time())), + "reason": "not enough x segments", + }, + ) + print("exit()") + exit() + + font_for_print() + F = M.figure_axis_xy(6, 5.5, view_scale=0.7, container=True) + gs = GridSpec(8, 6, wspace=0.1, hspace=3.1) + color_schemes.colormaps2(21) + + cmap_spec = plt.cm.ocean_r + clev_spec = np.linspace(-8, -1, 21) * 10 + + cmap_angle = color_schemes.cascade_r + clev_angle = np.linspace(0, 4, 21) + + ax1 = F.fig.add_subplot(gs[0:3, :]) + ax1.tick_params(labelbottom=False) + + weighted_spec = (Gk.gFT_PSD_data * Gk.N_per_stancil).sum( + "beam" + ) / Gk.N_per_stancil.sum("beam") + x_spec = weighted_spec.x / 1e3 + k = weighted_spec.k + + xlims = x_spec[0], x_spec[-1] + clev_spec = np.linspace(-80, (10 * np.log(weighted_spec)).max() * 0.9, 21) + + plt.pcolor( + x_spec, + k, + 10 * np.log(weighted_spec), + vmin=clev_spec[0], + vmax=clev_spec[-1], + cmap=cmap_spec, + ) - dir_data = (i_dir * i_dir.N_data).sum(["beam_group", "x"]) / i_dir.N_data.sum( - ["beam_group", "x"] - ) - lims = ( - dir_data.k[(dir_data.sum("angle") != 0)][0].data, - dir_data.k[(dir_data.sum("angle") != 0)][-1].data, - ) + _title = f"{track_name}\nPower Spectra (m/m)$^2$ k$^{{-1}}$" + plt.title(_title, loc="left") + + cbar = plt.colorbar( + fraction=0.018, pad=0.01, 
orientation="vertical", label="Power" + ) + cbar.outline.set_visible(False) + clev_ticks = np.round(clev_spec[::3], 0) + cbar.set_ticks(clev_ticks) + cbar.set_ticklabels(clev_ticks) + + plt.ylabel("wavenumber $k$") + + ax2 = F.fig.add_subplot(gs[3:5, :]) + ax2.tick_params(labelleft=True) + + dir_data = Gpdf.interp(x=weighted_spec.x).weighted_angle_PDF_smth.T + + angle = Gpdf.angle + plt.pcolor( + x_spec, + angle, + dir_data, + vmin=clev_angle[0], + vmax=clev_angle[-1], + cmap=cmap_angle, + ) + + cbar = plt.colorbar( + fraction=0.01, pad=0.01, orientation="vertical", label="Density" + ) + plt.title("Direction PDF", loc="left") - N_angle = i_dir.angle.size - dir_data2 = dir_data + plt.xlabel("x (km)") + plt.ylabel("angle") - plot_data = dir_data2 * i_spec.mean("x") - plot_data = plot_data.rolling(angle=5, k=10).median() + ax2.set_yticks(xticks_pi) + ax2.set_yticklabels(xtick_labels_pi) - plot_data = plot_data.sel(k=slice(lims[0], lims[-1])) - xx = 2 * np.pi / plot_data.k - - if np.nanmax(plot_data.data) != np.nanmin(plot_data.data): - ax3 = F.fig.add_subplot(gs, polar=True) - FP = plot_polarspectra( - xx, - plot_data.angle, - plot_data, - lims=None, - verbose=False, - data_type="fraction", + x_ticks = np.arange(0, xlims[-1].data, 50) + x_tick_labels, x_ticks = MT.tick_formatter( + x_ticks, expt_flag=False, shift=0, rounder=1, interval=2 ) - FP.clevs = np.linspace( - np.nanpercentile(plot_data.data, 1), np.round(plot_data.max(), 4), 21 + + for ax in [ax1, ax2]: + update_axes(ax, x_ticks, x_tick_labels, xlims) + + xx_list = np.insert(corrected_marginals.x.data, 0, 0) + x_chunks = spec.create_chunk_boundaries( + int(xx_list.size / 3), xx_list.size, iter_flag=False ) - FP.linear(ax=ax3, cbar_flag=False) + x_chunks = x_chunks[:, ::2] + x_chunks[-1, -1] = xx_list.size - 1 -F.save_pup(path=plot_path + "../", name="B05_dir_ov") + for x_pos, gs in zip(x_chunks.T, [gs[-3:, 0:2], gs[-3:, 2:4], gs[-3:, 4:]]): + x_range = xx_list[[x_pos[0], x_pos[-1]]] -# save data -Gpdf.to_netcdf(save_path + "/B05_" + track_name + "_angle_pdf.nc") + ax1.axvline(x_range[0] / 1e3, linestyle=":", color="white", alpha=0.5) + ax1.axvline(x_range[-1] / 1e3, color="gray", alpha=0.5) -MT.json_save( - "B05_success", - plot_path + "../", - {"time": time.asctime(time.localtime(time.time()))}, -) + ax2.axvline(x_range[0] / 1e3, linestyle=":", color="white", alpha=0.5) + ax2.axvline(x_range[-1] / 1e3, color="gray", alpha=0.5) + + i_spec = weighted_spec.sel(x=slice(x_range[0], x_range[-1])) + i_dir = corrected_marginals.sel(x=slice(x_range[0], x_range[-1])) + + dir_data = (i_dir * i_dir.N_data).sum( + ["beam_group", "x"] + ) / i_dir.N_data.sum(["beam_group", "x"]) + lims = get_first_and_last_nonzero_data(dir_data) + + plot_data = build_plot_data(dir_data, i_spec, lims) + + xx = 2 * np.pi / plot_data.k + + if np.nanmax(plot_data.data) != np.nanmin(plot_data.data): + ax3 = F.fig.add_subplot(gs, polar=True) + FP = PlotPolarSpectra( + xx, + plot_data.angle, + plot_data, + dir_data, + lims=None, + verbose=False, + data_type="fraction", + ) + FP.clevs = np.linspace( + np.nanpercentile(plot_data.data, 1), + np.round(plot_data.max(), 4), + 21, + ) + FP.linear(ax=ax3, cbar_flag=False) + + F.save_pup(path=plot_path.parent, name="B05_dir_ov") + + # save data + Gpdf.to_netcdf(save_path / ("B05_" + track_name + "_angle_pdf.nc")) + + MT.json_save( + "B05_success", + plot_path.parent, + {"time": time.asctime(time.localtime(time.time()))}, + ) + + +define_angle_app = makeapp(define_angle, name="B04_angle") + +if __name__ == "__main__": + 
define_angle_app() diff --git a/src/icesat2_tracks/analysis_db/B06_correct_separate_var.py b/src/icesat2_tracks/analysis_db/B06_correct_separate_var.py index 14d485df..710a8ba5 100644 --- a/src/icesat2_tracks/analysis_db/B06_correct_separate_var.py +++ b/src/icesat2_tracks/analysis_db/B06_correct_separate_var.py @@ -2,18 +2,12 @@ This file open a ICEsat2 track applied filters and corections and returns smoothed photon heights on a regular grid in an .nc file. This is python 3 """ -import os, sys -from icesat2_tracks.config.IceSAT2_startup import ( - mconfig, - color_schemes, - font_for_pres, - font_for_print, - lstrings, - fig_sizes -) +import os + import h5py +from pathlib import Path import icesat2_tracks.ICEsat2_SI_tools.iotools as io import icesat2_tracks.local_modules.m_tools_ph3 as MT from icesat2_tracks.local_modules import m_general_ph3 as M @@ -25,268 +19,76 @@ import numpy as np import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpec +import piecewise_regression +import typer -xr.set_options(display_style="text") -ID_name, batch_key, test_flag = io.init_from_input(sys.argv) -hemis, batch = batch_key.split("_") - -all_beams = mconfig["beams"]["all_beams"] -high_beams = mconfig["beams"]["high_beams"] -low_beams = mconfig["beams"]["low_beams"] - -load_path_work = mconfig["paths"]["work"] + "/" + batch_key + "/" -B3_hdf5 = h5py.File( - load_path_work + "B01_regrid" + "/" + ID_name + "_B01_binned.h5", "r" -) - - -load_path_angle = mconfig["paths"]["work"] + "/" + batch_key + "/B04_angle/" - -B3 = dict() -for b in all_beams: - B3[b] = io.get_beam_hdf_store(B3_hdf5[b]) - -B3_hdf5.close() - -load_file = load_path_work + "/B02_spectra/" + "B02_" + ID_name # + '.nc' -Gk = xr.open_dataset(load_file + "_gFT_k.nc") -Gx = xr.open_dataset(load_file + "_gFT_x.nc") -Gfft = xr.open_dataset(load_file + "_FFT.nc") - -plot_path = ( - mconfig["paths"]["plot"] - + "/" - + hemis - + "/" - + batch_key - + "/" - + ID_name - + "/B06_correction/" +from icesat2_tracks.config.IceSAT2_startup import ( + mconfig, + color_schemes, + font_for_pres, + font_for_print, + lstrings, + fig_sizes, ) -MT.mkdirs_r(plot_path) - -save_path = mconfig["paths"]["work"] + batch_key + "/B06_corrected_separated/" -MT.mkdirs_r(save_path) - - -color_schemes.colormaps2(31, gamma=1) -col_dict = color_schemes.rels - - -def dict_weighted_mean(Gdict, weight_key): - """ - returns the weighted meean of a dict of xarray, data_arrays - weight_key must be in the xr.DataArrays - """ - - akey = list(Gdict.keys())[0] - GSUM = Gdict[akey].copy() - GSUM.data = np.zeros(GSUM.shape) - N_per_stancil = GSUM.N_per_stancil * 0 - N_photons = np.zeros(GSUM.N_per_stancil.size) - - counter = 0 - for _,I in Gdict.items(): - I = I.squeeze() - print(len(I.x)) - if len(I.x) != 0: - GSUM += I.where(~np.isnan(I), 0) * I[weight_key] - N_per_stancil += I[weight_key] - if "N_photons" in GSUM.coords: - N_photons += I["N_photons"] - counter += 1 - - GSUM = GSUM / N_per_stancil - - if "N_photons" in GSUM.coords: - GSUM.coords["N_photons"] = (("x", "beam"), np.expand_dims(N_photons, 1)) - - GSUM["beam"] = ["weighted_mean"] - GSUM.name = "power_spec" - - return GSUM - - -G_gFT_wmean = (Gk.where(~np.isnan(Gk["gFT_PSD_data"]), 0) * Gk["N_per_stancil"]).sum( - "beam" -) / Gk["N_per_stancil"].sum("beam") -G_gFT_wmean["N_photons"] = Gk["N_photons"].sum("beam") - -G_fft_wmean = (Gfft.where(~np.isnan(Gfft), 0) * Gfft["N_per_stancil"]).sum( - "beam" -) / Gfft["N_per_stancil"].sum("beam") -G_fft_wmean["N_per_stancil"] = Gfft["N_per_stancil"].sum("beam") - - -# 
plot -# derive spectral errors: -Lpoints = Gk.Lpoints.mean("beam").data -N_per_stancil = Gk.N_per_stancil.mean("beam").data # [0:-2] - -G_error_model = dict() -G_error_data = dict() - -for bb in Gk.beam.data: - I = Gk.sel(beam=bb) - b_bat_error = np.concatenate([I.model_error_k_cos.data, I.model_error_k_sin.data]) - Z_error = gFT.complex_represenation(b_bat_error, Gk.k.size, Lpoints) - PSD_error_data, PSD_error_model = gFT.Z_to_power_gFT( - Z_error, np.diff(Gk.k)[0], N_per_stancil, Lpoints - ) - - G_error_model[bb] = xr.DataArray( - data=PSD_error_model, - coords=I.drop("N_per_stancil").coords, - name="gFT_PSD_data_error", - ).expand_dims("beam") - G_error_data[bb] = xr.DataArray( - data=PSD_error_data, - coords=I.drop("N_per_stancil").coords, - name="gFT_PSD_data_error", - ).expand_dims("beam") - -gFT_PSD_data_error_mean = xr.concat(G_error_model.values(), dim="beam") -gFT_PSD_data_error_mean = xr.concat(G_error_data.values(), dim="beam") - -gFT_PSD_data_error_mean = ( - gFT_PSD_data_error_mean.where(~np.isnan(gFT_PSD_data_error_mean), 0) - * Gk["N_per_stancil"] -).sum("beam") / Gk["N_per_stancil"].sum("beam") -gFT_PSD_data_error_mean = ( - gFT_PSD_data_error_mean.where(~np.isnan(gFT_PSD_data_error_mean), 0) - * Gk["N_per_stancil"] -).sum("beam") / Gk["N_per_stancil"].sum("beam") - -G_gFT_wmean["gFT_PSD_data_err"] = gFT_PSD_data_error_mean -G_gFT_wmean["gFT_PSD_data_err"] = gFT_PSD_data_error_mean -Gk["gFT_PSD_data_err"] = xr.concat(G_error_model.values(), dim="beam") -Gk["gFT_PSD_data_err"] = xr.concat(G_error_data.values(), dim="beam") - - -# - -G_gFT_smth = ( - G_gFT_wmean["gFT_PSD_data"].rolling(k=30, center=True, min_periods=1).mean() +from icesat2_tracks.clitools import ( + echo, + validate_batch_key, + validate_output_dir, + suppress_stdout, + update_paths_mconfig, + report_input_parameters, + validate_track_name_steps_gt_1, + makeapp, ) -G_gFT_smth["N_photons"] = G_gFT_wmean.N_photons -G_gFT_smth["N_per_stancil_fraction"] = Gk["N_per_stancil"].T.mean( - "beam" -) / Gk.Lpoints.mean("beam") - -k = G_gFT_smth.k - -F = M.figure_axis_xy() - -plt.loglog(k, G_gFT_smth / k) - -plt.title("displacement power Spectra", loc="left") -def define_noise_wavenumber_tresh_simple( - data_xr, k_peak, k_end_lim=None, plot_flag=False -): - """ - returns noise wavenumber on the high end of a spectral peak. This method fits a straight line in loglog speace using robust regression. - The noise level is defined as the wavenumber at which the residual error of a linear fit to the data is minimal. 
- - inputs: - data_xr xarray.Dataarray with the power spectra with k as dimension - k_peak wavenumber above which the searh should start - dk the intervall over which the regrssion is repeated - - returns: - k_end the wavenumber at which the spectrum flattens - m slope of the fitted line - b intersect of the fitted line - """ - - if k_end_lim is None: - k_end_lim = data_xr.k[-1] - - k_lead_peak_margin = k_peak * 1.05 - try: - data_log = ( - np.log(data_xr) - .isel(k=(data_xr.k > k_lead_peak_margin)) - .rolling(k=10, center=True, min_periods=1) - .mean() - ) - - except: - data_log = ( - np.log(data_xr) - .isel(k=(data_xr.k > k_lead_peak_margin / 2)) - .rolling(k=10, center=True, min_periods=1) - .mean() - ) - - k_log = np.log(data_log.k) - try: - d_grad = ( - data_log.differentiate("k").rolling(k=40, center=True, min_periods=4).mean() - ) - except: - d_grad = ( - data_log.differentiate("k").rolling(k=20, center=True, min_periods=2).mean() - ) - ll = label(d_grad >= -5) - - if ll[0][0] != 0: - print("no decay, set to peak") - return k_peak +def get_correct_breakpoint(pw_results): + br_points = [i for i in pw_results.keys() if "breakpoint" in i] - if sum(ll[0]) == 0: - k_end = d_grad.k[-1] - else: - k_end = d_grad.k[(ll[0] == 1)][0].data + br_points_df = pw_results[br_points] + br_points_sorted = br_points_df.sort_values() - if plot_flag: - plt.plot(np.log(data_xr.k), np.log(data_xr)) - plt.plot(k_log, data_log) - plt.plot([np.log(k_end), np.log(k_end)], [-6, -5]) - return k_end + alphas_sorted = [] + betas_sorted = [] + for point in br_points_sorted.index: + alphas_sorted.append(point.replace("breakpoint", "alpha")) + betas_sorted.append(point.replace("breakpoint", "beta")) + alphas_sorted.append(f"alpha{len(alphas_sorted) + 1}") -# new version -def get_correct_breakpoint(pw_results): - br_points = list() - for i in pw_results.keys(): - [br_points.append(i) if "breakpoint" in i else None] - br_points_df = pw_results[br_points] - br_points_sorted = br_points_df.sort_values() + ## TODO: Camilo decided to leave this piece of code here in case the output data is not + ## the right one - alphas_sorted = [ - i.replace("breakpoint", "alpha") for i in br_points_df.sort_values().index - ] - alphas_sorted.append("alpha" + str(len(alphas_sorted) + 1)) + # alphas_sorted = [ + # point.replace("breakpoint", "alpha") for point in br_points_df.sort_values().index + # ] + # alphas_sorted.append(f"alpha{len(alphas_sorted) + 1}") - betas_sorted = [ - i.replace("breakpoint", "beta") for i in br_points_df.sort_values().index - ] + # betas_sorted = [ + # point.replace("breakpoint", "beta") for point in br_points_df.sort_values().index + # ] # betas_sorted alphas_v2 = list() alpha_i = pw_results["alpha1"] - for i in [0] + list(pw_results[betas_sorted]): - alpha_i += i - alphas_v2.append(alpha_i) + alphas_v2 = [alpha_i := alpha_i + i for i in [0] + list(pw_results[betas_sorted])] alphas_v2_sorted = pd.Series(index=alphas_sorted, data=alphas_v2) - br_points_sorted["breakpoint" + str(br_points_sorted.size + 1)] = "end" + br_points_sorted[f"breakpoint{br_points_sorted.size + 1}"] = "end" - print("all alphas") - print(alphas_v2_sorted) + echo("all alphas") + echo(alphas_v2_sorted) slope_mask = alphas_v2_sorted < 0 if sum(slope_mask) == 0: - print("no negative slope found, set to lowest") + echo("no negative slope found, set to lowest") breakpoint = "start" else: # take steepest slope alpah_v2_sub = alphas_v2_sorted[slope_mask] - print(alpah_v2_sub) - print(alpah_v2_sub.argmin()) + echo(alpah_v2_sub) + 
echo(alpah_v2_sub.argmin())
         break_point_name = alpah_v2_sub.index[alpah_v2_sub.argmin()].replace(
             "alpha", "breakpoint"
         )
@@ -298,13 +100,11 @@ def get_correct_breakpoint(pw_results):
 
 
 def get_breakingpoints(xx, dd):
-    import piecewise_regression
-
     x2, y2 = xx, dd
     convergence_flag = True
     n_breakpoints = 3
     while convergence_flag:
-        pw_fit = piecewise_regression.Fit(x2, y2, n_breakpoints=n_breakpoints)
+        pw_fit = piecewise_regression.Fit(xx, dd, n_breakpoints=n_breakpoints)
         print("n_breakpoints", n_breakpoints, pw_fit.get_results()["converged"])
         convergence_flag = not pw_fit.get_results()["converged"]
         n_breakpoints += 1
@@ -325,7 +125,7 @@ def get_breakingpoints(xx, dd):
 
 
 def define_noise_wavenumber_piecewise(data_xr, plot_flag=False):
-    data_log = data_xr
+    data_log = np.log(data_xr)
 
     k = data_log.k.data
 
@@ -333,11 +133,11 @@ def define_noise_wavenumber_piecewise(data_xr, plot_flag=False):
 
     pw_fit, breakpoint_log = get_breakingpoints(k_log, data_log.data)
 
-    if breakpoint_log is "start":
-        print("no decay, set to lowerst wavenumber")
+    if breakpoint_log == "start":
+        echo("no decay, set to lowest wavenumber")
        breakpoint_log = k_log[0]
 
-    if (breakpoint_log is "end") | (breakpoint_log is False):
-        print("higest wavenumner")
+    if (breakpoint_log == "end") or (breakpoint_log is False):
+        echo("highest wavenumber")
         breakpoint_log = k_log[-1]
 
     breakpoint_pos = abs(k_log - breakpoint_log).argmin()
@@ -350,419 +150,633 @@ def define_noise_wavenumber_piecewise(data_xr, plot_flag=False):
     return breakpoint_k, pw_fit
 
 
-k_lim_list = list()
-k_end_previous = np.nan
-x = G_gFT_smth.x.data[0]
-k = G_gFT_smth.k.data
+def weighted_mean(data, weights, additional_data=None):
+    # Where data is not NaN, replace NaN with 0 and multiply by weights
+    weighted_data = data.where(~np.isnan(data), 0) * weights
+
+    # Calculate the sum of weighted data along the specified dimension
+    mean_data = weighted_data.sum("beam") / weights.sum("beam")
 
-for x in G_gFT_smth.x.data:
-    print(x)
-    # use displacement power spectrum
-    k_end, pw_fit = define_noise_wavenumber_piecewise(
+    # Optionally, add additional data to the resulting array
+    if additional_data is not None:
+        mean_data[additional_data] = data[additional_data].sum("beam")
+
+    return mean_data
+
+
+def calculate_k_end(x, k, k_end_previous, G_gFT_smth):
+    echo(x)
+    k_end, _ = define_noise_wavenumber_piecewise(
        G_gFT_smth.sel(x=x) / k, plot_flag=False
    )
-    k_save = k_end_previous if k_end == k[0] else k_end
-    k_end_previous = k_save
-    k_lim_list.append(k_save)
-    print("--------------------------")
-
-font_for_pres()
-G_gFT_smth.coords["k_lim"] = ("x", k_lim_list)
-G_gFT_smth.k_lim.plot()
-k_lim_smth = G_gFT_smth.k_lim.rolling(x=3, center=True, min_periods=1).mean()
-k_lim_smth.plot(c="r")
-
-plt.title("k_c filter", loc="left")
-F.save_light(path=plot_path, name=str(ID_name) + "_B06_atten_ov")
-
-G_gFT_smth["k_lim"] = k_lim_smth
-G_gFT_wmean.coords["k_lim"] = k_lim_smth
-
-font_for_print()
-
-fn = copy.copy(lstrings)
-F = M.figure_axis_xy(
-    fig_sizes["two_column"][0],
-    fig_sizes["two_column"][0] * 0.9,
-    container=True,
-    view_scale=1,
-)
+    # fall back to the previous cut-off when the fit collapses onto the lowest wavenumber
+    k_save = k_end_previous if k_end == k[0] else k_end
+    echo("--------------------------")
+    return k_save
 
-plt.suptitle(
-    "Cut-off Frequency for Displacement Spectral\n" + io.ID_to_str(ID_name), y=0.97
-)
-gs = GridSpec(8, 3, wspace=0.1, hspace=1.5)
-
-k_lims = G_gFT_wmean.k_lim
-xlims = G_gFT_wmean.k[0], G_gFT_wmean.k[-1]
-#
-k = high_beams[0]
-for pos, k, pflag in zip(
-    [gs[0:2, 0], gs[0:2, 1], gs[0:2, 2]], high_beams, [True, False, False]
-):
-    ax0 = F.fig.add_subplot(pos)
-    Gplot = (
-        Gk.sel(beam=k)
-        .isel(x=slice(0, -1))
-        .gFT_PSD_data.squeeze()
-        .rolling(k=20, x=2, min_periods=1, center=True)
-        .mean()
-    )
-    Gplot = Gplot.where(Gplot["N_per_stancil"] / Gplot["Lpoints"] >= 0.1)
-    alpha_range = iter(np.linspace(1, 0, Gplot.x.data.size))
-    for x in Gplot.x.data:
-        ialpha = next(alpha_range)
-        plt.loglog(
-            Gplot.k,
-            Gplot.sel(x=x) / Gplot.k,
-            linewidth=0.5,
-            color=color_schemes.rels[k],
-            alpha=ialpha,
-        )
-        ax0.axvline(
-            k_lims.sel(x=x), linewidth=0.4, color="black", zorder=0, alpha=ialpha
-        )
 
+def tanh_fitler(x, x_cutoff, sigma_g=0.01):
+    """
+    Smooth low-pass taper based on tanh: ~1 below x_cutoff, ~0 above it,
+    with a transition width set by sigma_g.
+    """
+
+    decay = 0.5 - np.tanh((x - x_cutoff) / sigma_g) / 2
+    return decay
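+
+
+# tanh_fitler above acts as a smooth low-pass taper: it is ~1 for x well below
+# x_cutoff and ~0 well above it, rolling off over a width set by sigma_g.
+# Illustrative sketch with made-up numbers (not values from this analysis):
+#
+#   k = np.linspace(0.001, 0.1, 200)
+#   taper = tanh_fitler(k, x_cutoff=0.05, sigma_g=0.01)
+#   coeffs_lowpass = coeffs * taper  # attenuates coefficients beyond the cut-off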
 
-plt.suptitle(
-    "Cut-off Frequency for Displacement Spectral\n" + io.ID_to_str(ID_name), y=0.97
-)
-gs = GridSpec(8, 3, wspace=0.1, hspace=1.5)
-
-k_lims = G_gFT_wmean.k_lim
-xlims = G_gFT_wmean.k[0], G_gFT_wmean.k[-1]
-#
-k = high_beams[0]
-for pos, k, pflag in zip(
-    [gs[0:2, 0], gs[0:2, 1], gs[0:2, 2]], high_beams, [True, False, False]
-):
-    ax0 = F.fig.add_subplot(pos)
-    Gplot = (
-        Gk.sel(beam=k)
-        .isel(x=slice(0, -1))
-        .gFT_PSD_data.squeeze()
-        .rolling(k=20, x=2, min_periods=1, center=True)
-        .mean()
-    )
-    Gplot = Gplot.where(Gplot["N_per_stancil"] / Gplot["Lpoints"] >= 0.1)
-    alpha_range = iter(np.linspace(1, 0, Gplot.x.data.size))
-    for x in Gplot.x.data:
-        ialpha = next(alpha_range)
-        plt.loglog(
-            Gplot.k,
-            Gplot.sel(x=x) / Gplot.k,
-            linewidth=0.5,
-            color=color_schemes.rels[k],
-            alpha=ialpha,
-        )
-        ax0.axvline(
-            k_lims.sel(x=x), linewidth=0.4, color="black", zorder=0, alpha=ialpha
-        )
+def tanh_fitler(x, x_cutoff, sigma_g=0.01):
+    """
+    Smooth low-pass taper: decays from 1 to 0 around x_cutoff,
+    with the width of the transition set by sigma_g.
+    """
+
+    decay = 0.5 - np.tanh((x - x_cutoff) / sigma_g) / 2
+    return decay
 
-    plt.title(next(fn) + k, color=col_dict[k], loc="left")
-    plt.xlim(xlims)
-    #
-    if pflag:
-        ax0.tick_params(labelbottom=False, bottom=True)
-        plt.ylabel("Power (m$^2$/k')")
-        plt.legend()
-    else:
-        ax0.tick_params(labelbottom=False, bottom=True, labelleft=False)
 
-for pos, k, pflag in zip(
-    [gs[2:4, 0], gs[2:4, 1], gs[2:4, 2]], low_beams, [True, False, False]
+def get_k_x_corrected(Gk, theta=0, theta_flag=False):
+
+    if not theta_flag:
+        # keep array shapes so .data works downstream even without an angle estimate
+        return Gk.k * np.nan, Gk.x * np.nan
+
+    lam_p = 2 * np.pi / Gk.k
+    lam = lam_p * np.cos(theta)
+    k_corrected = 2 * np.pi / lam
+    x_corrected = Gk.x * np.cos(theta)
+
+    return k_corrected, x_corrected
+
+
+### TODO: Fix the variables in this function.
+##
+# dx = Gx.eta.diff("eta").mean().dat
+###
+# def reconstruct_displacement(Gx_1, Gk_1, T3, k_thresh):
+#     """
+#     reconstructs photon displacement heights for each stancil given the model parameters in Gk_1
+#     A low-pass frequency filter can be applied using k_thresh
+
+#     inputs:
+#     Gk_1 model data per stencil from _gFT_k file with sin and cos coefficients
+#     Gx_1 real data per stencil from _gFT_x file with mean photon heights and coordinate systems
+#     T3
+#     k_thresh (None) threshold for low-pass filter
+
+#     returns:
+#     height_model reconstructed displacement heights of the stencil
+#     poly_offset fitted straight line to the residual between observations and model to account for low-pass variability
+#     nan_mask mask of where observed data exists
+#     """
+
+#     dist_stencil = Gx_1.eta + Gx_1.x
+
+#     gFT_cos_coeff_sel = np.copy(Gk_1.gFT_cos_coeff)
+#     gFT_sin_coeff_sel = np.copy(Gk_1.gFT_sin_coeff)
+
+#     gFT_cos_coeff_sel = gFT_cos_coeff_sel * tanh_fitler(
+#         Gk_1.k, k_thresh, sigma_g=0.003
+#     )
+#     gFT_sin_coeff_sel = gFT_sin_coeff_sel * tanh_fitler(
+#         Gk_1.k, k_thresh, sigma_g=0.003
+#     )
+
+#     FT_int = gFT.generalized_Fourier(Gx_1.eta + Gx_1.x, None, Gk_1.k)
+#     _ = FT_int.get_H()
+#     FT_int.p_hat = np.concatenate(
+#         [-gFT_sin_coeff_sel / Gk_1.k, gFT_cos_coeff_sel / Gk_1.k]
+#     )
+
+#     dx = Gx.eta.diff("eta").mean().data
+#     height_model = FT_int.model() / dx
+#     dist_nanmask = np.isnan(Gx_1.y_data)
+#     height_data = np.interp(
+#         dist_stencil, T3_sel["dist"], T3_sel["heights_c_weighted_mean"]
+#     )
+#     return height_model, np.nan, dist_nanmask
+
+
+def save_table(data, tablename, save_path):
+    try:
+        io.save_pandas_table(data, tablename, save_path)
+    except Exception as e:
+        tabletoremove = save_path / f"{tablename}.h5"
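+        # an existing (possibly corrupt) file can block the write; drop it and retry once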
+        echo(e, f"Failed to save table. Removing {tabletoremove} and re-trying..")
+        os.remove(tabletoremove)
+        io.save_pandas_table(data, tablename, save_path)
+
+
+def build_G_error(Gk_sel, PSD_list, list_name):
+
+    return xr.DataArray(
+        data=PSD_list, coords=Gk_sel.drop("N_per_stancil").coords, name=list_name
+    ).expand_dims("beam")
+
+
+def run_B06_correct_separate_var(
+    track_name: str = typer.Option(..., callback=validate_track_name_steps_gt_1),
+    batch_key: str = typer.Option(..., callback=validate_batch_key),
+    ID_flag: bool = True,
+    output_dir: str = typer.Option(..., callback=validate_output_dir),
+    verbose: bool = False,
 ):
-    ax0 = F.fig.add_subplot(pos)
-    Gplot = (
-        Gk.sel(beam=k)
-        .isel(x=slice(0, -1))
-        .gFT_PSD_data.squeeze()
-        .rolling(k=20, x=2, min_periods=1, center=True)
-        .mean()
+
+    color_schemes.colormaps2(31, gamma=1)
+    col_dict = color_schemes.rels
+
+    track_name, batch_key, test_flag = io.init_from_input(
+        [
+            None,
+            track_name,
+            batch_key,
+            ID_flag,
+        ]  # init_from_input expects sys.argv with 4 elements
    )
-    Gplot = Gplot.where(Gplot["N_per_stancil"] / Gplot["Lpoints"] >= 0.1)
-
-    alpha_range = iter(np.linspace(1, 0, Gplot.x.data.size))
-    for x in Gplot.x.data:
-        ialpha = next(alpha_range)
-        plt.loglog(
-            Gplot.k,
-            Gplot.sel(x=x) / Gplot.k,
-            linewidth=0.5,
-            color=color_schemes.rels[k],
-            alpha=ialpha,
-        )
-        ax0.axvline(
-            k_lims.sel(x=x), linewidth=0.4, color="black", zorder=0, alpha=ialpha
-        )
+    kwargs = {
+        "track_name": track_name,
+        "batch_key": batch_key,
+        "ID_flag": ID_flag,
+        "output_dir": output_dir,
+    }
 
-    plt.title(next(fn) + k, color=col_dict[k], loc="left")
-    plt.xlim(xlims)
-    plt.xlabel("observed wavenumber k' ")
+    report_input_parameters(**kwargs)
+    with suppress_stdout(verbose):
 
-    if pflag:
-        ax0.tick_params(bottom=True)
-        plt.ylabel("Power (m$^2$/k')")
-        plt.legend()
-    else:
-        ax0.tick_params(bottom=True, labelleft=False)
+        hemis, _ = batch_key.split("_")
 
-F.save_light(path=plot_path, name=str(ID_name) + "_B06_atten_ov_simple")
-F.save_pup(path=plot_path, name=str(ID_name) + "_B06_atten_ov_simple")
+        workdir, plotsdir = update_paths_mconfig(output_dir, mconfig)
 
-pos = gs[5:, 0:2]
-ax0 = F.fig.add_subplot(pos)
+        xr.set_options(display_style="text")
 
-lat_str = (
-    str(np.round(Gx.isel(x=0).lat.mean().data, 2))
-    + " to "
-    + str(np.round(Gx.isel(x=-1).lat.mean().data, 2))
-)
-plt.title(next(fn) + "Mean Displacement Spectra\n(lat=" + lat_str + ")", loc="left")
+        all_beams = mconfig["beams"]["all_beams"]
+        high_beams = mconfig["beams"]["high_beams"]
+        low_beams = mconfig["beams"]["low_beams"]
 
-dd = 10 * np.log((G_gFT_smth / G_gFT_smth.k).isel(x=slice(0, -1)))
-dd = dd.where(~np.isinf(dd), np.nan)
+        load_path_work = workdir / batch_key
 
-## filter out segments with less then 10% of data points
-dd = dd.where(G_gFT_smth["N_per_stancil_fraction"] >= 0.1)
+        h5_file_path = load_path_work / "B01_regrid" / (track_name + "_B01_binned.h5")
+        with h5py.File(h5_file_path, "r") as B3_hdf5:
+            load_path_angle = load_path_work / "B04_angle"
+            B3 = {b: io.get_beam_hdf_store(B3_hdf5[b]) for b in all_beams}
 
-dd_lims = np.round(dd.quantile(0.01).data * 0.95, 0), np.round(
-    dd.quantile(0.95).data * 1.05, 0
-)
-plt.pcolor(
-    dd.x / 1e3,
-    dd.k,
-    dd,
-    vmin=dd_lims[0],
-    vmax=dd_lims[-1],
-    cmap=color_schemes.white_base_blgror,
-)
-cb = plt.colorbar(orientation="vertical")
-
-cb.set_label("Power (m$^2$/k)")
-plt.plot(
-    G_gFT_smth.isel(x=slice(0, -1)).x / 1e3,
-    G_gFT_smth.isel(x=slice(0, -1)).k_lim,
-    color=color_schemes.black,
-    linewidth=1,
-)
-plt.ylabel("wavenumber k")
-plt.xlabel("X (km)")
-
-pos 
= gs[6:, -1] -ax9 = F.fig.add_subplot(pos) - -plt.title("Data Coverage (%)", loc="left") -plt.plot( - G_gFT_smth.x / 1e3, - G_gFT_smth["N_per_stancil_fraction"] * 100, - linewidth=0.8, - color="black", -) -ax9.spines["left"].set_visible(False) -ax9.spines["right"].set_visible(True) -ax9.tick_params(labelright=True, right=True, labelleft=False, left=False) -ax9.axhline(10, linewidth=0.8, linestyle="--", color="black") -plt.xlabel("X (km)") + file_suffixes = ["_gFT_k.nc", "_gFT_x.nc"] + Gk, Gx = [ + xr.open_dataset( + load_path_work / "B02_spectra" / f"B02_{track_name}{suffix}" + ) + for suffix in file_suffixes + ] + plot_path = Path(plotsdir, hemis, batch_key, track_name, "B06_correction") + plot_path.mkdir(parents=True, exist_ok=True) -F.save_light(path=plot_path, name=str(ID_name) + "_B06_atten_ov") -F.save_pup(path=plot_path, name=str(ID_name) + "_B06_atten_ov") + save_path = Path(workdir, batch_key, "B06_corrected_separated") + save_path.mkdir(parents=True, exist_ok=True) + G_gFT_wmean = ( + Gk.where(~np.isnan(Gk["gFT_PSD_data"]), 0) * Gk["N_per_stancil"] + ).sum("beam") / Gk["N_per_stancil"].sum("beam") + G_gFT_wmean["N_photons"] = Gk["N_photons"].sum("beam") -# reconstruct slope displacement data -def fit_offset(x, data, model, nan_mask, deg): - p_offset = np.polyfit(x[~nan_mask], data[~nan_mask] - model[~nan_mask], deg) - p_offset[-1] = 0 - poly_offset = np.polyval(p_offset, x) - return poly_offset + # plot + # derive spectral errors: + Lpoints = Gk.Lpoints.mean("beam").data + N_per_stancil = Gk.N_per_stancil.mean("beam").data # [0:-2] + G_error_data = dict() -def tanh_fitler(x, x_cutoff, sigma_g=0.01): - """ - zdgfsg - """ + for bb in Gk.beam.data: + I = Gk.sel(beam=bb) + b_bat_error = np.concatenate( + [I.model_error_k_cos.data, I.model_error_k_sin.data] + ) + Z_error = gFT.complex_represenation(b_bat_error, Gk.k.size, Lpoints) + PSD_error_data, PSD_error_model = gFT.Z_to_power_gFT( + Z_error, np.diff(Gk.k)[0], N_per_stancil, Lpoints + ) - decay = 0.5 - np.tanh((x - x_cutoff) / sigma_g) / 2 - return decay + G_error_data[bb] = xr.DataArray( + data=PSD_error_data, + coords=I.drop("N_per_stancil").coords, + name="gFT_PSD_data_error", + ).expand_dims("beam") + gFT_PSD_data_error_mean = xr.concat(G_error_data.values(), dim="beam") -def reconstruct_displacement(Gx_1, Gk_1, T3, k_thresh): - """ - reconstructs photon displacement heights for each stancil given the model parameters in Gk_1 - A low-pass frequeny filter can be applied using k-thresh - - inputs: - Gk_1 model data per stencil from _gFT_k file with sin and cos coefficients - Gx_1 real data per stencil from _gFT_x file with mean photon heights and coordindate systems - T3 - k_thresh (None) threshold for low-pass filter - - returns: - height_model reconstucted displements heights of the stancil - poly_offset fitted staight line to the residual between observations and model to account for low-pass variability - nan_mask mask where is observed data in - """ + gFT_PSD_data_error_mean = ( + gFT_PSD_data_error_mean.where(~np.isnan(gFT_PSD_data_error_mean), 0) + * Gk["N_per_stancil"] + ).sum("beam") / Gk["N_per_stancil"].sum("beam") - dist_stencil = Gx_1.eta + Gx_1.x - dist_stencil_lims = dist_stencil[0].data, dist_stencil[-1].data + G_gFT_wmean["gFT_PSD_data_err"] = gFT_PSD_data_error_mean - gFT_cos_coeff_sel = np.copy(Gk_1.gFT_cos_coeff) - gFT_sin_coeff_sel = np.copy(Gk_1.gFT_sin_coeff) + Gk["gFT_PSD_data_err"] = xr.concat(G_error_data.values(), dim="beam") - gFT_cos_coeff_sel = gFT_cos_coeff_sel * tanh_fitler(Gk_1.k, k_thresh, 
sigma_g=0.003)
-    gFT_sin_coeff_sel = gFT_sin_coeff_sel * tanh_fitler(Gk_1.k, k_thresh, sigma_g=0.003)
+    G_gFT_smth = (
+        G_gFT_wmean["gFT_PSD_data"].rolling(k=30, center=True, min_periods=1).mean()
+    )
+    G_gFT_smth["N_photons"] = G_gFT_wmean.N_photons
+    G_gFT_smth["N_per_stancil_fraction"] = Gk["N_per_stancil"].T.mean(
+        "beam"
+    ) / Gk.Lpoints.mean("beam")
 
-    FT_int = gFT.generalized_Fourier(Gx_1.eta + Gx_1.x, None, Gk_1.k)
-    _ = FT_int.get_H()
-    FT_int.p_hat = np.concatenate(
-        [-gFT_sin_coeff_sel / Gk_1.k, gFT_cos_coeff_sel / Gk_1.k]
-    )
+    k = G_gFT_smth.k
 
-    dx = Gx.eta.diff("eta").mean().data
-    height_model = FT_int.model() / dx
-    dist_nanmask = np.isnan(Gx_1.y_data)
-    height_data = np.interp(
-        dist_stencil, T3_sel["dist"], T3_sel["heights_c_weighted_mean"]
-    )
-    return height_model, np.nan, dist_nanmask
-
-
-# cutting Table data
-G_height_model = dict()
-k = "gt2l"
-for bb in Gx.beam.data:
-    G_height_model_temp = dict()
-    for i in np.arange(Gx.x.size):
-        Gx_1 = Gx.isel(x=i).sel(beam=bb)
-        Gk_1 = Gk.isel(x=i).sel(beam=bb)
-        k_thresh = G_gFT_smth.k_lim.isel(x=0).data
-
-        dist_stencil = Gx_1.eta + Gx_1.x
-        dist_stencil_lims = dist_stencil[0].data, dist_stencil[-1].data
-        dist_stencil_lims_plot = dist_stencil_lims
-        dist_stencil_lims_plot = Gx_1.eta[0] * 1 + Gx_1.x, Gx_1.eta[-1] * 1 + Gx_1.x
-
-        T3_sel = B3[k].loc[
-            (
-                (B3[k]["dist"] >= dist_stencil_lims[0])
-                & (B3[k]["dist"] <= dist_stencil_lims[1])
-            )
+    F = M.figure_axis_xy()
+
+    plt.loglog(k, G_gFT_smth / k)
+
+    plt.title("Displacement Power Spectra", loc="left")
+
+    # new version
+    k = G_gFT_smth.k.data
+    k_end_previous = np.nan
+
+    # chain the previous cut-off forward: calculate_k_end returns it when the
+    # piecewise fit falls back to the lowest wavenumber
+    k_lim_list = [
+        k_end_previous := calculate_k_end(x, k, k_end_previous, G_gFT_smth)
+        for x in G_gFT_smth.x.data
    ]
-    if T3_sel.shape[0] != 0:
-        height_model, poly_offset, dist_nanmask = reconstruct_displacement(
-            Gx_1, Gk_1, T3_sel, k_thresh=k_thresh
+    font_for_pres()
+    G_gFT_smth.coords["k_lim"] = ("x", k_lim_list)
+    G_gFT_smth.k_lim.plot()
+    k_lim_smth = G_gFT_smth.k_lim.rolling(x=3, center=True, min_periods=1).mean()
+    k_lim_smth.plot(c="r")
+
+    plt.title("k_c filter", loc="left")
+    F.save_light(path=plot_path, name=str(track_name) + "_B06_atten_ov")
+
+    G_gFT_smth["k_lim"] = k_lim_smth
+    G_gFT_wmean.coords["k_lim"] = k_lim_smth
+
+    font_for_print()
+
+    fn = copy.copy(lstrings)
+    F = M.figure_axis_xy(
+        fig_sizes["two_column"][0],
+        fig_sizes["two_column"][0] * 0.9,
+        container=True,
+        view_scale=1,
+    )
+
+    plt.suptitle(
+        "Cut-off Frequency for Displacement Spectral\n" + io.ID_to_str(track_name),
+        y=0.97,
+    )
+    gs = GridSpec(8, 3, wspace=0.1, hspace=1.5)
+
+    k_lims = G_gFT_wmean.k_lim
+    xlims = G_gFT_wmean.k[0], G_gFT_wmean.k[-1]
+
+    k = high_beams[0]
+
+    for pos, k, pflag in zip(
+        [gs[0:2, 0], gs[0:2, 1], gs[0:2, 2]], high_beams, [True, False, False]
+    ):
+        ax0 = F.fig.add_subplot(pos)
+        Gplot = (
+            Gk.sel(beam=k)
+            .isel(x=slice(0, -1))
+            .gFT_PSD_data.squeeze()
+            .rolling(k=20, x=2, min_periods=1, center=True)
+            .mean()
        )
-        poly_offset = poly_offset * 0
-        G_height_model_temp[str(i) + bb] = xr.DataArray(
-            height_model, coords=Gx_1.coords, dims=Gx_1.dims, name="height_model"
+        Gplot = Gplot.where(Gplot["N_per_stancil"] / Gplot["Lpoints"] >= 0.1)
+
+        alpha_range = np.linspace(1, 0, Gplot.x.data.size)
+
+        for x, ialpha in zip(Gplot.x.data, alpha_range):
+            plt.loglog(
+                Gplot.k,
+                Gplot.sel(x=x) / Gplot.k,
+                linewidth=0.5,
+                color=color_schemes.rels[k],
+                alpha=ialpha,
+            )
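+            # vertical line: smoothed cut-off wavenumber k_lim at this stencil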
+            ax0.axvline(
+                k_lims.sel(x=x),
+                linewidth=0.4,
+                color="black",
+                zorder=0,
+                alpha=ialpha,
+            )
+
+        plt.title(next(fn) + k, color=col_dict[k], loc="left")
+        plt.xlim(xlims)
+
+        if pflag:
+            ax0.tick_params(labelbottom=False, bottom=True)
+            plt.ylabel("Power (m$^2$/k')")
+            plt.legend()
+        else:
+            ax0.tick_params(labelbottom=False, bottom=True, labelleft=False)
+
+    for pos, k, pflag in zip(
+        [gs[2:4, 0], gs[2:4, 1], gs[2:4, 2]], low_beams, [True, False, False]
+    ):
+        ax0 = F.fig.add_subplot(pos)
+        Gplot = (
+            Gk.sel(beam=k)
+            .isel(x=slice(0, -1))
+            .gFT_PSD_data.squeeze()
+            .rolling(k=20, x=2, min_periods=1, center=True)
+            .mean()
+        )
+
+        Gplot = Gplot.where(Gplot["N_per_stancil"] / Gplot["Lpoints"] >= 0.1)
+
+        alpha_range = np.linspace(1, 0, Gplot.x.data.size)
+        for x, ialpha in zip(Gplot.x.data, alpha_range):
+            plt.loglog(
+                Gplot.k,
+                Gplot.sel(x=x) / Gplot.k,
+                linewidth=0.5,
+                color=color_schemes.rels[k],
+                alpha=ialpha,
+            )
+            ax0.axvline(
+                k_lims.sel(x=x),
+                linewidth=0.4,
+                color="black",
+                zorder=0,
+                alpha=ialpha,
+            )
+
+        plt.title(next(fn) + k, color=col_dict[k], loc="left")
+        plt.xlim(xlims)
+        plt.xlabel("observed wavenumber k' ")
+
+        if pflag:
+            ax0.tick_params(bottom=True)
+            plt.ylabel("Power (m$^2$/k')")
+            plt.legend()
+        else:
+            ax0.tick_params(bottom=True, labelleft=False)
+
+    F.save_light(path=plot_path, name=str(track_name) + "_B06_atten_ov_simple")
+    F.save_pup(path=plot_path, name=str(track_name) + "_B06_atten_ov_simple")
+
+    pos = gs[5:, 0:2]
+    ax0 = F.fig.add_subplot(pos)
+
+    lat_str = (
+        str(np.round(Gx.isel(x=0).lat.mean().data, 2))
+        + " to "
+        + str(np.round(Gx.isel(x=-1).lat.mean().data, 2))
+    )
+    plt.title(
+        next(fn) + "Mean Displacement Spectra\n(lat=" + lat_str + ")", loc="left"
+    )
+
+    dd = 10 * np.log((G_gFT_smth / G_gFT_smth.k).isel(x=slice(0, -1)))
+    dd = dd.where(~np.isinf(dd), np.nan)
+
+    ## filter out segments with less than 10% of data points
+    dd = dd.where(G_gFT_smth["N_per_stancil_fraction"] >= 0.1)
+
+    dd_lims = np.round(dd.quantile(0.01).data * 0.95, 0), np.round(
+        dd.quantile(0.95).data * 1.05, 0
+    )
+    plt.pcolor(
+        dd.x / 1e3,
+        dd.k,
+        dd,
+        vmin=dd_lims[0],
+        vmax=dd_lims[-1],
+        cmap=color_schemes.white_base_blgror,
+    )
+    cb = plt.colorbar(orientation="vertical")
+
+    cb.set_label("Power (m$^2$/k)")
+    plt.plot(
+        G_gFT_smth.isel(x=slice(0, -1)).x / 1e3,
+        G_gFT_smth.isel(x=slice(0, -1)).k_lim,
+        color=color_schemes.black,
+        linewidth=1,
+    )
+    plt.ylabel("wavenumber k")
+    plt.xlabel("X (km)")
+
+    pos = gs[6:, -1]
+    ax9 = F.fig.add_subplot(pos)
+
+    plt.title("Data Coverage (%)", loc="left")
+    plt.plot(
+        G_gFT_smth.x / 1e3,
+        G_gFT_smth["N_per_stancil_fraction"] * 100,
+        linewidth=0.8,
+        color="black",
+    )
+    ax9.spines["left"].set_visible(False)
+    ax9.spines["right"].set_visible(True)
+    ax9.tick_params(labelright=True, right=True, labelleft=False, left=False)
+    ax9.axhline(10, linewidth=0.8, linestyle="--", color="black")
+    plt.xlabel("X (km)")
+
+    F.save_light(path=plot_path, name=str(track_name) + "_B06_atten_ov")
+    F.save_pup(path=plot_path, name=str(track_name) + "_B06_atten_ov")
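+
+    # Sketch of the low-pass step used in reconstruct_displacement below
+    # (illustrative values, not pipeline data): the sin/cos coefficients are
+    # tapered smoothly to zero above k_thresh instead of being cut hard:
+    #   taper = tanh_fitler(Gk_1.k, k_thresh, sigma_g=0.003)
+    #   gFT_cos_coeff_lowpass = Gk_1.gFT_cos_coeff * taper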
+
+    def reconstruct_displacement(Gx_1, Gk_1, T3, k_thresh):
+        """
+        reconstructs photon displacement heights for each stancil given the model parameters in Gk_1
+        A low-pass frequency filter can be applied using k_thresh
+
+        inputs:
+        Gk_1 model data per stencil from _gFT_k file with sin and cos coefficients
+        Gx_1 real data per stencil from _gFT_x file with mean photon heights and coordinate systems
+        T3
+        k_thresh (None) threshold for low-pass filter
+
+        returns:
+        height_model reconstructed displacement heights of the stencil
+        poly_offset fitted straight line to the residual between observations and model to account for low-pass variability
+        nan_mask mask of where observed data exists
+        """
+
+        dist_stencil = Gx_1.eta + Gx_1.x
+
+        gFT_cos_coeff_sel = np.copy(Gk_1.gFT_cos_coeff)
+        gFT_sin_coeff_sel = np.copy(Gk_1.gFT_sin_coeff)
+
+        gFT_cos_coeff_sel = gFT_cos_coeff_sel * tanh_fitler(
+            Gk_1.k, k_thresh, sigma_g=0.003
+        )
+        gFT_sin_coeff_sel = gFT_sin_coeff_sel * tanh_fitler(
+            Gk_1.k, k_thresh, sigma_g=0.003
        )
 
-    G_height_model[bb] = xr.concat(G_height_model_temp.values(), dim="x").T
+        FT_int = gFT.generalized_Fourier(Gx_1.eta + Gx_1.x, None, Gk_1.k)
+        _ = FT_int.get_H()
+        FT_int.p_hat = np.concatenate(
+            [-gFT_sin_coeff_sel / Gk_1.k, gFT_cos_coeff_sel / Gk_1.k]
+        )
 
-Gx["height_model"] = xr.concat(G_height_model.values(), dim="beam").transpose(
-    "eta", "beam", "x"
-)
+        dx = Gx.eta.diff("eta").mean().data
+        height_model = FT_int.model() / dx
+        dist_nanmask = np.isnan(Gx_1.y_data)
+
+        return height_model, np.nan, dist_nanmask
+
+    # cutting Table data
+    G_height_model = dict()
+    for bb in Gx.beam.data:
+        G_height_model_temp = dict()
+        for i in np.arange(Gx.x.size):
+            Gx_1 = Gx.isel(x=i).sel(beam=bb)
+            Gk_1 = Gk.isel(x=i).sel(beam=bb)
+            k_thresh = G_gFT_smth.k_lim.isel(x=0).data
+
+            dist_stencil = Gx_1.eta + Gx_1.x
+            dist_stencil_lims = dist_stencil[0].data, dist_stencil[-1].data
+
+            # was B3[k] with a hard-coded k = "gt2l"; cut each beam by its own table
+            T3_sel = B3[bb].loc[
+                (
+                    (B3[bb]["dist"] >= dist_stencil_lims[0])
+                    & (B3[bb]["dist"] <= dist_stencil_lims[1])
+                )
+            ]
+
+            if T3_sel.shape[0] != 0:
+                height_model, poly_offset, _ = reconstruct_displacement(
+                    Gx_1, Gk_1, T3_sel, k_thresh=k_thresh
+                )
+                poly_offset = poly_offset * 0
+                G_height_model_temp[str(i) + bb] = xr.DataArray(
+                    height_model,
+                    coords=Gx_1.coords,
+                    dims=Gx_1.dims,
+                    name="height_model",
+                )
+            else:
+                G_height_model_temp[str(i) + bb] = xr.DataArray(
+                    Gx_1.y_model.data,
+                    coords=Gx_1.coords,
+                    dims=Gx_1.dims,
+                    name="height_model",
+                )
+
+        G_height_model[bb] = xr.concat(G_height_model_temp.values(), dim="x").T
+
+    Gx["height_model"] = xr.concat(G_height_model.values(), dim="beam").transpose(
+        "eta", "beam", "x"
+    )
 
-Gx_v2, B2_v2, B3_v2 = dict(), dict(), dict()
-for bb in Gx.beam.data:
-    print(bb)
-    Gx_k = Gx.sel(beam=bb)
-    Gh = Gx["height_model"].sel(beam=bb).T
+        Gx_v2, B2_v2, B3_v2 = dict(), dict(), dict()
+
+        for bb in Gx.beam.data:
+            echo(bb)
+            Gx_k = Gx.sel(beam=bb)
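+            # flatten the per-stencil (eta, x) model heights onto one continuous
+            # along-track axis so they can be interpolated to photon positions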
Gx["height_model"].sel(beam=bb).T + Gh_err = Gx_k["model_error_x"].T + Gnans = np.isnan(Gx_k.y_model) -try: - G_angle = xr.open_dataset(load_path_angle + "/B05_" + ID_name + "_angle_pdf.nc") + concented_heights = Gh.data.reshape(Gh.data.size) + concented_err = Gh_err.data.reshape(Gh.data.size) + concented_x = (Gh.x + Gh.eta).data.reshape(Gh.data.size) - font_for_pres() + dx = Gh.eta.diff("eta")[0].data + continous_x_grid = np.arange(concented_x.min(), concented_x.max(), dx) + continous_height_model = np.interp( + continous_x_grid, concented_x, concented_heights + ) + concented_err = np.interp(continous_x_grid, concented_x, concented_err) - Ga_abs = ( - G_angle.weighted_angle_PDF_smth.isel(angle=G_angle.angle > 0).data - + G_angle.weighted_angle_PDF_smth.isel(angle=G_angle.angle < 0).data[:, ::-1] - ) / 2 - Ga_abs = xr.DataArray( - data=Ga_abs.T, - dims=G_angle.dims, - coords=G_angle.isel(angle=G_angle.angle > 0).coords, - ) + T3 = B3[bb] + T3 = T3.sort_values("x") + T3 = T3.sort_values("dist") - Ga_abs_front = Ga_abs.isel(x=slice(0, 3)) - Ga_best = (Ga_abs_front * Ga_abs_front.N_data).sum("x") / Ga_abs_front.N_data.sum( - "x" - ) + T3["heights_c_model"] = np.interp( + T3["dist"], continous_x_grid, continous_height_model + ) + T3["heights_c_model_err"] = np.interp( + T3["dist"], continous_x_grid, concented_err + ) + T3["heights_c_residual"] = ( + T3["heights_c_weighted_mean"] - T3["heights_c_model"] + ) - theta = Ga_best.angle[Ga_best.argmax()].data - theta_flag = True + B3_v2[bb] = T3 + Gx_v2[bb] = Gx_k - font_for_print() - F = M.figure_axis_xy(3, 5, view_scale=0.7) + try: + G_angle = xr.open_dataset( + load_path_angle / (f"B05_{track_name}_angle_pdf.nc") + ) - plt.subplot(2, 1, 1) - plt.pcolor(Ga_abs) - plt.xlabel("abs angle") - plt.ylabel("x") + except ValueError as e: + echo(f"{e} no angle data found, skip angle corretion") + theta = 0 + theta_flag = False + else: + font_for_pres() + Ga_abs = ( + G_angle.weighted_angle_PDF_smth.isel(angle=G_angle.angle > 0).data + + G_angle.weighted_angle_PDF_smth.isel(angle=G_angle.angle < 0).data[ + :, ::-1 + ] + ) / 2 + Ga_abs = xr.DataArray( + data=Ga_abs.T, + dims=G_angle.dims, + coords=G_angle.isel(angle=G_angle.angle > 0).coords, + ) - ax = plt.subplot(2, 1, 2) - Ga_best.plot() - plt.title("angle front " + str(theta * 180 / np.pi), loc="left") - ax.axvline(theta, color="red") - F.save_light(path=plot_path, name="B06_angle_def") -except: - print("no angle data found, skip angle corretion") - theta = 0 - theta_flag = False + Ga_abs_front = Ga_abs.isel(x=slice(0, 3)) + Ga_best = (Ga_abs_front * Ga_abs_front.N_data).sum( + "x" + ) / Ga_abs_front.N_data.sum("x") -# %% -lam_p = 2 * np.pi / Gk.k -lam = lam_p * np.cos(theta) + theta = Ga_best.angle[Ga_best.argmax()].data + theta_flag = True -if theta_flag: - k_corrected = 2 * np.pi / lam - x_corrected = Gk.x * np.cos(theta) -else: - k_corrected = 2 * np.pi / lam * np.nan - x_corrected = Gk.x * np.cos(theta) * np.nan + font_for_print() + F = M.figure_axis_xy(3, 5, view_scale=0.7) -# spectral save -G5 = G_gFT_wmean.expand_dims(dim="beam", axis=1) -G5.coords["beam"] = ["weighted_mean"] -G5 = G5.assign_coords(N_photons=G5.N_photons) -G5["N_photons"] = G5["N_photons"].expand_dims("beam") -G5["N_per_stancil_fraction"] = G5["N_per_stancil_fraction"].expand_dims("beam") + plt.subplot(2, 1, 1) + plt.pcolor(Ga_abs) + plt.xlabel("abs angle") + plt.ylabel("x") -Gk_v2 = xr.merge([Gk, G5]) + ax = plt.subplot(2, 1, 2) + Ga_best.plot() + plt.title("angle front " + str(theta * 180 / np.pi), loc="left") + 
+
+            font_for_print()
+            F = M.figure_axis_xy(3, 5, view_scale=0.7)
+
+            plt.subplot(2, 1, 1)
+            plt.pcolor(Ga_abs)
+            plt.xlabel("abs angle")
+            plt.ylabel("x")
 
-    ax = plt.subplot(2, 1, 2)
-    Ga_best.plot()
-    plt.title("angle front " + str(theta * 180 / np.pi), loc="left")
-    ax.axvline(theta, color="red")
-    F.save_light(path=plot_path, name="B06_angle_def")
-except:
-    print("no angle data found, skip angle corretion")
-    theta = 0
-    theta_flag = False
+            ax = plt.subplot(2, 1, 2)
+            Ga_best.plot()
+            plt.title("angle front " + str(theta * 180 / np.pi), loc="left")
+            ax.axvline(theta, color="red")
+            F.save_light(path=plot_path, name="B06_angle_def")
 
-# %%
-lam_p = 2 * np.pi / Gk.k
-lam = lam_p * np.cos(theta)
+    k_corrected, x_corrected = get_k_x_corrected(Gk, theta, theta_flag)
 
-if theta_flag:
-    k_corrected = 2 * np.pi / lam
-    x_corrected = Gk.x * np.cos(theta)
-else:
-    k_corrected = 2 * np.pi / lam * np.nan
-    x_corrected = Gk.x * np.cos(theta) * np.nan
+    # spectral save
+    G5 = G_gFT_wmean.expand_dims(dim="beam", axis=1)
+    G5.coords["beam"] = ["weighted_mean"]
+    G5 = G5.assign_coords(N_photons=G5.N_photons)
+    G5["N_photons"] = G5["N_photons"].expand_dims("beam")
+    G5["N_per_stancil_fraction"] = G5["N_per_stancil_fraction"].expand_dims("beam")
 
-# spectral save
-G5 = G_gFT_wmean.expand_dims(dim="beam", axis=1)
-G5.coords["beam"] = ["weighted_mean"]
-G5 = G5.assign_coords(N_photons=G5.N_photons)
-G5["N_photons"] = G5["N_photons"].expand_dims("beam")
-G5["N_per_stancil_fraction"] = G5["N_per_stancil_fraction"].expand_dims("beam")
+    Gk_v2 = xr.merge([Gk, G5])
 
-Gk_v2 = xr.merge([Gk, G5])
+    Gk_v2 = Gk_v2.assign_coords(x_corrected=("x", x_corrected.data)).assign_coords(
+        k_corrected=("k", k_corrected.data)
+    )
 
-Gk_v2 = Gk_v2.assign_coords(x_corrected=("x", x_corrected.data)).assign_coords(
-    k_corrected=("k", k_corrected.data)
-)
+    ## TODO: abstract to a function, GitHub issue #117
+    Gk_v2.attrs["best_guess_incident_angle"] = theta
 
-Gk_v2.attrs["best_guess_incident_angle"] = theta
+    # save collected spectral data
+    Gk_v2.to_netcdf(save_path / ("B06_" + track_name + "_gFT_k_corrected.nc"))
 
-# save collected spectral data
-Gk_v2.to_netcdf(save_path + "/B06_" + ID_name + "_gFT_k_corrected.nc")
+    # save real space data
+    Gx.to_netcdf(save_path / ("B06_" + track_name + "_gFT_x_corrected.nc"))
 
-# save real space data
-Gx.to_netcdf(save_path + "/B06_" + ID_name + "_gFT_x_corrected.nc")
 
-def save_table(data, tablename, save_path):
-    try:
-        io.save_pandas_table(data, tablename, save_path)
-    except Exception as e:
-        tabletoremove = save_path + tablename + ".h5"
-        print(e, f"Failed to save table. Removing {tabletoremove} and re-trying..")
-        os.remove(tabletoremove)
-        io.save_pandas_table(data, tablename, save_path)
+    B06_ID_name = "B06_" + track_name
+    table_names = [
+        B06_ID_name + suffix for suffix in ["_B06_corrected_resid", "_binned_resid"]
+    ]
 
-B06_ID_name = "B06_" + ID_name
-table_names = [B06_ID_name + suffix for suffix in ["_B06_corrected_resid", "_binned_resid"]]
-data = [B2_v2, B3_v2]
-for tablename, data in zip(table_names, data):
-    save_table(data, tablename, save_path)
+    tables = [B2_v2, B3_v2]
+    for tablename, data in zip(table_names, tables):
+        save_table(data, tablename, save_path)
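+
+    # NOTE: B2_v2 is never populated above, so the "..._B06_corrected_resid"
+    # table is currently written out empty (same behaviour as the original script)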
 
-MT.json_save(
-    "B06_success",
-    plot_path + "../",
-    {"time": time.asctime(time.localtime(time.time()))},
-)
-print("done. saved target at " + plot_path + "../B06_success")
+    MT.json_save(
+        "B06_success",
+        (plot_path / "../"),
+        {"time": time.asctime(time.localtime(time.time()))},
+    )
+    echo("done. saved target at " + str(plot_path / "../B06_success"))
+    echo("Done B06_correct_separate_var")
+
+
+correct_separate_app = makeapp(run_B06_correct_separate_var, name="correct-separate")
+
+if __name__ == "__main__":
+    correct_separate_app()
diff --git a/src/icesat2_tracks/app.py b/src/icesat2_tracks/app.py
index f2e350be..654a09b4 100644
--- a/src/icesat2_tracks/app.py
+++ b/src/icesat2_tracks/app.py
@@ -7,6 +7,26 @@
     run_B01_SL_load_single_file as _loadfile,
 )
 
+from icesat2_tracks.analysis_db.B02_make_spectra_gFT import (
+    run_B02_make_spectra_gFT as _makespectra,
+)
+
+from icesat2_tracks.analysis_db.B03_plot_spectra_ov import (
+    run_B03_plot_spectra_ov as _plotspectra,
+)
+
+from icesat2_tracks.analysis_db.A02c_IOWAGA_thredds_prior import (
+    run_A02c_IOWAGA_thredds_prior as _threddsprior,
+)
+
+from icesat2_tracks.analysis_db.B04_angle import run_B04_angle as _run_B04_angle
+
+from icesat2_tracks.analysis_db.B05_define_angle import define_angle as _define_angle
+
+from icesat2_tracks.analysis_db.B06_correct_separate_var import (
+    run_B06_correct_separate_var as _run_correct_separate_var,
+)
+
 from icesat2_tracks.clitools import (
     validate_track_name,
     validate_batch_key,
@@ -14,11 +34,10 @@
     validate_track_name_steps_gt_1,
 )
 
-
 app = Typer(add_completion=False)
 validate_track_name_gt_1_opt = Option(..., callback=validate_track_name_steps_gt_1)
 validate_batch_key_opt = Option(..., callback=validate_batch_key)
-validate_output_dir_opt = Option(None, callback=validate_output_dir)
+validate_output_dir_opt = Option(..., callback=validate_output_dir)
 
 
 def run_job(
@@ -27,24 +46,94 @@
     batch_key: str,
     ID_flag: bool = True,
     output_dir: str = validate_output_dir_opt,
+    verbose: bool = False,
 ):
-    analysis_func(
-        track_name,
-        batch_key,
-        ID_flag,
-        output_dir,
-    )
+    analysis_func(track_name, batch_key, ID_flag, output_dir, verbose)
 
 
 @app.command(help=_loadfile.__doc__)
-def loadfile(
+def load_file(
     track_name: str = Option(..., callback=validate_track_name),
     batch_key: str = validate_batch_key_opt,
     ID_flag: bool = True,
     output_dir: str = validate_output_dir_opt,
+    verbose: bool = False,
+):
+    run_job(_loadfile, track_name, batch_key, ID_flag, output_dir, verbose)
+
+
+@app.command(help=_makespectra.__doc__)
+def make_spectra(
+    track_name: str = validate_track_name_gt_1_opt,
+    batch_key: str = validate_batch_key_opt,
+    ID_flag: bool = True,
+    output_dir: str = validate_output_dir_opt,
+    verbose: bool = False,
+):
+    run_job(_makespectra, track_name, batch_key, ID_flag, output_dir, verbose)
+
+
+@app.command(help=_plotspectra.__doc__)
+def plot_spectra(
+    track_name: str = validate_track_name_gt_1_opt,
+    batch_key: str = validate_batch_key_opt,
+    ID_flag: bool = True,
+    output_dir: str = validate_output_dir_opt,
+    verbose: bool = False,
 ):
-    run_job(_loadfile, track_name, batch_key, ID_flag, output_dir)
+    run_job(_plotspectra, track_name, batch_key, ID_flag, output_dir, verbose)
 
 
+@app.command(help=_threddsprior.__doc__)
+def make_iowaga_threads_prior(  # TODO: revise naming @mochell
+    track_name: str = validate_track_name_gt_1_opt,
+    batch_key: str = validate_batch_key_opt,
+    ID_flag: bool = True,
+    output_dir: str = validate_output_dir_opt,
+    verbose: bool = False,
+):
+    run_job(_threddsprior, track_name, batch_key, ID_flag, output_dir, verbose)
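+
+# Hypothetical usage sketch (values are placeholders, not a tested invocation):
+#   icesat2waves make-iowaga-threads-prior --track-name SH_20190502_05180312 \
+#       --batch-key SH_testSLsinglefile2 --output-dir ./work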
+
+
+@app.command(help=_run_B04_angle.__doc__)
+def make_b04_angle(  # TODO: revise naming @mochell
+    track_name: str = validate_track_name_gt_1_opt,
+    batch_key: str = validate_batch_key_opt,
+    ID_flag: bool = True,
+    output_dir: str = validate_output_dir_opt,
+    verbose: bool = False,
+):
+    run_job(_run_B04_angle, track_name, batch_key, ID_flag, output_dir, verbose)
+
+
+@app.command(help=_define_angle.__doc__)
+def define_angle(
+    track_name: str = validate_track_name_gt_1_opt,
+    batch_key: str = validate_batch_key_opt,
+    ID_flag: bool = True,
+    output_dir: str = validate_output_dir_opt,
+    verbose: bool = False,
+):
+    run_job(_define_angle, track_name, batch_key, ID_flag, output_dir, verbose)
+
+
+@app.command(help=_run_correct_separate_var.__doc__)
+def correct_separate(  # TODO: rename with a verb or something
+    track_name: str = validate_track_name_gt_1_opt,
+    batch_key: str = validate_batch_key_opt,
+    ID_flag: bool = True,
+    output_dir: str = validate_output_dir_opt,
+    verbose: bool = False,
+):
+    run_job(_run_correct_separate_var, track_name, batch_key, ID_flag, output_dir, verbose)
 
 if __name__ == "__main__":
-    app()
\ No newline at end of file
+    app()
diff --git a/src/icesat2_tracks/clitools.py b/src/icesat2_tracks/clitools.py
index dd01a182..dd896027 100644
--- a/src/icesat2_tracks/clitools.py
+++ b/src/icesat2_tracks/clitools.py
@@ -1,7 +1,6 @@
 import os
 import re
-import sys
-from contextlib import contextmanager
+from contextlib import contextmanager, redirect_stdout
 from pathlib import Path
 
 import typer
@@ -14,12 +13,8 @@
         yield
     else:
         with open(os.devnull, "w") as devnull:
-            old_stdout = sys.stdout
-            sys.stdout = devnull
-            try:
+            with redirect_stdout(devnull):
                 yield
-            finally:
-                sys.stdout = old_stdout
 
 
 # Callbacks for typer
@@ -35,7 +30,9 @@
     return value
 
 
-def validate_track_name(ctx: typer.Context, param: typer.CallbackParam, value: str) -> str:
+def validate_track_name(
+    ctx: typer.Context, param: typer.CallbackParam, value: str
+) -> str:
     """
     Validate the track name `value` based on a specific pattern (see below).
 
@@ -65,7 +62,7 @@
     '20221231115959_87654321_321_21'
     >>> validate_track_name(None, None, '20220228235959_00000000_000_00')
     '20220228235959_00000000_000_00'
-    
+
     Doctest:
     >>> validate_track_name(None, None, 'invalid_track_name')
     Traceback (most recent call last):
@@ -83,7 +80,9 @@
     )
 
 
-def validate_batch_key(ctx: typer.Context, param: typer.CallbackParam, value: str) -> str:
+def validate_batch_key(
+    ctx: typer.Context, param: typer.CallbackParam, value: str
+) -> str:
     """
     Validate a batch key based on a specific pattern (see below).