User/wmerynda/processing kf2dataset #554

Open · wants to merge 5 commits into master
13 changes: 10 additions & 3 deletions rec_to_nwb/processing/builder/nwb_file_builder.py
@@ -68,6 +68,8 @@ class NWBFileBuilder:
process_dio (boolean): flag if dio data should be processed
process_mda (boolean): flag if mda data should be processed
process_analog (boolean): flag if analog data should be processed
process_video (boolean): flag if video data should be processed
process_camera_sample_frame_count(boolean): flag if camera sample frame count should be processed
video_path (string): path to directory with video files associated to nwb file
output_file (string): path and name specifying where .nwb file gonna be written

@@ -88,6 +90,8 @@ def __init__(
process_mda: bool = True,
process_analog: bool = True,
process_pos_timestamps: bool = True,
process_video: bool = False,
process_camera_sample_frame_count: bool = False,
video_path: str = '',
output_file: str = 'output.nwb',
reconfig_header: str = ''
@@ -126,7 +130,9 @@ def __init__(
self.process_dio = process_dio
self.process_mda = process_mda
self.process_analog = process_analog
self.process_video = process_video
self.process_pos_timestamps = process_pos_timestamps
self.process_camera_sample_frame_count = process_camera_sample_frame_count
self.output_file = output_file
self.video_path = video_path
self.link_to_notes = self.metadata.get('link to notes', None)
@@ -286,7 +292,8 @@ def build(self):

self.camera_device_originator.make(nwb_content)

self.video_files_originator.make(nwb_content)
if self.process_video:
self.video_files_originator.make(nwb_content)

electrode_groups = self.electrode_group_originator.make(
nwb_content, probes, valid_map_dict['electrode_groups']
@@ -303,8 +310,8 @@ self.sample_count_timestamp_corespondence_originator.make(nwb_content)
self.sample_count_timestamp_corespondence_originator.make(nwb_content)

self.task_originator.make(nwb_content)

self.camera_sample_frame_counts_originator.make(nwb_content)
if self.process_camera_sample_frame_count:
self.camera_sample_frame_counts_originator.make(nwb_content)

if self.process_dio:
self.dio_originator.make(nwb_content)
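
Both new flags default to False, so existing pipelines skip these originators unless a caller opts in. A reduced, self-contained sketch of the gating added to build() (the class and the printed calls are stand-ins, not the real originator classes):

```python
class BuilderSketch:
    """Toy model of the two new NWBFileBuilder flags (hypothetical class)."""

    def __init__(self, process_video=False, process_camera_sample_frame_count=False):
        self.process_video = process_video
        self.process_camera_sample_frame_count = process_camera_sample_frame_count

    def build(self):
        if self.process_video:
            print('video_files_originator.make(nwb_content)')                  # stand-in
        if self.process_camera_sample_frame_count:
            print('camera_sample_frame_counts_originator.make(nwb_content)')   # stand-in


BuilderSketch(process_video=True).build()   # only the video originator runs
```
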
13 changes: 11 additions & 2 deletions rec_to_nwb/processing/builder/originators/mda_originator.py
@@ -15,18 +15,27 @@ def __init__(self, datasets, header, metadata):
self.datasets = datasets
self.header = header
self.metadata = metadata
self.number_of_channels = self.count_number_of_channels(header)

def make(self, nwb_content):
logger.info('MDA: Building')
fl_mda_manager = FlMdaManager(
nwb_content=nwb_content,
sampling_rate=float(self.header.configuration.hardware_configuration.sampling_rate),
datasets=self.datasets,
conversion=self.metadata['raw_data_to_volts']
conversion=self.metadata['raw_data_to_volts'],
number_of_channels=self.number_of_channels
)
fl_mda = fl_mda_manager.get_data()
logger.info('MDA: Injecting')
MdaInjector.inject_mda(
nwb_content=nwb_content,
electrical_series=ElectricalSeriesCreator.create_mda(fl_mda)
)
)

def count_number_of_channels(self, header):
spike_configuration = header.configuration.spike_configuration
counter = 0
for spike_n_trode in spike_configuration.spike_n_trodes:
counter += len(spike_n_trode.spike_channels)
return counter
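
A self-contained sketch of the counting rule (the SpikeNTrode stand-in mirrors the header module's spike_channels attribute; the numbers are hypothetical):

```python
class SpikeNTrode:                      # stand-in for the parsed header node
    def __init__(self, n_channels):
        self.spike_channels = list(range(n_channels))


spike_n_trodes = [SpikeNTrode(4), SpikeNTrode(4), SpikeNTrode(3)]

number_of_channels = 0
for spike_n_trode in spike_n_trodes:    # same loop as count_number_of_channels
    number_of_channels += len(spike_n_trode.spike_channels)

print(number_of_channels)               # 11, the second axis of the mda shape
```
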
2 changes: 1 addition & 1 deletion rec_to_nwb/processing/header/module/spike_n_trode.py
@@ -22,4 +22,4 @@ def __init__(self, element):
self.filter_on = self.tree.get('filterOn')
self.ref_on = self.tree.get('refOn')
self.module_data_on = self.tree.get('moduleDataOn')
self.ref_n_trode_id = self.tree.get('refNTrodeID')
self.ref_n_trode_id = self.tree.get('refNTrodeID', self.tree.get('refNTrode'))
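
The extra default means headers that still use the legacy refNTrode attribute are read instead of yielding None. A minimal illustration with plain dicts (attribute values hypothetical):

```python
new_style = {'refNTrodeID': '7'}   # current attribute name
old_style = {'refNTrode': '4'}     # legacy attribute name

for tree in (new_style, old_style):
    # try the current name first, fall back to the legacy one
    print(tree.get('refNTrodeID', tree.get('refNTrode')))   # '7', then '4'
```
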
@@ -13,6 +13,6 @@ def get_fl_header_device(self):

def __compare_global_configuration_with_default(self):
for single_key in self.default_configuration:
if single_key not in self.global_configuration.keys():
if single_key not in self.global_configuration.keys() or self.global_configuration[single_key] is None:
self.global_configuration[single_key] = self.default_configuration[single_key]
return self.global_configuration
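
The added clause treats an explicit None the same as a missing key. A small sketch of the defaulting rule with assumed keys:

```python
default_configuration = {'name': 'header_device', 'amplifier': 'unknown'}
global_configuration = {'name': None, 'amplifier': 'intan'}

for single_key in default_configuration:
    # fill in the default when the key is absent OR explicitly None
    if single_key not in global_configuration or global_configuration[single_key] is None:
        global_configuration[single_key] = default_configuration[single_key]

print(global_configuration)   # {'name': 'header_device', 'amplifier': 'intan'}
```
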
17 changes: 9 additions & 8 deletions rec_to_nwb/processing/nwb/components/iterator/data_iterator.py
@@ -3,18 +3,18 @@


class DataIterator(AbstractDataChunkIterator):
def __init__(self, data):
def __init__(self, data, number_of_channels):
self.data = data
self.current_number_of_rows = 0

self._current_index = 0
self.current_file = 0
self.current_dataset = 0

self.number_of_steps = self.data.get_number_of_datasets() * self.data.get_number_of_files_per_dataset()
self.dataset_file_length = self.data.get_file_lenghts_in_datasets()
self.number_of_rows = self.data.get_number_of_rows_per_file()
self.number_of_files_in_single_dataset = self.data.get_number_of_files_per_dataset()
self.shape = [self.data.get_final_data_shape()[1], self.data.get_final_data_shape()[0]]
self.shape = [self.data.get_final_data_shape()[1], number_of_channels]

def __iter__(self):
return self
@@ -25,12 +25,13 @@ def _get_selection(self):
(self.current_file * self.number_of_rows):
((self.current_file + 1) * self.number_of_rows)]

@staticmethod
def get_selection(number_of_threads, current_dataset, dataset_file_length, current_file, number_of_rows):
return np.s_[sum(dataset_file_length[0:current_dataset]):
def get_selection(self, current_dataset, dataset_file_length, number_of_new_rows):
selection = np.s_[sum(dataset_file_length[0:current_dataset]):
sum(dataset_file_length[0:current_dataset + 1]),
(current_file * number_of_rows):
((current_file + number_of_threads) * number_of_rows)]
self.current_number_of_rows: self.current_number_of_rows + number_of_new_rows]

self.current_number_of_rows += number_of_new_rows
return selection

def recommended_chunk_shape(self):
return None
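
The instance method replaces the static one because the column offset is now stateful: each chunk's slice starts where the previous chunk ended, however many threads produced it. A runnable sketch with hypothetical sizes:

```python
import numpy as np


class SelectionTracker:
    def __init__(self, dataset_file_length):
        self.dataset_file_length = dataset_file_length   # samples per dataset
        self.current_number_of_rows = 0                  # columns written so far

    def get_selection(self, current_dataset, number_of_new_rows):
        # rows cover the current dataset's extent; columns advance with state
        selection = np.s_[sum(self.dataset_file_length[:current_dataset]):
                          sum(self.dataset_file_length[:current_dataset + 1]),
                          self.current_number_of_rows:
                          self.current_number_of_rows + number_of_new_rows]
        self.current_number_of_rows += number_of_new_rows
        return selection


tracker = SelectionTracker([100, 100])
print(tracker.get_selection(0, 30))   # rows 0:100, cols 0:30
print(tracker.get_selection(0, 30))   # rows 0:100, cols 30:60
```
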
46 changes: 46 additions & 0 deletions rec_to_nwb/processing/nwb/components/iterator/data_iterator_pos.py
@@ -0,0 +1,46 @@
import numpy as np
from hdmf.data_utils import AbstractDataChunkIterator


class DataIteratorPos(AbstractDataChunkIterator):
def __init__(self, data):
self.data = data

self._current_index = 0
self.current_file = 0
self.current_dataset = 0
self.number_of_rows = self.data.get_number_of_rows_per_file()
self.number_of_steps = self.data.get_number_of_datasets() * self.data.get_number_of_files_per_dataset()
self.dataset_file_length = self.data.get_file_lenghts_in_datasets()
self.number_of_files_in_single_dataset = self.data.get_number_of_files_per_dataset()
self.shape = [self.data.get_final_data_shape()[1], self.data.get_final_data_shape()[0]]

def __iter__(self):
return self

def _get_selection(self):
return np.s_[sum(self.dataset_file_length[0:self.current_dataset]):
sum(self.dataset_file_length[0:self.current_dataset + 1]),
(self.current_file * self.number_of_rows):
((self.current_file + 1) * self.number_of_rows)]

@staticmethod
def get_selection(number_of_threads, current_dataset, dataset_file_length, current_file, number_of_rows):
return np.s_[sum(dataset_file_length[0:current_dataset]):
sum(dataset_file_length[0:current_dataset + 1]),
(current_file * number_of_rows):
((current_file + number_of_threads) * number_of_rows)]

def recommended_chunk_shape(self):
return None

def recommended_data_shape(self):
return self.shape

@property
def dtype(self):
return np.dtype('int16')

@property
def maxshape(self):
return self.shape
rec_to_nwb/processing/nwb/components/iterator/multi_thread_data_iterator.py
@@ -7,8 +7,8 @@


class MultiThreadDataIterator(DataIterator):
def __init__(self, data, number_of_threads=6):
DataIterator.__init__(self, data)
def __init__(self, data, number_of_channels, number_of_threads=6):
DataIterator.__init__(self, data, number_of_channels)
self.number_of_threads = number_of_threads

def __next__(self):
@@ -23,11 +23,10 @@ def __next__(self):
for thread in threads:
data_from_multiple_files += (thread.result(),)
stacked_data_from_multiple_files = np.hstack(data_from_multiple_files)
selection = self.get_selection(number_of_threads=number_of_threads_in_current_step,
current_dataset=self.current_dataset,
number_of_new_rows = stacked_data_from_multiple_files.shape[1]
selection = self.get_selection(current_dataset=self.current_dataset,
dataset_file_length=self.dataset_file_length,
current_file=self.current_file,
number_of_rows=self.number_of_rows)
number_of_new_rows=number_of_new_rows)
data_chunk = DataChunk(data=stacked_data_from_multiple_files, selection=selection)

self._current_index += number_of_threads_in_current_step
@@ -36,6 +35,7 @@ def __next__(self):
if self.current_file >= self.number_of_files_in_single_dataset:
self.current_dataset += 1
self.current_file = 0
self.current_number_of_rows = 0

del stacked_data_from_multiple_files
return data_chunk
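
The per-chunk width is now taken from the stacked result itself (shape[1]) rather than computed as threads × rows-per-file, which keeps the selection correct when a file is shorter than expected. A reduced sketch of the read-and-stack step (file contents fabricated):

```python
import concurrent.futures

import numpy as np


def get_data_from_file(file_index):
    # stand-in for np.transpose(data.read_data(dataset, file_index));
    # widths are deliberately uneven to show why shape[1] is measured
    return np.full((4, 10 - file_index), file_index, dtype='int16')


with concurrent.futures.ThreadPoolExecutor() as executor:
    threads = [executor.submit(get_data_from_file, i) for i in range(3)]
    data_from_multiple_files = tuple(t.result() for t in threads)

stacked = np.hstack(data_from_multiple_files)   # concatenate along samples
number_of_new_rows = stacked.shape[1]           # actual width: 10 + 9 + 8
print(stacked.shape, number_of_new_rows)        # (4, 27) 27
```
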
rec_to_nwb/processing/nwb/components/iterator/multi_thread_data_iterator_pos.py (new file)
@@ -0,0 +1,51 @@
import concurrent.futures

import numpy as np
from hdmf.data_utils import DataChunk

from rec_to_nwb.processing.nwb.components.iterator.data_iterator_pos import DataIteratorPos


class MultiThreadDataIteratorPos(DataIteratorPos):
def __init__(self, data, number_of_threads=6):
DataIteratorPos.__init__(self, data)
self.number_of_threads = number_of_threads

def __next__(self):
if self._current_index < self.number_of_steps:
number_of_threads_in_current_step = min(self.number_of_threads,
self.number_of_files_in_single_dataset - self.current_file)
with concurrent.futures.ThreadPoolExecutor() as executor:
threads = [executor.submit(MultiThreadDataIteratorPos.get_data_from_file,
self.data, self.current_dataset, self.current_file + i)
for i in range(number_of_threads_in_current_step)]
data_from_multiple_files = ()
for thread in threads:
data_from_multiple_files += (thread.result(),)
stacked_data_from_multiple_files = np.hstack(data_from_multiple_files)
selection = self.get_selection(number_of_threads=number_of_threads_in_current_step,
current_dataset=self.current_dataset,
dataset_file_length=self.dataset_file_length,
current_file=self.current_file,
number_of_rows=self.number_of_rows)
data_chunk = DataChunk(data=stacked_data_from_multiple_files, selection=selection)

self._current_index += number_of_threads_in_current_step
self.current_file += number_of_threads_in_current_step

if self.current_file >= self.number_of_files_in_single_dataset:
self.current_dataset += 1
self.current_file = 0
self.current_number_of_rows = 0

del stacked_data_from_multiple_files
return data_chunk

raise StopIteration

next = __next__

@staticmethod
def get_data_from_file(data, current_dataset, current_file):
return np.transpose(data.read_data(current_dataset, current_file))

rec_to_nwb/processing/nwb/components/iterator/multi_thread_timestamp_iterator.py
@@ -8,7 +8,7 @@

class MultiThreadTimestampIterator(TimestampIterator):

def __init__(self, data, number_of_threads=6):
def __init__(self, data, number_of_threads=1):
TimestampIterator.__init__(self, data)
self.number_of_threads = number_of_threads

5 changes: 3 additions & 2 deletions rec_to_nwb/processing/nwb/components/mda/fl_mda_extractor.py
@@ -8,8 +8,9 @@

class FlMdaExtractor:

def __init__(self, datasets):
def __init__(self, datasets, number_of_channels):
self.datasets = datasets
self.number_of_channels = number_of_channels

def get_data(self):
mda_data, timestamps, continuous_time = self.__extract_data()
@@ -18,7 +19,7 @@ def get_data(self):
continuous_time_directories=continuous_time
)
mda_data_manager = MdaDataManager(mda_data)
data_iterator = MultiThreadDataIterator(mda_data_manager)
data_iterator = MultiThreadDataIterator(mda_data_manager, self.number_of_channels)
timestamp_iterator = MultiThreadTimestampIterator(mda_timestamp_data_manager)

return MdaContent(data_iterator, timestamp_iterator)
4 changes: 2 additions & 2 deletions rec_to_nwb/processing/nwb/components/mda/fl_mda_manager.py
@@ -4,9 +4,9 @@


class FlMdaManager:
def __init__(self, nwb_content, sampling_rate, datasets, conversion):
def __init__(self, nwb_content, sampling_rate, datasets, conversion, number_of_channels):
self.__table_region_builder = TableRegionBuilder(nwb_content)
self.__fl_mda_extractor = FlMdaExtractor(datasets)
self.__fl_mda_extractor = FlMdaExtractor(datasets, number_of_channels)
self.__fl_mda_builder = FlMdaBuilder(sampling_rate, conversion)

def get_data(self):
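
Taken together, the channel count computed once in MdaOriginator now flows through FlMdaManager and FlMdaExtractor into MultiThreadDataIterator, so the second axis of DataIterator.shape comes from the header's spike channels rather than from the shape reported by the data files.
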
@@ -1,5 +1,5 @@
from rec_to_nwb.processing.exceptions.missing_data_exception import MissingDataException
from rec_to_nwb.processing.nwb.components.iterator.multi_thread_data_iterator import MultiThreadDataIterator
from rec_to_nwb.processing.nwb.components.iterator.multi_thread_data_iterator_pos import MultiThreadDataIteratorPos
from rec_to_nwb.processing.nwb.components.iterator.multi_thread_timestamp_iterator import MultiThreadTimestampIterator
from rec_to_nwb.processing.nwb.components.position.pos_data_manager import PosDataManager
from rec_to_nwb.processing.nwb.components.position.pos_timestamp_manager import PosTimestampManager
@@ -26,6 +26,9 @@ def __extract_data(self):
'Incomplete data in dataset '
+ str(dataset.name)
+ 'missing continuous time file')
if len(data_from_current_dataset) == 0:
# otherwise get IndexError downstream (PosDataManager)
continue
all_pos.append(data_from_current_dataset)
continuous_time.append(dataset.get_continuous_time())
return all_pos, continuous_time
Expand All @@ -36,7 +39,7 @@ def get_positions(self):
for single_pos in self.all_pos
]
return [
MultiThreadDataIterator(pos_data)
MultiThreadDataIteratorPos(pos_data)
for pos_data in pos_datas
]

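
A minimal illustration of the new empty-dataset guard (dataset contents hypothetical):

```python
all_pos, continuous_time = [], []
datasets = {'ds1': [[1.0, 2.0]], 'ds2': []}   # ds2 produced no position rows

for name, data_from_current_dataset in datasets.items():
    if len(data_from_current_dataset) == 0:
        # skipping here avoids an IndexError later in PosDataManager
        continue
    all_pos.append(data_from_current_dataset)
    continuous_time.append(name + '.continuoustime')   # stand-in

print(len(all_pos), len(continuous_time))   # 1 1
```
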
@@ -1,3 +1,5 @@
from pathlib import Path

import numpy as np
from rec_to_binaries.read_binaries import readTrodesExtractedDataFile

@@ -15,13 +17,21 @@ def extract_video_files(self):
video_files = self.video_files_metadata
extracted_video_files = []
for video_file in video_files:
new_fl_video_file = {
"name": video_file["name"],
"timestamps": self.convert_timestamps(readTrodesExtractedDataFile(
if Path(self.raw_data_path + "/" + video_file["name"][:-4] + "videoTimeStamps.cameraHWSync").is_file():
converted_timestamps = self.convert_timestamps(readTrodesExtractedDataFile(
self.raw_data_path + "/"
+ video_file["name"][:-4]
+ "videoTimeStamps.cameraHWSync"
)["data"]),
+ "videoTimeStamps.cameraHWFrameCount"
)["data"])
else:
converted_timestamps = readTrodesExtractedDataFile(
self.raw_data_path + "/"
+ video_file["name"][:-4]
+ "videoTimeStamps.cameraHWFrameCount"
)["data"]
new_fl_video_file = {
"name": video_file["name"],
"timestamps": converted_timestamps,
"device": video_file["camera_id"]
}
extracted_video_files.append(new_fl_video_file)
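
For orientation, a condensed sketch of the new branch (helper and paths are hypothetical; readTrodesExtractedDataFile comes from rec_to_binaries, as in the diff). Note that the probe tests for the .cameraHWSync file while both branches read the .cameraHWFrameCount file, exactly as the diff shows:

```python
from pathlib import Path

from rec_to_binaries.read_binaries import readTrodesExtractedDataFile


def convert_timestamps(data):
    # stand-in for the extractor's converter (assumed)
    return data


def extract_timestamps(stem):
    # stem is raw_data_path + "/" + video_file["name"][:-4], e.g.
    # "/data/raw/20190718_sample.1." for "...sample.1.h264" (path hypothetical)
    if Path(stem + "videoTimeStamps.cameraHWSync").is_file():
        # sync file present: frame counts go through convert_timestamps
        return convert_timestamps(readTrodesExtractedDataFile(
            stem + "videoTimeStamps.cameraHWFrameCount")["data"])
    # no sync file: raw frame-count data is used unconverted
    return readTrodesExtractedDataFile(
        stem + "videoTimeStamps.cameraHWFrameCount")["data"]
```
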