Skip to content

Commit

Permalink
Merge pull request #259 from nyx-space/quickfix/residuals-plot
Browse files Browse the repository at this point in the history
Quickfix/residuals plot
  • Loading branch information
ChristopherRabotin authored Dec 1, 2023
2 parents a081425 + 8e01cf4 commit d5ab39d
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 32 deletions.
3 changes: 2 additions & 1 deletion python/nyx_space/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

from .gauss_markov import plot_gauss_markov
from .od import plot_covar, plot_estimates, plot_measurements
from .od import plot_covar, plot_estimates, plot_measurements, overlay_measurements
from .traj import plot_traj, plot_ground_track, plot_traj_errors

__all__ = [
Expand All @@ -28,4 +28,5 @@
"plot_traj_errors",
"plot_ground_track",
"plot_measurements",
"overlay_measurements",
]
122 changes: 97 additions & 25 deletions python/nyx_space/plots/od.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"""

import plotly.graph_objects as go
import plotly.express as px
from datetime import datetime

from .utils import plot_with_error, plot_line, finalize_plot, colors

Expand All @@ -28,6 +30,7 @@
import numpy as np
from scipy.stats import norm

from nyx_space.time import Epoch

def plot_estimates(
dfs,
Expand Down Expand Up @@ -93,8 +96,8 @@ def plot_estimates(
epoch = epoch.replace("UTC", "").strip()
if "." not in epoch:
epoch += ".0"
pd_ok_epochs += [epoch]
time_col = pd.to_datetime(pd_ok_epochs)
pd_ok_epochs += [datetime.fromisoformat(str(epoch).replace("UTC", "").strip())]
time_col = pd.Series(pd_ok_epochs)
x_title = "Epoch {}".format(time_col_name[-3:])

# Check that the requested covariance frame exists
Expand Down Expand Up @@ -247,12 +250,12 @@ def plot_estimates(

if msr_df is not None:
# Plot the measurements on both plots
pos_fig = plot_measurements(
msr_df, title, time_col_name, fig=pos_fig, show=False
pos_fig = overlay_measurements(
pos_fig, msr_df, title, time_col_name, show=False
)

vel_fig = plot_measurements(
msr_df, title, time_col_name, fig=vel_fig, show=False
vel_fig = overlay_measurements(
vel_fig, msr_df, title, time_col_name, show=False
)

if html_out:
Expand Down Expand Up @@ -333,8 +336,8 @@ def plot_covar(
epoch = epoch.replace("UTC", "").strip()
if "." not in epoch:
epoch += ".0"
pd_ok_epochs += [epoch]
time_col = pd.to_datetime(pd_ok_epochs)
pd_ok_epochs += [datetime.fromisoformat(str(epoch).replace("UTC", "").strip())]
time_col = pd.Series(pd_ok_epochs)
x_title = "Epoch {}".format(time_col_name[-3:])

# Check that the requested covariance frame exists
Expand Down Expand Up @@ -454,12 +457,12 @@ def plot_covar(

if msr_df is not None:
# Plot the measurements on both plots
pos_fig = plot_measurements(
msr_df, title, time_col_name, fig=pos_fig, show=False
pos_fig = overlay_measurements(
pos_fig, msr_df, title, time_col_name, show=False
)

vel_fig = plot_measurements(
msr_df, title, time_col_name, fig=vel_fig, show=False
vel_fig = overlay_measurements(
vel_fig, msr_df, title, time_col_name, show=False
)

if html_out:
Expand All @@ -481,21 +484,22 @@ def plot_covar(
return pos_fig, vel_fig


def plot_measurements(
def overlay_measurements(
fig,
dfs,
title,
time_col_name="Epoch:Gregorian UTC",
html_out=None,
copyright=None,
fig=None,
show=True,
):
"""
Given a plotly figure, overlay the measurements as shaded regions on top of the existing plot.
For a plot of measurements only, use `plot_measurements`.
"""
if not isinstance(dfs, list):
dfs = [dfs]

if fig is None:
fig = go.Figure()

color_values = list(colors.values())

station_colors = {}
Expand All @@ -518,8 +522,8 @@ def plot_measurements(
epoch = epoch.replace("UTC", "").strip()
if "." not in epoch:
epoch += ".0"
pd_ok_epochs += [epoch]
time_col = pd.to_datetime(pd_ok_epochs)
pd_ok_epochs += [datetime.fromisoformat(str(epoch).replace("UTC", "").strip())]
time_col = pd.Series(pd_ok_epochs)
x_title = "Epoch {}".format(time_col_name[-3:])

# Diff the epochs of the measurements to find when there is a start and end.
Expand Down Expand Up @@ -571,7 +575,7 @@ def plot_measurements(
line_width=0,
)

finalize_plot(fig, title, x_title, copyright, show)
finalize_plot(fig, title, x_title, None, copyright)

if html_out:
with open(html_out, "w") as f:
Expand All @@ -595,7 +599,7 @@ def plot_residuals(
show=True,
):
"""
Plot of residuals, with 3-σ lines
Plot of residuals, with 3-σ lines. Returns a tuple of the plots if show=False.
"""

try:
Expand All @@ -615,12 +619,14 @@ def plot_residuals(
epoch = epoch.replace("UTC", "").strip()
if "." not in epoch:
epoch += ".0"
pd_ok_epochs += [epoch]
time_col = pd.to_datetime(pd_ok_epochs)
pd_ok_epochs += [datetime.fromisoformat(str(epoch).replace("UTC", "").strip())]
time_col = pd.Series(pd_ok_epochs)
x_title = "Epoch {}".format(time_col_name[-3:])

plt_any = False

rtn_plots = []

for col in df.columns:
if col.startswith(kind):
fig = go.Figure()
Expand Down Expand Up @@ -671,8 +677,8 @@ def plot_residuals(

if msr_df is not None:
# Plot the measurements on both plots
fig = plot_measurements(
msr_df, title, time_col_name, fig=fig, show=False
fig = overlay_measurements(
fig, msr_df, title, time_col_name, show=False
)

finalize_plot(
Expand All @@ -689,10 +695,15 @@ def plot_residuals(

if show:
fig.show()
else:
rtn_plots += [fig]

if not plt_any:
raise ValueError(f"No columns ending with {kind} found -- nothing plotted")

if not show:
return rtn_plots


def plot_residual_histogram(
df, title, kind="Prefit", copyright=None, html_out=None, show=True
Expand Down Expand Up @@ -737,3 +748,64 @@ def plot_residual_histogram(

if show:
fig.show()

def plot_measurements(
df,
msr_type=None,
title=None,
time_col_name="Epoch:Gregorian UTC",
html_out=None,
copyright=None,
show=True,
):
"""
Plot the provided measurement type, fuzzy matching of the column name, or plot all as a strip
"""

if title is None:
# Build a title
station_names = ", ".join([name for name in df["Tracking device"].unique()])
start = Epoch(df["Epoch:Gregorian UTC"].iloc[0])
end = Epoch(df["Epoch:Gregorian UTC"].iloc[-1])
arc_duration = end.timedelta(start)
title = f"Measurements from {station_names} spanning {start} to {end} ({arc_duration})"

try:
orig_tim_col = df[time_col_name]
except KeyError:
# Find the time column
try:
col_name = [x for x in df.columns if x.startswith("Epoch")][0]
except IndexError:
raise KeyError("Could not find any Epoch column")
print(f"Could not find time column {time_col_name}, using `{col_name}`")
orig_tim_col = df[col_name]

# Build a Python datetime column
pd_ok_epochs = []
for epoch in orig_tim_col:
epoch = epoch.replace("UTC", "").strip()
if "." not in epoch:
epoch += ".0"
pd_ok_epochs += [datetime.fromisoformat(str(epoch).replace("UTC", "").strip())]
df["time_col"] = pd.Series(pd_ok_epochs)
x_title = "Epoch {}".format(time_col_name[-3:])

if msr_type is None:
fig = px.strip(df, x="time_col", y="Tracking device", color="Tracking device")
finalize_plot(fig, title, x_title, "All tracking data", copyright)
else:
msr_col_name = [col for col in df.columns if msr_type in col.lower()]

fig = px.scatter(df, x="time_col", y=msr_col_name, color="Tracking device")
finalize_plot(fig, title, x_title, msr_col_name[0], copyright)

if html_out:
with open(html_out, "w") as f:
f.write(fig.to_html())
print(f"Saved HTML to {html_out}")

if show:
fig.show()
else:
return fig
9 changes: 5 additions & 4 deletions python/nyx_space/plots/traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
from datetime import datetime

from .utils import (
radii,
Expand Down Expand Up @@ -219,8 +220,8 @@ def plot_orbit_elements(
epoch = epoch.replace("UTC", "").strip()
if "." not in epoch:
epoch += ".0"
pd_ok_epochs += [epoch]
df["Epoch"] = pd.to_datetime(pd_ok_epochs)
pd_ok_epochs += [datetime.fromisoformat(str(epoch).replace("UTC", "").strip())]
df["Epoch"] = pd.Series(pd_ok_epochs)

if not isinstance(names, list):
names = [names]
Expand Down Expand Up @@ -317,8 +318,8 @@ def plot_traj_errors(
epoch = epoch.replace("UTC", "").strip()
if "." not in epoch:
epoch += ".0"
pd_ok_epochs += [epoch]
df["Epoch"] = pd.to_datetime(pd_ok_epochs)
pd_ok_epochs += [datetime.fromisoformat(str(epoch).replace("UTC", "").strip())]
df["Epoch"] = pd.Series(pd_ok_epochs)

if not isinstance(names, list):
names = [names]
Expand Down
4 changes: 2 additions & 2 deletions python/nyx_space/plots/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _add_watermark(who):
nyx_tpl.layout.annotations = [
dict(
name="watermark",
text=f"Nyx Space 🄯 AGPLv3 {year}",
text=f"Powered by Nyx Space © {year}",
opacity=0.75,
font=dict(color="#3d84e8", size=12),
xref="paper",
Expand Down Expand Up @@ -201,7 +201,7 @@ def finalize_plot(fig, title, xtitle=None, ytitle=None, copyright=None):
"""

annotations = [dict(templateitemname="watermark")]
if copyright:
if copyright is not None:
annotations += [
dict(
templateitemname="watermark",
Expand Down

0 comments on commit d5ab39d

Please sign in to comment.