diff --git a/sksurgerynditracker/nditracker.py b/sksurgerynditracker/nditracker.py index c7f3eae..8f9f490 100644 --- a/sksurgerynditracker/nditracker.py +++ b/sksurgerynditracker/nditracker.py @@ -116,6 +116,8 @@ def __init__(self, configuration): ports to probe: + use quaternions: + :raises Exception: IOError, KeyError, OSError """ self._device = None diff --git a/tests/polaris_mocks.py b/tests/polaris_mocks.py index 4c007f6..a92a04f 100644 --- a/tests/polaris_mocks.py +++ b/tests/polaris_mocks.py @@ -2,6 +2,7 @@ """scikit-surgerynditracker mocks for polaris""" +from numpy import array, concatenate import ndicapy SETTINGS_POLARIS = { @@ -12,6 +13,15 @@ "data/8700339.rom"] } +SETTINGS_POLARIS_QUAT = { + "tracker type": "polaris", + "ports to probe": 20, + "romfiles" : [ + "data/something_else.rom", + "data/8700339.rom"], + "use quaternions": "true" + } + class MockPort: """A fake serial port for ndi""" device = 'bad port' @@ -44,14 +54,6 @@ def mockComports(): #pylint:disable=invalid-name mock_ports[5].device = 'good port' return mock_ports -def mockndiGetPHSRNumberOfHandles(_device): #pylint:disable=invalid-name - """Mock of ndiGetPHSRNumberOfHandles""" - return 4 - -def mockndiGetPHRQHandle(_device): #pylint:disable=invalid-name - """Mock of ndiGetPHRQHandle""" - return int(0) - def mockndiGetPHSRHandle(_device, index): #pylint:disable=invalid-name """Mock of ndiGetPHSRHandle""" return int(index) @@ -60,15 +62,84 @@ def mockndiVER(_device, _other_arg): #pylint:disable=invalid-name """Mock of ndiVER""" return 'Mock for Testing' -def mockndiGetBXFrame(_device, _port_handle): #pylint:disable=invalid-name - """Mock of ndiGetBXFrame""" - bx_frame_count = 0 - return bx_frame_count - -def mockndiGetBXTransform(_device, _port_handle): #pylint:disable=invalid-name - """Mock of ndiGetBXTransform""" - return [0,0,0,0,0,0,0,0] - -def mockndiGetBXTransformMissing(_device, _port_handle): #pylint:disable=invalid-name - """Mock of ndiGetBXTransform""" - return "MISSING" +class MockNDIDevice(): + """ + A mock NDI device, enables us to keep track of how many tools we've added + """ + def __init__(self): + self.attached_tools = 0 + + def mockndiCommand(self, _device, command): #pylint:disable=invalid-name + """Mock a general command, strings over serial""" + if command == "PHRQ:*********1****": + self.attached_tools += 1 + + def mockndiGetPHRQHandle(self, _device): #pylint:disable=invalid-name + """Mock of ndiGetPHRQHandle""" + return int(self.attached_tools - 1) + + def mockndiGetPHSRNumberOfHandles(self, _device): #pylint:disable=invalid-name + """Mock of ndiGetPHSRNumberOfHandles""" + return self.attached_tools + + +class MockBXFrameSource(): + """ + A class to handle mocking of calls to get + frame. Enables us to increment the frame number and return + changing values + """ + def __init__(self): + self.bx_frame_count = 0 + self.bx_call_count = 0 + self.rotation = array([1., 0., 0., 0.]) + self.position = array([0., 0., 0.]) + self.velocity = array([10., -20., 5.]) + self.quality = array([1.]) + self.tracked_tools = 0 + + def setdevice(self, ndidevice): + """ + We set an ndidevice so we know how many tracked objects to + return + """ + assert isinstance (ndidevice, MockNDIDevice) + self.tracked_tools = ndidevice.attached_tools + + def mockndiGetBXFrame(self, _device, port_handle): #pylint:disable=invalid-name + """Mock of ndiGetBXFrame""" + self.bx_call_count += 1 + ph_int = int.from_bytes(port_handle, byteorder = 'little') + if ph_int == 0: + self.bx_frame_count += 1 + + assert ph_int < self.tracked_tools + if ph_int == self.tracked_tools - 1: + assert self.bx_call_count == self.bx_frame_count * \ + self.tracked_tools + + return self.bx_frame_count + + def mockndiGetBXTransform(self, _device, port_handle): #pylint:disable=invalid-name + """ + Mock of ndiGetBXTransform. To enable a simple test of tracking + smoothing translate the mock object between frames. Full testing of + the averaging code is in the base class + sksurgerycore.tests.algorithms + The base ndicapi library uses Py_BuildValue to return the transform + as a tuple of double float values, so we also + return a tuple + """ + assert self.bx_frame_count > 0 + ph_int = int.from_bytes(port_handle, byteorder = 'little') + if ph_int == 0: + self.position = self.velocity * self.bx_frame_count + return tuple(concatenate((self.rotation, self.position, + self.quality))) + + return tuple(concatenate((self.rotation, array([0, 0, 0]), + self.quality))) + + def mockndiGetBXTransformMissing(self, _device, _port_handle): #pylint:disable=invalid-name + """Mock of ndiGetBXTransform""" + return "MISSING" diff --git a/tests/test_sksurgerynditracker_mockndi_getframe.py b/tests/test_sksurgerynditracker_mockndi_getframe.py index 950c22f..3814057 100644 --- a/tests/test_sksurgerynditracker_mockndi_getframe.py +++ b/tests/test_sksurgerynditracker_mockndi_getframe.py @@ -1,13 +1,15 @@ # coding=utf-8 """scikit-surgerynditracker tests using a mocked ndicapy""" + +import numpy as np from sksurgerynditracker.nditracker import NDITracker -from tests.polaris_mocks import SETTINGS_POLARIS, mockndiProbe, \ +from tests.polaris_mocks import SETTINGS_POLARIS, SETTINGS_POLARIS_QUAT, \ + mockndiProbe, \ mockndiOpen, mockndiGetError, mockComports, \ - mockndiGetPHSRNumberOfHandles, mockndiGetPHRQHandle, \ - mockndiGetPHSRHandle, mockndiVER, mockndiGetBXFrame, \ - mockndiGetBXTransform, mockndiGetBXTransformMissing + mockndiGetPHSRHandle, mockndiVER, \ + MockNDIDevice, MockBXFrameSource def test_getframe_polaris_mock(mocker): """ @@ -15,23 +17,109 @@ def test_getframe_polaris_mock(mocker): reqs: 03, 04 """ tracker = None + bxsource = MockBXFrameSource() + ndidevice = MockNDIDevice() mocker.patch('serial.tools.list_ports.comports', mockComports) mocker.patch('ndicapy.ndiProbe', mockndiProbe) mocker.patch('ndicapy.ndiOpen', mockndiOpen) - mocker.patch('ndicapy.ndiCommand') + mocker.patch('ndicapy.ndiCommand', ndidevice.mockndiCommand) mocker.patch('ndicapy.ndiGetError', mockndiGetError) mocker.patch('ndicapy.ndiClose') mocker.patch('ndicapy.ndiGetPHSRNumberOfHandles', - mockndiGetPHSRNumberOfHandles) - mocker.patch('ndicapy.ndiGetPHRQHandle', mockndiGetPHRQHandle) + ndidevice.mockndiGetPHSRNumberOfHandles) + mocker.patch('ndicapy.ndiGetPHRQHandle', ndidevice.mockndiGetPHRQHandle) mocker.patch('ndicapy.ndiPVWRFromFile') mocker.patch('ndicapy.ndiGetPHSRHandle', mockndiGetPHSRHandle) mocker.patch('ndicapy.ndiVER', mockndiVER) - mocker.patch('ndicapy.ndiGetBXFrame', mockndiGetBXFrame) - mocker.patch('ndicapy.ndiGetBXTransform', mockndiGetBXTransform) + mocker.patch('ndicapy.ndiGetBXFrame', bxsource.mockndiGetBXFrame) + mocker.patch('ndicapy.ndiGetBXTransform', bxsource.mockndiGetBXTransform) tracker = NDITracker(SETTINGS_POLARIS) - tracker.get_frame() + + bxsource.setdevice(ndidevice) + + (port_handles, time_stamps, frame_numbers, tracking, + tracking_quality ) = tracker.get_frame() + + assert len(port_handles) == 2 + assert len(time_stamps) == 2 + assert frame_numbers.count(1) == 2 + expected_tracking_0 = np.array([[1.,0.,0.,10.], + [0.,1.,0.,-20.], + [0.,0.,1.,5.], + [0.,0.,0.,1.]]) + assert np.array_equal(expected_tracking_0, tracking[0]) + expected_tracking_1 = np.array([[1.,0.,0.,0.], + [0.,1.,0.,0.], + [0.,0.,1.,0.], + [0.,0.,0.,1.]]) + assert np.array_equal(expected_tracking_1, tracking[1]) + assert tracking_quality.count(1.) == 2 + + (port_handles, time_stamps, frame_numbers, tracking, + tracking_quality ) = tracker.get_frame() + + assert len(port_handles) == 2 + assert len(time_stamps) == 2 + assert frame_numbers.count(2) == 2 + expected_tracking_0 = np.array([[1.,0.,0.,20.], + [0.,1.,0.,-40.], + [0.,0.,1.,10.], + [0.,0.,0.,1.]]) + assert np.array_equal(expected_tracking_0, tracking[0]) + assert np.array_equal(expected_tracking_1, tracking[1]) + assert tracking_quality.count(1.) == 2 + + del tracker + +def test_getframe_polaris_mock_quat(mocker): + """ + Checks that get frame works with quaternions + """ + tracker = None + bxsource = MockBXFrameSource() + ndidevice = MockNDIDevice() + mocker.patch('serial.tools.list_ports.comports', mockComports) + mocker.patch('ndicapy.ndiProbe', mockndiProbe) + mocker.patch('ndicapy.ndiOpen', mockndiOpen) + mocker.patch('ndicapy.ndiCommand', ndidevice.mockndiCommand) + mocker.patch('ndicapy.ndiGetError', mockndiGetError) + mocker.patch('ndicapy.ndiClose') + mocker.patch('ndicapy.ndiGetPHSRNumberOfHandles', + ndidevice.mockndiGetPHSRNumberOfHandles) + mocker.patch('ndicapy.ndiGetPHRQHandle', ndidevice.mockndiGetPHRQHandle) + mocker.patch('ndicapy.ndiPVWRFromFile') + mocker.patch('ndicapy.ndiGetPHSRHandle', mockndiGetPHSRHandle) + mocker.patch('ndicapy.ndiVER', mockndiVER) + mocker.patch('ndicapy.ndiGetBXFrame', bxsource.mockndiGetBXFrame) + mocker.patch('ndicapy.ndiGetBXTransform', bxsource.mockndiGetBXTransform) + + tracker = NDITracker(SETTINGS_POLARIS_QUAT) + + bxsource.setdevice(ndidevice) + + (port_handles, time_stamps, frame_numbers, tracking, + tracking_quality ) = tracker.get_frame() + + assert len(port_handles) == 2 + assert len(time_stamps) == 2 + assert frame_numbers.count(1) == 2 + expected_tracking_0 = np.array([[1.,0.,0.,0.,10.,-20,5.]]) + assert np.array_equal(expected_tracking_0, tracking[0]) + expected_tracking_1 = np.array([[1.,0.,0.,0.,0.,0.,0.]]) + assert np.array_equal(expected_tracking_1, tracking[1]) + assert tracking_quality.count(1.) == 2 + + (port_handles, time_stamps, frame_numbers, tracking, + tracking_quality ) = tracker.get_frame() + + assert len(port_handles) == 2 + assert len(time_stamps) == 2 + assert frame_numbers.count(2) == 2 + expected_tracking_0 = np.array([[1.,0.,0.,0.,20.,-40,10.]]) + assert np.array_equal(expected_tracking_0, tracking[0]) + assert np.array_equal(expected_tracking_1, tracking[1]) + assert tracking_quality.count(1.) == 2 del tracker @@ -41,22 +129,34 @@ def test_getframe_missing(mocker): reqs: 03, 04 """ tracker = None + bxsource = MockBXFrameSource() + ndidevice = MockNDIDevice() mocker.patch('serial.tools.list_ports.comports', mockComports) mocker.patch('ndicapy.ndiProbe', mockndiProbe) mocker.patch('ndicapy.ndiOpen', mockndiOpen) - mocker.patch('ndicapy.ndiCommand') + mocker.patch('ndicapy.ndiCommand', ndidevice.mockndiCommand) mocker.patch('ndicapy.ndiGetError', mockndiGetError) mocker.patch('ndicapy.ndiClose') mocker.patch('ndicapy.ndiGetPHSRNumberOfHandles', - mockndiGetPHSRNumberOfHandles) - mocker.patch('ndicapy.ndiGetPHRQHandle', mockndiGetPHRQHandle) + ndidevice.mockndiGetPHSRNumberOfHandles) + mocker.patch('ndicapy.ndiGetPHRQHandle', ndidevice.mockndiGetPHRQHandle) mocker.patch('ndicapy.ndiPVWRFromFile') mocker.patch('ndicapy.ndiGetPHSRHandle', mockndiGetPHSRHandle) mocker.patch('ndicapy.ndiVER', mockndiVER) - mocker.patch('ndicapy.ndiGetBXFrame', mockndiGetBXFrame) - mocker.patch('ndicapy.ndiGetBXTransform', mockndiGetBXTransformMissing) + mocker.patch('ndicapy.ndiGetBXFrame', bxsource.mockndiGetBXFrame) + mocker.patch('ndicapy.ndiGetBXTransform', + bxsource.mockndiGetBXTransformMissing) tracker = NDITracker(SETTINGS_POLARIS) - tracker.get_frame() + + bxsource.setdevice(ndidevice) + (port_handles, time_stamps, frame_numbers, tracking, + tracking_quality ) = tracker.get_frame() + + assert len(port_handles) == 2 + assert len(time_stamps) == 2 + assert frame_numbers.count(1) == 2 + assert np.all(np.isnan(tracking)) + assert np.all(np.isnan(tracking_quality)) del tracker diff --git a/tests/test_sksurgerynditracker_mockndi_polaris.py b/tests/test_sksurgerynditracker_mockndi_polaris.py index df8ab34..2a90be0 100644 --- a/tests/test_sksurgerynditracker_mockndi_polaris.py +++ b/tests/test_sksurgerynditracker_mockndi_polaris.py @@ -8,8 +8,7 @@ from tests.polaris_mocks import SETTINGS_POLARIS, mockndiProbe, \ mockndiOpen, mockndiOpen_fail, mockndiGetError, mockComports, \ - mockndiGetPHSRNumberOfHandles, mockndiGetPHRQHandle, \ - mockndiGetPHSRHandle, mockndiVER + mockndiGetPHSRHandle, mockndiVER, MockNDIDevice def test_connect_polaris_mock(mocker): """ @@ -17,40 +16,41 @@ def test_connect_polaris_mock(mocker): reqs: 03, 04 """ tracker = None + ndidevice = MockNDIDevice() + #add a couple of extra tools to increase test coverage + ndidevice.attached_tools = 2 mocker.patch('serial.tools.list_ports.comports', mockComports) mocker.patch('ndicapy.ndiProbe', mockndiProbe) mocker.patch('ndicapy.ndiOpen', mockndiOpen) - mocker.patch('ndicapy.ndiCommand') + mocker.patch('ndicapy.ndiCommand', ndidevice.mockndiCommand) mocker.patch('ndicapy.ndiGetError', mockndiGetError) mocker.patch('ndicapy.ndiClose') mocker.patch('ndicapy.ndiGetPHSRNumberOfHandles', - mockndiGetPHSRNumberOfHandles) - mocker.patch('ndicapy.ndiGetPHRQHandle', mockndiGetPHRQHandle) + ndidevice.mockndiGetPHSRNumberOfHandles) + mocker.patch('ndicapy.ndiGetPHRQHandle', ndidevice.mockndiGetPHRQHandle) mocker.patch('ndicapy.ndiPVWRFromFile') mocker.patch('ndicapy.ndiGetPHSRHandle', mockndiGetPHSRHandle) mocker.patch('ndicapy.ndiVER', mockndiVER) spy = mocker.spy(ndicapy, 'ndiCommand') tracker = NDITracker(SETTINGS_POLARIS) - assert spy.call_count == 18 + assert spy.call_count == 16 assert spy.call_args_list[0] == call(True, 'INIT:') assert spy.call_args_list[1] == call(True, 'COMM:50000') assert spy.call_args_list[2] == call(True, 'PHSR:01') assert spy.call_args_list[3] == call(True, 'PHF:00') assert spy.call_args_list[4] == call(True, 'PHF:01') - assert spy.call_args_list[5] == call(True, 'PHF:02') - assert spy.call_args_list[6] == call(True, 'PHF:03') - assert spy.call_args_list[7] == call(True, 'PHRQ:*********1****') - assert spy.call_args_list[8] == call(True, 'PHRQ:*********1****') - assert spy.call_args_list[9] == call(True, 'PHSR:01') - assert spy.call_args_list[10] == call(True, 'PHSR:02') - assert spy.call_args_list[11] == call(True, 'PINIT:00') - assert spy.call_args_list[12] == call(True, 'PINIT:00') - assert spy.call_args_list[13] == call(True, 'PHSR:03') - assert spy.call_args_list[14] == call(True, 'PENA:00D') - assert spy.call_args_list[15] == call(True, 'PENA:01D') - assert spy.call_args_list[16] == call(True, 'PENA:02D') - assert spy.call_args_list[17] == call(True, 'PENA:03D') + assert spy.call_args_list[5] == call(True, 'PHRQ:*********1****') + assert spy.call_args_list[6] == call(True, 'PHRQ:*********1****') + assert spy.call_args_list[7] == call(True, 'PHSR:01') + assert spy.call_args_list[8] == call(True, 'PHSR:02') + assert spy.call_args_list[9] == call(True, 'PINIT:02') + assert spy.call_args_list[10] == call(True, 'PINIT:03') + assert spy.call_args_list[11] == call(True, 'PHSR:03') + assert spy.call_args_list[12] == call(True, 'PENA:00D') + assert spy.call_args_list[13] == call(True, 'PENA:01D') + assert spy.call_args_list[14] == call(True, 'PENA:02D') + assert spy.call_args_list[15] == call(True, 'PENA:03D') del tracker def test_connect_polaris_mk_fserial(mocker):