Skip to content

Commit

Permalink
Better parameters handling
Browse files Browse the repository at this point in the history
  • Loading branch information
DradeAW committed Mar 20, 2024
1 parent d858d0c commit 3836169
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 29 deletions.
1 change: 1 addition & 0 deletions src/lussac/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .lussac_data import LussacData, MonoSortingData, MultiSortingsData
from .lussac_params import LussacParams
from .module import LussacModule, MonoSortingModule, MultiSortingsModule
from .module_factory import ModuleFactory
from .pipeline import LussacPipeline
Expand Down
79 changes: 79 additions & 0 deletions src/lussac/core/lussac_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import json
import pathlib
import platform

import jsmin

import lussac


class LussacParams:

@staticmethod
def load_from_string(params: str, params_folder: pathlib.Path | str | None = None):
"""
Loads the parameters from a string and returns them as a dict.
@param params: str
Lussac's parameters.
@param params_folder: str
Path to replace the "$PARAMS_FOLDER".
"""

if params_folder is not None:
params_folder = str(pathlib.Path(params_folder).absolute())
params = params.replace("$PARAMS_FOLDER", params_folder)
if platform.system() == "Windows": # pragma: no cover (OS specific).
params = params.replace("\\", "\\\\")

return json.loads(params)

@staticmethod
def load_from_json_file(filename: str, params_folder: pathlib.Path | str | None = None) -> dict:
"""
Loads the JSON parameters file and returns its content as a dict.
@param filename: str
Path to the file containing Lussac's parameters.
@param params_folder: Path | str | None
Path to replace the "$PARAMS_FOLDER".
If None (default), will use the parent folder of the filename.
@return params: dict
Lussac's parameters.
"""

if params_folder is None:
params_folder = str(pathlib.Path(filename).parent.absolute())
else:
params_folder = str(pathlib.Path(params_folder).absolute())

with open(filename) as json_file:
minified = jsmin.jsmin(json_file.read()) # Parses out comments.
return LussacParams.load_from_string(minified, params_folder)

@staticmethod
def load_default_params(name: str, folder: pathlib.Path | str) -> dict:
"""
Loads the default parameters from the "params_example" folder.
@param name: str
The name of the default params file to load.
@param folder: str
Path to the folder where to create the "lussac" folder.
@return params: dict
The default parameters.
"""

if not name.startswith("params_"):
name = f"params_{name}"
if not name.endswith(".json"):
name = f"{name}.json"

params_folder = pathlib.Path(lussac.__file__).parent.parent.parent / "params_examples"
file = params_folder / name

return LussacParams.load_from_json_file(str(file), folder)




28 changes: 2 additions & 26 deletions src/lussac/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import logging
import pathlib
import platform
import sys
import argparse
import json
import jsmin
from lussac.core import LussacData, LussacPipeline, LussacSpikeSorter
from lussac.core import LussacData, LussacParams, LussacPipeline, LussacSpikeSorter


def parse_arguments(args: list | None) -> str:
Expand All @@ -25,34 +21,14 @@ def parse_arguments(args: list | None) -> str:
return args.params_file


def load_json(filename: str) -> dict:
"""
Loads the JSON parameters file and returns its content.
@param filename: str
Path to the file containing Lussac's parameters.
@return params: dict
Lussac's parameters.
"""

folder = pathlib.Path(filename).parent
with open(filename) as json_file:
minified = jsmin.jsmin(json_file.read()) # Parses out comments.
minified = minified.replace("$PARAMS_FOLDER", str(folder.absolute()))
if platform.system() == "Windows": # pragma: no cover (OS specific).
minified = minified.replace("\\", "\\\\")

return json.loads(minified)


def main() -> None: # pragma: no cover
"""
The main function to execute Lussac.
"""

# STEP 0: Loading the parameters into the LussacData object.
params_file = parse_arguments(sys.argv[1:])
params = load_json(params_file)
params = LussacParams.load_from_json_file(params_file)
data = LussacData.create_from_params(params)

# STEP 1: Running the spike sorting.
Expand Down
5 changes: 2 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import shutil
import sys
import pytest
from lussac.core import LussacData, LussacPipeline, MonoSortingData, MultiSortingsData
import lussac.main
from lussac.core import LussacData, LussacParams, LussacPipeline, MonoSortingData, MultiSortingsData


sys.path.append(str(pathlib.Path(__file__).parent.parent.absolute())) # Otherwise the tests are not found properly with 'pytest'.
Expand All @@ -13,7 +12,7 @@

@pytest.fixture(scope="session")
def params() -> dict:
return lussac.main.load_json(str(params_path.resolve()))
return LussacParams.load_from_json_file(str(params_path.resolve()))


@pytest.fixture(scope="session")
Expand Down
7 changes: 7 additions & 0 deletions tests/core/test_lussac_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from lussac.core import LussacParams


def test_load_default_params() -> None:
params = LussacParams.load_default_params('params_synthetic', '/aze/')
print(params)
assert params['lussac']['tmp_folder'] == "/aze/lussac/tmp"

0 comments on commit 3836169

Please sign in to comment.