diff --git a/examples/example_16/ob03235_2_full.yaml b/examples/example_16/ob03235_2_full.yaml index 1af097fa..4fd95e88 100644 --- a/examples/example_16/ob03235_2_full.yaml +++ b/examples/example_16/ob03235_2_full.yaml @@ -120,6 +120,7 @@ plots: best model: # You can skip the line below - the light curve will be plotted on screen. file: ob03235_2_model.png + interactive : ob03235_2_model.html time range: 2452820 2452855 magnitude range: 19.3 16.9 rcParams: diff --git a/examples/example_16/ob08092-o4_minimal_interactive_plot.yaml b/examples/example_16/ob08092-o4_minimal_interactive_plot.yaml new file mode 100644 index 00000000..cacb34a6 --- /dev/null +++ b/examples/example_16/ob08092-o4_minimal_interactive_plot.yaml @@ -0,0 +1,9 @@ +photometry_files: data/OB08092/phot_ob08092_O4.dat +# Define model to be plotted: +model: + parameters: t_0 u_0 t_E + values: 2455379.5716130617 0.5236953977433282 17.926270401551196 +# Now set the file where you want the light curve to be plotted: +plots: + best model: + interactive: ob08092-o4_minimal_plot.html diff --git a/examples/example_16/ob08092-o4_minimal_plot.yaml b/examples/example_16/ob08092-o4_minimal_plot.yaml index 8b7978e1..08e31d22 100644 --- a/examples/example_16/ob08092-o4_minimal_plot.yaml +++ b/examples/example_16/ob08092-o4_minimal_plot.yaml @@ -19,4 +19,4 @@ fitting_parameters: # Now set the file where you want the light curve to be plotted: plots: best model: - file: ob08092-o4_minimal_plot.png + file: ob08092-o4_minimal_plot.png diff --git a/examples/example_16/requirements.txt b/examples/example_16/requirements.txt index a56b342a..be35a5ef 100644 --- a/examples/example_16/requirements.txt +++ b/examples/example_16/requirements.txt @@ -1,6 +1,7 @@ numpy scipy>=0.18.0 matplotlib +plotly yaml emcee corner diff --git a/examples/example_16/ulens_model_fit.py b/examples/example_16/ulens_model_fit.py index 935d39d9..85af3521 100644 --- a/examples/example_16/ulens_model_fit.py +++ b/examples/example_16/ulens_model_fit.py @@ -37,14 +37,17 @@ import ultranest except Exception: import_failed.add("ultranest") - +try: + import plotly.graph_objects as go +except Exception: + import_failed.add("plotly") try: import MulensModel as mm except Exception: raise ImportError('\nYou have to install MulensModel first!\n') -__version__ = '0.40.1' +__version__ = '0.41.0' class UlensModelFit(object): @@ -330,8 +333,8 @@ class UlensModelFit(object): ``trajectory``, and ``best model``. The values are also dicts and currently accepted keys are: 1) for ``best model``: - ``'file'``, ``'time range'``, ``'magnitude range'``, ``'legend'``, - and ``'rcParams'``, + ``'file'``,``'interactive' ``'time range'``, ``'magnitude range'``, + ``'legend'``,`and ``'rcParams'``, 2) for ``triangle`` and ``trace``: ``'file'`` and ``'shift t_0'`` (*bool*, *True* is default) 3) for ``trajectory``: @@ -352,6 +355,7 @@ class UlensModelFit(object): 'time range': 2456050. 2456300. 'best model': 'file': 'my_fit_best.png' + 'interactive' : 'my_fit_best.html' 'time range': 2456000. 2456300. 'magnitude range': 15.123 13.012 'legend': @@ -690,6 +694,9 @@ def _check_imports(self): if self._plots is not None and 'triangle' in self._plots: required_packages.add('corner') + if self._plots['best model'] and 'interactive' in self._plots['best model']: + required_packages.add('plotly') + failed = import_failed.intersection(required_packages) if len(failed) > 0: @@ -701,6 +708,8 @@ def _check_imports(self): "\nFor corner package it's enough that you run:\nwget " + "https://raw.githubusercontent.com/dfm/corner.py/" + "v2.0.0/corner/corner.py") + if "plotly" in failed: + message += ("\nThe plotly package is required for creating interactive best model plots.") raise ImportError(message) @@ -792,7 +801,8 @@ def _check_plots_parameters(self): if 'trace' in self._plots: self._check_plots_parameters_trace() - names = {key: value['file'] for (key, value) in self._plots.items()} + names = {key: value.get('file', None) + for (key, value) in self._plots.items()} done = {} for (plot_type, name) in names.items(): if name is None: @@ -808,7 +818,7 @@ def _check_plots_parameters_best_model(self): Check if parameters of best model make sense """ allowed = set(['file', 'time range', 'magnitude range', 'legend', - 'rcParams', 'second Y scale']) + 'rcParams', 'second Y scale', 'interactive']) unknown = set(self._plots['best model'].keys()) - allowed if len(unknown) > 0: raise ValueError( @@ -839,6 +849,9 @@ def _check_plots_parameters_best_model(self): args = [key, type(self._plots['best model'][key])] raise TypeError(msg.format(*args)) + if 'interactive' in self._plots['best model']: + self._check_plots_parameters_best_model_interactive() + if 'second Y scale' in self._plots['best model']: self._check_plots_parameters_best_model_Y_scale() @@ -862,6 +875,14 @@ def _set_time_range_for_plot(self, plot_type): "plot:\n" + text[0] + " " + text[1]) self._plots[plot_type]['time range'] = [t_0, t_1] + def _check_plots_parameters_best_model_interactive(self): + """ + Check if there is no problem with interactive best plot + """ + if "second Y scale" in self._plots['best model']: + msg = "Interactive plot will not have a second Y scale. This feature is not yet implemented." + raise NotImplementedError(msg) + def _check_plots_parameters_best_model_Y_scale(self): """ Check if parameters for second Y scale make sense. @@ -1721,7 +1742,7 @@ def _parse_fit_constraints_soft_blending(self, key, value): sigma = float(value.split()[0]) sets = list(map(self._get_no_of_dataset, - shlex.split(value, posix=False)[1:])) + shlex.split(value, posix=False)[1:])) if len(sets) > len(self._datasets): raise ValueError( "dataset number specified in" + @@ -3353,6 +3374,8 @@ def _make_plots(self): self._trace_plot() if 'best model' in self._plots: self._best_model_plot() + if 'interactive' in self._plots['best model']: + self._make_interactive_plot() if 'trajectory' in self._plots: self._make_trajectory_plot() @@ -3645,15 +3668,12 @@ def _plot_models_for_best_model_plot(self, fluxes, kwargs_model): """ for dataset in self._datasets: if dataset.ephemerides_file is None: - self._model.plot_lc( - source_flux=fluxes[0], blend_flux=fluxes[1], - **kwargs_model) + self._model.plot_lc(source_flux=fluxes[0], blend_flux=fluxes[1], **kwargs_model) break for model in self._models_satellite: model.parameters.parameters = {**self._model.parameters.parameters} - model.plot_lc(source_flux=fluxes[0], blend_flux=fluxes[1], - **kwargs_model) + model.plot_lc(source_flux=fluxes[0], blend_flux=fluxes[1], **kwargs_model) def _plot_legend_for_best_model_plot(self): """ @@ -3821,6 +3841,323 @@ def _make_trajectory_plot(self): self._save_figure(self._plots['trajectory'].get('file'), dpi=dpi) + def _make_interactive_plot(self): + """ + plot best model and residuals interactively + """ + scale = 0.5 # original size=(1920:1440) + + self._ln_like(self._best_model_theta) # Sets all parameters to the best model. + + self._reset_rcParams() + if 'rcParams' in self._plots['best model']: + for (key, value) in self._plots['best model']['rcParams'].items(): + rcParams[key] = value + + kwargs_all = self._get_kwargs_for_best_model_plot() + (ylim, ylim_residuals) = self._get_ylim_for_best_model_plot(*kwargs_all[4:6]) + (layout, kwargs_model, kwargs_interactive, kwargs) = \ + self._prepare_interactive_layout(scale, kwargs_all, ylim, ylim_residuals) + + (t_data_start, t_data_stop) = self._get_time_span_data() + kwargs_model['t_start'] = t_data_start + kwargs_model['t_stop'] = t_data_stop + data_ref = self._event.data_ref + (f_source_0, f_blend_0) = self._event.get_flux_for_dataset(data_ref) + traces_lc = self._make_interactive_lc_traces(f_source_0, f_blend_0, **kwargs_model, **kwargs_interactive) + self._interactive_fig = go.Figure(data=traces_lc, layout=layout) + + self._add_interactive_zero_trace(**kwargs_model, **kwargs_interactive) + self._add_interactive_data_traces(kwargs_interactive, **kwargs) + self._add_interactive_residuals_traces(kwargs_interactive, **kwargs_model) + + self._save_interactive_fig() + + def _prepare_interactive_layout(self, scale, kwargs_all, ylim, ylim_residuals): + """Prepares the layout for the interactive plot.""" + kwargs_grid, kwargs_model, kwargs, xlim, t_1, t_2 = kwargs_all[:6] + kwargs_axes_1, kwargs_axes_2 = kwargs_all[6:] + kwargs_interactive = self._get_kwargs_for_plotly_plot(scale) + + layout = self._make_interactive_layout( + ylim, ylim_residuals, + **kwargs_grid, + **kwargs_model, + **kwargs_interactive + ) + return layout, kwargs_model, kwargs_interactive, kwargs + + def _get_kwargs_for_plotly_plot(self, scale): + """_ + setting kwargs for interactive plot + """ + sizes = np.array([ + 10., # markers data points + 4., # model line + 4., # residuals error thickens + 4., # residuals error width + 56., # font label + 4., # zero-line width in residuals + 4., # axes and legend border width + 15., # ticks len + 30., # font lagend + ]) + sizes = sizes*scale + colors = ['black', 'black', '#b9b9b9'] # This are: axes, font, and legend border. + + kwargs_interactive = dict(sizes=sizes, colors=colors, opacity=0.7, width=1920*scale, + height=1440*scale, font='Old Standard TT, serif', paper_bgcolor='white') + return kwargs_interactive + + def _make_interactive_layout(self, ylim, ylim_residuals, height_ratios, hspace, sizes, colors, opacity, width, + height, font, paper_bgcolor, t_start, t_stop, + subtract_2450000=None, subtract_2460000=None, **kwargs): + """ + Creates plotly.graph_objects.Layout object analogues to best model plot + """ + hsplit = height_ratios[1] / height_ratios[0] + subtract = mm.utils.PlotUtils.find_subtract(subtract_2450000, subtract_2460000) + + xtitle = 'Time' + if subtract > 0.: + xtitle = xtitle+' - {:d}'.format(int(subtract)) + + t_start = t_start - subtract + t_stop = t_stop - subtract + + font_base = dict(family=font, size=sizes[4], color=colors[1]) + font_legend = dict(family=font, size=sizes[8]) + kwargs_ = dict(showgrid=False, ticks='inside', showline=True, ticklen=sizes[7], + tickwidth=sizes[6], linewidth=sizes[6], linecolor=colors[0], tickfont=font_base) + kwargs_y = {'mirror': 'all', **kwargs_} + kwargs_x = {'range': [t_start, t_stop], **kwargs_} + layout = go.Layout( + autosize=True, width=width, height=height, showlegend=True, + legend=dict( + x=1.02, y=.98, bgcolor=paper_bgcolor, bordercolor=colors[2], borderwidth=sizes[6], font=font_legend), + paper_bgcolor=paper_bgcolor, plot_bgcolor=paper_bgcolor, font=font_base, + yaxis=dict(title_text='Magnitude', domain=[hsplit+(hspace/2), 1], range=ylim, **kwargs_y), + yaxis2=dict(title_text='Residuals', domain=[0, hsplit-(hspace/2)], anchor="x", range=ylim_residuals, + **kwargs_y), + xaxis=dict(anchor="y", mirror='ticks', side='top', scaleanchor='x2', matches='x2', showticklabels=False, + **kwargs_x), + xaxis2=dict(title_text=xtitle, anchor="y2", mirror='all', scaleanchor='x', matches='x', **kwargs_x) + ) + return layout + + def _get_time_span_data(self): + """ + Returning time span of datasets + """ + t_min = np.zeros(len(self._datasets)) + t_max = np.zeros(len(self._datasets)) + for (i, data) in enumerate(self._datasets): + t_min[i] = min(data.time) + t_max[i] = max(data.time) + + return (min(t_min), max(t_max)) + + def _make_interactive_lc_traces(self, f_source_0, f_blend_0, sizes, colors, opacity, width, height, font, + paper_bgcolor, t_start, t_stop, name=None, dash='solid', subtract_2450000=None, + subtract_2460000=None, gamma=None, bandpass=None, **kwargs): + """ + Creates plotly.graph_objects.Scatter objects with model light curve + """ + traces_lc = [] + subtract = mm.utils.PlotUtils.find_subtract(subtract_2450000, subtract_2460000) + times = np.linspace(t_start, t_stop, num=5000) - subtract + + if isinstance(name, type(None)): + showlegend = False + else: + showlegend = True + + for dataset in self._datasets: + if dataset.ephemerides_file is None: + lc = self._model.get_lc( + times=times, source_flux=f_source_0, blend_flux=f_blend_0, gamma=gamma, bandpass=bandpass) + traces_lc.append(self._make_interactive_scatter_lc( + times, lc, name, showlegend, colors[1], sizes[1], dash)) + break + + traces_lc.extend(self._make_interactive_scatter_lc_satellite( + traces_lc, times, f_source_0, f_blend_0, gamma, bandpass, colors, sizes, dash, subtract, showlegend)) + return traces_lc + + def _make_interactive_scatter_lc_satellite( + self, traces, times, f_source_0, f_blend_0, gamma, + bandpass, colors, sizes, dash, subtract, showlegend): + """Generates Plotly Scatter traces for the light-curve satellite models.""" + + for (i, model) in enumerate(self._models_satellite): + name = self._event.datasets[i].plot_properties['label'] + model.parameters.parameters = {**self._model.parameters.parameters} + lc = self._model.get_lc(times=times, source_flux=f_source_0, blend_flux=f_blend_0, + gamma=gamma, bandpass=bandpass) + times = times - subtract + trace = self._make_interactive_scatter_lc( + times, lc, name, showlegend, colors[1], sizes[1], dash) + traces.append(trace) + return traces + + def _make_interactive_scatter_lc( + self, times, lc, name, + showlegend, color, size, dash): + """Creates a Plotly Scatter trace for the light curve.""" + + return go.Scatter(x=times, y=lc, name=name, showlegend=showlegend, mode='lines', + line=dict(color=color, width=size, dash=dash), + xaxis="x", yaxis="y") + + def _add_interactive_zero_trace(self, t_start, t_stop, colors, sizes, + subtract_2450000=False, subtract_2460000=False, **kwargs): + """ + Creates plotly.graph_objects.Scatter object for line y=0 in + residuals plot + """ + subtract = mm.utils.PlotUtils.find_subtract(subtract_2450000, subtract_2460000) + times = np.linspace(t_start, t_stop, num=2000) + line = np.zeros(len(times)) + trace_0 = go.Scatter(x=times-subtract, y=line, mode='lines', + line=dict(color=colors[0], width=sizes[5], dash='dash'), + xaxis="x2", yaxis="y2", showlegend=False,) + self._interactive_fig.add_trace(trace_0) + + def _add_interactive_data_traces(self, kwargs_interactive, phot_fmt='mag', data_ref=None, show_errorbars=True, + show_bad=None, subtract_2450000=False, subtract_2460000=False, **kwargs): + """ + Creates plotly.graph_objects.Scatter object for observation points + per each data set. + """ + self._event._set_default_colors() + if self._event.fits is None: + self._event.get_chi2() + + if data_ref is None: + data_ref = self._event.data_ref + + subtract = mm.utils.PlotUtils.find_subtract(subtract_2450000, subtract_2460000) + (f_source_0, f_blend_0) = self._event.get_flux_for_dataset(data_ref) + + traces_data = [] + for (dataset_index, data) in enumerate(self._datasets): + # Scale the data flux + (flux, err_flux) = self._event.fits[dataset_index].scale_fluxes(f_source_0, f_blend_0) + (y_value, y_err) = mm.utils.PlotUtils.get_y_value_y_err(phot_fmt, flux, err_flux) + times = data.time - subtract + trace_data = self._make_one_interactive_data_trace( + dataset_index, times, y_value, y_err, xaxis='x', yaxis='y', showlegend=True, + show_errorbars=show_errorbars, show_bad=show_bad, **kwargs_interactive) + traces_data.extend(trace_data) + + for trace in traces_data: + self._interactive_fig.add_trace(trace) + + def _make_one_interactive_data_trace(self, dataset_index, times, y_value, y_err, xaxis, yaxis, showlegend, + colors, sizes, opacity, show_errorbars=None, show_bad=None, **kwargs): + """ + Creates plotly.graph_objects.Scatter object with data points form a given data set. + """ + trace_data = [] + dataset, show_errorbars, show_bad = self._get_interactive_dataset(dataset_index) + + trace_data_good = self._make_interactive_good_data_trace( + dataset, times, y_value, y_err, opacity, sizes, xaxis, yaxis, showlegend, show_errorbars) + trace_data.append(trace_data_good) + + if show_bad: + trace_data_bad = self._make_interactive_bad_data_trace( + dataset, times, y_value, y_err, opacity, sizes, xaxis, yaxis, showlegend) + trace_data.append(trace_data_bad) + + return trace_data + + def _get_interactive_dataset(self, dataset_index): + """Get dataset properties for interactive plot settings.""" + dataset = self._event.datasets[dataset_index] + show_errorbars = dataset.plot_properties.get('show_errorbars', True) + show_bad = dataset.plot_properties.get('show_bad', False) + return (dataset, show_errorbars, show_bad) + + def _make_interactive_good_data_trace( + self, dataset, times, y_value, y_err, opacity, + sizes, xaxis, yaxis, showlegend, show_errorbars): + """Creates a single plotly.graph_objects.Scatter object for the good data points.""" + times_good = times[dataset.good] + y_good = y_value[dataset.good] + y_err_good = y_err[dataset.good] + return self._make_interactive_data_trace( + times_good, y_good, y_err_good, dataset, + opacity, sizes, xaxis, yaxis, showlegend, show_errorbars) + + def _make_interactive_data_trace(self, x, y, y_err, dataset, opacity, sizes, xaxis, yaxis, + showlegend, show_errorbars, color_override=None, error_visible=True): + """Creates single plotly.graph_objects.Scatter object for good or bad data.""" + color = color_override if color_override else dataset.plot_properties['color'] + error_y = dict(type='data', array=y_err, visible=error_visible, thickness=sizes[2], width=sizes[3]) + marker = dict(color=color, size=sizes[0], line=dict(color=color, width=1)) + return go.Scatter(x=x, y=y, opacity=opacity, name=dataset.plot_properties['label'], mode='markers', + showlegend=showlegend, error_y=error_y, marker=marker, xaxis=xaxis, yaxis=yaxis) + + def _make_interactive_bad_data_trace(self, dataset, times, y_value, y_err, opacity, sizes, + xaxis, yaxis, showlegend): + """Creates a single plotly.graph_objects.Scatter object for the bad data points.""" + times_bad = times[dataset.bad] + y_bad = y_value[dataset.bad] + y_err_bad = y_err[dataset.bad] + return self._make_interactive_data_trace( + times_bad, y_bad, y_err_bad, dataset, + opacity, sizes, xaxis, yaxis, showlegend, + show_errorbars=False, + color_override='black', + error_visible=False + ) + + def _add_interactive_residuals_traces(self, kwargs_interactive, phot_fmt='mag', data_ref=None, show_errorbars=True, + show_bad=None, subtract_2450000=False, subtract_2460000=False, **kwargs,): + """ + Creates plotly.graph_objects.Scatter object for residuals points + per each data set. + """ + traces_residuals = [] + self._event._set_default_colors() # For each dataset + if self._event.fits is None: + self._event.get_chi2() + + if data_ref is None: + data_ref = self._event.data_ref + + subtract = mm.utils.PlotUtils.find_subtract(subtract_2450000, subtract_2460000) + + # Get fluxes for the reference dataset + (f_source_0, f_blend_0) = self._event.get_flux_for_dataset(data_ref) + kwargs_residuals = {'phot_fmt': 'scaled', 'bad': False, + 'source_flux': f_source_0, 'blend_flux': f_blend_0} + if show_bad: + kwargs_residuals['bad'] = True + + for (dataset_index, data) in enumerate(self._datasets): + (y_value, y_err) = self._event.fits[dataset_index].get_residuals(**kwargs_residuals) + times = data.time-subtract + trace_residuals = self._make_one_interactive_data_trace( + dataset_index, times, y_value, y_err, xaxis='x2', yaxis='y2', + showlegend=False, show_errorbars=show_errorbars, show_bad=show_bad, **kwargs_interactive) + traces_residuals.extend(trace_residuals) + + for trace in traces_residuals: + self._interactive_fig.add_trace(trace) + + def _save_interactive_fig(self): + """ + Saving interactive figure + """ + file_ = self._plots['best model']['interactive'] + if path.exists(file_): + if path.isfile(file_): + msg = "Existing file " + file_ + " will be overwritten" + warnings.warn(msg) + self._interactive_fig.write_html(file_, full_html=True) + if __name__ == '__main__': if len(sys.argv) != 2: