diff --git a/bc_qualityMetrics_pipeline.asv b/bc_qualityMetrics_pipeline.asv index 80c4ab20..b7eb3252 100644 --- a/bc_qualityMetrics_pipeline.asv +++ b/bc_qualityMetrics_pipeline.asv @@ -19,12 +19,15 @@ ephysRawDir = dir('/home/netshare/zaru/JF093/2023-03-06/ephys/site1/*ap*.*bin'); ephysMetaDir = dir('/home/netshare/zaru/JF093/2023-03-06/ephys/site1/*ap*.*meta'); % path to your .meta or .oebin meta file savePath = '/media/julie/ExtraHD/JF093/qMetrics'; % where you want to save the quality metrics decompressDataLocal = '/media/julie/ExtraHD/decompressedData'; % where to save raw decompressed ephys data -gain_to_uV = 0.195; % use this if you are not using spikeGLX or openEphys to record your data. You can then leave ephysMetaDir +gain_to_uV = 0.195; % use this if you are not using spikeGLX or openEphys to record your data. You then must leave the ephysMetaDir % empty(e.g. ephysMetaDir = '') %% check MATLAB version oldMATLAB = isMATLABReleaseOlderThan("R2019a"); -of +if oldMATLAB + error('This MATLAB version is older than 2019a - download a more recent version before continuing') +end + %% load data [spikeTimes_samples, spikeTemplates, templateWaveforms, templateAmplitudes, pcFeatures, ... pcFeatureIdx, channelPositions] = bc_loadEphysData(ephysKilosortPath); @@ -33,8 +36,9 @@ of rawFile = bc_manageDataCompression(ephysRawDir, decompressDataLocal); %% which quality metric parameters to extract and thresholds -param = bc_qualityParamValues(ephysMetaDir, rawFile, ephysKilosortPath, gain_to_uV); -% param = bc_qualityParamValuesForUnitMatch(ephysMetaDir, rawFile) % Run this if you want to use UnitMatch after +param = bc_qualityParamValues(ephysMetaDir, rawFile, ephysKilosortPath, gain_to_uV); %for unitmatch, run this: +% param = bc_qualityParamValuesForUnitMatch(ephysMetaDir, rawFile, ephysKilosortPath, gain_to_uV) + %% compute quality metrics rerun = 0; diff --git a/bc_qualityMetrics_pipeline.m b/bc_qualityMetrics_pipeline.m index b78d1728..00ccfb2d 100644 --- a/bc_qualityMetrics_pipeline.m +++ b/bc_qualityMetrics_pipeline.m @@ -14,10 +14,11 @@ %% set paths - EDIT THESE -ephysKilosortPath = '/home/netshare/zaru/JF093/2023-03-06/ephys/kilosort2/site1/';% path to your kilosort output files -ephysRawDir = dir('/home/netshare/zaru/JF093/2023-03-06/ephys/site1/*ap*.*bin'); % path to yourraw .bin or .dat data -ephysMetaDir = dir('/home/netshare/zaru/JF093/2023-03-06/ephys/site1/*ap*.*meta'); % path to your .meta or .oebin meta file -savePath = '/media/julie/ExtraHD/JF093/qMetrics'; % where you want to save the quality metrics +% '/home/netshare/zaru/JF093/2023-03-06/ephys/kilosort2/site1 +ephysKilosortPath = '/home/netshare/zaru/JF093/2023-03-08/ephys/pykilosort/site2/output';% path to your kilosort output files +ephysRawDir = dir('/home/netshare/zaru/JF093/2023-03-08/ephys/site2/*ap*.*bin'); % path to your raw .bin or .dat data +ephysMetaDir = dir('/home/netshare/zaru/JF093/2023-03-08/ephys/site2/*ap*.*meta'); % path to your .meta or .oebin meta file +savePath = '/media/julie/ExtraHD/JF093/2023-03-08/ephys/site2/qMetrics'; % where you want to save the quality metrics decompressDataLocal = '/media/julie/ExtraHD/decompressedData'; % where to save raw decompressed ephys data gain_to_uV = 0.195; % use this if you are not using spikeGLX or openEphys to record your data. You then must leave the ephysMetaDir % empty(e.g. ephysMetaDir = '') @@ -36,8 +37,9 @@ rawFile = bc_manageDataCompression(ephysRawDir, decompressDataLocal); %% which quality metric parameters to extract and thresholds -param = bc_qualityParamValues(ephysMetaDir, rawFile, ephysKilosortPath, gain_to_uV); -% param = bc_qualityParamValuesForUnitMatch(ephysMetaDir, rawFile) % Run this if you want to use UnitMatch after +param = bc_qualityParamValues(ephysMetaDir, rawFile, ephysKilosortPath, gain_to_uV); %for unitmatch, run this: +% param = bc_qualityParamValuesForUnitMatch(ephysMetaDir, rawFile, ephysKilosortPath, gain_to_uV) + %% compute quality metrics rerun = 0; diff --git a/loading/bc_loadMetricsForGUI.m b/loading/bc_loadMetricsForGUI.m index 81c77c52..38cc9712 100644 --- a/loading/bc_loadMetricsForGUI.m +++ b/loading/bc_loadMetricsForGUI.m @@ -43,6 +43,14 @@ rawWaveforms.average = readNPY([fullfile(savePath, 'templates._bc_rawWaveforms.npy')]); rawWaveforms.peakChan = readNPY([fullfile(savePath, 'templates._bc_rawWaveformPeakChannels.npy')]); +% remove any duplicate spikes +[uniqueTemplates, ~, ephysData.spike_times_samples, ephysData.spike_templates, ephysData.template_amplitudes, ... + ~, rawWaveforms.average, rawWaveforms.peakChan, signalToNoiseRatio] = ... + bc_removeDuplicateSpikes(ephysData.spike_times_samples, ephysData.spike_templates, ephysData.template_amplitudes,... + [], rawWaveforms.average, rawWaveforms.peakChan,[],... + qMetric.maxChannels, param.removeDuplicateSpikes, param.duplicateSpikeWindow_s, ... + param.ephys_sample_rate, param.saveSpikes_withoutDuplicates, savePath, param.recomputeDuplicateSpikes); + % load other gui stuffs if ~exist('forGUI', 'var') || ~isempty(dir([savePath, filesep, 'templates.qualityMetricDetailsforGUI.mat'])) load([savePath, filesep, 'templates.qualityMetricDetailsforGUI.mat']) diff --git a/personal_work_in_progress/bc_qualityMetricsPipeline_JF.m b/personal_work_in_progress/bc_qualityMetricsPipeline_JF.m index b09f71ec..fdf2e363 100644 --- a/personal_work_in_progress/bc_qualityMetricsPipeline_JF.m +++ b/personal_work_in_progress/bc_qualityMetricsPipeline_JF.m @@ -30,7 +30,7 @@ end %% run qmetrics -param = bc_qualityParamValues(ephysMetaDir, rawFile); +param = bc_qualityParamValues_JF(ephysMetaDir, rawFile); %param.computeDistanceMetrics = 1; %% compute quality metrics diff --git a/personal_work_in_progress/bc_qualityParamValues_JF.m b/personal_work_in_progress/bc_qualityParamValues_JF.m new file mode 100644 index 00000000..a494a634 --- /dev/null +++ b/personal_work_in_progress/bc_qualityParamValues_JF.m @@ -0,0 +1,6 @@ +function param = bc_qualityParamValues_JF(ephysMetaDir, rawFile, ephysKilosortPath, gain_to_uV) + +param = bc_qualityParamValues(ephysMetaDir, rawFile, ephysKilosortPath, gain_to_uV); +param.removeDuplicateSpikes = 0; + +end \ No newline at end of file diff --git a/qualityMetrics/bc_addMissingFieldsWithDefault.m b/qualityMetrics/bc_addMissingFieldsWithDefault.m new file mode 100644 index 00000000..ce6d9f85 --- /dev/null +++ b/qualityMetrics/bc_addMissingFieldsWithDefault.m @@ -0,0 +1,34 @@ +function [data, missingFields] = bc_addMissingFieldsWithDefault(data, defaultValues) +% JF, Check input structure has all necessary fields + add them with +% defualt values if not. +% ------ +% Inputs +% ------ + if ~isstruct(data) && ~istable(data) + error('Input must be a structure or table'); + end + + if ~isstruct(defaultValues) + error('Default values must be provided as a structure'); + end + + fieldnames = fields(defaultValues); + + if isstruct(data) + missingFields = fieldnames(~isfield(data, fieldnames)); + + for i = 1:length(missingFields) + fieldName = missingFields{i}; + data.(fieldName) = defaultValues.(fieldName); + end + else % data is a table + existingFields = data.Properties.VariableNames; + missingFields = setdiff(fieldnames, existingFields); + + for i = 1:length(missingFields) + fieldName = missingFields{i}; + data.(fieldName) = repmat(defaultValues.(fieldName), height(data), 1); + end + end +end + diff --git a/qualityMetrics/bc_checkParameterFields.m b/qualityMetrics/bc_checkParameterFields.m new file mode 100644 index 00000000..824fafa1 --- /dev/null +++ b/qualityMetrics/bc_checkParameterFields.m @@ -0,0 +1,41 @@ +function param_complete = bc_checkParameterFields(param) +% JF, Check input structure has all necessary fields + add them with +% defualt values if not. This is to ensure backcompatibility when any new +% paramaters are introduced. By default, any parameters not already present +% will be set so that the quality metrics are calculated in the same way as +% they were before these new parameters were introduced. +% ------ +% Inputs +% ------ + + + +%% Default values for fields +% duplicate spikes +defaultValues.removeDuplicateSpikes = 0; +defaultValues.duplicateSpikeWindow_s = 0.0001; +defaultValues.saveSpikes_withoutDuplicates = 1; +defaultValues.recomputeDuplicateSpikes = 0; + +% raw waveforms +defaultValues.detrendWaveforms = 0; +defaultValues.extractRaw = 1; + +% amplitude +defaultValues.gain_to_uV = NaN; + +% phy saving +defaultValues.saveAsTSV = 0; +defaultValues.unitType_for_phy = 0; + + +%% Check for missing fields and add them with default value +[param_complete, missingFields] = bc_addMissingFieldsWithDefault(param, defaultValues); + +%% Display result +if ~isempty(missingFields) + disp('Missing param fields filled in with default values'); + disp(missingFields); +end + +end diff --git a/qualityMetrics/bc_getQualityUnitType.m b/qualityMetrics/bc_getQualityUnitType.m index 5d3d73b5..104ed547 100644 --- a/qualityMetrics/bc_getQualityUnitType.m +++ b/qualityMetrics/bc_getQualityUnitType.m @@ -7,6 +7,10 @@ % ------ % Outputs % ------ + +% check paramaters +param = bc_checkParameterFields(param); + if nargin < 3 && param.unitType_for_phy == 1 savePath = pwd; warning('no save path specified. using current working directory') diff --git a/qualityMetrics/bc_qualityParamValues.m b/qualityMetrics/bc_qualityParamValues.m index 9b13f21f..f157bb46 100644 --- a/qualityMetrics/bc_qualityParamValues.m +++ b/qualityMetrics/bc_qualityParamValues.m @@ -39,6 +39,12 @@ end param.saveMatFileForGUI = 1; % save certain outputs at .mat file - useful for GUI +% duplicate spikes parameters +param.removeDuplicateSpikes = 1; +param.duplicateSpikeWindow_s = 0.00001; % in seconds +param.saveSpikes_withoutDuplicates = 1; +param.recomputeDuplicateSpikes = 0; + % amplitude / raw waveform parameters param.detrendWaveform = 1; % If this is set to 1, each raw extracted spike is % detrended (we remove the best straight-fit line from the spike) @@ -56,6 +62,7 @@ % For additional probe types, make a pull request with more % information. If your spikeGLX meta file contains information about your probe % type, or if you are using open ephys, this paramater wil be ignored. +param.detrendWaveforms = 0; % signal to noise ratio param.waveformBaselineNoiseWindow = 20; %time in samples at beginning of times diff --git a/qualityMetrics/bc_removeDuplicateSpikes.m b/qualityMetrics/bc_removeDuplicateSpikes.m new file mode 100644 index 00000000..64d5ba93 --- /dev/null +++ b/qualityMetrics/bc_removeDuplicateSpikes.m @@ -0,0 +1,160 @@ +function [nonEmptyUnits, duplicateSpikes_idx, spikeTimes_samples, spikeTemplates, templateAmplitudes, ... + pcFeatures, rawWaveformsFull, rawWaveformsPeakChan, signalToNoiseRatio, maxChannels] = ... + bc_removeDuplicateSpikes(spikeTimes_samples, spikeTemplates, templateAmplitudes, ... + pcFeatures, rawWaveformsFull, rawWaveformsPeakChan, signalToNoiseRatio, ... + maxChannels, removeDuplicateSpikes_flag, ... + duplicateSpikeWindow_s, ephys_sample_rate, saveSpikes_withoutDuplicates_flag, savePath, recompute) +% JF, Remove any duplicate spikes +% Some spike sorters (including kilosort) can sometimes count spikes twice +% if for instance the residuals are re-fitted. see https://github.com/MouseLand/Kilosort/issues/29 +% ------ +% Inputs +% ------ +% +% ------ +% Outputs +% ------ +% + +if removeDuplicateSpikes_flag + % Check if we need to extract duplicate spikes + if recompute || isempty(dir([savePath, filesep, 'spikes._bc_duplicateSpikes.npy'])) + % Parameters + duplicateSpikeWindow_samples = duplicateSpikeWindow_s * ephys_sample_rate; + batch_size = 10000; + overlap_size = 100; + numSpikes_full = length(spikeTimes_samples); + + % initialize and re-allocate + duplicateSpikes_idx = false(1, numSpikes_full); + + % Rename the spike templates according to the remaining templates + good_templates_idx = unique(spikeTemplates); + new_spike_idx = nan(max(spikeTemplates), 1); + new_spike_idx(good_templates_idx) = 1:length(good_templates_idx); + spikeTemplates_flat = new_spike_idx(spikeTemplates); + + % check for duplicate spikes in batches + for start_idx = 1:batch_size - overlap_size:numSpikes_full + end_idx = min(start_idx+batch_size-1, numSpikes_full); + batch_spikeTimes_samples = spikeTimes_samples(start_idx:end_idx); + batch_spikeTemplates = spikeTemplates(start_idx:end_idx); + batch_templateAmplitudes = templateAmplitudes(start_idx:end_idx); + + [~, ~, batch_removeIdx] = removeDuplicates(batch_spikeTimes_samples, ... + batch_spikeTemplates, batch_templateAmplitudes, duplicateSpikeWindow_samples, ... + maxChannels, spikeTemplates_flat); + + duplicateSpikes_idx(start_idx:end_idx) = batch_removeIdx; + + if end_idx == numSpikes_full + break; + end + end + % save data if required + if saveSpikes_withoutDuplicates_flag + writeNPY(duplicateSpikes_idx, [savePath, filesep, 'spikes._bc_duplicateSpikes.npy']) + end + + else + duplicateSpikes_idx = readNPY([savePath, filesep, 'spikes._bc_duplicateSpikes.npy']); + end + + % check if there are any empty units + unique_templates = unique(spikeTemplates); + nonEmptyUnits = unique(spikeTemplates(~duplicateSpikes_idx)); + emptyUnits_idx = ~ismember(unique_templates, nonEmptyUnits); + + % remove any empty units from ephys data + spikeTimes_samples = spikeTimes_samples(~duplicateSpikes_idx); + spikeTemplates = spikeTemplates(~duplicateSpikes_idx); + templateAmplitudes = templateAmplitudes(~duplicateSpikes_idx); + if ~isempty(pcFeatures) + pcFeatures = pcFeatures(~duplicateSpikes_idx, :, :); + end + if ~isempty(rawWaveformsFull) + rawWaveformsFull = rawWaveformsFull(~emptyUnits_idx, :, :); + rawWaveformsPeakChan = rawWaveformsPeakChan(~emptyUnits_idx); + end + + if ~isempty(signalToNoiseRatio) + signalToNoiseRatio = signalToNoiseRatio(~emptyUnits_idx); + end + + fprintf('\n Removed %.0f spike duplicates out of %.0f. \n', sum(duplicateSpikes_idx), length(duplicateSpikes_idx)) + +else + nonEmptyUnits = unique(spikeTemplates); + duplicateSpikes_idx = zeros(size(spikeTimes_samples, 1), 1); + +end + + + function [spikeTimes_samples, spikeTemplates, removeIdx] = removeDuplicates(spikeTimes_samples, ... + spikeTemplates, templateAmplitudes, duplicateSpikeWindow_samples, maxChannels, spikeTemplates_flat) + numSpikes = length(spikeTimes_samples); + removeIdx = false(1, numSpikes); + + % Intra-unit duplicate removal + for iSpike1 = 1:numSpikes + if removeIdx(iSpike1) + continue; + end + + for iSpike2 = iSpike1 + 1:numSpikes + if removeIdx(iSpike2) + continue; + end + if maxChannels(spikeTemplates_flat(iSpike2)) ~= maxChannels(spikeTemplates_flat(iSpike1)) % spikes are not on same channel + continue; + end + + if spikeTemplates(iSpike1) == spikeTemplates(iSpike2) + if abs(spikeTimes_samples(iSpike1)-spikeTimes_samples(iSpike2)) <= duplicateSpikeWindow_samples + if templateAmplitudes(iSpike1) < templateAmplitudes(iSpike2) + spikeTimes_samples(iSpike1) = NaN; + removeIdx(iSpike1) = true; + break; + else + spikeTimes_samples(iSpike2) = NaN; + removeIdx(iSpike2) = true; + end + end + end + end + end + + + % Inter-unit duplicate removal + unitSpikeCounts = accumarray(spikeTemplates, 1); + for iSpike1 = 1:length(spikeTimes_samples) + if removeIdx(iSpike1) + continue; + end + + for iSpike2 = iSpike1 + 1:length(spikeTimes_samples) + if removeIdx(iSpike2) + continue; + end + if maxChannels(spikeTemplates_flat(iSpike2)) ~= maxChannels(spikeTemplates_flat(iSpike1)) % spikes are not on same channel + continue; + end + + if spikeTemplates(iSpike1) ~= spikeTemplates(iSpike2) + if abs(spikeTimes_samples(iSpike1)-spikeTimes_samples(iSpike2)) <= duplicateSpikeWindow_samples + if unitSpikeCounts(spikeTemplates(iSpike1)) < unitSpikeCounts(spikeTemplates(iSpike2)) + spikeTimes_samples(iSpike1) = NaN; + removeIdx(iSpike1) = true; + break; + else + spikeTimes_samples(iSpike2) = NaN; + removeIdx(iSpike2) = true; + end + end + end + end + end + end + + +end diff --git a/qualityMetrics/bc_runAllQualityMetrics.asv b/qualityMetrics/bc_runAllQualityMetrics.asv deleted file mode 100644 index a1482974..00000000 --- a/qualityMetrics/bc_runAllQualityMetrics.asv +++ /dev/null @@ -1,238 +0,0 @@ -function [qMetric, unitType] = bc_runAllQualityMetrics(param, spikeTimes_samples, spikeTemplates, ... - templateWaveforms, templateAmplitudes, pcFeatures, pcFeatureIdx, channelPositions, savePath) -% JF -% ------ -% Inputs -% ------ -% param: parameter structure with fields: -% tauR = 0.0010; %refractory period time (s) -% tauC = 0.0002; %censored period time (s) -% maxPercSpikesMissing: maximum percent (eg 30) of estimated spikes below detection -% threshold to define timechunks in the recording on which to compute -% quality metrics for each unit. -% minNumSpikes: minimum number of spikes (eg 300) for unit to classify it as good -% maxNtroughsPeaks: maximum number of troughs and peaks (eg 3) to classify unit -% waveform as good -% isSomatic: boolean, whether to keep only somatic spikes -% maxRPVviolations: maximum estimated % (eg 20) of refractory period violations to classify unit as good -% minAmplitude: minimum amplitude of raw waveform in microVolts to -% classify unit as good -% plotThis: boolean, whether to plot figures for each metric and unit - ! -% this will create * a lot * of plots if run on all units - use just -% for debugging a particular issue / creating plots for one single -% unit -% rawFolder: string containing the location of the raw .dat or .bin file -% deltaTimeChunk: size of time chunks to cut the recording in, in seconds -% (eg 600 for 10 min time chunks or duration of recording if you don't -% want time chunks) -% ephys_sample_rate: recording sample rate (eg 30000) -% nChannels: number of recorded channels, including any sync channels (eg -% 385) -% nRawSpikesToExtract: number of spikes to extract from the raw data for -% each waveform (eg 100) -% nChannelsIsoDist: number of channels on which to compute the distance -% metrics (eg 4) -% computeDistanceMetrics: boolean, whether to compute distance metrics or not -% isoDmin: minimum isolation distance to classify unit as single-unit -% lratioMin: minimum l-ratio to classify unit as single-unit -% ssMin: silhouette score to classify unit as single-unit -% computeTimeChunks -% -% spikeTimes_samples: nSpikes × 1 uint64 vector giving each spike time in samples (*not* seconds) -% -% spikeTemplates: nSpikes × 1 uint32 vector giving the identity of each -% spike's matched template -% -% templateWaveforms: nTemplates × nTimePoints × nChannels single matrix of -% template waveforms for each template and channel -% -% templateAmplitudes: nSpikes × 1 double vector of the amplitude scaling factor -% that was applied to the template when extracting that spike -% -% pcFeatures: nSpikes × nFeaturesPerChannel × nPCFeatures single -% matrix giving the PC values for each spike -% -% pcFeatureIdx: nTemplates × nPCFeatures uint32 matrix specifying which -% channels contribute to each entry in dim 3 of the pc_features matrix -% -% channelPositions -% goodChannels -%------ -% Outputs -% ------ -% qMetric: structure with fields: -% percentageSpikesMissing : a gaussian is fit to the spike amplitudes with a -% 'cutoff' parameter below which there are no spikes to estimate the -% percentage of spikes below the spike-sorting detection threshold - will -% slightly underestimate in the case of 'bursty' cells with burst -% adaptation (eg see Fig 5B of Harris/Buzsaki 2000 DOI: 10.1152/jn.2000.84.1.401) -% fractionRefractoryPeriodViolations: percentage of false positives, ie spikes within the refractory period -% defined by param.tauR of another spike. This also excludes -% duplicated spikes that occur within param.tauC of another spike. -% useTheseTimes : param.computeTimeChunks, this defines the time chunks -% (deivding the recording in time of chunks of param.deltaTimeChunk size) -% where the percentage of spike missing and percentage of false positives -% is below param.maxPercSpikesMissing and param.maxRPVviolations -% nSpikes : number of spikes for each unit -% nPeaks : number of detected peaks in each units template waveform -% nTroughs : number of detected troughs in each units template waveform -% isSomatic : a unit is defined as Somatic of its trough precedes its main -% peak (see Deligkaris/Frey DOI: 10.3389/fnins.2016.00421) -% rawAmplitude : amplitude in uV of the units mean raw waveform at its peak -% channel. The peak channel is defined by the template waveform. -% spatialDecay : gets the minumum amplitude for each unit 5 channels from -% the peak channel and calculates the slope of this decrease in amplitude. -% isoD : isolation distance, a measure of how well a units spikes are seperate from -% other nearby units spikes -% Lratio : l-ratio, a similar measure to isolation distance. see -% Schmitzer-Torbert/Redish 2005 DOI: 10.1016/j.neuroscience.2004.09.066 -% for a comparison of l-ratio/isolation distance -% silhouetteScore : another measure similar ti isolation distance and -% l-ratio. See Rousseeuw 1987 DOI: 10.1016/0377-0427(87)90125-7) -% -% unitType: nUnits x 1 vector indicating whether each unit met the -% threshold criterion to be classified as a single unit (1), noise -% (0) or multi-unit (2) - -%% if some manual curation already performed, remove bad units - -%% prepare for quality metrics computations -% initialize structures -qMetric = struct; -forGUI = struct; - -% get unit max channels -maxChannels = bc_getWaveformMaxChannel(templateWaveforms); -qMetric.maxChannels = maxChannels; - -% get unique templates -uniqueTemplates = unique(spikeTemplates); - -% extract and save or load in raw waveforms - -if param.extractRaw - [rawWaveformsFull, rawWaveformsPeakChan, signalToNoiseRatio] = bc_extractRawWaveformsFast(param, ... - spikeTimes_samples, spikeTemplates, param.reextractRaw, savePath, param.verbose); % takes ~10' for -end -% an average dataset, the first time it is run, <1min after that - -% previous, slower method: -% [qMetric.rawWaveforms, qMetric.rawMemMap] = bc_extractRawWaveforms(param.rawFolder, param.nChannels, param.nRawSpikesToExtract, ... -% spikeTimes, spikeTemplates, usedChannels, verbose); - -% divide recording into time chunks -spikeTimes_seconds = spikeTimes_samples ./ param.ephys_sample_rate; %convert to seconds after using sample indices to extract raw waveforms -if param.computeTimeChunks - timeChunks = [min(spikeTimes_seconds):param.deltaTimeChunk:max(spikeTimes_seconds), max(spikeTimes_seconds)]; -else - timeChunks = [min(spikeTimes_seconds), max(spikeTimes_seconds)]; -end - -%% loop through units and get quality metrics -fprintf('\n Extracting quality metrics from %s ... \n', param.rawFile) - -for iUnit = 61:80%length(uniqueTemplates) - clearvars thisUnit theseSpikeTimes theseAmplis theseSpikeTemplates - - % get this unit's attributes - thisUnit = uniqueTemplates(iUnit); - qMetric.phy_clusterID(iUnit) = thisUnit - 1; % this is the cluster ID as it appears in phy - qMetric.clusterID(iUnit) = thisUnit; % this is the cluster ID as it appears in phy, 1-indexed (adding 1) - - theseSpikeTimes = spikeTimes_seconds(spikeTemplates == thisUnit); - theseAmplis = templateAmplitudes(spikeTemplates == thisUnit); - - %% remove duplicate spikes - - - %% percentage spikes missing (false negatives) - param.plotDetails=1 - [percentageSpikesMissing_gaussian, percentageSpikesMissing_symmetric, ksTest_pValue, ~, ~, ~] = ... - bc_percSpikesMissing(theseAmplis, theseSpikeTimes, timeChunks, param.plotDetails); - - %% fraction contamination (false positives) - tauR_window = param.tauR_valuesMin:param.tauR_valuesStep:param.tauR_valuesMax; - [fractionRPVs, ~, ~] = bc_fractionRPviolations(theseSpikeTimes, theseAmplis, ... - tauR_window, param.tauC, ... - timeChunks, param.plotDetails, NaN); - - %% define timechunks to keep: keep times with low percentage spikes missing and low fraction contamination - - [theseSpikeTimes, theseAmplis, theseSpikeTemplates, qMetric.useTheseTimesStart(iUnit), qMetric.useTheseTimesStop(iUnit),... - qMetric.RPV_tauR_estimate(iUnit)] = bc_defineTimechunksToKeep(... - percentageSpikesMissing_gaussian, fractionRPVs, param.maxPercSpikesMissing, ... - param.maxRPVviolations, theseAmplis, theseSpikeTimes, spikeTemplates, timeChunks); %QQ add kstest thing, symmetric ect -param.plotDetails=0 - %% re-compute percentage spikes missing and fraction contamination on timechunks - thisUnits_timesToUse = [qMetric.useTheseTimesStart(iUnit), qMetric.useTheseTimesStop(iUnit)]; - - [qMetric.percentageSpikesMissing_gaussian(iUnit), qMetric.percentageSpikesMissing_symmetric(iUnit), ... - qMetric.ksTest_pValue(iUnit), forGUI.ampliBinCenters{iUnit}, forGUI.ampliBinCounts{iUnit}, ... - forGUI.ampliGaussianFit{iUnit}] = bc_percSpikesMissing(theseAmplis, theseSpikeTimes, ... - thisUnits_timesToUse, param.plotDetails); - - [qMetric.fractionRPVs(iUnit,:), ~, ~] = bc_fractionRPviolations(theseSpikeTimes, theseAmplis, ... - tauR_window, param.tauC, thisUnits_timesToUse, param.plotDetails, qMetric.RPV_tauR_estimate(iUnit)); - - %% presence ratio (potential false negatives) - [qMetric.presenceRatio(iUnit)] = bc_presenceRatio(theseSpikeTimes, theseAmplis, param.presenceRatioBinSize, ... - qMetric.useTheseTimesStart(iUnit), qMetric.useTheseTimesStop(iUnit), param.plotDetails); - - %% maximum cumulative drift estimate - [qMetric.maxDriftEstimate(iUnit),qMetric.cumDriftEstimate(iUnit)] = bc_maxDriftEstimate(pcFeatures, pcFeatureIdx, theseSpikeTemplates, ... - theseSpikeTimes, channelPositions(:,2), thisUnit, param.driftBinSize, param.computeDrift, param.plotDetails); - - %% number spikes - qMetric.nSpikes(iUnit) = bc_numberSpikes(theseSpikeTimes); - - %% waveform - waveformBaselineWindow = [param.waveformBaselineWindowStart, param.waveformBaselineWindowStop]; - [qMetric.nPeaks(iUnit), qMetric.nTroughs(iUnit), qMetric.isSomatic(iUnit), forGUI.peakLocs{iUnit},... - forGUI.troughLocs{iUnit}, qMetric.waveformDuration_peakTrough(iUnit), ... - forGUI.spatialDecayPoints(iUnit,:), qMetric.spatialDecaySlope(iUnit), qMetric.waveformBaselineFlatness(iUnit), .... - forGUI.tempWv(iUnit,:)] = bc_waveformShape(templateWaveforms,thisUnit, qMetric.maxChannels(thisUnit),... - param.ephys_sample_rate, channelPositions, param.maxWvBaselineFraction, waveformBaselineWindow,... - param.minThreshDetectPeaksTroughs, param.plotDetails); %do we need tempWv ? - - %% amplitude - if param.extractRaw - qMetric.rawAmplitude(iUnit) = bc_getRawAmplitude(rawWaveformsFull(iUnit,rawWaveformsPeakChan(iUnit),:), ... - param.ephysMetaFile, param.probeType); - else - qMetric.rawAmplitude(iUnit) =NaN; - qMetric.signalToNoiseRatio(iUnit) = NaN; - end - - %% distance metrics - if param.computeDistanceMetrics - [qMetric.isoD(iUnit), qMetric.Lratio(iUnit), qMetric.silhouetteScore(iUnit), ... - forGUI.d2_mahal{iUnit}, forGUI.mahalobnis_Xplot{iUnit}, forGUI.mahalobnis_Yplot{iUnit}] = bc_getDistanceMetrics(pcFeatures, ... - pcFeatureIdx, thisUnit, sum(spikeTemplates == thisUnit), spikeTemplates == thisUnit, theseSpikeTemplates, ... - param.nChannelsIsoDist, param.plotDetails); %QQ - end - if ((mod(iUnit, 50) == 0) || iUnit == length(uniqueTemplates)) && param.verbose - fprintf(['\n Finished ', num2str(iUnit), ' / ', num2str(length(uniqueTemplates)), ' units.']); - end - -end - - -%% get unit types and save data -qMetric.maxChannels = qMetric.maxChannels(uniqueTemplates)'; -if param.extractRaw - qMetric.signalToNoiseRatio = signalToNoiseRatio'; -end - -fprintf('\n Finished extracting quality metrics from %s', param.rawFile) -try - qMetric = bc_saveQMetrics(param, qMetric, forGUI, savePath); - fprintf('\n Saved quality metrics from %s to %s \n', param.rawFile, savePath) - %% get some summary plots - -catch - warning('\n Warning, quality metrics from %s not saved! \n', param.rawFile) -end - -unitType = bc_getQualityUnitType(param, qMetric, savePath); -bc_plotGlobalQualityMetric(qMetric, param, unitType, uniqueTemplates, forGUI.tempWv); -end diff --git a/qualityMetrics/bc_runAllQualityMetrics.m b/qualityMetrics/bc_runAllQualityMetrics.m index 41e0592c..68bb7eb4 100644 --- a/qualityMetrics/bc_runAllQualityMetrics.m +++ b/qualityMetrics/bc_runAllQualityMetrics.m @@ -4,39 +4,8 @@ % ------ % Inputs % ------ -% param: parameter structure with fields: -% tauR = 0.0010; %refractory period time (s) -% tauC = 0.0002; %censored period time (s) -% maxPercSpikesMissing: maximum percent (eg 30) of estimated spikes below detection -% threshold to define timechunks in the recording on which to compute -% quality metrics for each unit. -% minNumSpikes: minimum number of spikes (eg 300) for unit to classify it as good -% maxNtroughsPeaks: maximum number of troughs and peaks (eg 3) to classify unit -% waveform as good -% isSomatic: boolean, whether to keep only somatic spikes -% maxRPVviolations: maximum estimated % (eg 20) of refractory period violations to classify unit as good -% minAmplitude: minimum amplitude of raw waveform in microVolts to -% classify unit as good -% plotThis: boolean, whether to plot figures for each metric and unit - ! -% this will create * a lot * of plots if run on all units - use just -% for debugging a particular issue / creating plots for one single -% unit -% rawFolder: string containing the location of the raw .dat or .bin file -% deltaTimeChunk: size of time chunks to cut the recording in, in seconds -% (eg 600 for 10 min time chunks or duration of recording if you don't -% want time chunks) -% ephys_sample_rate: recording sample rate (eg 30000) -% nChannels: number of recorded channels, including any sync channels (eg -% 385) -% nRawSpikesToExtract: number of spikes to extract from the raw data for -% each waveform (eg 100) -% nChannelsIsoDist: number of channels on which to compute the distance -% metrics (eg 4) -% computeDistanceMetrics: boolean, whether to compute distance metrics or not -% isoDmin: minimum isolation distance to classify unit as single-unit -% lratioMin: minimum l-ratio to classify unit as single-unit -% ssMin: silhouette score to classify unit as single-unit -% computeTimeChunks +% param: parameter structure. See bc_qualityParamValues for all fields +% anf information about them. % % spikeTimes_samples: nSpikes × 1 uint64 vector giving each spike time in samples (*not* seconds) % @@ -55,8 +24,11 @@ % pcFeatureIdx: nTemplates × nPCFeatures uint32 matrix specifying which % channels contribute to each entry in dim 3 of the pc_features matrix % -% channelPositions -% goodChannels +% channelPositions: nChannels x 2 double matrix corresponding to the x and +% z locations of each channel on the probe, in um +% +% savePath: sting defining the path where to save bombcell's output +% %------ % Outputs % ------ @@ -65,27 +37,27 @@ % 'cutoff' parameter below which there are no spikes to estimate the % percentage of spikes below the spike-sorting detection threshold - will % slightly underestimate in the case of 'bursty' cells with burst -% adaptation (eg see Fig 5B of Harris/Buzsaki 2000 DOI: 10.1152/jn.2000.84.1.401) +% adaptation (eg see Fig 5B of Harris/Buzsaki 2000 DOI: 10.1152/jn.2000.84.1.401) % fractionRefractoryPeriodViolations: percentage of false positives, ie spikes within the refractory period % defined by param.tauR of another spike. This also excludes -% duplicated spikes that occur within param.tauC of another spike. -% useTheseTimes : param.computeTimeChunks, this defines the time chunks +% duplicated spikes that occur within param.tauC of another spike. +% useTheseTimes : param.computeTimeChunks, this defines the time chunks % (deivding the recording in time of chunks of param.deltaTimeChunk size) % where the percentage of spike missing and percentage of false positives % is below param.maxPercSpikesMissing and param.maxRPVviolations -% nSpikes : number of spikes for each unit +% nSpikes : number of spikes for each unit % nPeaks : number of detected peaks in each units template waveform % nTroughs : number of detected troughs in each units template waveform % isSomatic : a unit is defined as Somatic of its trough precedes its main % peak (see Deligkaris/Frey DOI: 10.3389/fnins.2016.00421) % rawAmplitude : amplitude in uV of the units mean raw waveform at its peak -% channel. The peak channel is defined by the template waveform. +% channel. The peak channel is defined by the template waveform. % spatialDecay : gets the minumum amplitude for each unit 5 channels from % the peak channel and calculates the slope of this decrease in amplitude. % isoD : isolation distance, a measure of how well a units spikes are seperate from % other nearby units spikes % Lratio : l-ratio, a similar measure to isolation distance. see -% Schmitzer-Torbert/Redish 2005 DOI: 10.1016/j.neuroscience.2004.09.066 +% Schmitzer-Torbert/Redish 2005 DOI: 10.1016/j.neuroscience.2004.09.066 % for a comparison of l-ratio/isolation distance % silhouetteScore : another measure similar ti isolation distance and % l-ratio. See Rousseeuw 1987 DOI: 10.1016/0377-0427(87)90125-7) @@ -95,25 +67,32 @@ % (0) or multi-unit (2) %% prepare for quality metrics computations -% initialize structures +% initialize structures qMetric = struct; forGUI = struct; +% check parameter values +param = bc_checkParameterFields(param); + % get unit max channels maxChannels = bc_getWaveformMaxChannel(templateWaveforms); -qMetric.maxChannels = maxChannels; -% get unique templates -uniqueTemplates = unique(spikeTemplates); +% extract and save or load in raw waveforms +[rawWaveformsFull, rawWaveformsPeakChan, signalToNoiseRatio] = bc_extractRawWaveformsFast(param, ... + spikeTimes_samples, spikeTemplates, param.reextractRaw, savePath, param.verbose); % takes ~10' for +% an average dataset, the first time it is run, <1min after that -% extract and save or load in raw waveforms -if param.extractRaw - [rawWaveformsFull, rawWaveformsPeakChan, signalToNoiseRatio] = bc_extractRawWaveformsFast(param, ... - spikeTimes_samples, spikeTemplates, param.reextractRaw, savePath, param.verbose); % takes ~10' for - % an average dataset, the first time it is run, <1min after that -end -% divide recording into time chunks +% remove any duplicate spikes +[uniqueTemplates, ~, spikeTimes_samples, spikeTemplates, templateAmplitudes, ... + pcFeatures, rawWaveformsFull, rawWaveformsPeakChan, signalToNoiseRatio, ... + qMetric.maxChannels] = ... + bc_removeDuplicateSpikes(spikeTimes_samples, spikeTemplates, templateAmplitudes, ... + pcFeatures, rawWaveformsFull, rawWaveformsPeakChan, signalToNoiseRatio, ... + maxChannels, param.removeDuplicateSpikes, param.duplicateSpikeWindow_s, ... + param.ephys_sample_rate, param.saveSpikes_withoutDuplicates, savePath, param.recomputeDuplicateSpikes); + +% divide recording into time chunks spikeTimes_seconds = spikeTimes_samples ./ param.ephys_sample_rate; %convert to seconds after using sample indices to extract raw waveforms if param.computeTimeChunks timeChunks = [min(spikeTimes_seconds):param.deltaTimeChunk:max(spikeTimes_seconds), max(spikeTimes_seconds)]; @@ -124,10 +103,10 @@ %% loop through units and get quality metrics fprintf('\n Extracting quality metrics from %s ... \n', param.rawFile) -for iUnit = 1:length(uniqueTemplates) - clearvars thisUnit theseSpikeTimes theseAmplis theseSpikeTemplates +for iUnit = 1:size(uniqueTemplates, 1) - % get this unit's attributes + clearvars thisUnit theseSpikeTimes theseAmplis theseSpikeTemplates + % get this unit's attributes thisUnit = uniqueTemplates(iUnit); qMetric.phy_clusterID(iUnit) = thisUnit - 1; % this is the cluster ID as it appears in phy qMetric.clusterID(iUnit) = thisUnit; % this is the cluster ID as it appears in phy, 1-indexed (adding 1) @@ -135,9 +114,6 @@ theseSpikeTimes = spikeTimes_seconds(spikeTemplates == thisUnit); theseAmplis = templateAmplitudes(spikeTemplates == thisUnit); - %% remove duplicate spikes - - %% percentage spikes missing (false negatives) [percentageSpikesMissing_gaussian, percentageSpikesMissing_symmetric, ksTest_pValue, ~, ~, ~] = ... bc_percSpikesMissing(theseAmplis, theseSpikeTimes, timeChunks, param.plotDetails); @@ -147,51 +123,51 @@ [fractionRPVs, ~, ~] = bc_fractionRPviolations(theseSpikeTimes, theseAmplis, ... tauR_window, param.tauC, ... timeChunks, param.plotDetails, NaN); - + %% define timechunks to keep: keep times with low percentage spikes missing and low fraction contamination - [theseSpikeTimes, theseAmplis, theseSpikeTemplates, qMetric.useTheseTimesStart(iUnit), qMetric.useTheseTimesStop(iUnit),... - qMetric.RPV_tauR_estimate(iUnit)] = bc_defineTimechunksToKeep(... + [theseSpikeTimes, theseAmplis, theseSpikeTemplates, qMetric.useTheseTimesStart(iUnit), qMetric.useTheseTimesStop(iUnit), ... + qMetric.RPV_tauR_estimate(iUnit)] = bc_defineTimechunksToKeep( ... percentageSpikesMissing_gaussian, fractionRPVs, param.maxPercSpikesMissing, ... - param.maxRPVviolations, theseAmplis, theseSpikeTimes, spikeTemplates, timeChunks); %QQ add kstest thing, symmetric ect + param.maxRPVviolations, theseAmplis, theseSpikeTimes, spikeTemplates, timeChunks); %QQ add kstest thing, symmetric ect %% re-compute percentage spikes missing and fraction contamination on timechunks thisUnits_timesToUse = [qMetric.useTheseTimesStart(iUnit), qMetric.useTheseTimesStop(iUnit)]; - + [qMetric.percentageSpikesMissing_gaussian(iUnit), qMetric.percentageSpikesMissing_symmetric(iUnit), ... qMetric.ksTest_pValue(iUnit), forGUI.ampliBinCenters{iUnit}, forGUI.ampliBinCounts{iUnit}, ... forGUI.ampliGaussianFit{iUnit}] = bc_percSpikesMissing(theseAmplis, theseSpikeTimes, ... thisUnits_timesToUse, param.plotDetails); - [qMetric.fractionRPVs(iUnit,:), ~, ~] = bc_fractionRPviolations(theseSpikeTimes, theseAmplis, ... + [qMetric.fractionRPVs(iUnit, :), ~, ~] = bc_fractionRPviolations(theseSpikeTimes, theseAmplis, ... tauR_window, param.tauC, thisUnits_timesToUse, param.plotDetails, qMetric.RPV_tauR_estimate(iUnit)); - + %% presence ratio (potential false negatives) [qMetric.presenceRatio(iUnit)] = bc_presenceRatio(theseSpikeTimes, theseAmplis, param.presenceRatioBinSize, ... qMetric.useTheseTimesStart(iUnit), qMetric.useTheseTimesStop(iUnit), param.plotDetails); %% maximum cumulative drift estimate - [qMetric.maxDriftEstimate(iUnit),qMetric.cumDriftEstimate(iUnit)] = bc_maxDriftEstimate(pcFeatures, pcFeatureIdx, theseSpikeTemplates, ... - theseSpikeTimes, channelPositions(:,2), thisUnit, param.driftBinSize, param.computeDrift, param.plotDetails); - + [qMetric.maxDriftEstimate(iUnit), qMetric.cumDriftEstimate(iUnit)] = bc_maxDriftEstimate(pcFeatures, pcFeatureIdx, theseSpikeTemplates, ... + theseSpikeTimes, channelPositions(:, 2), thisUnit, param.driftBinSize, param.computeDrift, param.plotDetails); + %% number spikes qMetric.nSpikes(iUnit) = bc_numberSpikes(theseSpikeTimes); %% waveform waveformBaselineWindow = [param.waveformBaselineWindowStart, param.waveformBaselineWindowStop]; - [qMetric.nPeaks(iUnit), qMetric.nTroughs(iUnit), qMetric.isSomatic(iUnit), forGUI.peakLocs{iUnit},... + [qMetric.nPeaks(iUnit), qMetric.nTroughs(iUnit), qMetric.isSomatic(iUnit), forGUI.peakLocs{iUnit}, ... forGUI.troughLocs{iUnit}, qMetric.waveformDuration_peakTrough(iUnit), ... - forGUI.spatialDecayPoints(iUnit,:), qMetric.spatialDecaySlope(iUnit), qMetric.waveformBaselineFlatness(iUnit), .... - forGUI.tempWv(iUnit,:)] = bc_waveformShape(templateWaveforms,thisUnit, qMetric.maxChannels(thisUnit),... - param.ephys_sample_rate, channelPositions, param.maxWvBaselineFraction, waveformBaselineWindow,... - param.minThreshDetectPeaksTroughs, param.plotDetails); %do we need tempWv ? - + forGUI.spatialDecayPoints(iUnit, :), qMetric.spatialDecaySlope(iUnit), qMetric.waveformBaselineFlatness(iUnit), ... . + forGUI.tempWv(iUnit, :)] = bc_waveformShape(templateWaveforms, thisUnit, qMetric.maxChannels(thisUnit), ... + param.ephys_sample_rate, channelPositions, param.maxWvBaselineFraction, waveformBaselineWindow, ... + param.minThreshDetectPeaksTroughs, param.plotDetails); %do we need tempWv ? + %% amplitude if param.extractRaw - qMetric.rawAmplitude(iUnit) = bc_getRawAmplitude(rawWaveformsFull(iUnit,rawWaveformsPeakChan(iUnit),:), ... + qMetric.rawAmplitude(iUnit) = bc_getRawAmplitude(rawWaveformsFull(iUnit, rawWaveformsPeakChan(iUnit), :), ... param.ephysMetaFile, param.probeType, param.gain_to_uV); else - qMetric.rawAmplitude(iUnit) =NaN; - qMetric.signalToNoiseRatio(iUnit) = NaN; + qMetric.rawAmplitude(iUnit) = NaN; + qMetric.signalToNoiseRatio(iUnit) = NaN; end %% distance metrics @@ -202,16 +178,15 @@ param.nChannelsIsoDist, param.plotDetails); %QQ end if ((mod(iUnit, 50) == 0) || iUnit == length(uniqueTemplates)) && param.verbose - fprintf(['\n Finished ', num2str(iUnit), ' / ', num2str(length(uniqueTemplates)), ' units.']); + fprintf(['\n Finished ', num2str(iUnit), ' / ', num2str(length(uniqueTemplates)), ' units.']); end end - %% get unit types and save data -qMetric.maxChannels = qMetric.maxChannels(uniqueTemplates)'; +qMetric.maxChannels = qMetric.maxChannels(uniqueTemplates)'; if param.extractRaw - qMetric.signalToNoiseRatio = signalToNoiseRatio'; + qMetric.signalToNoiseRatio = signalToNoiseRatio'; end fprintf('\n Finished extracting quality metrics from %s', param.rawFile) diff --git a/qualityMetrics/helpers/bc_extractRawWaveformsFast.asv b/qualityMetrics/helpers/bc_extractRawWaveformsFast.asv deleted file mode 100644 index 0b8af910..00000000 --- a/qualityMetrics/helpers/bc_extractRawWaveformsFast.asv +++ /dev/null @@ -1,191 +0,0 @@ - -function [rawWaveformsFull, rawWaveformsPeakChan, signalToNoiseRatio] = bc_extractRawWaveformsFast(param, spikeTimes_samples, ... - spikeTemplates, reExtract, savePath, verbose) -% JF, Get raw waveforms for all templates -% ------ -% Inputs -% ------ -% param with: - % rawFile: string containing the location of the raw .bin or .dat file location - % nChannels: number of recorded channels (including sync), (eg 385) - % nSpikesToExtract: number of spikes to extract per template - % detrendWaveforms: boolean, whether to detrend spikes or not -% spikeTimes_samples: nSpikes × 1 uint64 vector giving each spike time in samples (*not* seconds) -% spikeTemplates: nSpikes × 1 uint32 vector giving the identity of each -% spike's matched template -% reextract: boolean, whether to reextract raw waveforms or not -% verbose: boolean, display progress bar or not -% savePath: where to save output data -% ------ -% Outputs -% ------ -% rawWaveformsFull: nUnits × nTimePoints × nChannels single matrix of -% mean raw waveforms for each unit and channel -% rawWaveformsPeakChan: nUnits x 1 vector of each unit's channel with the maximum -% amplitude -% signalToNoiseRatio: nUnits x 1 vector defining the absolute maximum -% value of the mean raw waveform for that value divided by the variance -% of the data before detected waveforms. implementation : Enny van Beest - -%% Check if data needs to be extracted -rawWaveformFolder = dir(fullfile(savePath, 'templates._bc_rawWaveforms.npy')); - -if ~isempty(rawWaveformFolder) && reExtract == 0 % no need to extract data, - % simply load it in - rawWaveformsFull = readNPY(fullfile(savePath, 'templates._bc_rawWaveforms.npy')); - rawWaveformsPeakChan = readNPY(fullfile(savePath, 'templates._bc_rawWaveformPeakChannels.npy')); - -else -%% Extract raw waveforms - %% Initialize parameters - nChannels = param.nChannels; % (385) - nSpikesToExtract = param.nRawSpikesToExtract; - spikeWidth = param.spikeWidth; - halfWidth = spikeWidth / 2; - dataTypeNBytes = numel(typecast(cast(0, 'uint16'), 'uint8')); - clustInds = unique(spikeTemplates); - nClust = numel(clustInds); - rawFileInfo = dir(param.rawFile); - BatchSize = 5000; - if param.saveMultipleRaw && ~isfolder(fullfile(savePath,'RawWaveforms')) - mkdir(fullfile(savePath,'RawWaveforms')) - end - - fprintf('\n Extracting raw waveforms from %s ...', param.rawFile) - % Get binary file name - fid = fopen(param.rawFile, 'r'); - - %% Interate over spike clusters and find spikes associated with them - % Initialize and pre-allocate variables - rawWaveforms = struct; - rawWaveformsFull = nan(nClust, nChannels-param.nSyncChannels, spikeWidth); - rawWaveformsPeakChan = nan(nClust, 1); - average_baseline = cell(1,nClust); - - % loop over spike clusters - for iCluster = 1:nClust - % Get cluster information - rawWaveforms(iCluster).clInd = clustInds(iCluster); - rawWaveforms(iCluster).spkInd = spikeTimes_samples(spikeTemplates == clustInds(iCluster)); - - % Determine # of spikes to extract - if numel(rawWaveforms(iCluster).spkInd) >= nSpikesToExtract - spksubi = round(linspace(1, numel(rawWaveforms(iCluster).spkInd), nSpikesToExtract))'; - rawWaveforms(iCluster).spkIndsub = rawWaveforms(iCluster).spkInd(spksubi); - else - rawWaveforms(iCluster).spkIndsub = rawWaveforms(iCluster).spkInd; - end - nSpkLocal = numel(rawWaveforms(iCluster).spkIndsub); - - % loop over spikes for this cluster - rawWaveforms(iCluster).spkMap = nan(nChannels-param.nSyncChannels, spikeWidth, nSpkLocal); - for iSpike = 1:nSpkLocal - thisSpikeIdx = rawWaveforms(iCluster).spkIndsub(iSpike); - - if ((thisSpikeIdx - halfWidth) * nChannels) * dataTypeNBytes > halfWidth &&... - (thisSpikeIdx + halfWidth) * nChannels * dataTypeNBytes < rawFileInfo.bytes % check that it's not out of bounds - % extract spike - bytei = ((thisSpikeIdx - halfWidth) * nChannels) * dataTypeNBytes; - fseek(fid, bytei, 'bof'); - data0 = fread(fid, nChannels*spikeWidth, 'int16=>int16'); % read individual waveform from binary file - frewind(fid); - data = reshape(data0, nChannels, []); - % if whitenBool - % [data, mu, invMat, whMat]=whiten(double(data)); - % end - % if size(data, 2) == spikeWidth - % rawWaveforms(iCluster).spkMap(:, :, iSpike) = data; - % end - - % detrend spike if required - if param.detrendWaveform - rawWaveforms(iCluster).spkMap(:, :, iSpike) = permute(detrend(double(permute(data(1:nChannels-param.nSyncChannels, :),[2,1]))), [2,1]); - else - rawWaveforms(iCluster).spkMap(:, :, iSpike) = data(1:nChannels-param.nSyncChannels, :); %remove sync channel - end - - end - - end - - % align raw spikes to each other (using the trough) - clearvars meanWaveform_temp peakChan_temp - meanWaveform_temp = nanmean(rawWaveforms(iCluster).spkMap,3); - [~, peakChan_temp] = max(max(meanWaveform_temp, [], 2) - min(meanWaveform_temp, [], 2)); % maximum channel per cluster - - [~, troughLocation] = min(squeeze(rawWaveforms(iCluster).spkMap(peakChan_temp, :, :))); - end - - % save waveforms for unitmatch - if param.saveMultipleRaw - tmpspkmap = permute(rawWaveforms(iCluster).spkMap,[2,1,3]); % Compatible with UnitMatch QQ - %Do smoothing in batches - nBatch = ceil(nSpkLocal./BatchSize); - for bid = 1:nBatch - spkId = (bid-1)*BatchSize+(1:BatchSize); - spkId(spkId>nSpkLocal) = []; - tmpspkmap(:,:,spkId) = smoothdata(tmpspkmap(:,:,spkId) - mean(tmpspkmap(1:param.waveformBaselineNoiseWindow,:,spkId),1),1,'gaussian',5); % Subtract baseline and smooth - end - % Save two averages for UnitMatch - tmpspkmap = arrayfun(@(X) nanmedian(tmpspkmap(:,:,(X-1)*floor(size(tmpspkmap,3)/2)+1:X*floor(size(tmpspkmap,3)/2)),3),1:2,'Uni',0); - tmpspkmap = cat(3,tmpspkmap{:}); - writeNPY(tmpspkmap, fullfile(savePath,'RawWaveforms',['Unit' num2str(clustInds(iCluster)-1) '_RawSpikes.npy'])) % Back to 0-indexed (same as Kilosort) - end - - % get average, baseline-subtracted and smoothed raw waveform - rawWaveforms(iCluster).spkMapMean = nanmean(rawWaveforms(iCluster).spkMap, 3); % initialize and pre-allocate - rawWaveformsFull(iCluster, :, :) = rawWaveforms(iCluster).spkMapMean - ... - mean(rawWaveforms(iCluster).spkMapMean(:, 1:param.waveformBaselineNoiseWindow), 2); % remove baseline - spkMapMean_sm = smoothdata(rawWaveforms(iCluster).spkMapMean, 1, 'gaussian', 5); % smooth - - [~, rawWaveformsPeakChan(iCluster)] = max(max(spkMapMean_sm, [], 2) - min(spkMapMean_sm, [], 2)); % maximum channel per cluster - average_baseline{iCluster} = squeeze(nanmean(rawWaveforms(iCluster).spkMap(rawWaveformsPeakChan(iCluster),... - 1:param.waveformBaselineNoiseWindow,:),3)); % waveform baseline (for signal-to-noise calculation) - - % delete current cluster raw spikes (for memory) - rawWaveforms(iCluster).spkMap = []; - - % display progress - if (mod(iCluster, 100) == 0 || iCluster == nClust) && verbose - fprintf(['\n Finished ', num2str(iCluster), ' / ', num2str(nClust), ' units.']); - end - - end - % close file - fclose(fid); - - % save extracted mean raw waveforms - if ~isfolder(savePath) - mkdir(savePath) - end - writeNPY(rawWaveformsFull, fullfile(savePath, 'templates._bc_rawWaveforms.npy')) - writeNPY(rawWaveformsPeakChan, fullfile(savePath, 'templates._bc_rawWaveformPeakChannels.npy')) - - % save mean raw waveform baseline average (for signal-to-noise - % calculation) - average_baseline_cat = cat(2, average_baseline{:})'; - average_baseline_idx = arrayfun(@(x) ones(param.waveformBaselineNoiseWindow,1)*x, 1:nClust, 'UniformOutput',false); - average_baseline_idx_cat = cat(1, average_baseline_idx{:}); - writeNPY(average_baseline_cat, fullfile(savePath, 'templates._bc_baselineNoiseAmplitude.npy')) - writeNPY(average_baseline_idx_cat, fullfile(savePath, 'templates._bc_baselineNoiseAmplitudeIndex.npy')) -end -%% estimate signal-to-noise ratio -clustInds = unique(spikeTemplates); -nClust = numel(clustInds); - -if ~isempty(fullfile(savePath, 'templates._bc_baselineNoiseAmplitude.npy')) - - average_baseline_cat = readNPY(fullfile(savePath, 'templates._bc_baselineNoiseAmplitude.npy')); - average_baseline_idx_cat = readNPY(fullfile(savePath, 'templates._bc_baselineNoiseAmplitudeIndex.npy')); - - signalToNoiseRatio = cell2mat(arrayfun(@(X) max(abs(squeeze(rawWaveformsFull(X,rawWaveformsPeakChan(X),:)))) ./... - var(average_baseline_cat(average_baseline_idx_cat==X)),1:nClust,'Uni',0))'; - - %signalToNoiseRatio = cell2mat(arrayfun(@(X) max(abs(squeeze(rawWaveformsFull(X,rawWaveformsPeakChan(X),:))),[],'omitnan') ./... - % var(average_baseline_cat(average_baseline_idx_cat==X)),1:nClust,'Uni',0), 'omitnan')'; - -else - fprintf('No saved waveform baseline file found, skipping signal to noise calculation') - signalToNoiseRatio = nan(nClust,1); -end - diff --git a/qualityMetrics/helpers/bc_extractRawWaveformsFast.m b/qualityMetrics/helpers/bc_extractRawWaveformsFast.m index 707abf84..1f68b09f 100644 --- a/qualityMetrics/helpers/bc_extractRawWaveformsFast.m +++ b/qualityMetrics/helpers/bc_extractRawWaveformsFast.m @@ -28,6 +28,7 @@ % of the data before detected waveforms. implementation : Enny van Beest %% Check if data needs to be extracted +if param.extractRaw rawWaveformFolder = dir(fullfile(savePath, 'templates._bc_rawWaveforms.npy')); if ~isempty(rawWaveformFolder) && reExtract == 0 % no need to extract data, @@ -90,12 +91,6 @@ data0 = fread(fid, nChannels*spikeWidth, 'int16=>int16'); % read individual waveform from binary file frewind(fid); data = reshape(data0, nChannels, []); - % if whitenBool - % [data, mu, invMat, whMat]=whiten(double(data)); - % end - % if size(data, 2) == spikeWidth - % rawWaveforms(iCluster).spkMap(:, :, iSpike) = data; - % end % detrend spike if required if param.detrendWaveform @@ -189,3 +184,8 @@ signalToNoiseRatio = nan(nClust,1); end +else + rawWaveformsFull = []; + rawWaveformsPeakChan = []; + signalToNoiseRatio = []; +end