Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add infrastructure and initial set of tests for ADQ component #215

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/extra/components/adq.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(self, data, channel, digitizer=None, pulses=None,
raise ValueError('channel expected to be 2 or 3 characters, '
'e.g. 1A or 1_A')

self._channel_number = ord(self._channel_letter) - ord('A')
self._channel_number = ord(self._channel_letter) - ord('A') + 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we confident no-one is already using channel numbers as they are now, counting from zero?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the numbers and letters are enumated at initialization time by the AdqDigitizer device and cannot be configured. They're not even (always) in the physical order in the crate.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, perhaps I'm misunderstanding. At the moment, if I create adq = AdqDigitizer(run, '1C'), then adq.number will give me 2 (A=0, B=1, C=2), yes? With this change, it becomes 3. Are you confident that no-one has started using those numbers as they are?

(We can still change it even if someone might be using it - I just want to think about if we need to announce the change somehow)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies, I understand now. So the earlier behaviour was a bug, adq.number was always meant to represent the board number in the sense it was used in CONTROL parameters (yes, someone funny used letters for fast data and numbers for slow data). As a matter of fact, .channel_parameters would've given you the wrong values as of now.

I highly doubt this parameter was used already directly, but if you feel strongly about it I can add it as a bugfix.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, thanks. Let's mention it in the changelog, though since we install master every day it doesn't make much difference.

self._channel_name = f'{self._channel_board}_{self._channel_letter}'

key = f'digitizers.channel_{self._channel_name}.raw.samples'
Expand Down Expand Up @@ -232,6 +232,10 @@ def _correct_cm_by_train(signal, out, period, baseline, baselevel=None):
if baselevel is not None:
baseline = baseline - baselevel

# Make sure the dtypes match, otherwise baseline is likely
# going to be float64 and not castable via `safe`.
baseline = baseline.astype(out.dtype, copy=False)

for offset in range(period):
sel = np.s_[offset::period]
np.subtract(
Expand Down Expand Up @@ -490,7 +494,7 @@ def name(self):
def channel_key(self, suffix):
"""Instrument KeyData object of this channel."""
return self._instrument_src[f'digitizers.channel_{self._channel_name}'
f'_{suffix}']
f'.{suffix}']

@property
def raw_samples_key(self):
Expand Down Expand Up @@ -707,7 +711,8 @@ def correct_common_mode(self, data, cm_period, baseline, baselevel=None):
baseline to 0.

Returns:
out (numpy.ndarray): Corrected input data.
out (numpy.ndarray): Corrected input data, same dtype as
input data if floating otherwise `float32`.
"""

if cm_period < 1:
Expand All @@ -716,7 +721,8 @@ def correct_common_mode(self, data, cm_period, baseline, baselevel=None):
if not isinstance(data, np.ndarray):
data = np.asarray(data)

out = np.zeros_like(data, dtype=np.float32)
out = np.zeros_like(data, dtype=data.dtype \
if np.issubdtype(data.dtype, np.floating) else np.float32)
self._correct_cm_by_train(data, out, cm_period, baseline, baselevel)

return out
Expand Down
11 changes: 8 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from extra_data.tests.mockdata import write_file

from extra_data.tests.mockdata.motor import Motor
from .mockdata.detector_motors import (
DetectorMotorDataSelector, get_motor_sources, write_motor_positions)
from .mockdata.adq import AdqDigitizer
from .mockdata.detector_motors import (DetectorMotorDataSelector,
get_motor_sources,
write_motor_positions)
from .mockdata.dld import ReconstructedDld
from .mockdata.timeserver import PulsePatternDecoder, Timeserver
from .mockdata.xgm import XGM, XGMD, XGMReduced
Expand Down Expand Up @@ -98,7 +100,10 @@ def mock_sqs_remi_directory():
XGM('SA3_XTD10_XGM/XGM/DOOCS'),
ReconstructedDld('SQS_REMI_DLD6/DET/TOP'),
ReconstructedDld('SQS_REMI_DLD6/DET/BOTTOM'),
Motor('SQS_ILH_LAS/MOTOR/DELAY_AX_800')]
Motor('SQS_ILH_LAS/MOTOR/DELAY_AX_800'),
AdqDigitizer('SQS_DIGITIZER_UTC1/ADC/1', channels_per_board=2 * [4]),
AdqDigitizer('SQS_DIGITIZER_UTC2/ADC/1', channels_per_board=4 * [4],
data_channels={(0, 0), (2, 1)})]

with TemporaryDirectory() as td:
write_file(Path(td) / 'RAW-R0001-DA01-S00000.h5', sources, 100)
Expand Down
82 changes: 82 additions & 0 deletions tests/mockdata/adq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@

import numpy as np

from extra_data.tests.mockdata.base import DeviceBase

from extra.utils import gaussian


class AdqDigitizer(DeviceBase):
# still needed?
control_keys = [
('fakeScalar', 'u4', ()),
]
takluyver marked this conversation as resolved.
Show resolved Hide resolved

output_channels = ('network/digitizers',)

instrument_keys = [
('channel_{board}_{ch_letter}/raw/length', 'u4', ()),
('channel_{board}_{ch_letter}/raw/position', 'u4', ()),
('channel_{board}_{ch_letter}/raw/triggerId', 'u8', ()),
('channel_{board}_{ch_letter}/raw/samples', 'i2', (50000,))
]

extra_run_values = [
('classId', None, 'AdqDigitizer'),
('board{board}/enable', None, True),
('board{board}/interleavedMode', None, False),
('board{board}/enable_raw', None, True),
('board{board}/channel_{ch_number}/offset', None, 0),
('board{board}/channel_{ch_number}/enable', None, True),
]

def __init__(self, *args, channels_per_board, data_channels={}, **kwargs):
self.channel_labels = []
self.data_channels = []

# These are dicts for now to have no duplicate keys, their
# values are turned into lists afterwards.
instrument_keys = {}
extra_run_values = {}

cls = self.__class__
format_fields = {}
for board_idx, num_channels in enumerate(channels_per_board):
format_fields['board'] = board_idx + 1

for ch_idx in range(num_channels):
format_fields['ch_number'] = ch_idx + 1
format_fields['ch_letter'] = chr(ord('A') + ch_idx)
self.channel_labels.append(
'{board}_{ch_letter}'.format(**format_fields))

for key, dtype, shape in cls.instrument_keys:
full_key = key.format(**format_fields)
instrument_keys[full_key] = (full_key, dtype, shape)

for key, dtype, value in cls.extra_run_values:
full_key = key.format(**format_fields)
extra_run_values[full_key] = (full_key, dtype, value)

if (board_idx, ch_idx) in data_channels:
self.data_channels.append(self.channel_labels[-1])

self.instrument_keys = list(instrument_keys.values())
self.extra_run_values = list(extra_run_values.values())

super().__init__(*args, **kwargs)

def write_instrument(self, f):
super().write_instrument(f)

root_grp = f[f'INSTRUMENT/{self.device_id}:network/digitizers']
x = np.arange(50000)

for i, ch_label in enumerate(self.channel_labels):
if ch_label not in self.data_channels:
continue

# Add a channel-dependent baseline shift and gaussian to
# each channel.
root_grp[f'channel_{ch_label}/raw/samples'][:] += -10 * (i+1) \
+ gaussian(x, 0, i*80, i*1000, 50).astype(np.int16)
176 changes: 176 additions & 0 deletions tests/test_components_adq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@

from itertools import product

import pytest
import numpy as np

from extra.components import AdqRawChannel, XrayPulses

from .mockdata import assert_equal_sourcedata, assert_equal_keydata


@pytest.mark.parametrize('channel', ['1_C', '1C'])
def test_adq_init(mock_sqs_remi_run, channel):
ctrl_sd = mock_sqs_remi_run['SQS_DIGITIZER_UTC2/ADC/1']
instr_sd = mock_sqs_remi_run['SQS_DIGITIZER_UTC2/ADC/1:network']

# Run contains UTC1 and UTC2, first test regular operation with
# specific digitizer.
ch = AdqRawChannel(
mock_sqs_remi_run, channel, digitizer='SQS_DIGITIZER_UTC2')
assert_equal_sourcedata(ch.control_source, ctrl_sd)
assert_equal_sourcedata(ch.instrument_source, instr_sd)
assert_equal_keydata(ch.raw_samples_key,
instr_sd['digitizers.channel_1_C.raw.samples'])
assert_equal_keydata(ch.channel_key('raw.samples'),
instr_sd['digitizers.channel_1_C.raw.samples'])
assert ch._pulses.sase == 3

# Full auto-detection should fail on the entire run with two
# digitizers.
with pytest.raises(ValueError):
ch = AdqRawChannel(mock_sqs_remi_run, channel)

# Full auto-detection should work after selection.
subrun = mock_sqs_remi_run.deselect('SQS_DIGITIZER_UTC1*')
ch = AdqRawChannel(subrun, channel)
assert_equal_sourcedata(
ch.instrument_source, subrun['SQS_DIGITIZER_UTC2/ADC/1:network'])

# Can also pass explicit (instrument!) source name.
ch = AdqRawChannel(mock_sqs_remi_run, channel,
digitizer='SQS_DIGITIZER_UTC2/ADC/1:network')
assert_equal_sourcedata(
ch.instrument_source, subrun['SQS_DIGITIZER_UTC2/ADC/1:network'])

# Remove control data.
ctrlless_run = subrun.deselect('SQS_DIGITIZER_UTC2/ADC/1')

# Will fail creation without further keywords.
with pytest.raises(ValueError):
AdqRawChannel(ctrlless_run, channel)

# But works with passing interleaving flag explicitly.
ch = AdqRawChannel(ctrlless_run, channel, interleaved=False)
assert_equal_sourcedata(
ch.instrument_source, ctrlless_run['SQS_DIGITIZER_UTC2/ADC/1:network'])

# Remove timeserver information
timeless_run = subrun.deselect('SQS_RR_UTC/*')

# Simply creating a channel should fail now.
with pytest.raises(ValueError):
AdqRawChannel(timeless_run, channel)

# But it still works with pulse information disabled.
ch = AdqRawChannel(timeless_run, channel, pulses=False)
assert_equal_sourcedata(
ch.instrument_source, timeless_run['SQS_DIGITIZER_UTC2/ADC/1:network'])


def test_adq_properties(mock_sqs_remi_run):
ch = AdqRawChannel(mock_sqs_remi_run, '1C', digitizer='SQS_DIGITIZER_UTC1')
assert ch.board == 1
assert ch.letter == 'C'
assert ch.number == 3
assert ch.name == '1_C'

assert not ch.interleaved
assert ch.clock_ratio == 440
assert np.isclose(ch.sampling_rate, 2.0e9, rtol=1e-2)
assert np.isclose(ch.sampling_period, 0.5e-9)
assert ch.trace_shape == 50000
assert np.isclose(ch.trace_duration, 25e-6, rtol=1e-2)

assert ch.board_parameters == {
'enable': True, 'enable_raw': True, 'interleavedMode': False}
assert ch.channel_parameters == {
'enable': True, 'offset': 0, 'enable_raw': True,
'interleavedMode': False}


# Overwrite interleaving.
ch = AdqRawChannel(mock_sqs_remi_run, '1C', digitizer='SQS_DIGITIZER_UTC1',
interleaved=True)
assert ch.clock_ratio == 880
assert np.isclose(ch.sampling_rate, 4.0e9, rtol=1e-2)
assert np.isclose(ch.sampling_period, 0.25e-9)
assert np.isclose(ch.trace_duration, 12.5e-6, rtol=1e-2)

# One of the special 3G boards.
ch = AdqRawChannel(mock_sqs_remi_run, '1C', digitizer='SQS_DIGITIZER_UTC2')
assert ch.clock_ratio == 392
assert np.isclose(ch.sampling_rate, 1.76e9, rtol=1e-2)
assert np.isclose(ch.sampling_period, 0.57e-9)
assert np.isclose(ch.trace_duration, 28e-6, rtol=1e-2)


def test_adq_samples_per_pulse(mock_sqs_remi_run):
# Skip early trains with no or too many pulses.
run = mock_sqs_remi_run.select_trains(np.s_[50:])

# Use SAS for pulse information, since SA3 has only a single one.
ch = AdqRawChannel(run, '1C', digitizer='SQS_DIGITIZER_UTC1',
pulses=XrayPulses(run, sase=1))

assert ch.samples_per_pulse() == 5280
assert ch.samples_per_pulse(pulse_period=14) == 6160
assert ch.samples_per_pulse(pulse_duration=3.1e-6) == 6160
assert ch.samples_per_pulse(repetition_rate=320e3) == 6160
assert ch.samples_per_pulse(pulse_ids=np.array([1000, 1014, 1028])) == 6160

# These can give different results with fractional enabled.
assert np.isclose(
ch.samples_per_pulse(pulse_duration=3.1e-6, fractional=True), 6156.94)
assert np.isclose(
ch.samples_per_pulse(repetition_rate=320e3, fractional=True), 6206.6)


@pytest.mark.parametrize('in_dtype', [np.int16, np.float32, np.float64])
def test_adq_correct_common_mode(mock_sqs_remi_run, in_dtype):
ch = AdqRawChannel(mock_sqs_remi_run, '1C', digitizer='SQS_DIGITIZER_UTC1')
expected_dtype = in_dtype if np.issubdtype(in_dtype, np.floating) \
else np.float32

# Construct a trace with extreme common mode and tile it.
traces = np.zeros((2, 3, 500), dtype=in_dtype)

for offset in range(5):
traces[:, :, offset::5] = offset

# Default baselevel (i.e. 0).
out = ch.correct_common_mode(traces, 5, np.s_[:50])
assert np.allclose(out, 0.0)
assert out.shape == traces.shape
assert out.dtype == expected_dtype

# Custom baselevel.
out = ch.correct_common_mode(traces, 5, np.s_[:50], 17.14)
assert np.allclose(out, 17.14)
assert out.shape == traces.shape
assert out.dtype == expected_dtype

# Also test the raveled array.
out = ch.correct_common_mode(traces.ravel(), 5, np.s_[:50])
assert np.allclose(out, 0.0)
assert out.shape == (traces.size,)
assert out.dtype == expected_dtype


@pytest.mark.parametrize('in_dtype', [np.int16, np.float32, np.float64])
def test_pull_baseline(mock_sqs_remi_run, in_dtype):
ch = AdqRawChannel(mock_sqs_remi_run, '1C', digitizer='SQS_DIGITIZER_UTC1')
expected_dtype = in_dtype if np.issubdtype(in_dtype, np.floating) \
else np.float32

# Test a single trace first.
single_trace = np.arange(100, dtype=in_dtype)
out = ch.pull_baseline(single_trace, np.s_[:50], 0)
np.testing.assert_allclose(out, single_trace - 24.5)

# Now tile it multiple times.
traces = np.tile(single_trace, [2, 3, 1])
out = ch.pull_baseline(traces, np.s_[:50], 0)

for i, j in product(range(2), range(3)):
np.testing.assert_allclose(out[i, j], single_trace - 24.5)