diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..c3fd080 --- /dev/null +++ b/.flake8 @@ -0,0 +1,17 @@ +[flake8] +ignore = + # Whitespace before ':' + E203, + # Module level import not at top of file + E402, + # Line break occurred before a binary operator + W503, + # Line break occurred after a binary operator + W504 + # line break before binary operator + E203 + # line too long + E501 + # No lambdas — too strict + E731 +max-line-length = 120 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8473d92 --- /dev/null +++ b/.gitignore @@ -0,0 +1,20 @@ +# python +*.pyc +**/__pycache__/ +.pytest_cache/* +.pydevproject + +# IDE +.vscode/* + +# Pip +*.egg-info + +# Log files +*.out +*.err +*.gz + +# Custom +*.ckpt +*.zip \ No newline at end of file diff --git a/README.md b/README.md index 43732f5..7ef85f3 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,17 @@
-
No histogram results, please add more experiments or adjust the search filter.
' + margin: ClassVar[List[int]] = [5, 5, 5, 30] + css_classes: ClassVar[List[str]] = ['histogram-default-div'] + + @classmethod + def get_config(cls) -> Dict[str, Any]: + """Get configs as a dict.""" + return {'text': cls.text, 'margin': cls.margin, 'css_classes': cls.css_classes} + + +@dataclass(frozen=True) +class HistogramTabPlotConfig: + """Config for the histogram tab plot column tag.""" + + css_classes: ClassVar[List[str]] = ['histogram-plots'] + name: ClassVar[str] = 'histogram_plots' + default_width: ClassVar[int] = 800 + + @classmethod + def get_config(cls) -> Dict[str, Any]: + """Get configs as a dict.""" + return {'name': cls.name, 'css_classes': cls.css_classes} + + +@dataclass(frozen=True) +class HistogramTabModalQueryButtonConfig: + """Config for the histogram tab modal query button tag.""" + + name: ClassVar[str] = 'histogram_modal_query_btn' + label: ClassVar[str] = 'Search Results' + css_classes: ClassVar[List[str]] = ['btn', 'btn-primary', 'modal-btn', 'histogram-tab-modal-query-btn'] + + @classmethod + def get_config(cls) -> Dict[str, Any]: + """Get configs as a dict.""" + return {'name': cls.name, 'label': cls.label, 'css_classes': cls.css_classes} diff --git a/sledge/sledgeboard/tabs/config/overview_tab_config.py b/sledge/sledgeboard/tabs/config/overview_tab_config.py new file mode 100644 index 0000000..ae33dbc --- /dev/null +++ b/sledge/sledgeboard/tabs/config/overview_tab_config.py @@ -0,0 +1,98 @@ +from dataclasses import dataclass +from typing import Any, ClassVar, Dict, List, Optional + +OVERVIEW_PLANNER_CHECKBOX_GROUP_NAME = 'overview_planner_checkbox_group' + + +@dataclass +class OverviewAggregatorData: + """Aggregator metric data in the overview tab.""" + + aggregator_file_name: str # Aggregator output file name + aggregator_type: str # Aggregator type + planner_name: str # Planner name + scenario_type: str # Scenario type + num_scenarios: int # Number of scenarios in the type + score: float # The aggregator scores for the scenario type + + +@dataclass(frozen=True) +class OverviewTabDefaultDataSourceDictConfig: + """Config for the overview tab default data source tag.""" + + experiment: ClassVar[List[str]] = ['-'] + scenario_type: ClassVar[List[str]] = ['-'] + planner: ClassVar[List[str]] = [ + 'No metric aggregator results, please add more experiments ' 'or adjust the search filter' + ] + + @classmethod + def get_config(cls) -> Dict[str, Any]: + """Get configs as a dict.""" + return {'experiment': cls.experiment, 'scenario_type': cls.scenario_type, 'planner': cls.planner} + + +@dataclass(frozen=True) +class OverviewTabExperimentTableColumnConfig: + """Config for the overview tab experiment table column tag.""" + + field: ClassVar[str] = 'experiment' + title: ClassVar[str] = 'Experiment' + width: ClassVar[int] = 150 + sortable: ClassVar[bool] = False + + @classmethod + def get_config(cls) -> Dict[str, Any]: + """Get configs as a dict.""" + return {'field': cls.field, 'title': cls.title, 'width': cls.width, 'sortable': cls.sortable} + + +@dataclass(frozen=True) +class OverviewTabScenarioTypeTableColumnConfig: + """Config for the overview tab scenario type table column tag.""" + + field: ClassVar[str] = 'scenario_type' + title: ClassVar[str] = 'Scenario Type (Number of Scenarios)' + width: ClassVar[int] = 200 + sortable: ClassVar[bool] = False + + @classmethod + def get_config(cls) -> Dict[str, Any]: + """Get configs as a dict.""" + return {'field': cls.field, 'title': cls.title, 'width': cls.width, 'sortable': cls.sortable} + + +@dataclass(frozen=True) +class OverviewTabPlannerTableColumnConfig: + """Config for the overview tab planner table column tag.""" + + field: ClassVar[str] = 'planner' + title: ClassVar[str] = 'Evaluation Score' + sortable: ClassVar[bool] = False + + @classmethod + def get_config(cls) -> Dict[str, Any]: + """Get configs as a dict.""" + return {'field': cls.field, 'title': cls.title, 'sortable': cls.sortable} + + +@dataclass(frozen=True) +class OverviewTabDataTableConfig: + """Config for the overview tab planner data table tag.""" + + selectable: ClassVar[bool] = True + row_height: ClassVar[int] = 80 + index_position: ClassVar[Optional[int]] = None + name: ClassVar[str] = 'overview_table' + css_classes: ClassVar[List[str]] = ['overview-table'] + + @classmethod + def get_config(cls) -> Dict[str, Any]: + """Get configs as a dict.""" + return { + 'selectable': cls.selectable, + 'row_height': cls.row_height, + 'index_position': cls.index_position, + 'name': cls.name, + 'css_classes': cls.css_classes, + } diff --git a/sledge/sledgeboard/tabs/config/scenario_tab_config.py b/sledge/sledgeboard/tabs/config/scenario_tab_config.py new file mode 100644 index 0000000..2cbc4cb --- /dev/null +++ b/sledge/sledgeboard/tabs/config/scenario_tab_config.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass, field +from typing import Any, ClassVar, Dict, List, Tuple + + +@dataclass(frozen=True) +class ScenarioTabTitleDivConfig: + """Config for the scenario tab title div tag.""" + + text: ClassVar[str] = "-" + name: ClassVar[str] = 'scenario_title_div' + css_classes: ClassVar[List[str]] = ['scenario-tab-title-div'] + + @classmethod + def get_config(cls) -> Dict[str, Any]: + """Get configs as a dict.""" + return {'text': cls.text, 'name': cls.name, 'css_classes': cls.css_classes} + + +@dataclass(frozen=True) +class ScenarioTabScenarioTokenMultiChoiceConfig: + """Config for scenario tab scenario token multi choice tag.""" + + max_items: ClassVar[int] = 1 + option_limit: ClassVar[int] = 10 + height: ClassVar[int] = 40 + placeholder: ClassVar[str] = "Scenario token" + name: ClassVar[str] = 'scenario_token_multi_choice' + css_classes: ClassVar[List[str]] = ['scenario-token-multi-choice'] + + @classmethod + def get_config(cls) -> Dict[str, Any]: + """Get configs as a dict.""" + return { + 'max_items': cls.max_items, + 'option_limit': cls.option_limit, + 'height': cls.height, + 'placeholder': cls.placeholder, + 'name': cls.name, + 'css_classes': cls.css_classes, + } + + +@dataclass(frozen=True) +class ScenarioTabModalQueryButtonConfig: + """Config for scenario tab modal query button tag.""" + + name: ClassVar[str] = 'scenario_modal_query_btn' + label: ClassVar[str] = 'Query Scenario' + css_classes: ClassVar[List[str]] = ['btn', 'btn-primary', 'modal-btn', 'scenario-tab-modal-query-btn'] + + @classmethod + def get_config(cls) -> Dict[str, Any]: + """Get configs as a dict.""" + return {'name': cls.name, 'label': cls.label, 'css_classes': cls.css_classes} + + +@dataclass(frozen=True) +class ScenarioTabFrameButtonConfig: + """Config for scenario tab's frame control buttons.""" + + label: str + margin: Tuple[int, int, int, int] = field(default_factory=lambda: (5, 19, 5, 35)) # Top, right, bottom, left + css_classes: List[str] = field(default_factory=lambda: ["frame-control-button"]) + width: int = field(default_factory=lambda: 56) + + +# Global config instances +first_button_config = ScenarioTabFrameButtonConfig(label="first") +prev_button_config = ScenarioTabFrameButtonConfig(label="prev") +play_button_config = ScenarioTabFrameButtonConfig(label="play") +next_button_config = ScenarioTabFrameButtonConfig(label="next") +last_button_config = ScenarioTabFrameButtonConfig(label="last") diff --git a/sledge/sledgeboard/tabs/configuration_tab.py b/sledge/sledgeboard/tabs/configuration_tab.py new file mode 100644 index 0000000..f755db3 --- /dev/null +++ b/sledge/sledgeboard/tabs/configuration_tab.py @@ -0,0 +1,124 @@ +import base64 +import io +import logging +import pathlib +import pickle +from pathlib import Path +from typing import Any, List + +from bokeh.document.document import Document +from bokeh.models import CheckboxGroup, FileInput + +from sledge.sledgeboard.base.base_tab import BaseTab +from sledge.sledgeboard.base.data_class import SledgeBoardFile +from sledge.sledgeboard.base.experiment_file_data import ExperimentFileData +from sledge.sledgeboard.style import configuration_tab_style + +logger = logging.getLogger(__name__) + + +class ConfigurationTab: + """Configuration tab for sledgeboard.""" + + def __init__(self, doc: Document, experiment_file_data: ExperimentFileData, tabs: List[BaseTab]): + """ + Configuration tab about configuring sledgeboard. + :param experiment_file_data: Experiment file data. + :param tabs: A list of tabs to be updated when configuration is changed. + """ + self._doc = doc + self._tabs = tabs + self.experiment_file_data = experiment_file_data + + self._file_path_input = FileInput( + accept=SledgeBoardFile.extension(), + css_classes=["file-path-input"], + margin=configuration_tab_style["file_path_input_margin"], + name="file_path_input", + ) + self._file_path_input.on_change("value", self._add_experiment_file) + self._experiment_file_path_checkbox_group = CheckboxGroup( + labels=self.experiment_file_path_stems, + active=[index for index in range(len(self.experiment_file_data.file_paths))], + name="experiment_file_path_checkbox_group", + css_classes=["experiment-file-path-checkbox-group"], + ) + self._experiment_file_path_checkbox_group.on_click(self._click_experiment_file_path_checkbox) + if self.experiment_file_data.file_paths: + self._file_paths_on_change() + + @property + def experiment_file_path_stems(self) -> List[str]: + """Return a list of file path stems.""" + experiment_paths = [] + for file_path in self.experiment_file_data.file_paths: + metric_path = file_path.current_path / file_path.metric_folder + if metric_path.exists(): + experiment_file_path_stem = file_path.current_path + else: + experiment_file_path_stem = file_path.metric_main_path + + if isinstance(experiment_file_path_stem, str): + experiment_file_path_stem = pathlib.Path(experiment_file_path_stem) + + experiment_file_path_stem = "/".join( + [experiment_file_path_stem.parts[-2], experiment_file_path_stem.parts[-1]] + ) + experiment_paths.append(experiment_file_path_stem) + return experiment_paths + + @property + def file_path_input(self) -> FileInput: + """Return the file path input widget.""" + return self._file_path_input + + @property + def experiment_file_path_checkbox_group(self) -> CheckboxGroup: + """Return experiment file path checkboxgroup.""" + return self._experiment_file_path_checkbox_group + + def _click_experiment_file_path_checkbox(self, attr: Any) -> None: + """ + Click event handler for experiment_file_path_checkbox_group. + :param attr: Clicked attributes. + """ + self._file_paths_on_change() + + def add_sledgeboard_file_to_experiments(self, sledgeboard_file: SledgeBoardFile) -> None: + """ + Add sledgeboard files to experiments. + :param sledgeboard_file: Added sledgeboard file. + """ + sledgeboard_file.current_path = Path(sledgeboard_file.metric_main_path) + if sledgeboard_file not in self.experiment_file_data.file_paths: + self.experiment_file_data.update_data(file_paths=[sledgeboard_file]) + self._experiment_file_path_checkbox_group.labels = self.experiment_file_path_stems + self._experiment_file_path_checkbox_group.active += [len(self.experiment_file_path_stems) - 1] + self._file_paths_on_change() + + def _add_experiment_file(self, attr: str, old: bytes, new: bytes) -> None: + """ + Event responds to file change. + :param attr: Attribute name. + :param old: Old value. + :param new: New value. + """ + if not new: + return + try: + decoded_string = base64.b64decode(new) + file_stream = io.BytesIO(decoded_string) + data = pickle.load(file_stream) + sledgeboard_file = SledgeBoardFile.deserialize(data=data) + self.add_sledgeboard_file_to_experiments(sledgeboard_file=sledgeboard_file) + file_stream.close() + except (OSError, IOError) as e: + logger.info(f"Error loading experiment file. {str(e)}.") + + def _file_paths_on_change(self) -> None: + """Function to call when we change file paths.""" + for tab in self._tabs: + tab.file_paths_on_change( + experiment_file_data=self.experiment_file_data, + experiment_file_active_index=self._experiment_file_path_checkbox_group.active, + ) diff --git a/sledge/sledgeboard/tabs/histogram_tab.py b/sledge/sledgeboard/tabs/histogram_tab.py new file mode 100644 index 0000000..6e3ca30 --- /dev/null +++ b/sledge/sledgeboard/tabs/histogram_tab.py @@ -0,0 +1,746 @@ +import logging +from collections import defaultdict +from copy import deepcopy +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import numpy.typing as npt +from bokeh.document.document import Document +from bokeh.layouts import column, gridplot, layout +from bokeh.models import Button, ColumnDataSource, Div, FactorRange, HoverTool, MultiChoice, Spinner, glyph +from bokeh.plotting import figure + +from sledge.sledgeboard.base.base_tab import BaseTab +from sledge.sledgeboard.base.experiment_file_data import ExperimentFileData +from sledge.sledgeboard.tabs.config.histogram_tab_config import ( + HistogramConstantConfig, + HistogramFigureData, + HistogramTabBinSpinnerConfig, + HistogramTabDefaultDivConfig, + HistogramTabFigureGridPlotStyleConfig, + HistogramTabFigureStyleConfig, + HistogramTabFigureTitleDivStyleConfig, + HistogramTabHistogramBarStyleConfig, + HistogramTabMetricNameMultiChoiceConfig, + HistogramTabModalQueryButtonConfig, + HistogramTabPlotConfig, + HistogramTabScenarioTypeMultiChoiceConfig, +) +from sledge.sledgeboard.tabs.js_code.histogram_tab_js_code import ( + HistogramTabLoadingEndJSCode, + HistogramTabLoadingJSCode, + HistogramTabUpdateWindowsSizeJSCode, +) +from sledge.sledgeboard.utils.sledgeboard_histogram_utils import ( + aggregate_metric_aggregator_dataframe_histogram_data, + aggregate_metric_statistics_dataframe_histogram_data, + compute_histogram_edges, + get_histogram_plot_x_range, +) + +logger = logging.getLogger(__name__) + + +class HistogramTab(BaseTab): + """Histogram tab in SledgeBoard.""" + + def __init__( + self, + doc: Document, + experiment_file_data: ExperimentFileData, + bins: int = HistogramTabBinSpinnerConfig.default_bins, + max_scenario_names: int = 20, + ): + """ + Histogram for metric results about simulation. + :param doc: Bokeh html document. + :param experiment_file_data: Experiment file data. + :param bins: Default number of bins in histograms. + :param max_scenario_names: Show the maximum list of scenario names in each bin, 0 or None to disable + """ + super().__init__(doc=doc, experiment_file_data=experiment_file_data) + self._bins = bins + self._max_scenario_names = max_scenario_names + + # UI. + # Planner selection + self.planner_checkbox_group.name = HistogramConstantConfig.PLANNER_CHECKBOX_GROUP_NAME + self.planner_checkbox_group.js_on_change("active", HistogramTabLoadingJSCode.get_js_code()) + + # Scenario type multi choices + self._scenario_type_multi_choice = MultiChoice(**HistogramTabScenarioTypeMultiChoiceConfig.get_config()) + self._scenario_type_multi_choice.on_change("value", self._scenario_type_multi_choice_on_change) + self._scenario_type_multi_choice.js_on_change("value", HistogramTabUpdateWindowsSizeJSCode.get_js_code()) + + # Metric name multi choices + self._metric_name_multi_choice = MultiChoice(**HistogramTabMetricNameMultiChoiceConfig.get_config()) + self._metric_name_multi_choice.on_change("value", self._metric_name_multi_choice_on_change) + self._metric_name_multi_choice.js_on_change("value", HistogramTabUpdateWindowsSizeJSCode.get_js_code()) + + self._bin_spinner = Spinner(**HistogramTabBinSpinnerConfig.get_config()) + self._histogram_modal_query_btn = Button(**HistogramTabModalQueryButtonConfig.get_config()) + self._histogram_modal_query_btn.js_on_click(HistogramTabLoadingJSCode.get_js_code()) + self._histogram_modal_query_btn.on_click(self._setting_modal_query_button_on_click) + + self._default_div = Div(**HistogramTabDefaultDivConfig.get_config()) + # Histogram plot frame. + self._histogram_plots = column(self._default_div, **HistogramTabPlotConfig.get_config()) + self._histogram_plots.js_on_change("children", HistogramTabLoadingEndJSCode.get_js_code()) + self._histogram_figures: Optional[column] = None + self._aggregated_data: Optional[HistogramConstantConfig.HistogramDataType] = None + self._histogram_edges: Optional[HistogramConstantConfig.HistogramEdgesDataType] = None + self._plot_data: Dict[str, List[glyph]] = defaultdict(list) + self._init_selection() + + @property + def bin_spinner(self) -> Spinner: + """Return a bin spinner.""" + return self._bin_spinner + + @property + def scenario_type_multi_choice(self) -> MultiChoice: + """Return scenario_type_multi_choice.""" + return self._scenario_type_multi_choice + + @property + def metric_name_multi_choice(self) -> MultiChoice: + """Return metric_name_multi_choice.""" + return self._metric_name_multi_choice + + @property + def histogram_plots(self) -> column: + """Return histogram_plots.""" + return self._histogram_plots + + @property + def histogram_modal_query_btn(self) -> Button: + """Return histogram modal query button.""" + return self._histogram_modal_query_btn + + def _click_planner_checkbox_group(self, attr: Any) -> None: + """ + Click event handler for planner_checkbox_group. + :param attr: Clicked attributes. + """ + if not self._aggregated_data and not self._histogram_edges: + return + + # Render histograms. + self._histogram_figures = self._render_histograms() + + # Make sure the histogram upgrades at the last + self._doc.add_next_tick_callback(self._update_histogram_layouts) + + def file_paths_on_change( + self, experiment_file_data: ExperimentFileData, experiment_file_active_index: List[int] + ) -> None: + """ + Interface to update layout when file_paths is changed. + :param experiment_file_data: Experiment file data. + :param experiment_file_active_index: Active indexes for experiment files. + """ + self._experiment_file_data = experiment_file_data + self._experiment_file_active_index = experiment_file_active_index + + self._init_selection() + self._update_histograms() + + def _update_histogram_layouts(self) -> None: + """Update histogram layouts.""" + self._histogram_plots.children[0] = layout(self._histogram_figures) + + def _update_histograms(self) -> None: + """Update histograms.""" + # Aggregate data + self._aggregated_data = self._aggregate_statistics() + + # Aggregate scenario type scores + aggregated_scenario_type_score_data = self._aggregate_scenario_type_score_histogram() + + self._aggregated_data.update(aggregated_scenario_type_score_data) + + # Compute histogram edges + self._histogram_edges = compute_histogram_edges(aggregated_data=self._aggregated_data, bins=self._bins) + + # Render histograms. + self._histogram_figures = self._render_histograms() + + # Make sure the histogram upgrades at the last + self._doc.add_next_tick_callback(self._update_histogram_layouts) + + def _setting_modal_query_button_on_click(self) -> None: + """Setting modal query button on click helper function.""" + if self._metric_name_multi_choice.tags: + self.window_width = self._metric_name_multi_choice.tags[0] + self.window_height = self._metric_name_multi_choice.tags[1] + + if self._bin_spinner.value: + self._bins = self._bin_spinner.value + self._update_histograms() + + def _metric_name_multi_choice_on_change(self, attr: str, old: str, new: str) -> None: + """ + Helper function to change event in histogram metric name. + :param attr: Attribute. + :param old: Old value. + :param new: New value. + """ + # Set up window width and height + if self._metric_name_multi_choice.tags: + self.window_width = self._metric_name_multi_choice.tags[0] + self.window_height = self._metric_name_multi_choice.tags[1] + + def _scenario_type_multi_choice_on_change(self, attr: str, old: str, new: str) -> None: + """ + Helper function to change event in histogram scenario type. + :param attr: Attribute. + :param old: Old value. + :param new: New value. + """ + # Set up window width and height + if self._scenario_type_multi_choice.tags: + self.window_width = self._scenario_type_multi_choice.tags[0] + self.window_height = self.scenario_type_multi_choice.tags[1] + + def _adjust_plot_width_size(self, n_bins: int) -> int: + """ + Adjust plot width size based on number of bins. + :param n_bins: Number of bins. + :return Width size of a histogram plot. + """ + base_plot_width: int = self.plot_sizes[0] + if n_bins < 20: + return base_plot_width + # Increase the width of 50 for every number of bins 20 + width_multiplier_factor: int = (n_bins // 20) * 100 + width_size: int = min( + base_plot_width + width_multiplier_factor, HistogramTabFigureStyleConfig.maximum_plot_width + ) + return width_size + + def _init_selection(self) -> None: + """Init histogram and scalar selection options.""" + # For planner checkbox + planner_name_list: List[str] = [] + # Clean up + self.planner_checkbox_group.labels = [] + self.planner_checkbox_group.active = [] + for index, metric_statistics_dataframes in enumerate(self.experiment_file_data.metric_statistics_dataframes): + if index not in self._experiment_file_active_index: + continue + for metric_statistics_dataframe in metric_statistics_dataframes: + planner_names = metric_statistics_dataframe.planner_names + planner_name_list += planner_names + + sorted_planner_name_list = sorted(list(set(planner_name_list))) + self.planner_checkbox_group.labels = sorted_planner_name_list + self.planner_checkbox_group.active = [index for index in range(len(sorted_planner_name_list))] + + self._init_multi_search_criteria_selection( + scenario_type_multi_choice=self._scenario_type_multi_choice, + metric_name_multi_choice=self._metric_name_multi_choice, + ) + + def plot_vbar( + self, + histogram_figure_data: HistogramFigureData, + counts: npt.NDArray[np.int64], + category: List[str], + planner_name: str, + legend_label: str, + color: str, + scenario_names: List[str], + x_values: List[str], + width: float = 0.4, + histogram_file_name: Optional[str] = None, + ) -> None: + """ + Plot a vertical bar plot. + :param histogram_figure_data: Figure class. + :param counts: An array of counts for each category. + :param category: A list of category (x-axis label). + :param planner_name: Planner name. + :param legend_label: Legend label. + :param color: Legend color. + :param scenario_names: A list of scenario names. + :param x_values: X-axis values. + :param width: Bar width. + :param histogram_file_name: Histogram file name for the histogram data. + """ + y_values = deepcopy(counts) + bottom: npt.NDArray[np.int64] = ( + np.zeros_like(counts) + if histogram_figure_data.frequency_array is None + else histogram_figure_data.frequency_array + ) + count_position = counts > 0 + bottom_arrays: npt.NDArray[np.int64] = bottom * count_position + top = counts + bottom_arrays + histogram_file_names = [histogram_file_name] * len(top) + data_source = ColumnDataSource( + dict( + x=category, + top=top, + bottom=bottom_arrays, + y_values=y_values, + x_values=x_values, + scenario_names=scenario_names, + histogram_file_name=histogram_file_names, + ) + ) + figure_plot = histogram_figure_data.figure_plot + vbar = figure_plot.vbar( + x="x", + top="top", + bottom="bottom", + fill_color=color, + legend_label=legend_label, + width=width, + source=data_source, + **HistogramTabHistogramBarStyleConfig.get_config(), + ) + self._plot_data[planner_name].append(vbar) + HistogramTabHistogramBarStyleConfig.update_histogram_bar_figure_style(histogram_figure=figure_plot) + + def plot_histogram( + self, + histogram_figure_data: HistogramFigureData, + hist: npt.NDArray[np.float64], + edges: npt.NDArray[np.float64], + planner_name: str, + legend_label: str, + color: str, + scenario_names: List[str], + x_values: List[str], + histogram_file_name: Optional[str] = None, + ) -> None: + """ + Plot a histogram. + Reference from https://docs.bokeh.org/en/latest/docs/gallery/histogram.html. + :param histogram_figure_data: Histogram figure data. + :param hist: Histogram data. + :param edges: Histogram bin data. + :param planner_name: Planner name. + :param legend_label: Legend label. + :param color: Legend color. + :param scenario_names: A list of scenario names. + :param x_values: A list of x value names. + :param histogram_file_name: Histogram file name for the histogram data. + """ + bottom: npt.NDArray[np.int64] = ( + np.zeros_like(hist) + if histogram_figure_data.frequency_array is None + else histogram_figure_data.frequency_array + ) + hist_position = hist > 0 + bottom_arrays: npt.NDArray[np.int64] = bottom * hist_position + top = hist + bottom_arrays + histogram_file_names = [histogram_file_name] * len(top) + data_source = ColumnDataSource( + dict( + top=top, + bottom=bottom_arrays, + left=edges[:-1], + right=edges[1:], + y_values=hist, + x_values=x_values, + scenario_names=scenario_names, + histogram_file_name=histogram_file_names, + ) + ) + figure_plot = histogram_figure_data.figure_plot + quad = figure_plot.quad( + top="top", + bottom="bottom", + left="left", + right="right", + fill_color=color, + legend_label=legend_label, + **HistogramTabHistogramBarStyleConfig.get_config(), + source=data_source, + ) + + self._plot_data[planner_name].append(quad) + HistogramTabHistogramBarStyleConfig.update_histogram_bar_figure_style(histogram_figure=figure_plot) + + def _render_histogram_plot( + self, + title: str, + x_axis_label: str, + x_range: Optional[Union[List[str], FactorRange]] = None, + histogram_file_name: Optional[str] = None, + ) -> HistogramFigureData: + """ + Render a histogram plot. + :param title: Title. + :param x_axis_label: x-axis label. + :param x_range: A list of category data if specified. + :param histogram_file_name: Histogram file name for the histogram plot. + :return a figure. + """ + if x_range is None: + len_plot_width = 1 + elif isinstance(x_range, list): + len_plot_width = len(x_range) + else: + len_plot_width = len(x_range.factors) + + plot_width = self._adjust_plot_width_size(n_bins=len_plot_width) + tooltips = [("Frequency", "@y_values"), ("Values", "@x_values{safe}"), ("Scenarios", "@scenario_names{safe}")] + if histogram_file_name: + tooltips.append(("File", "@histogram_file_name")) + + hover_tool = HoverTool(tooltips=tooltips, point_policy="follow_mouse") + statistic_figure = figure( + **HistogramTabFigureStyleConfig.get_config( + title=title, x_axis_label=x_axis_label, width=plot_width, height=self.plot_sizes[1], x_range=x_range + ), + tools=["pan", "wheel_zoom", "save", "reset", hover_tool], + ) + HistogramTabFigureStyleConfig.update_histogram_figure_style(histogram_figure=statistic_figure) + return HistogramFigureData(figure_plot=statistic_figure) + + def _render_histogram_layout(self, histograms: HistogramConstantConfig.HistogramFigureDataType) -> List[column]: + """ + Render histogram layout. + :param histograms: A dictionary of histogram names and their histograms. + :return: A list of lists of figures (a list per row). + """ + layouts = [] + ncols = self.get_plot_cols( + plot_width=self.plot_sizes[0], default_ncols=HistogramConstantConfig.HISTOGRAM_TAB_DEFAULT_NUMBER_COLS + ) + for metric_statistics_name, statistics_data in histograms.items(): + title_div = Div(**HistogramTabFigureTitleDivStyleConfig.get_config(title=metric_statistics_name)) + figures = [histogram_figure.figure_plot for statistic_name, histogram_figure in statistics_data.items()] + grid_plot = gridplot( + figures, + **HistogramTabFigureGridPlotStyleConfig.get_config(ncols=ncols, height=self.plot_sizes[1]), + ) + grid_layout = column(title_div, grid_plot) + layouts.append(grid_layout) + + return layouts + + def _aggregate_scenario_type_score_histogram(self) -> HistogramConstantConfig.HistogramDataType: + """ + Aggregate metric aggregator data. + :return: A dictionary of metric aggregator names and their metric scores. + """ + data: HistogramConstantConfig.HistogramDataType = defaultdict(list) + selected_scenario_types = self._scenario_type_multi_choice.value + + # Loop through all metric aggregators + for index, metric_aggregator_dataframes in enumerate(self.experiment_file_data.metric_aggregator_dataframes): + if index not in self._experiment_file_active_index: + continue + for metric_aggregator_filename, metric_aggregator_dataframe in metric_aggregator_dataframes.items(): + # Aggregate a list of histogram data list + histogram_data_list = aggregate_metric_aggregator_dataframe_histogram_data( + metric_aggregator_dataframe_index=index, + metric_aggregator_dataframe=metric_aggregator_dataframe, + scenario_types=selected_scenario_types, + dataframe_file_name=metric_aggregator_filename, + ) + if histogram_data_list: + data[HistogramConstantConfig.SCENARIO_TYPE_SCORE_HISTOGRAM_NAME] += histogram_data_list + + return data + + def _aggregate_statistics(self) -> HistogramConstantConfig.HistogramDataType: + """ + Aggregate statistics data. + :return A dictionary of metric names and their aggregated data. + """ + data: HistogramConstantConfig.HistogramDataType = defaultdict(list) + scenario_types = self._scenario_type_multi_choice.value + metric_choices = self._metric_name_multi_choice.value + if not len(scenario_types) and not len(metric_choices): + return data + + if 'all' in scenario_types: + scenario_types = None + else: + scenario_types = tuple(scenario_types) + + for index, metric_statistics_dataframes in enumerate(self.experiment_file_data.metric_statistics_dataframes): + if index not in self._experiment_file_active_index: + continue + + for metric_statistics_dataframe in metric_statistics_dataframes: + histogram_data_list = aggregate_metric_statistics_dataframe_histogram_data( + metric_statistics_dataframe=metric_statistics_dataframe, + metric_statistics_dataframe_index=index, + scenario_types=scenario_types, + metric_choices=metric_choices, + ) + + if histogram_data_list: + data[metric_statistics_dataframe.metric_statistic_name] += histogram_data_list + return data + + def _plot_bool_histogram( + self, + histogram_figure_data: HistogramFigureData, + values: npt.NDArray[np.float64], + scenarios: List[str], + planner_name: str, + legend_name: str, + color: str, + histogram_file_name: Optional[str] = None, + ) -> None: + """ + Plot boolean type of histograms. + :param histogram_figure_data: Histogram figure data. + :param values: An array of values. + :param scenarios: A list of scenario names. + :param planner_name: Planner name. + :param legend_name: Legend name. + :param color: Plot color. + :param histogram_file_name: Histogram file name for the histogram data. + """ + # False and True + num_true = np.nansum(values) + num_false = len(values[values == 0]) + scenario_names: List[List[str]] = [[] for _ in range(2)] # False and True bins only + # Get scenario names + for index, scenario in enumerate(scenarios): + scenario_name_index = 1 if values[index] else 0 + if not self._max_scenario_names or len(scenario_names[scenario_name_index]) < self._max_scenario_names: + scenario_names[scenario_name_index].append(scenario) + + scenario_names_flatten = ["No time series results, please add more experiments or + adjust the search filter.
""", + css_classes=['scenario-default-div'], + margin=default_div_style['margin'], + width=default_div_style['width'], + ) + self._time_series_layout = column( + self._default_time_series_div, + css_classes=["scenario-time-series-layout"], + name="time_series_layout", + ) + self._default_ego_expert_states_div = Div( + text="""No expert and ego states, please add more experiments or + adjust the search filter.
""", + css_classes=['scenario-default-div'], + margin=default_div_style['margin'], + width=default_div_style['width'], + ) + self._ego_expert_states_layout = column( + self._default_ego_expert_states_div, + css_classes=["scenario-ego-expert-states-layout"], + name="ego_expert_states_layout", + ) + self._default_simulation_div = Div( + text="""No simulation data, please add more experiments or + adjust the search filter.
""", + css_classes=['scenario-default-div'], + margin=default_div_style['margin'], + width=default_div_style['width'], + ) + self._simulation_tile_layout = column( + self._default_simulation_div, + css_classes=["scenario-simulation-layout"], + name="simulation_tile_layout", + ) + + self._simulation_tile_layout.js_on_change("children", ScenarioTabLoadingEndJSCode.get_js_code()) + self.simulation_tile = SimulationTile( + map_factory=self._scenario_builder.get_map_factory(), + doc=self._doc, + vehicle_parameters=vehicle_parameters, + experiment_file_data=experiment_file_data, + async_rendering=async_rendering, + frame_rate_cap_hz=frame_rate_cap_hz, + ) + + self._default_scenario_score_div = Div( + text="""No scenario score results, please add more experiments or + adjust the search filter.
""", + css_classes=['scenario-default-div'], + margin=default_div_style['margin'], + width=default_div_style['width'], + ) + self._scenario_score_layout = column( + self._default_scenario_score_div, + css_classes=["scenario-score-layout"], + name="scenario_score_layout", + ) + + self._scenario_metric_score_data_figure_sizes = scenario_tab_style['scenario_metric_score_figure_sizes'] + self._scenario_metric_score_data: scenario_metric_score_dict_type = {} + self._time_series_data: Dict[str, List[ScenarioTimeSeriesData]] = {} + self._simulation_figure_data: List[SimulationData] = [] + self._available_scenario_names: List[str] = [] + self._simulation_plots: Optional[column] = None + + object_types = ['Ego', 'Vehicle', 'Pedestrian', 'Bicycle', 'Generic', 'Traffic Cone', 'Barrier', 'Czone Sign'] + self._object_checkbox_group = CheckboxGroup( + labels=object_types, + active=list(range(len(object_types))), + css_classes=["scenario-object-checkbox-group"], + name='scenario_object_checkbox_group', + ) + self._object_checkbox_group.on_change('active', self._object_checkbox_group_active_on_change) + trajectories = ['Expert Trajectory', 'Ego Trajectory', 'Goal', 'Traffic Light', 'RoadBlock'] + self._traj_checkbox_group = CheckboxGroup( + labels=trajectories, + active=list(range(len(trajectories))), + css_classes=["scenario-traj-checkbox-group"], + name='scenario_traj_checkbox_group', + ) + self._traj_checkbox_group.on_change('active', self._traj_checkbox_group_active_on_change) + map_objects = [ + 'Lane', + 'Intersection', + 'Stop Line', + 'Crosswalk', + 'Walkway', + 'Carpark', + 'Lane Connector', + 'Lane Line', + ] + self._map_checkbox_group = CheckboxGroup( + labels=map_objects, + active=list(range(len(map_objects))), + css_classes=["scenario-map-checkbox-group"], + name='scenario_map_checkbox_group', + ) + self._map_checkbox_group.on_change('active', self._map_checkbox_group_active_on_change) + self.plot_state_keys = [ + 'x [m]', + 'y [m]', + 'heading [rad]', + 'velocity_x [m/s]', + 'velocity_y [m/s]', + 'speed [m/s]', + 'acceleration_x [m/s^2]', + 'acceleration_y [m/s^2]', + 'acceleration [m/s^2]', + 'steering_angle [rad]', + 'yaw_rate [rad/s]', + ] + self.expert_planner_key = 'Expert' + self._init_selection() + + @property + def scenario_title_div(self) -> Div: + """Return scenario title div.""" + return self._scenario_title_div + + @property + def scalar_scenario_type_select(self) -> Select: + """Return scalar_scenario_type_select.""" + return self._scalar_scenario_type_select + + @property + def scalar_log_name_select(self) -> Select: + """Return scalar_log_name_select.""" + return self._scalar_log_name_select + + @property + def scalar_scenario_name_select(self) -> Select: + """Return scalar_scenario_name_select.""" + return self._scalar_scenario_name_select + + @property + def scenario_token_multi_choice(self) -> MultiChoice: + """Return scenario_token multi choice.""" + return self._scenario_token_multi_choice + + @property + def scenario_modal_query_btn(self) -> Button: + """Return scenario_modal_query_button.""" + return self._scenario_modal_query_btn + + @property + def object_checkbox_group(self) -> CheckboxGroup: + """Return object checkbox group.""" + return self._object_checkbox_group + + @property + def traj_checkbox_group(self) -> CheckboxGroup: + """Return traj checkbox group.""" + return self._traj_checkbox_group + + @property + def map_checkbox_group(self) -> CheckboxGroup: + """Return map checkbox group.""" + return self._map_checkbox_group + + @property + def time_series_layout(self) -> column: + """Return time_series_layout.""" + return self._time_series_layout + + @property + def scenario_score_layout(self) -> column: + """Return scenario_score_layout.""" + return self._scenario_score_layout + + @property + def simulation_tile_layout(self) -> column: + """Return simulation_tile_layout.""" + return self._simulation_tile_layout + + @property + def ego_expert_states_layout(self) -> column: + """Return time_series_state_layout.""" + return self._ego_expert_states_layout + + def _update_glyph_checkbox_group(self, glyph_names: List[str]) -> None: + """ + Update visibility of glyphs according to checkbox group. + :param glyph_names: A list of updated glyph names. + """ + for simulation_figure in self.simulation_tile.figures: + simulation_figure.update_glyphs_visibility(glyph_names=glyph_names) + + def _traj_checkbox_group_active_on_change(self, attr: str, old: List[int], new: List[int]) -> None: + """ + Helper function for traj checkbox group when the list of actives changes. + :param attr: Attribute name. + :param old: Old active index. + :param new: New active index. + """ + active_indices = list(set(old) - set(new)) + list(set(new) - set(old)) + active_labels = [self._traj_checkbox_group.labels[index] for index in active_indices] + self._update_glyph_checkbox_group(glyph_names=active_labels) + + def _map_checkbox_group_active_on_change(self, attr: str, old: List[int], new: List[int]) -> None: + """ + Helper function for map checkbox group when the list of actives changes. + :param attr: Attribute name. + :param old: Old active index. + :param new: New active index. + """ + active_indices = list(set(old) - set(new)) + list(set(new) - set(old)) + active_labels = [self._map_checkbox_group.labels[index] for index in active_indices] + self._update_glyph_checkbox_group(glyph_names=active_labels) + + def _object_checkbox_group_active_on_change(self, attr: str, old: List[int], new: List[int]) -> None: + """ + Helper function for object checkbox group when the list of actives changes. + :param attr: Attribute name. + :param old: Old active index. + :param new: New active index. + """ + active_indices = list(set(old) - set(new)) + list(set(new) - set(old)) + active_labels = [self._object_checkbox_group.labels[index] for index in active_indices] + self._update_glyph_checkbox_group(glyph_names=active_labels) + + def file_paths_on_change( + self, experiment_file_data: ExperimentFileData, experiment_file_active_index: List[int] + ) -> None: + """ + Interface to update layout when file_paths is changed. + :param experiment_file_data: Experiment file data. + :param experiment_file_active_index: Active indexes for experiment files. + """ + self._experiment_file_data = experiment_file_data + self._experiment_file_active_index = experiment_file_active_index + + self.simulation_tile.init_simulations(figure_sizes=self.simulation_figure_sizes) + self._init_selection() + self._scenario_metric_score_data = self._update_aggregation_metric() + self._update_scenario_plot() + + def _click_planner_checkbox_group(self, attr: Any) -> None: + """ + Click event handler for planner_checkbox_group. + :param attr: Clicked attributes. + """ + # Render scenario metric figures + scenario_metric_score_figure_data = self._render_scenario_metric_score() + + # Render scenario metric score layout + scenario_metric_score_layout = self._render_scenario_metric_layout( + figure_data=scenario_metric_score_figure_data, + default_div=self._default_scenario_score_div, + plot_width=self._scenario_metric_score_data_figure_sizes[0], + legend=False, + ) + self._scenario_score_layout.children[0] = layout(scenario_metric_score_layout) + + # Filter time series data + filtered_time_series_data: Dict[str, List[ScenarioTimeSeriesData]] = defaultdict(list) + for key, time_series_data in self._time_series_data.items(): + for data in time_series_data: + if data.planner_name not in self.enable_planner_names: + continue + filtered_time_series_data[key].append(data) + + # Render time series figure data + time_series_figure_data = self._render_time_series(aggregated_time_series_data=filtered_time_series_data) + + # Render time series layout + time_series_figures = self._render_scenario_metric_layout( + figure_data=time_series_figure_data, + default_div=self._default_time_series_div, + plot_width=self.plot_sizes[0], + legend=True, + ) + self._time_series_layout.children[0] = layout(time_series_figures) + + # Render simulation + filtered_simulation_figures = [ + data for data in self._simulation_figure_data if data.planner_name in self.enable_planner_names + ] + if not filtered_simulation_figures: + simulation_layouts = column(self._default_simulation_div) + ego_expert_state_layouts = column(self._default_ego_expert_states_div) + else: + simulation_layouts = gridplot( + [simulation_figure.plot for simulation_figure in filtered_simulation_figures], + ncols=self.get_plot_cols( + plot_width=self.simulation_figure_sizes[0], offset_width=scenario_tab_style['col_offset_width'] + ), + toolbar_location=None, + ) + ego_expert_state_layouts = self._render_ego_expert_states( + simulation_figure_data=filtered_simulation_figures + ) + self._simulation_tile_layout.children[0] = layout(simulation_layouts) + self._ego_expert_states_layout.children[0] = layout(ego_expert_state_layouts) + + def _update_simulation_layouts(self) -> None: + """Update simulation layouts.""" + self._simulation_tile_layout.children[0] = layout(self._simulation_plots) + + def _update_scenario_plot(self) -> None: + """Update scenario plots when selection is made.""" + start_time = time.perf_counter() + self._simulation_figure_data = [] + + # Render scenario metric score figure data + scenario_metric_score_figure_data = self._render_scenario_metric_score() + + # Render scenario metric score layout + scenario_metric_score_layout = self._render_scenario_metric_layout( + figure_data=scenario_metric_score_figure_data, + default_div=self._default_scenario_score_div, + plot_width=self._scenario_metric_score_data_figure_sizes[0], + legend=False, + ) + self._scenario_score_layout.children[0] = layout(scenario_metric_score_layout) + + # Aggregate time series data + self._time_series_data = self._aggregate_time_series_data() + # Render time series figure data + time_series_figure_data = self._render_time_series(aggregated_time_series_data=self._time_series_data) + # Render time series layout + time_series_figures = self._render_scenario_metric_layout( + figure_data=time_series_figure_data, + default_div=self._default_time_series_div, + plot_width=self.plot_sizes[0], + legend=True, + ) + self._time_series_layout.children[0] = layout(time_series_figures) + + # Render simulations. + self._simulation_plots = self._render_simulations() + + # render ego and expert states, call after rendering simulation + ego_expert_state_layout = self._render_ego_expert_states(simulation_figure_data=self._simulation_figure_data) + self._ego_expert_states_layout.children[0] = layout(ego_expert_state_layout) + + # Make sure the simulation plot upgrades at the last + self._doc.add_next_tick_callback(self._update_simulation_layouts) + end_time = time.perf_counter() + elapsed_time = end_time - start_time + logger.info(f"Rending scenario plot takes {elapsed_time:.4f} seconds.") + + def _update_planner_names(self) -> None: + """Update planner name options in the checkbox widget.""" + self.planner_checkbox_group.labels = [] + self.planner_checkbox_group.active = [] + selected_keys = [ + key + for key in self.experiment_file_data.simulation_scenario_keys + if key.scenario_type == self._scalar_scenario_type_select.value + and key.scenario_name == self._scalar_scenario_name_select.value + ] + sorted_planner_names = sorted(list({key.planner_name for key in selected_keys})) + self.planner_checkbox_group.labels = sorted_planner_names + self.planner_checkbox_group.active = [index for index in range(len(sorted_planner_names))] + + def _scalar_scenario_type_select_on_change(self, attr: str, old: str, new: str) -> None: + """ + Helper function to change event in scalar scenario type. + :param attr: Attribute. + :param old: Old value. + :param new: New value. + """ + if new == "": + return + + available_log_names = self.load_log_name(scenario_type=self._scalar_scenario_type_select.value) + self._scalar_log_name_select.options = [""] + available_log_names + self._scalar_log_name_select.value = "" + self._scalar_scenario_name_select.options = [""] + self._scalar_scenario_name_select.value = "" + + def _scalar_log_name_select_on_change(self, attr: str, old: str, new: str) -> None: + """ + Helper function to change event in scalar log name. + :param attr: Attribute. + :param old: Old value. + :param new: New value. + """ + if new == "": + return + + available_scenario_names = self.load_scenario_names( + scenario_type=self._scalar_scenario_type_select.value, log_name=self._scalar_log_name_select.value + ) + self._scalar_scenario_name_select.options = [""] + available_scenario_names + self._scalar_scenario_name_select.value = "" + + def _scalar_scenario_name_select_on_change(self, attr: str, old: str, new: str) -> None: + """ + Helper function to change event in scalar scenario name. + :param attr: Attribute. + :param old: Old value. + :param new: New value. + """ + if self._scalar_scenario_name_select.tags: + self.window_width = self._scalar_scenario_name_select.tags[0] + self.window_height = self._scalar_scenario_name_select.tags[1] + + def _scenario_token_multi_choice_on_change(self, attr: str, old: List[str], new: List[str]) -> None: + """ + Helper function to change event in scenario token multi choice. + :param attr: Attribute. + :param old: List of old values. + :param new: List of new values. + """ + available_scenario_tokens = self._experiment_file_data.available_scenario_tokens + if not available_scenario_tokens or not new: + return + scenario_token_info = available_scenario_tokens.get(new[0]) + if self._scalar_scenario_type_select.value != scenario_token_info.scenario_type: + self._scalar_scenario_type_select.value = scenario_token_info.scenario_type + if self._scalar_log_name_select.value != scenario_token_info.log_name: + self._scalar_log_name_select.value = scenario_token_info.log_name + if self._scalar_scenario_name_select.value != scenario_token_info.scenario_name: + self.scalar_scenario_name_select.value = scenario_token_info.scenario_name + + def _scenario_modal_query_button_on_click(self) -> None: + """Helper function when click the modal query button.""" + if self._scalar_scenario_name_select.tags: + self.window_width = self._scalar_scenario_name_select.tags[0] + self.window_height = self._scalar_scenario_name_select.tags[1] + self._update_planner_names() + self._update_scenario_plot() + + def _init_selection(self) -> None: + """Init histogram and scalar selection options.""" + self._scalar_scenario_type_select.value = "" + self._scalar_scenario_type_select.options = [] + self._scalar_log_name_select.value = "" + self._scalar_log_name_select.options = [] + self._scalar_scenario_name_select.value = "" + self._scalar_scenario_name_select.options = [] + self._available_scenario_names = [] + self._simulation_figure_data = [] + + if len(self._scalar_scenario_type_select.options) == 0: + self._scalar_scenario_type_select.options = [""] + self.experiment_file_data.available_scenario_types + + if len(self._scalar_scenario_type_select.options) > 0: + self._scalar_scenario_type_select.value = self._scalar_scenario_type_select.options[0] + + available_scenario_tokens = list(self._experiment_file_data.available_scenario_tokens.keys()) + self._scenario_token_multi_choice.options = available_scenario_tokens + self._update_planner_names() + + @staticmethod + def _render_scalar_figure( + title: str, + y_axis_label: str, + hover: HoverTool, + sizes: List[int], + x_axis_label: Optional[str] = None, + x_range: Optional[List[str]] = None, + y_range: Optional[List[str]] = None, + ) -> Figure: + """ + Render a scalar figure. + :param title: Plot title. + :param y_axis_label: Y axis label. + :param hover: Hover tool for the plot. + :param sizes: Width and height in pixels. + :param x_axis_label: Label in x axis. + :param x_range: Labels in x major axis. + :param y_range: Labels in y major axis. + :return A time series plot. + """ + scenario_scalar_figure = Figure( + background_fill_color=PLOT_PALETTE["background_white"], + title=title, + css_classes=["time-series-figure"], + margin=scenario_tab_style["time_series_figure_margins"], + width=sizes[0], + height=sizes[1], + active_scroll="wheel_zoom", + output_backend="webgl", + x_range=x_range, + y_range=y_range, + ) + scenario_scalar_figure.add_tools(hover) + + scenario_scalar_figure.title.text_font_size = scenario_tab_style["time_series_figure_title_text_font_size"] + scenario_scalar_figure.xaxis.axis_label_text_font_size = scenario_tab_style[ + "time_series_figure_xaxis_axis_label_text_font_size" + ] + scenario_scalar_figure.xaxis.major_label_text_font_size = scenario_tab_style[ + "time_series_figure_xaxis_major_label_text_font_size" + ] + scenario_scalar_figure.yaxis.axis_label_text_font_size = scenario_tab_style[ + "time_series_figure_yaxis_axis_label_text_font_size" + ] + scenario_scalar_figure.yaxis.major_label_text_font_size = scenario_tab_style[ + "time_series_figure_yaxis_major_label_text_font_size" + ] + scenario_scalar_figure.toolbar.logo = None + + # Rotate the x_axis label with 45 (180/4) degrees. + scenario_scalar_figure.xaxis.major_label_orientation = np.pi / 4 + + scenario_scalar_figure.yaxis.axis_label = y_axis_label + scenario_scalar_figure.xaxis.axis_label = x_axis_label + + return scenario_scalar_figure + + def _update_aggregation_metric(self) -> scenario_metric_score_dict_type: + """ + Update metric score for each scenario. + :return A dict of log name: {scenario names and their metric scores}. + """ + data: scenario_metric_score_dict_type = defaultdict(lambda: defaultdict(list)) + # Loop through all metric aggregators + for index, metric_aggregator_dataframes in enumerate(self.experiment_file_data.metric_aggregator_dataframes): + if index not in self._experiment_file_active_index: + continue + for file_index, (metric_aggregator_filename, metric_aggregator_dataframe) in enumerate( + metric_aggregator_dataframes.items() + ): + # Get columns + columns = set(list(metric_aggregator_dataframe.columns)) + # List of non-metric columns to be excluded + non_metric_columns = { + 'scenario', + 'log_name', + 'scenario_type', + 'num_scenarios', + 'planner_name', + 'aggregator_type', + } + metric_columns = sorted(list(columns - non_metric_columns)) + # Iterate through rows + for _, row_data in metric_aggregator_dataframe.iterrows(): + num_scenarios = row_data["num_scenarios"] + if not np.isnan(num_scenarios): + continue + + planner_name = row_data["planner_name"] + scenario_name = row_data["scenario"] + log_name = row_data["log_name"] + for metric_column in metric_columns: + score = row_data[metric_column] + # Add scenario metric score data + if score is not None: + data[log_name][scenario_name].append( + ScenarioMetricScoreData( + experiment_index=index, + metric_aggregator_file_name=metric_aggregator_filename, + metric_aggregator_file_index=file_index, + planner_name=planner_name, + metric_statistic_name=metric_column, + score=np.round(score, 4), + ) + ) + + return data + + def _aggregate_time_series_data(self) -> Dict[str, List[ScenarioTimeSeriesData]]: + """ + Aggregate time series data. + :return A dict of metric statistic names and their data. + """ + aggregated_time_series_data: Dict[str, List[ScenarioTimeSeriesData]] = {} + scenario_types = ( + tuple([self._scalar_scenario_type_select.value]) if self._scalar_scenario_type_select.value else None + ) + log_names = tuple([self._scalar_log_name_select.value]) if self._scalar_log_name_select.value else None + if not len(self._scalar_scenario_name_select.value): + return aggregated_time_series_data + for index, metric_statistics_dataframes in enumerate(self.experiment_file_data.metric_statistics_dataframes): + if index not in self._experiment_file_active_index: + continue + + for metric_statistics_dataframe in metric_statistics_dataframes: + planner_names = metric_statistics_dataframe.planner_names + if metric_statistics_dataframe.metric_statistic_name not in aggregated_time_series_data: + aggregated_time_series_data[metric_statistics_dataframe.metric_statistic_name] = [] + for planner_name in planner_names: + data_frame = metric_statistics_dataframe.query_scenarios( + scenario_names=tuple([str(self._scalar_scenario_name_select.value)]), + scenario_types=scenario_types, + planner_names=tuple([planner_name]), + log_names=log_names, + ) + if not len(data_frame): + continue + + time_series_headers = metric_statistics_dataframe.time_series_headers + time_series: pandas.DataFrame = data_frame[time_series_headers] + if time_series[time_series_headers[0]].iloc[0] is None: + continue + + time_series_values: npt.NDArray[np.float64] = np.round( + np.asarray( + list( + chain.from_iterable(time_series[metric_statistics_dataframe.time_series_values_column]) + ) + ), + 4, + ) + + time_series_timestamps = list( + chain.from_iterable(time_series[metric_statistics_dataframe.time_series_timestamp_column]) + ) + time_series_unit = time_series[metric_statistics_dataframe.time_series_unit_column].iloc[0] + time_series_selected_frames = metric_statistics_dataframe.get_time_series_selected_frames + + scenario_time_series_data = ScenarioTimeSeriesData( + experiment_index=index, + planner_name=planner_name, + time_series_values=time_series_values, + time_series_timestamps=time_series_timestamps, + time_series_unit=time_series_unit, + time_series_selected_frames=time_series_selected_frames, + ) + + aggregated_time_series_data[metric_statistics_dataframe.metric_statistic_name].append( + scenario_time_series_data + ) + + return aggregated_time_series_data + + def _render_time_series( + self, aggregated_time_series_data: Dict[str, List[ScenarioTimeSeriesData]] + ) -> Dict[str, Figure]: + """ + Render time series plots. + :param aggregated_time_series_data: Aggregated scenario time series data. + :return A dict of figure name and figures. + """ + time_series_figures: Dict[str, Figure] = {} + for metric_statistic_name, scenario_time_series_data in aggregated_time_series_data.items(): + for data in scenario_time_series_data: + if not len(data.time_series_values): + continue + + if metric_statistic_name not in time_series_figures: + time_series_figures[metric_statistic_name] = self._render_scalar_figure( + title=metric_statistic_name, + y_axis_label=data.time_series_unit, + x_axis_label='frame', + hover=HoverTool( + tooltips=[ + ("Frame", "@x"), + ("Value", "@y{0.0000}"), + ("Time_us", "@time_us"), + ("Planner", "$name"), + ] + ), + sizes=self.plot_sizes, + ) + planner_name = data.planner_name + f" ({self.get_file_path_last_name(data.experiment_index)})" + color = self.experiment_file_data.file_path_colors[data.experiment_index][data.planner_name] + time_series_figure = time_series_figures[metric_statistic_name] + # Get frame numbers based on timestamps + timestamp_frames = ( + data.time_series_selected_frames + if data.time_series_selected_frames is not None + else list(range(len(data.time_series_timestamps))) + ) + data_source = ColumnDataSource( + dict( + x=timestamp_frames, + y=data.time_series_values, + time_us=data.time_series_timestamps, + ) + ) + if data.time_series_selected_frames is not None: + time_series_figure.scatter( + x="x", y="y", name=planner_name, color=color, legend_label=planner_name, source=data_source + ) + else: + time_series_figure.line( + x="x", y="y", name=planner_name, color=color, legend_label=planner_name, source=data_source + ) + + return time_series_figures + + def _render_scenario_metric_score_scatter( + self, scatter_figure: Figure, scenario_metric_score_data: Dict[str, List[ScenarioMetricScoreData]] + ) -> None: + """ + Render scatter plot with scenario metric score data. + :param scatter_figure: A scatter figure. + :param scenario_metric_score_data: Metric score data for a scenario. + """ + # Aggregate data sources + data_sources: Dict[str, ScenarioMetricScoreDataSource] = {} + for metric_name, metric_score_data in scenario_metric_score_data.items(): + for index, score_data in enumerate(metric_score_data): + experiment_name = self.get_file_path_last_name(score_data.experiment_index) + legend_label = f"{score_data.planner_name} ({experiment_name})" + data_source_index = legend_label + f" - {score_data.metric_aggregator_file_index})" + if data_source_index not in data_sources: + data_sources[data_source_index] = ScenarioMetricScoreDataSource( + xs=[], + ys=[], + planners=[], + aggregators=[], + experiments=[], + fill_colors=[], + marker=self.get_scatter_sign(score_data.metric_aggregator_file_index), + legend_label=legend_label, + ) + fill_color = self.experiment_file_data.file_path_colors[score_data.experiment_index][ + score_data.planner_name + ] + data_sources[data_source_index].xs.append(score_data.metric_statistic_name) + data_sources[data_source_index].ys.append(score_data.score) + data_sources[data_source_index].planners.append(score_data.planner_name) + data_sources[data_source_index].aggregators.append(score_data.metric_aggregator_file_name) + data_sources[data_source_index].experiments.append( + self.get_file_path_last_name(score_data.experiment_index) + ) + data_sources[data_source_index].fill_colors.append(fill_color) + + # Plot scatter + for legend_label, data_source in data_sources.items(): + sources = ColumnDataSource( + dict( + xs=data_source.xs, + ys=data_source.ys, + planners=data_source.planners, + experiments=data_source.experiments, + aggregators=data_source.aggregators, + fill_colors=data_source.fill_colors, + line_colors=data_source.fill_colors, + ) + ) + glyph_renderer = self.get_scatter_render_func( + scatter_sign=data_source.marker, scatter_figure=scatter_figure + ) + glyph_renderer(x="xs", y="ys", size=10, fill_color="fill_colors", line_color="fill_colors", source=sources) + + def _render_scenario_metric_score(self) -> Dict[str, Figure]: + """ + Render scenario metric score plot. + :return A dict of figure names and figures. + """ + if ( + not self._scalar_log_name_select.value + or not self._scalar_scenario_name_select.value + or not self._scenario_metric_score_data + ): + return {} + selected_scenario_metric_score: List[ScenarioMetricScoreData] = self._scenario_metric_score_data[ + self._scalar_log_name_select.value + ][self._scalar_scenario_name_select.value] + # Rearranged to {metric_statistic_namae: List[scenario_metric_score_data]} + data: Dict[str, List[ScenarioMetricScoreData]] = defaultdict(list) + for scenario_metric_score_data in selected_scenario_metric_score: + if scenario_metric_score_data.planner_name not in self.enable_planner_names: + continue + + # Rename final score from score to scenario_score + metric_statistic_name = scenario_metric_score_data.metric_statistic_name + data[metric_statistic_name].append(scenario_metric_score_data) + metric_statistic_names = sorted(list(set(data.keys()))) + # Make sure the final score of a scenario is the last element + if 'score' in metric_statistic_names: + metric_statistic_names.remove('score') + metric_statistic_names.append('score') + hover = HoverTool( + tooltips=[ + ("Metric", "@xs"), + ("Score", "@ys"), + ("Planner", "@planners"), + ("Experiment", "@experiments"), + ("Aggregator", "@aggregators"), + ] + ) + number_of_figures = ceil(len(metric_statistic_names) / self._number_metrics_per_figure) + + # Create figures based on the number of metrics per figure + scenario_metric_score_figures: Dict[str, Figure] = defaultdict() + for index in range(number_of_figures): + starting_index = index * self._number_metrics_per_figure + ending_index = starting_index + self._number_metrics_per_figure + selected_metric_names = metric_statistic_names[starting_index:ending_index] + scenario_metric_score_figure = self._render_scalar_figure( + title="", + y_axis_label="score", + hover=hover, + x_range=selected_metric_names, + sizes=self._scenario_metric_score_data_figure_sizes, + ) + + # Plot scatter on the figure + metric_score_data = {metric_name: data[metric_name] for metric_name in selected_metric_names} + self._render_scenario_metric_score_scatter( + scatter_figure=scenario_metric_score_figure, scenario_metric_score_data=metric_score_data + ) + scenario_metric_score_figures[str(index)] = scenario_metric_score_figure + return scenario_metric_score_figures + + def _render_grid_plot(self, figures: Dict[str, Figure], plot_width: int, legend: bool = True) -> LayoutDOM: + """ + Render a grid plot. + :param figures: A dict of figure names and figures. + :param plot_width: Width of each plot. + :param legend: If figures have legends. + :return A grid plot. + """ + figure_plot_list: List[Figure] = [] + for figure_name, figure_plot in figures.items(): + if legend: + figure_plot.legend.label_text_font_size = scenario_tab_style["plot_legend_label_text_font_size"] + figure_plot.legend.background_fill_alpha = 0.0 + figure_plot.legend.click_policy = "hide" + figure_plot_list.append(figure_plot) + + grid_plot = gridplot( + figure_plot_list, + ncols=self.get_plot_cols(plot_width=plot_width), + toolbar_location="left", + ) + return grid_plot + + def _render_scenario_metric_layout( + self, figure_data: Dict[str, Figure], default_div: Div, plot_width: int, legend: bool = True + ) -> column: + """ + Render a layout for scenario metric. + :param figure_data: A dict of figure_data. + :param default_div: Default message when there is no result. + :param plot_width: Figure width. + :param legend: If figures have legends. + :return A bokeh column layout. + """ + if not figure_data: + return column(default_div) + + grid_plot = self._render_grid_plot(figures=figure_data, plot_width=plot_width, legend=legend) + scenario_metric_layout = column(grid_plot) + return scenario_metric_layout + + def _render_simulations(self) -> column: + """ + Render simulation plot. + :return: A list of Bokeh columns or rows. + """ + selected_keys = [ + key + for key in self.experiment_file_data.simulation_scenario_keys + if key.scenario_type == self._scalar_scenario_type_select.value + and key.log_name == self._scalar_log_name_select.value + and key.scenario_name == self._scalar_scenario_name_select.value + and key.sledgeboard_file_index in self._experiment_file_active_index + ] + if not selected_keys: + self._scenario_title_div.text = "-" + simulation_layouts = column(self._default_simulation_div) + else: + hidden_glyph_names = [ + label + for checkbox_group in [self._object_checkbox_group, self._traj_checkbox_group, self._map_checkbox_group] + for index, label in enumerate(checkbox_group.labels) + if index not in checkbox_group.active + ] + self._simulation_figure_data = self.simulation_tile.render_simulation_tiles( + selected_scenario_keys=selected_keys, + figure_sizes=self.simulation_figure_sizes, + hidden_glyph_names=hidden_glyph_names, + ) + simulation_figures = [data.plot for data in self._simulation_figure_data] + simulation_layouts = gridplot( + simulation_figures, + ncols=self.get_plot_cols( + plot_width=self.simulation_figure_sizes[0], offset_width=scenario_tab_style['col_offset_width'] + ), + toolbar_location=None, + ) + self._scenario_title_div.text = ( + f"{self._scalar_scenario_type_select.value} - " + f"{self._scalar_log_name_select.value} - " + f"{self._scalar_scenario_name_select.value}" + ) + + return simulation_layouts + + @staticmethod + def _get_ego_expert_states(state_key: str, ego_state: EgoState) -> float: + """ + Get states based on the state key. + :param state_key: Ego state key. + :param ego_state: Ego state. + :return ego state based on the key. + """ + if state_key == 'x [m]': + return cast(float, ego_state.car_footprint.center.x) + elif state_key == 'y [m]': + return cast(float, ego_state.car_footprint.center.y) + elif state_key == 'velocity_x [m/s]': + return cast(float, ego_state.dynamic_car_state.rear_axle_velocity_2d.x) + elif state_key == 'velocity_y [m/s]': + return cast(float, ego_state.dynamic_car_state.rear_axle_velocity_2d.y) + elif state_key == 'speed [m/s]': + return cast(float, ego_state.dynamic_car_state.speed) + elif state_key == 'acceleration_x [m/s^2]': + return cast(float, ego_state.dynamic_car_state.rear_axle_acceleration_2d.x) + elif state_key == 'acceleration_y [m/s^2]': + return cast(float, ego_state.dynamic_car_state.rear_axle_acceleration_2d.y) + elif state_key == 'acceleration [m/s^2]': + return cast(float, ego_state.dynamic_car_state.acceleration) + elif state_key == 'heading [rad]': + return cast(float, ego_state.car_footprint.center.heading) + elif state_key == 'steering_angle [rad]': + return cast(float, ego_state.dynamic_car_state.tire_steering_rate) + elif state_key == 'yaw_rate [rad/s]': + return cast(float, ego_state.dynamic_car_state.angular_velocity) + else: + raise ValueError(f"{state_key} not available!") + + def _render_ego_expert_state_glyph( + self, + ego_expert_plot_aggregated_states: scenario_ego_expert_state_figure_type, + ego_expert_plot_colors: Dict[str, str], + ) -> column: + """ + Render line and circle glyphs on ego_expert_state figures and get a grid plot. + :param ego_expert_plot_aggregated_states: Aggregated ego and expert states over frames. + :param ego_expert_plot_colors: Colors for different planners. + :return Column layout for ego and expert states. + """ + # Render figures with the state keys + ego_expert_state_figures: Dict[str, Figure] = defaultdict() + for plot_state_key in self.plot_state_keys: + hover = HoverTool( + tooltips=[ + ("Frame", "@x"), + ("Value", "@y{0.0000}"), + ("Planner", "$name"), + ] + ) + ego_expert_state_figure = self._render_scalar_figure( + title='', + y_axis_label=plot_state_key, + x_axis_label='frame', + hover=hover, + sizes=scenario_tab_style["ego_expert_state_figure_sizes"], + ) + # Disable scientific notation + ego_expert_state_figure.yaxis.formatter = BasicTickFormatter(use_scientific=False) + ego_expert_state_figures[plot_state_key] = ego_expert_state_figure + for planner_name, plot_states in ego_expert_plot_aggregated_states.items(): + color = ego_expert_plot_colors.get(planner_name, None) + if not color: + color = None + for plot_state_key, plot_state_values in plot_states.items(): + ego_expert_state_figure = ego_expert_state_figures[plot_state_key] + data_source = ColumnDataSource( + dict( + x=list(range(len(plot_state_values))), + y=np.round(plot_state_values, 2), + ) + ) + if self.expert_planner_key in planner_name: + ego_expert_state_figure.circle( + x="x", + y="y", + name=planner_name, + color=color, + legend_label=planner_name, + source=data_source, + size=2, + ) + else: + ego_expert_state_figure.line( + x="x", + y="y", + name=planner_name, + color=color, + legend_label=planner_name, + source=data_source, + line_width=1, + ) + + # Make layout horizontally + ego_expert_states_layout = self._render_grid_plot( + figures=ego_expert_state_figures, + plot_width=scenario_tab_style["ego_expert_state_figure_sizes"][0], + legend=True, + ) + return ego_expert_states_layout + + def _get_ego_expert_plot_color(self, planner_name: str, file_path_index: int, figure_planer_name: str) -> str: + """ + Get color for ego expert plot states based on the planner name. + :param planner_name: Plot planner name. + :param file_path_index: File path index for the plot. + :param figure_planer_name: Figure original planner name. + """ + return cast( + str, + self.experiment_file_data.expert_color_palettes[file_path_index] + if self.expert_planner_key in planner_name + else self.experiment_file_data.file_path_colors[file_path_index][figure_planer_name], + ) + + def _render_ego_expert_states(self, simulation_figure_data: List[SimulationData]) -> column: + """ + Render expert and ego time series states. Make sure it is called after _render_simulation. + :param simulation_figure_data: Simulation figure data after rendering simulation. + :return Column layout for ego and expert states. + """ + if not simulation_figure_data: + return column(self._default_ego_expert_states_div) + + # Aggregate data, {planner_name: {state_key: A list of values for the state}} + ego_expert_plot_aggregated_states: scenario_ego_expert_state_figure_type = defaultdict( + lambda: defaultdict(list) + ) + ego_expert_plot_colors: Dict[str, str] = defaultdict() + for figure_data in simulation_figure_data: + experiment_file_index = figure_data.simulation_figure.file_path_index + experiment_name = self.get_file_path_last_name(experiment_file_index) + expert_planner_name = f'{self.expert_planner_key} - ({experiment_name})' + ego_planner_name = f'{figure_data.planner_name} - ({experiment_name})' + ego_expert_states = { + expert_planner_name: figure_data.simulation_figure.scenario.get_expert_ego_trajectory(), + ego_planner_name: figure_data.simulation_figure.simulation_history.extract_ego_state, + } + for planner_name, planner_states in ego_expert_states.items(): + # Get expert color + ego_expert_plot_colors[planner_name] = self._get_ego_expert_plot_color( + planner_name=planner_name, + figure_planer_name=figure_data.planner_name, + file_path_index=figure_data.simulation_figure.file_path_index, + ) + if planner_name in ego_expert_plot_aggregated_states: + continue + for planner_state in planner_states: + for plot_state_key in self.plot_state_keys: + state_key_value = self._get_ego_expert_states(state_key=plot_state_key, ego_state=planner_state) + ego_expert_plot_aggregated_states[planner_name][plot_state_key].append(state_key_value) + + ego_expert_states_layout = self._render_ego_expert_state_glyph( + ego_expert_plot_aggregated_states=ego_expert_plot_aggregated_states, + ego_expert_plot_colors=ego_expert_plot_colors, + ) + return ego_expert_states_layout diff --git a/sledge/sledgeboard/templates/index.html b/sledge/sledgeboard/templates/index.html new file mode 100644 index 0000000..95a1677 --- /dev/null +++ b/sledge/sledgeboard/templates/index.html @@ -0,0 +1,79 @@ + + + {% extends base %} {% block head %} + + + {% block inner_head %} + + + + + + + + + + + + +