Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Oct 4, 2023
1 parent cb11247 commit 1a2b81a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
23 changes: 12 additions & 11 deletions src/spikeanalysis/curated_spike_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,29 @@

from .spike_analysis import SpikeAnalysis


def read_responsive_neurons(folder_path) -> dict:
"""
Function for reading a response profile json file
and converting into the appropriate dictionary for curation
Parameters
----------
folder_path: str | Path
The path way to the directory containing the `response_profile.json`
Returns
-------
curation: dict
The curation file that has been previously generated for a dataset
"""

import json

file_path = Path(folder_path)
assert file_path.is_dir(), "please input the directory containing the response_profile json"

with open(file_path/'response_profile.json', 'r') as read_file:
with open(file_path / "response_profile.json", "r") as read_file:
response_dict = json.load(read_file)

for stim in response_dict.keys():
Expand All @@ -48,7 +50,7 @@ def __init__(self, curation: dict):
----------
curation: dict
The curation dictionary to be used for curated data
"""

self.curation = curation
Expand All @@ -69,12 +71,11 @@ def curate(
by_trial: Literal["all"] | bool = False,
trial_index: Optional[int] = None,
):

"""Function for loading the current curation
Parameters
----------
criteria: str | dict
by_stim: bool, default: False
Whether to analyze data by a particular stimulus
by_response: bool, default: False
Expand All @@ -83,7 +84,7 @@ def curate(
*****
trial_index: Optional[int], default: None
Must be given if by_trial=True, to indicate which specific trial to be used
"""
curation = self.curation

Expand All @@ -103,7 +104,7 @@ def curate(
), f"by_trial must be 'all' or boolean you entered {by_trial}"

if by_trial == "all":
if len(sub_curation.shape)==1:
if len(sub_curation.shape) == 1:
sub_curation = np.expand_dims(sub_curation, axis=1)
mask = np.all(sub_curation, axis=1)
self.cluster_ids = self.cluster_ids[mask]
Expand All @@ -114,7 +115,7 @@ def curate(
self.cluster_ids = self.cluster_ids[mask]

else:
if len(sub_curation.shape)==1:
if len(sub_curation.shape) == 1:
sub_curation = np.expand_dims(sub_curation, axis=1)
mask = np.any(sub_curation, axis=1)
self.cluster_ids = self.cluster_ids[mask]
Expand All @@ -132,7 +133,7 @@ def curate(
else:
mask_array = np.array(mask_list[0])

if len(mask_array.shape)==1:
if len(mask_array.shape) == 1:
mask_array = np.expand_dims(mask_array, axis=1)

if by_trial == "all":
Expand All @@ -155,7 +156,7 @@ def curate(
else:
mask_array = np.array(mask_list[0])

if len(mask_array.shape)==1:
if len(mask_array.shape) == 1:
mask_array = np.expand_dims(mask_array, axis=1)

if by_trial == "all":
Expand Down
16 changes: 9 additions & 7 deletions test/test_curated_spike_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,25 @@ def test_curation_both(csa):
csa.curate(criteria={"test": "activated"}, by_stim=True, by_response=True, by_trial=False)
assert len(csa.cluster_ids) == 1
csa.revert_curation()
assert len(csa.cluster_ids)==2
assert len(csa.cluster_ids) == 2


def test_curation_stim(csa):
csa.curate(criteria="test", by_stim=True, by_response=False, by_trial=False)
assert len(csa.cluster_ids)==1
assert len(csa.cluster_ids) == 1
csa.revert_curation()
assert len(csa.cluster_ids) == 2


def test_curation_response(csa):
csa.curate(criteria="activated", by_stim=False, by_response=True, by_trial=False)
assert len(csa.cluster_ids)==1
assert len(csa.cluster_ids) == 1
csa.revert_curation()
assert len(csa.cluster_ids) == 2


def test_curation_trial_all(csa):
csa.curate(criteria="test", by_stim=True, by_response=False, by_trial='all')
assert len(csa.cluster_ids)==1
csa.curate(criteria="test", by_stim=True, by_response=False, by_trial="all")
assert len(csa.cluster_ids) == 1
csa.revert_curation()
assert len(csa.cluster_ids) == 2
assert len(csa.cluster_ids) == 2

0 comments on commit 1a2b81a

Please sign in to comment.