Skip to content

Commit

Permalink
plots.xy_1d: adjustable num_samples_per_point
Browse files Browse the repository at this point in the history
Added QSpinBox to context menu so that user can get more accurate error
bars when averaging.
  • Loading branch information
pmldrmota committed May 8, 2024
1 parent 9b15fde commit 3760f42
Showing 1 changed file with 37 additions and 8 deletions.
45 changes: 37 additions & 8 deletions ndscan/plots/xy_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import defaultdict
from typing import NamedTuple

from .._qt import QtCore
from .._qt import QtCore, QtWidgets
from .annotation_items import ComputedCurveItem, CurveItem, VLineItem
from .cursor import CrosshairAxisLabel, LabeledCrosshairCursor
from .model import ScanModel
Expand Down Expand Up @@ -69,12 +69,14 @@ def __init__(self, view_box, data_name, data_item, error_bar_name, error_bar_ite

#: Whether to average points with the same x coordinate.
self.averaging_enabled = False
#: Assumed number of samples per point for calculating the combined uncertainty.
self.num_samples_per_point = 1

#: Keeps track of source points for each x coordinate for faster updates while
#: averaging is enabled.
self.source_points_by_x = defaultdict[float, list[SourcePoint]](list)

def update(self, x_data, data, averaging_enabled):
def update(self, x_data, data, averaging_enabled, num_samples_per_point):
def channel(name):
return np.array(data.get("channel_" + name, []))

Expand All @@ -90,13 +92,14 @@ def channel(name):

# If nothing has changed, skip the update.
if (num_to_show == self.num_current_points
and averaging_enabled == self.averaging_enabled):
and averaging_enabled == self.averaging_enabled
and num_samples_per_point == self.num_samples_per_point):
return

# Combine points with same coordinates if enabled.
if averaging_enabled:
x_data, y_data, y_err, source_idxs = self._average_add_points(
num_to_show, x_data, y_data, y_err)
num_to_show, x_data, y_data, y_err, num_samples_per_point)
else:
x_data = x_data[:num_to_show]
y_data = y_data[:num_to_show]
Expand All @@ -121,9 +124,11 @@ def channel(name):
self.view_box.addItem(self.error_bar_item)

self.averaging_enabled = averaging_enabled
self.num_samples_per_point = num_samples_per_point
self.num_current_points = num_to_show

def _average_add_points(self, num_to_show, x_data, y_data, y_err):
def _average_add_points(self, num_to_show, x_data, y_data, y_err,
num_samples_per_point):
# Append new data to collection.
start_idx = sum(len(v) for v in self.source_points_by_x.values())
for i in range(start_idx, num_to_show):
Expand All @@ -140,8 +145,10 @@ def _average_add_points(self, num_to_show, x_data, y_data, y_err):
# points -- see ``combined_uncertainty()``.
y_data = np.array(
[np.nanmean([p.y for p in self.source_points_by_x[x]]) for x in x_data])
y_err = np.array(
[combined_uncertainty(self.source_points_by_x[x]) for x in x_data])
y_err = np.array([
combined_uncertainty(self.source_points_by_x[x], num_samples_per_point)
for x in x_data
])

# We can only ascribe a single source index to the data if there wasn't any
# actual averaging.
Expand Down Expand Up @@ -180,6 +187,7 @@ def __init__(self, model: ScanModel, get_alternate_plot_names):
self.unique_x_data = set()
self.found_duplicate_x_data = False
self.averaging_enabled = False
self.num_samples_per_point = 1

self.x_schema = self.model.axes[0]
self.x_param_spec = self.x_schema["param"]["spec"]
Expand Down Expand Up @@ -303,7 +311,7 @@ def _update_points(self, points):
if self.x_schema["param"]["type"] == "enum":
x_data = enum_to_numeric(self.x_param_spec["members"].keys(), x_data)
for s in self.series:
s.update(x_data, points, self.averaging_enabled)
s.update(x_data, points, self.averaging_enabled, self.num_samples_per_point)

def _clear_annotations(self):
for item in self.annotation_items:
Expand Down Expand Up @@ -412,6 +420,23 @@ def build_context_menu(self, pane_idx, builder):
action.setChecked(self.averaging_enabled)
action.triggered.connect(
lambda *a: self.enable_averaging(not self.averaging_enabled))

if self.averaging_enabled:
num_samples_box = QtWidgets.QSpinBox()
num_samples_box.setMinimum(1)
num_samples_box.setMaximum(2**16)
num_samples_box.setValue(self.num_samples_per_point)
num_samples_box.valueChanged.connect(self.change_num_samples_per_point)
container = QtWidgets.QWidget()
layout = QtWidgets.QHBoxLayout()
container.setLayout(layout)
label = QtWidgets.QLabel("Samples per point:")
layout.addWidget(label)
layout.addWidget(num_samples_box)
layout.insertStretch(0)
action = builder.append_widget_action()
action.setDefaultWidget(container)

builder.ensure_separator()

if len(self.data_names) > 1:
Expand All @@ -426,6 +451,10 @@ def enable_averaging(self, enabled: bool):
self.averaging_enabled = enabled
self._update_points(self.model.get_point_data())

def change_num_samples_per_point(self, num_samples_per_point: int):
self.num_samples_per_point = num_samples_per_point
self._update_points(self.model.get_point_data())

def _set_dataset_from_crosshair_x(self, pane_idx, dataset_key):
if not self.crosshairs:
logger.warning("Plot not initialised yet, ignoring set dataset request")
Expand Down

0 comments on commit 3760f42

Please sign in to comment.