diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..7f73516 --- /dev/null +++ b/.flake8 @@ -0,0 +1,40 @@ +[flake8] +exclude = + .git, + __pycache__, + build, + dist, +--select = D,E,F,W +docstring-convention = numpy +max-line-length = 240 +# For PEP8 error codes see +# http://pep8.readthedocs.org/en/latest/intro.html#error-codes + # D100-D104: missing docstring + # D105: missing docstring in magic method + # D107: missing docstring in __init__ + # W504: line break after binary operator +per-file-ignores = + **/__init__.py: D104 +ignore = + BLK100, + D105 + D107, + E402, + E266, + E721, + E731, + E713, + E714, + E741, + F403, + F405, + E401, + F401, + F811, + F821, + W503, + + +# for compatibility with black +# https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#flake8 +extend-ignore = E203 diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..5ab0ddf --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,9 @@ +--- +# Documentation +# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file +version: 2 +updates: +- package-ecosystem: github-actions + directory: / + schedule: + interval: monthly diff --git a/.github/workflows/run_precommit.yml b/.github/workflows/run_precommit.yml new file mode 100644 index 0000000..97100d0 --- /dev/null +++ b/.github/workflows/run_precommit.yml @@ -0,0 +1,15 @@ +--- +name: pre-commit + +on: + pull_request: + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + - uses: pre-commit/action@v3.0.0 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b549776..cdacd85 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,88 +1,89 @@ +--- name: Build and test on: - push: - branches: - - main - tags: - - "*" - pull_request: - branches: - - main + push: + branches: + - main + tags: + - '*' + pull_request: + branches: + - main # Run weekly to avoid missing deprecations during low activity - schedule: - - cron: '0 0 * * 1' + schedule: + - cron: 0 0 * * 1 # Allow job to be triggered manually from GitHub interface - workflow_dispatch: + workflow_dispatch: defaults: - run: - shell: bash + run: + shell: bash # Force tox and pytest to use color env: - FORCE_COLOR: true + FORCE_COLOR: true concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true permissions: - contents: read + contents: read jobs: - test: + test: # Check each OS, all supported Python, minimum versions and latest releases - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: ['ubuntu-latest'] - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] - include: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + include: # Basic dependencies only - - os: ubuntu-latest - python-version: 3.8 - dependencies: 'min' + - os: ubuntu-latest + python-version: 3.8 + dependencies: min - env: - DEPENDS: ${{ matrix.dependencies }} - ARCH: ${{ !contains(fromJSON('["none", "min"]'), matrix.dependencies) && matrix.architecture }} + env: + DEPENDS: ${{ matrix.dependencies }} + ARCH: ${{ !contains(fromJSON('["none", "min"]'), matrix.dependencies) && matrix.architecture }} - steps: - - uses: actions/checkout@v4 
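A note on the `.flake8` added at the top of this patch: in an INI-style config, flake8 expects bare option names, so `--select = D,E,F,W` (the command-line spelling) is most likely ignored as written and would need to read `select = D,E,F,W` to take effect; `D105` in the `ignore` block is also missing its trailing comma, although flake8's list parser tolerates whitespace separators. The long `ignore` list mutes codes this codebase hits constantly, such as the star imports (F403/F405) and single-letter `l` loop variables (E741) visible in the Python hunks further down. A hypothetical snippet (not from this repo) showing what a few of the ignored codes would flag:

```python
# Illustration only: valid, runnable Python that flake8 flags without
# the ignore list above.
from math import *  # F403: star import; names it provides trigger F405

square = lambda x: x * x  # E731: lambda assigned to a name instead of a def

for l in range(3):  # E741: ambiguous single-letter name "l"
    print(square(l), floor(pi))  # floor and pi resolve via the star import (F405)
```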
- with: - submodules: recursive - fetch-depth: 0 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - architecture: ${{ matrix.architecture }} - allow-prereleases: true - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - name: Install tox - run: | - python -m pip install --upgrade pip - python -m pip install tox tox-gh-actions - - name: Show tox config - run: tox c - - name: Run tox - run: tox -v --exit-and-dump-after 1200 + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 0 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + architecture: ${{ matrix.architecture }} + allow-prereleases: true + - name: Display Python version + run: python -c "import sys; print(sys.version)" + - name: Install tox + run: | + python -m pip install --upgrade pip + python -m pip install tox tox-gh-actions + - name: Show tox config + run: tox c + - name: Run tox + run: tox -v --exit-and-dump-after 1200 - publish: - runs-on: ubuntu-latest - environment: "Package deployment" - needs: [test] - permissions: + publish: + runs-on: ubuntu-latest + environment: Package deployment + needs: [test] + permissions: # Required for trusted publishing - id-token: write - if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') - steps: - - uses: actions/checkout@v4 - with: - submodules: recursive - fetch-depth: 0 - - run: pipx run build - - uses: pypa/gh-action-pypi-publish@release/v1 + id-token: write + if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 0 + - run: pipx run build + - uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/update_precommit_hooks.yml b/.github/workflows/update_precommit_hooks.yml new file mode 100644 index 0000000..a0cf4bf --- /dev/null +++ b/.github/workflows/update_precommit_hooks.yml @@ -0,0 +1,53 @@ +--- +name: Update precommit hooks + + +on: + +# Uses the cron schedule for github actions +# +# https://docs.github.com/en/free-pro-team@latest/actions/reference/events-that-trigger-workflows#scheduled-events +# +# ┌───────────── minute (0 - 59) +# │ ┌───────────── hour (0 - 23) +# │ │ ┌───────────── day of the month (1 - 31) +# │ │ │ ┌───────────── month (1 - 12 or JAN-DEC) +# │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT) +# │ │ │ │ │ +# │ │ │ │ │ +# │ │ │ │ │ +# * * * * * + schedule: + - cron: 0 0 * * 1 # every monday + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + update_precommit_hooks: + + # only run on upstream repo + if: github.repository_owner == 'SIMEXP' + + runs-on: ubuntu-latest + steps: + - name: Checkout repo + uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + allow-prereleases: false + - name: Install pre-commit + run: pip install pre-commit + - name: Update pre-commit hooks + run: pre-commit autoupdate + - name: Create Pull Request + uses: peter-evans/create-pull-request@v5 + with: + commit-message: pre-commit hooks auto-update + base: main + token: ${{ secrets.GITHUB_TOKEN }} + delete-branch: true + title: '[BOT] update pre-commit hooks' + body: done via this [GitHub Action](https://github.com/${{ github.repository_owner 
}}/giga_connectome/blob/main/.github/workflows/update_precommit_hooks.yml) diff --git a/.gitignore b/.gitignore index f92e659..a41f38a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,19 +1,14 @@ **/.DS_Store + __pycache__ *.pyc *.cpython -sample_data/sub-0001_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz -sample_data/sub-0001_task-restingstate_acq-mb3_desc-confounds_regressors.tsv -sample_data/sub-0002_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz -sample_data/sub-0002_task-restingstate_acq-mb3_desc-confounds_regressors.tsv -sample_data/sub-0003_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz -sample_data/sub-0003_task-restingstate_acq-mb3_desc-confounds_regressors.tsv -sample_data/sub-0004_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz -sample_data/sub-0004_task-restingstate_acq-mb3_desc-confounds_regressors.tsv -sample_data/sub-0005_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz -sample_data/sub-0005_task-restingstate_acq-mb3_desc-confounds_regressors.tsv + +.vscode + +sample_data/ # build related pydfc.egg-info build -dist/ \ No newline at end of file +dist/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..9a7fe62 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,60 @@ +--- +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-ast + - id: check-case-conflict + - id: check-json + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: mixed-line-ending + - id: trailing-whitespace + +- repo: https://github.com/ikamensh/flynt/ + rev: 1.0.1 + hooks: + - id: flynt + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: [--profile, black] + +- repo: https://github.com/psf/black-pre-commit-mirror + rev: 24.2.0 + hooks: + - id: black-jupyter + args: [--config, pyproject.toml] + +- repo: https://github.com/codespell-project/codespell + rev: v2.2.6 + hooks: + - id: codespell + args: [--toml, pyproject.toml] + additional_dependencies: [tomli] + +- repo: https://github.com/jumanjihouse/pre-commit-hook-yamlfmt + rev: 0.2.3 + hooks: + - id: yamlfmt + args: [--mapping, '4', --sequence, '4', --offset, '0'] + +- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks + rev: v2.12.0 + hooks: + - id: pretty-format-toml + args: [--autofix, --indent, '4'] + +- repo: https://github.com/pyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + args: [--config, .flake8, --verbose, pydfc, HCP_resting_state_analysis, task_dFC] + additional_dependencies: [flake8-use-fstring] diff --git a/HCP_resting_state_analysis/FCS_estimate.py b/HCP_resting_state_analysis/FCS_estimate.py index d167883..bafabf8 100644 --- a/HCP_resting_state_analysis/FCS_estimate.py +++ b/HCP_resting_state_analysis/FCS_estimate.py @@ -1,94 +1,117 @@ -from functions.dFC_funcs import * -import numpy as np +import os import time + import hdf5storage +import numpy as np import scipy.io as sio -import os -os.environ["MKL_NUM_THREADS"] = '64' -os.environ["NUMEXPR_NUM_THREADS"] = '64' -os.environ["OMP_NUM_THREADS"] = '64' +from functions.dFC_funcs import * + +os.environ["MKL_NUM_THREADS"] = "64" +os.environ["NUMEXPR_NUM_THREADS"] = "64" +os.environ["OMP_NUM_THREADS"] = "64" 
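One caveat on the reordered header above: `MKL_NUM_THREADS`, `NUMEXPR_NUM_THREADS`, and `OMP_NUM_THREADS` are read when the numerical runtime first loads, so assigning them after `import numpy` (which both the old and the new ordering do) generally has no effect on the current process. A minimal sketch of the ordering under which the caps do apply, assuming they are meant for this script's own NumPy/MKL:

```python
# Sketch: export thread caps before NumPy (or anything that imports it)
# first loads; MKL and OpenMP read these variables at library init.
import os

os.environ["MKL_NUM_THREADS"] = "64"
os.environ["NUMEXPR_NUM_THREADS"] = "64"
os.environ["OMP_NUM_THREADS"] = "64"

import numpy as np  # noqa: E402 -- deliberately imported after the caps
```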
-print('################################# CODE started running ... #################################') +print( + "################################# CODE started running ... #################################" +) ################################# Parameters ################################# ###### DATA PARAMETERS ###### -output_root = './' +output_root = "./" # DATA_type is either 'sample' or 'Gordon' or 'simulated' or 'ICA' -params_data_load = { - 'DATA_type': 'Gordon', - 'SESSIONs':['Rest1_LR' , 'Rest1_RL', 'Rest2_LR', 'Rest2_RL'], - 'networks2include':['Auditory', 'CinguloOperc', 'Default', 'DorsalAttn', 'FrontoParietal', - 'MedialParietal', 'ParietoOccip', 'SMhand', 'SMmouth', - 'Salience', 'VentralAttn', 'Visual'], - - 'data_root_simul': './../../../../DATA/TVB data/', - 'data_root_sample': './sampleDATA/', - 'data_root_gordon': './../../../../DATA/HCP/HCP_Gordon/', - 'data_root_ica': './../../../../DATA/HCP/HCP_PTN1200/node_timeseries/3T_HCP1200_MSMAll_d50_ts2/' +params_data_load = { + "DATA_type": "Gordon", + "SESSIONs": ["Rest1_LR", "Rest1_RL", "Rest2_LR", "Rest2_RL"], + "networks2include": [ + "Auditory", + "CinguloOperc", + "Default", + "DorsalAttn", + "FrontoParietal", + "MedialParietal", + "ParietoOccip", + "SMhand", + "SMmouth", + "Salience", + "VentralAttn", + "Visual", + ], + "data_root_simul": "./../../../../DATA/TVB data/", + "data_root_sample": "./sampleDATA/", + "data_root_gordon": "./../../../../DATA/HCP/HCP_Gordon/", + "data_root_ica": "./../../../../DATA/HCP/HCP_PTN1200/node_timeseries/3T_HCP1200_MSMAll_d50_ts2/", } ###### MEASUREMENT PARAMETERS ###### # W is in sec -params_methods = { +params_methods = { # Sliding Parameters - 'W': 44, 'n_overlap': 0.5, 'sw_method':'pear_corr', 'tapered_window':True, + "W": 44, + "n_overlap": 0.5, + "sw_method": "pear_corr", + "tapered_window": True, # TIME_FREQ - 'TF_method':'WTC', + "TF_method": "WTC", # CLUSTERING AND DHMM - 'clstr_base_measure':'SlidingWindow', + "clstr_base_measure": "SlidingWindow", # HMM - 'hmm_iter': 30, 'dhmm_obs_state_ratio': 16/24, + "hmm_iter": 30, + "dhmm_obs_state_ratio": 16 / 24, # State Parameters - 'n_states': 12, 'n_subj_clstrs': 20, + "n_states": 12, + "n_subj_clstrs": 20, # Parallelization Parameters - 'n_jobs': 2, 'verbose': 0, 'backend': 'loky', + "n_jobs": 2, + "verbose": 0, + "backend": "loky", # SESSION - 'session': 'Rest1_LR', + "session": "Rest1_LR", # Hyper Parameters - 'normalization': True, - 'num_subj': 395, - 'num_select_nodes': 96, - 'num_time_point': 1200, - 'Fs_ratio': 1.00, - 'noise_ratio': 0.00, - 'num_realization': 1 + "normalization": True, + "num_subj": 395, + "num_select_nodes": 96, + "num_time_point": 1200, + "Fs_ratio": 1.00, + "noise_ratio": 0.00, + "num_realization": 1, } ###### HYPER PARAMETERS ALTERNATIVE ###### -MEASURES_name_lst = [ - 'SlidingWindow', - 'Time-Freq', - 'CAP', - 'ContinuousHMM', - 'Windowless', - 'Clustering', - 'DiscreteHMM' +MEASURES_name_lst = [ + "SlidingWindow", + "Time-Freq", + "CAP", + "ContinuousHMM", + "Windowless", + "Clustering", + "DiscreteHMM", ] -alter_hparams = { \ - 'session': ['Rest1_RL', 'Rest2_LR', 'Rest2_RL'], - # 'n_overlap': [0, 0.25, 0.75, 1], - # 'n_states': [6, 16], - # # 'normalization': [], - # 'num_subj': [50, 100, 200], - # 'num_select_nodes': [30, 50, 333], - # 'num_time_point': [800, 1000], - # 'Fs_ratio': [0.50, 0.75, 1.5], - # 'noise_ratio': [1.00, 2.00, 3.00], - # 'num_realization': [] - } +alter_hparams = { + "session": ["Rest1_RL", "Rest2_LR", "Rest2_RL"], + # 'n_overlap': [0, 0.25, 0.75, 1], + # 'n_states': [6, 16], + # 
# 'normalization': [], + # 'num_subj': [50, 100, 200], + # 'num_select_nodes': [30, 50, 333], + # 'num_time_point': [800, 1000], + # 'Fs_ratio': [0.50, 0.75, 1.5], + # 'noise_ratio': [1.00, 2.00, 3.00], + # 'num_realization': [] +} ###### dFC ANALYZER PARAMETERS ###### -params_dFC_analyzer = { +params_dFC_analyzer = { # Parallelization Parameters - 'n_jobs': None, 'verbose': 0, 'backend': 'loky' + "n_jobs": None, + "verbose": 0, + "backend": "loky", } @@ -105,24 +128,21 @@ ################################# Measures of dFC ################################# -dFC_analyzer = DFC_ANALYZER( - analysis_name='reproducibility assessment', - **params_dFC_analyzer +dFC_analyzer = DFC_ANALYZER( + analysis_name="reproducibility assessment", **params_dFC_analyzer ) -MEASURES_lst = dFC_analyzer.measures_initializer( - MEASURES_name_lst, - params_methods, - alter_hparams +MEASURES_lst = dFC_analyzer.measures_initializer( + MEASURES_name_lst, params_methods, alter_hparams ) tic = time.time() -print('Measurement Started ...') +print("Measurement Started ...") ################################# estimate FCS ################################# task_id = int(os.getenv("SGE_TASK_ID")) -MEASURE_id = task_id-1 # SGE_TASK_ID starts from 1 not 0 +MEASURE_id = task_id - 1 # SGE_TASK_ID starts from 1 not 0 if MEASURE_id >= len(MEASURES_lst): @@ -132,18 +152,18 @@ print("FCS estimation started...") - time_series = BOLD[measure.params['session']] + time_series = BOLD[measure.params["session"]] if measure.is_state_based: measure.estimate_FCS(time_series=time_series) - + # dFC_analyzer.estimate_group_FCS(time_series_dict=BOLD) print("FCS estimation done.") - print('Measurement required %0.3f seconds.' % (time.time() - tic, )) + print(f"Measurement required {time.time() - tic:0.3f} seconds.") # Save - np.save(output_root+'fitted_MEASURES/MEASURE_'+str(MEASURE_id)+'.npy', measure) - np.save(output_root+'dFC_analyzer.npy', dFC_analyzer) - np.save(output_root+'data_loader.npy', data_loader) + np.save(output_root + "fitted_MEASURES/MEASURE_" + str(MEASURE_id) + ".npy", measure) + np.save(output_root + "dFC_analyzer.npy", dFC_analyzer) + np.save(output_root + "data_loader.npy", data_loader) -################################################################################# \ No newline at end of file +################################################################################# diff --git a/HCP_resting_state_analysis/dFC_assessment.py b/HCP_resting_state_analysis/dFC_assessment.py index c1304e9..1df0f5c 100644 --- a/HCP_resting_state_analysis/dFC_assessment.py +++ b/HCP_resting_state_analysis/dFC_assessment.py @@ -1,57 +1,63 @@ -from functions.dFC_funcs import * -import numpy as np +import os import time + import hdf5storage +import numpy as np import scipy.io as sio -import os -os.environ["MKL_NUM_THREADS"] = '64' -os.environ["NUMEXPR_NUM_THREADS"] = '64' -os.environ["OMP_NUM_THREADS"] = '64' +from functions.dFC_funcs import * + +os.environ["MKL_NUM_THREADS"] = "64" +os.environ["NUMEXPR_NUM_THREADS"] = "64" +os.environ["OMP_NUM_THREADS"] = "64" -print('################################# subject-level dFC assessment CODE started running ... #################################') +print( + "################################# subject-level dFC assessment CODE started running ... 
#################################" +) ################################# Parameters ################################# ###### DATA PARAMETERS ###### -input_root = './' -output_root = './' +input_root = "./" +output_root = "./" ################################# LOAD ################################# -dFC_analyzer = np.load(input_root+'dFC_analyzer.npy',allow_pickle='TRUE').item() -data_loader = np.load(input_root+'data_loader.npy',allow_pickle='TRUE').item() +dFC_analyzer = np.load(input_root + "dFC_analyzer.npy", allow_pickle="TRUE").item() +data_loader = np.load(input_root + "data_loader.npy", allow_pickle="TRUE").item() ################################# LOAD FIT MEASURES ################################# -if dFC_analyzer.MEASURES_fit_lst==[]: - ALL_RECORDS = os.listdir(input_root+'fitted_MEASURES/') - ALL_RECORDS = [i for i in ALL_RECORDS if 'MEASURE' in i] +if dFC_analyzer.MEASURES_fit_lst == []: + ALL_RECORDS = os.listdir(input_root + "fitted_MEASURES/") + ALL_RECORDS = [i for i in ALL_RECORDS if "MEASURE" in i] ALL_RECORDS.sort() MEASURES_fit_lst = list() for s in ALL_RECORDS: - fit_measure = np.load(input_root+'fitted_MEASURES/'+s, allow_pickle='TRUE').item() + fit_measure = np.load( + input_root + "fitted_MEASURES/" + s, allow_pickle="TRUE" + ).item() MEASURES_fit_lst.append(fit_measure) dFC_analyzer.set_MEASURES_fit_lst(MEASURES_fit_lst) - print('fitted MEASURES loaded ...') - # np.save('./dFC_analyzer.npy', dFC_analyzer) + print("fitted MEASURES loaded ...") + # np.save('./dFC_analyzer.npy', dFC_analyzer) ################################# LOAD DATA ################################# task_id = int(os.getenv("SGE_TASK_ID")) -subj_id = data_loader.SUBJECTS[task_id-1] # SGE_TASK_ID starts from 1 not 0 +subj_id = data_loader.SUBJECTS[task_id - 1] # SGE_TASK_ID starts from 1 not 0 BOLD = data_loader.load(subj_id2load=subj_id) ################################# dFC ASSESSMENT ################################# tic = time.time() -print('Measurement Started ...') +print("Measurement Started ...") print("dFCM estimation started...") dFCM_dict = dFC_analyzer.subj_lvl_dFC_assess(time_series_dict=BOLD) print("dFCM estimation done.") -print('Measurement required %0.3f seconds.' % (time.time() - tic, )) +print(f"Measurement required {time.time() - tic:0.3f} seconds.") ################################# SAVE DATA ################################# @@ -60,26 +66,28 @@ # os.makedirs(folder) # for dFCM_id, dFCM in enumerate(dFCM_dict['dFCM_lst']): -# np.save(folder+'/dFCM_'+str(dFCM_id)+'.npy', dFCM) +# np.save(folder+'/dFCM_'+str(dFCM_id)+'.npy', dFCM) ################################# SIMILARITY MEASUREMENT ################################# -similarity_assessment = SIMILARITY_ASSESSMENT(dFCM_lst=dFCM_dict['dFCM_lst']) +similarity_assessment = SIMILARITY_ASSESSMENT(dFCM_lst=dFCM_dict["dFCM_lst"]) tic = time.time() -print('Measurement Started ...') +print("Measurement Started ...") print("Similarity measurement started...") -SUBJ_output = similarity_assessment.run(FILTERS=dFC_analyzer.hyper_param_info, downsampling_method='default') +SUBJ_output = similarity_assessment.run( + FILTERS=dFC_analyzer.hyper_param_info, downsampling_method="default" +) print("Similarity measurement done.") -print('Measurement required %0.3f seconds.' 
% (time.time() - tic, )) +print(f"Measurement required {time.time() - tic:0.3f} seconds.") # Save -folder = output_root+'similarity_measured' +folder = output_root + "similarity_measured" if not os.path.exists(folder): os.makedirs(folder) -np.save(folder+'/SUBJ_'+str(subj_id)+'_output.npy', SUBJ_output) +np.save(folder + "/SUBJ_" + str(subj_id) + "_output.npy", SUBJ_output) -################################################################################# \ No newline at end of file +################################################################################# diff --git a/HCP_resting_state_analysis/functions/dFC_funcs.py b/HCP_resting_state_analysis/functions/dFC_funcs.py index 635b2bf..5102972 100755 --- a/HCP_resting_state_analysis/functions/dFC_funcs.py +++ b/HCP_resting_state_analysis/functions/dFC_funcs.py @@ -6,24 +6,24 @@ @author: mte """ -import numpy as np -from scipy import signal -import scipy.spatial.distance as ssd -import scipy.cluster.hierarchy as shc +import os +import time from copy import deepcopy -import matplotlib.pyplot as plt + +import hdf5storage import matplotlib as mpl +import matplotlib.pyplot as plt +import networkx as nx +import numpy as np +import pandas as pd +import scipy.cluster.hierarchy as shc +import scipy.io as sio +import scipy.spatial.distance as ssd import seaborn as sns +from joblib import Parallel, delayed from nilearn.plotting import plot_markers -import networkx as nx +from scipy import signal, stats from scipy.spatial import distance -from scipy import stats -from joblib import Parallel, delayed -import os -import time -import hdf5storage -import scipy.io as sio -import pandas as pd # np.seterr(invalid='ignore') @@ -46,79 +46,83 @@ ################################# Parameters #################################### fig_dpi = 120 -fig_bbox_inches = 'tight' +fig_bbox_inches = "tight" fig_pad = 0.1 show_title = False -save_fig_format = 'png' +save_fig_format = "png" ################################# Other Functions #################################### + # test def zip_name(name): # zip measure names - if 'Clustering' in name: - new_name = 'SWC' - if 'CAP' in name: - new_name = 'CAP' - if 'ContinuousHMM' in name: - new_name = 'CHMM' - if 'Windowless' in name: - new_name = 'WL' - if 'DiscreteHMM' in name: - new_name = 'DHMM' - if 'Time-Freq' in name: - new_name = 'TF' - if 'SlidingWindow' in name: - new_name = 'SW' + if "Clustering" in name: + new_name = "SWC" + if "CAP" in name: + new_name = "CAP" + if "ContinuousHMM" in name: + new_name = "CHMM" + if "Windowless" in name: + new_name = "WL" + if "DiscreteHMM" in name: + new_name = "DHMM" + if "Time-Freq" in name: + new_name = "TF" + if "SlidingWindow" in name: + new_name = "SW" return new_name + # test # pear_corr problem def unzip_name(name): # unzip measure names - if 'SWC' in name: - new_name = 'Clustering' - elif 'CAP' in name: - new_name = 'CAP' - elif 'CHMM' in name: - new_name = 'ContinuousHMM' - elif 'WL' in name: - new_name = 'Windowless' - elif 'DHMM' in name: - new_name = 'DiscreteHMM' - elif 'TF' in name: - new_name = 'Time-Freq' - elif 'SW' in name: - new_name = 'SlidingWindow' + if "SWC" in name: + new_name = "Clustering" + elif "CAP" in name: + new_name = "CAP" + elif "CHMM" in name: + new_name = "ContinuousHMM" + elif "WL" in name: + new_name = "Windowless" + elif "DHMM" in name: + new_name = "DiscreteHMM" + elif "TF" in name: + new_name = "Time-Freq" + elif "SW" in name: + new_name = "SlidingWindow" return new_name + def find_new_order(old_list, new_list): - ''' + """ new_order is 
a list of indices old_list = ['E', 'B', 'A', 'C', 'D'] new_list = ['A', 'B', 'C', 'D', 'E'] - ''' - new_order = [old_list.index(a) for a in new_list] + """ + new_order = [old_list.index(a) for a in new_list] return new_order + def mat_reorder(A, new_order): - ''' + """ new_order must be a list of indices: old_list = ['E', 'B', 'A', 'C', 'D'] new_list = ['A', 'B', 'C', 'D', 'E'] new_order = find_new_order(old_list, new_list) A_sorted is a copy of A - ''' + """ assert ( - len(new_order)==A.shape[0] - and len(new_order)==A.shape[1] - ), 'dimension mismatch in reordering.' + len(new_order) == A.shape[0] and len(new_order) == A.shape[1] + ), "dimension mismatch in reordering." A_sorted = deepcopy(A) A_sorted = [[A_sorted[i][j] for j in new_order] for i in new_order] A_sorted = np.array(A_sorted) return A_sorted + # test def get_subj_ts_dict(time_series_dict, subjs_id): subj_ts_dict = {} @@ -126,94 +130,97 @@ def get_subj_ts_dict(time_series_dict, subjs_id): subj_ts_dict[session] = time_series_dict[session].get_subj_ts(subjs_id=subjs_id) return subj_ts_dict + # test def filter_dFCM_lst(dFCM_lst, **param_dict): dFCM_lst2check = list() for dFCM in dFCM_lst: if dFCM.measure.param_match(**param_dict): - dFCM_lst2check.append(dFCM) + dFCM_lst2check.append(dFCM) return dFCM_lst2check + def SW_downsample(data, Fs, W, n_overlap, tapered_window=False): - ''' + """ data = (n_time, ...) - the time samples will be picked after + the time samples will be picked after averaging over a window which slides W is in sec SWed_data = (n_time_new, ...) - ''' + """ SWed_data = list() L = data.shape[0] # change W to timepoints - W = int(W * Fs) - step = int((1-n_overlap)*W) + W = int(W * Fs) + step = int((1 - n_overlap) * W) if step == 0: step = 1 - window_taper = signal.windows.gaussian(W, std=3*W/22) + window_taper = signal.windows.gaussian(W, std=3 * W / 22) TR_array = list() - for l in range(0, L-W+1, step): + for l in range(0, L - W + 1, step): ######### creating a rectangel window ############ window = np.zeros((L)) - window[l:l+W] = 1 - + window[l : l + W] = 1 + ########### tapering the window ############## if tapered_window: - window = signal.convolve(window, window_taper, mode='same') / sum(window_taper) + window = signal.convolve(window, window_taper, mode="same") / sum( + window_taper + ) # int(l-W/2):int(l+3*W/2) is the nonzero interval after tapering SWed_data.append(np.average(data, weights=window, axis=0)) - - TR_array.append(int((l + (l+W)) / 2) ) - - + + TR_array.append(int((l + (l + W)) / 2)) + SWed_data = np.array(SWed_data) return SWed_data + def mutual_information(X, Y, N_bins=100): - """ Mutual information for joint histogram + """Mutual information for joint histogram https://matthew-brett.github.io/teaching/mutual_information.html#:~:text=Mutual%20information%20is%20a%20measure,signal%20intensity%20in%20the%20first. 
""" # 2D histogram - hist_2d, x_edges, y_edges = np.histogram2d( - X, - Y, - bins=N_bins) - + hist_2d, x_edges, y_edges = np.histogram2d(X, Y, bins=N_bins) + # Convert bins counts to probability values pxy = hist_2d / float(np.sum(hist_2d)) - px = np.sum(pxy, axis=1) # marginal for x over y - py = np.sum(pxy, axis=0) # marginal for y over x - px_py = px[:, None] * py[None, :] # Broadcast to multiply marginals + px = np.sum(pxy, axis=1) # marginal for x over y + py = np.sum(pxy, axis=0) # marginal for y over x + px_py = px[:, None] * py[None, :] # Broadcast to multiply marginals # Now we can do the calculation using the pxy, px_py 2D arrays - nzs = pxy > 0 # Only non-zero pxy values contribute to the sum + nzs = pxy > 0 # Only non-zero pxy values contribute to the sum return np.sum(pxy[nzs] * np.log(pxy[nzs] / px_py[nzs])) + def corr2distance(corr_mat, metric): - ''' + """ metric can be: MI, euclidean_distance, spearman, corr (pearson) - ''' - if metric=='MI': - if np.any(corr_mat>1): - print('MI values cannot be converted to distances.') + """ + if metric == "MI": + if np.any(corr_mat > 1): + print("MI values cannot be converted to distances.") # negative corr will be > 1.0 dist_mat = 1 - corr_mat - elif metric=='euclidean_distance': + elif metric == "euclidean_distance": dist_mat = corr_mat else: # negative corr will be > 1.0 dist_mat = 1 - corr_mat # dist_mat must be symmetric - dist_mat = 0.5*(dist_mat + dist_mat.T) + dist_mat = 0.5 * (dist_mat + dist_mat.T) # diagonal values of dist_mat must equal exactly zero np.fill_diagonal(dist_mat, 0) return dist_mat + # test def normalizeAdjacency(W): """ @@ -229,23 +236,31 @@ def normalizeAdjacency(W): """ W_norm = W - np.min(W) W_norm = np.divide(W_norm, np.max(W_norm)) - return W_norm + return W_norm + # test def normalized_euc_dist(x, y): # https://stats.stackexchange.com/questions/136232/definition-of-normalized-euclidean-distance#:~:text=The%20normalized%20squared%20euclidean%20distance,not%20related%20to%20Mahalanobis%20distance. 
- if np.linalg.norm(x-np.mean(x))**2==0 and np.linalg.norm(y-np.mean(y))**2==0: + if ( + np.linalg.norm(x - np.mean(x)) ** 2 == 0 + and np.linalg.norm(y - np.mean(y)) ** 2 == 0 + ): return 0 - return 0.5*((np.linalg.norm((x-np.mean(x)) - (y-np.mean(y)))**2)/(np.linalg.norm(x-np.mean(x))**2 + np.linalg.norm(y-np.mean(y))**2)) + return 0.5 * ( + (np.linalg.norm((x - np.mean(x)) - (y - np.mean(y))) ** 2) + / (np.linalg.norm(x - np.mean(x)) ** 2 + np.linalg.norm(y - np.mean(y)) ** 2) + ) + def calc_graph_propoerty(A, property, threshold=False, binarize=False): """ - calc_graph_propoerty: Computes Graph-based properties + calc_graph_propoerty: Computes Graph-based properties of adjacency matrix A A is converted to positive before calc property: - - ECM: Computes Eigenvector Centrality Mapping (ECM) + - ECM: Computes Eigenvector Centrality Mapping (ECM) - shortest_path - degree - clustering_coef @@ -258,32 +273,32 @@ def calc_graph_propoerty(A, property, threshold=False, binarize=False): graph-property (np.array): a vector """ - N_edges = 200 # number of edges to keep - if property=='shortest_path' or property=='clustering_coef': - threshold=True + N_edges = 200 # number of edges to keep + if property == "shortest_path" or property == "clustering_coef": + threshold = True - G = nx.from_numpy_matrix(np.abs(A)) + G = nx.from_numpy_matrix(np.abs(A)) G.remove_edges_from(nx.selfloop_edges(G)) # G = G.to_undirected() - # pruning edges + # pruning edges if threshold: labels = [d["weight"] for (u, v, d) in G.edges(data=True)] labels.sort() - threshold = labels[-1*N_edges] - ebunch = [(u, v) for u, v, d in G.edges(data=True) if d['weight'] dFC_mat_new = (n_region, n_region) - ''' + """ dFC_mat_copy = deepcopy(dFC_mat) flag_dim = False - if len(dFC_mat_copy.shape)<3: + if len(dFC_mat_copy.shape) < 3: dFC_mat_copy = np.expand_dims(dFC_mat_copy, axis=0) flag_dim = True - assert dFC_mat_copy.shape[1]==dFC_mat_copy.shape[2], \ - 'dimension mismatch.' + assert dFC_mat_copy.shape[1] == dFC_mat_copy.shape[2], "dimension mismatch." 
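Two notes on the helpers above. `mutual_information` fills a joint histogram and sums p(x, y) log(p(x, y) / (p(x) p(y))) over occupied bins, so it returns values in nats and carries a small positive bias for independent inputs. A quick sanity check (hypothetical usage, assuming the module is importable the way the scripts above import it):

```python
import numpy as np

from functions.dFC_funcs import mutual_information

rng = np.random.default_rng(0)
x = rng.standard_normal(10_000)
y = rng.standard_normal(10_000)

# MI(X; X) equals H(X), so this is large ...
print(mutual_information(x, x, N_bins=100))
# ... while independent signals give a value near zero (binning bias aside).
print(mutual_information(x, y, N_bins=100))
```

Separately, `calc_graph_propoerty` (spelling kept as-is, since callers elsewhere use this name) builds its graph with `nx.from_numpy_matrix`, which NetworkX deprecated in the 2.x series and removed in 3.0. A compatibility sketch, assuming either major version may be installed:

```python
import networkx as nx
import numpy as np

A = np.abs(np.random.default_rng(1).standard_normal((5, 5)))

# NetworkX >= 3.0 only ships from_numpy_array; fall back to the legacy
# name used in the hunk above on older installs.
if hasattr(nx, "from_numpy_array"):
    G = nx.from_numpy_array(A)
else:
    G = nx.from_numpy_matrix(A)
G.remove_edges_from(nx.selfloop_edges(G))
```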
n_time = dFC_mat_copy.shape[0] n_region = dFC_mat_copy.shape[1] - dFC_vecs = dFC_mat2vec(dFC_mat_copy) # (n_time, (n_region*(n_region-1))/2) + dFC_vecs = dFC_mat2vec(dFC_mat_copy) # (n_time, (n_region*(n_region-1))/2) if global_norm: - dFC_vecs_flatten = dFC_vecs.flatten() # (n_time*(n_region*(n_region-1))/2,) + dFC_vecs_flatten = dFC_vecs.flatten() # (n_time*(n_region*(n_region-1))/2,) dFC_vecs_flatten_ranked = stats.rankdata(dFC_vecs_flatten) - dFC_vecs_ranked = dFC_vecs_flatten_ranked.reshape((n_time, -1)) # (n_time, (n_region*(n_region-1))/2) - dFC_mat_ranked = dFC_vec2mat(dFC_vecs_ranked, N=n_region) # (n_time, n_region, n_region) + dFC_vecs_ranked = dFC_vecs_flatten_ranked.reshape( + (n_time, -1) + ) # (n_time, (n_region*(n_region-1))/2) + dFC_mat_ranked = dFC_vec2mat( + dFC_vecs_ranked, N=n_region + ) # (n_time, n_region, n_region) dFC_mat_new = dFC_mat_ranked else: # normalize time point-wise dFC_vecs_new = list() for i, vec in enumerate(dFC_vecs): - vec_ranked = stats.rankdata(vec) # (n_region*(n_region-1))/2,) + vec_ranked = stats.rankdata(vec) # (n_region*(n_region-1))/2,) dFC_vecs_new.append(vec_ranked) - dFC_vecs_new = np.array(dFC_vecs_new) # (n_time, (n_region*(n_region-1))/2) - dFC_mat_new = dFC_vec2mat(dFC_vecs_new, N=n_region) # (n_time, n_region, n_region) + dFC_vecs_new = np.array(dFC_vecs_new) # (n_time, (n_region*(n_region-1))/2) + dFC_mat_new = dFC_vec2mat( + dFC_vecs_new, N=n_region + ) # (n_time, n_region, n_region) if flag_dim: - dFC_mat_new = np.squeeze(dFC_mat_new) # (n_region, n_region) + dFC_mat_new = np.squeeze(dFC_mat_new) # (n_region, n_region) return dFC_mat_new + def cat_data(X_t, N): - ''' + """ X_t = (time, roi, roi) X_t is preferable to be ranked prior to cat_data - ''' + """ X_t_new = list() for X in X_t: borders = np.linspace(1, np.max(X), N, endpoint=False) @@ -352,11 +374,12 @@ def cat_data(X_t, N): X_t_new = np.array(X_t_new) return X_t_new + def dFC_mask(dFC_mat, mask): - ''' + """ dFC_mat and mask will be vectorized using dFC_mat2vec mask = (roi, roi) - ''' + """ dFC_vecs = dFC_mat2vec(dFC_mat) mask_vec = dFC_mat2vec(mask) @@ -367,49 +390,50 @@ def dFC_mask(dFC_mat, mask): return dFC_vec_new -#test + +# test # toDo: use ssd.squareform def dFC_mat2vec(C_t): - ''' + """ C_t must be an array of matrices or a single matrix - diagonal values not included. if you want to include + diagonal values not included. 
if you want to include them set k=0 if C_t is a single matrix, F will be one dim changing F will not change C_t - ''' - if len(C_t.shape)==2: - assert C_t.shape[0]==C_t.shape[1],\ - 'C is not a square matrix' + """ + if len(C_t.shape) == 2: + assert C_t.shape[0] == C_t.shape[1], "C is not a square matrix" return C_t[np.triu_indices(C_t.shape[1], k=1)] F = list() for t in range(C_t.shape[0]): - C = C_t[t, : , :] - assert C.shape[0]==C.shape[1],\ - 'C is not a square matrix' + C = C_t[t, :, :] + assert C.shape[0] == C.shape[1], "C is not a square matrix" F.append(C[np.triu_indices(C_t.shape[1], k=1)]) F = np.array(F) return F -#test + +# test # toDo: use ssd.squareform def dFC_vec2mat(F, N): - ''' + """ diagonal values are set to 1.0 F shape is (observations, features) - ''' + """ C = list() iu = np.triu_indices(N, k=1) for i in range(F.shape[0]): K = np.zeros((N, N)) - K[iu] = F[i,:] + K[iu] = F[i, :] K = K + K.T K = K + np.eye(N) C.append(K) C = np.array(C) return C + # test def common_subj_lst(time_series_dict): SUBJECTs = None @@ -420,24 +444,26 @@ def common_subj_lst(time_series_dict): SUBJECTs = intersection(SUBJECTs, time_series_dict[session].subj_id_lst) return SUBJECTs -def intersection(lst1, lst2): # input is a list + +def intersection(lst1, lst2): # input is a list lst3 = [value for value in lst1 if value in lst2] return lst3 -def TR_intersection(dFCM_lst): # input is a list of dFCM objs + +def TR_intersection(dFCM_lst): # input is a list of dFCM objs TRs_lst_old = dFCM_lst[0].TR_array - common_Fs = dFCM_lst[0].TS_info['Fs'] + common_Fs = dFCM_lst[0].TS_info["Fs"] for dFCM in dFCM_lst: - assert dFCM.TS_info['Fs'] == common_Fs, \ - 'Fs mismatch. Cannot find the common TRs' - + assert dFCM.TS_info["Fs"] == common_Fs, "Fs mismatch. Cannot find the common TRs" + TRs_lst_new = intersection(TRs_lst_old, dFCM.TR_array) TRs_lst_old = TRs_lst_new TRs_lst_old.sort() - if len(TRs_lst_old)==0: - print('No TR intersection.') + if len(TRs_lst_old) == 0: + print("No TR intersection.") return TRs_lst_old + def dFC_dict_slice(data, idx_lst): data_sliced = {} for i, k in enumerate(data): @@ -445,39 +471,53 @@ def dFC_dict_slice(data, idx_lst): data_sliced[k] = data[k] return data_sliced + def node_info2network(nodes_info): node_networks = [] for info in nodes_info: - if info[3]=='Network': + if info[3] == "Network": continue - node_networks.append(info[3]) + node_networks.append(info[3]) return node_networks + def segment_FC(FC, node_networks): unique_node_networks = list(set(node_networks)) segmented = np.zeros_like(FC) for network_i in unique_node_networks: - node_id_i = [idx for idx, value in enumerate(node_networks) if value==network_i] + node_id_i = [idx for idx, value in enumerate(node_networks) if value == network_i] for network_j in unique_node_networks: - node_id_j = [idx for idx, value in enumerate(node_networks) if value==network_j] - segmented[node_id_i[0]:node_id_i[-1]+1, node_id_j[0]:node_id_j[-1]+1] = np.mean(FC[node_id_i[0]:node_id_i[-1]+1, node_id_j[0]:node_id_j[-1]+1]) + node_id_j = [ + idx for idx, value in enumerate(node_networks) if value == network_j + ] + segmented[ + node_id_i[0] : node_id_i[-1] + 1, node_id_j[0] : node_id_j[-1] + 1 + ] = np.mean( + FC[node_id_i[0] : node_id_i[-1] + 1, node_id_j[0] : node_id_j[-1] + 1] + ) return segmented - + + def segment_FC_dict(FC_dict, node_networks): segmented_dict = {} for key in FC_dict: segmented_dict[key] = segment_FC(FC_dict[key], node_networks) return segmented_dict - -def visualize_conn_mat(C, axis=None, title='', - 
cmap='seismic', - V_MIN=None, V_MAX=None, + + +def visualize_conn_mat( + C, + axis=None, + title="", + cmap="seismic", + V_MIN=None, + V_MAX=None, node_networks=None, - title_fontsize=18 - ): - ''' + title_fontsize=18, +): + """ C is (regions, regions) - ''' + """ if axis is None: fig, axis = plt.subplots(1, 1, figsize=(5, 5)) @@ -488,14 +528,20 @@ def visualize_conn_mat(C, axis=None, title='', if V_MAX is None: V_MAX = np.max(np.abs(C)) if V_MIN is None: - V_MIN = -1*V_MAX + V_MIN = -1 * V_MAX + + im = axis.imshow( + C, + interpolation="nearest", + aspect="equal", + cmap=cmap, # 'viridis' or 'jet' + vmin=V_MIN, + vmax=V_MAX, + ) - im = axis.imshow(C, interpolation='nearest', aspect='equal', cmap=cmap, # 'viridis' or 'jet' - vmin=V_MIN, vmax=V_MAX) - # cluster node networks if not node_networks is None: - + # finding unique network names wrt order network_names = [] for node in node_networks: @@ -503,62 +549,66 @@ def visualize_conn_mat(C, axis=None, title='', network_names.append(node) network_labels = [network_names.index(node) for node in node_networks] - network_borders = np.argwhere(np.diff(network_labels)!=0) + network_borders = np.argwhere(np.diff(network_labels) != 0) ticks_position = [] last_line_position = 0 for i in network_borders: # 0.5 is the visualization offset of imshow - line_position = i[0]+1-0.5 - axis.axvline(x=line_position, color='k', linewidth=1) - axis.axhline(y=line_position, color='k', linewidth=1) - ticks_position.append((line_position+last_line_position)/2) + line_position = i[0] + 1 - 0.5 + axis.axvline(x=line_position, color="k", linewidth=1) + axis.axhline(y=line_position, color="k", linewidth=1) + ticks_position.append((line_position + last_line_position) / 2) last_line_position = line_position - line_position = len(node_networks)+1-0.5 - ticks_position.append((line_position+last_line_position)/2) + line_position = len(node_networks) + 1 - 0.5 + ticks_position.append((line_position + last_line_position) / 2) axis.set_xticks(ticks_position) axis.set_yticks(ticks_position) axis.set_xticklabels(network_names, rotation=90, fontsize=13) axis.set_yticklabels(network_names, fontsize=13) - - axis.set_title(title, fontdict={'fontsize': title_fontsize, 'fontweight': 'bold'}) + + axis.set_title(title, fontdict={"fontsize": title_fontsize, "fontweight": "bold"}) return im -def visualize_conn_mat_dict(data, title='', - cmap='seismic', + +def visualize_conn_mat_dict( + data, + title="", + cmap="seismic", normalize=False, disp_diag=True, label_dict={}, - save_image=False, output_root=None, axes=None, fig=None, - fix_lim=True, center_0=True, - node_networks=None, segmented=False - ): - - ''' + save_image=False, + output_root=None, + axes=None, + fig=None, + fix_lim=True, + center_0=True, + node_networks=None, + segmented=False, +): + """ - data must be a dict of connectivity matrices sample: Suptitle1 - 0.00 0.31 0.76 - 0.31 0.00 0.43 - 0.76 0.43 0.00 + 0.00 0.31 0.76 + 0.31 0.00 0.43 + 0.76 0.43 0.00 Suptitle1 - 0.00 0.32 0.76 - 0.32 0.00 0.45 - 0.76 0.45 0.00 - ''' - - sns.set_context("paper", - font_scale=2.5, - rc={"lines.linewidth": 3.0} - ) + 0.00 0.32 0.76 + 0.32 0.00 0.45 + 0.76 0.45 0.00 + """ + + sns.set_context("paper", font_scale=2.5, rc={"lines.linewidth": 3.0}) - sns.set_style('white') + sns.set_style("white") if node_networks is None: - fig_width = 25*(len(data)/10) + fig_width = 25 * (len(data) / 10) else: - fig_width = 60*(len(data)/10) + fig_width = 60 * (len(data) / 10) fig_height = 5 fig_flag = True @@ -566,14 +616,15 @@ def 
visualize_conn_mat_dict(data, title='', fig_flag = False if not fig_flag: - fig, axes = plt.subplots(1, len(data), figsize=(fig_width, fig_height), \ - facecolor='w', edgecolor='k') + fig, axes = plt.subplots( + 1, len(data), figsize=(fig_width, fig_height), facecolor="w", edgecolor="k" + ) if not type(axes) is np.ndarray: axes = np.array([axes]) if show_title: - fig.suptitle(title, fontsize=20, y=0.98) #, fontsize=20, size=20 + fig.suptitle(title, fontsize=20, y=0.98) # , fontsize=20, size=20 axes = axes.ravel() @@ -581,21 +632,23 @@ def visualize_conn_mat_dict(data, title='', conn_mats = list() V_MAX_all = None for i, key in enumerate(data): - + if segmented: C = data[key] if not disp_diag: - C = np.multiply(C, 1-np.eye(len(C))) + C = np.multiply(C, 1 - np.eye(len(C))) C = C + np.mean(C.flatten()) * np.eye(len(C)) C = segment_FC(C, node_networks) else: C = data[key] if normalize: - C = dFC_mat_normalize(C[None,:,:], global_normalization=False, threshold=0.0)[0] + C = dFC_mat_normalize( + C[None, :, :], global_normalization=False, threshold=0.0 + )[0] if (not disp_diag) and (not segmented): - C = np.multiply(C, 1-np.eye(len(C))) + C = np.multiply(C, 1 - np.eye(len(C))) C = C + np.mean(C.flatten()) * np.eye(len(C)) if V_MAX_all is None: @@ -606,16 +659,16 @@ def visualize_conn_mat_dict(data, title='', conn_mats.append(C) conn_mats = np.array(conn_mats) - if np.any(conn_mats<0) or center_0: + if np.any(conn_mats < 0) or center_0: V_MIN = -1 V_MAX = 1 - else: + else: V_MIN = 0 V_MAX = 1 if not fix_lim: V_MAX = V_MAX_all - if np.any(conn_mats<0) or center_0: + if np.any(conn_mats < 0) or center_0: V_MIN = -1 * V_MAX_all else: V_MIN = 0 @@ -623,102 +676,114 @@ def visualize_conn_mat_dict(data, title='', # plot for i, key in enumerate(data): - C = conn_mats[i,:,:] + C = conn_mats[i, :, :] mat_title = key if key in label_dict: mat_title = label_dict[key] - im = visualize_conn_mat(C, axis=axes[i], title=mat_title, + im = visualize_conn_mat( + C, + axis=axes[i], + title=mat_title, cmap=cmap, - V_MIN=V_MIN, V_MAX=V_MAX, - node_networks=node_networks - ) + V_MIN=V_MIN, + V_MAX=V_MAX, + node_networks=node_networks, + ) if not fig_flag: fig.subplots_adjust( - bottom=0.1, \ - top=0.85, \ - left=0.1, \ + bottom=0.1, + top=0.85, + left=0.1, right=0.9, # wspace=0.02, \ # hspace=0.02\ ) if not node_networks is None: - fig.subplots_adjust( - wspace=0.55 - ) - + fig.subplots_adjust(wspace=0.55) + l, b, w, h = axes[-1].get_position().bounds if fig_flag: cb_ax = fig.add_axes([0.91, b, 0.007, h]) else: cb_ax = fig.add_axes([0.91, b, 0.01, h]) - cbar = fig.colorbar(im, cax=cb_ax, shrink=0.8) # shrink=0.8?? + fig.colorbar(im, cax=cb_ax, shrink=0.8) # shrink=0.8?? if save_image: - folder = output_root[:output_root.rfind('/')] + folder = output_root[: output_root.rfind("/")] if not os.path.exists(folder): os.makedirs(folder) - plt.savefig(output_root+title.replace(" ", "_")+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) + plt.savefig( + output_root + title.replace(" ", "_") + "." 
+ save_fig_format, + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) plt.close() # else: # plt.show() -def visualize_conn_mat_2D_dict(data, title='', \ - cmap='seismic',\ - normalize=False,\ - disp_diag=True,\ - save_image=False, output_root=None, \ - fix_lim=True, center_0=True, \ - node_networks=None, segmented=False \ - ): - - ''' +def visualize_conn_mat_2D_dict( + data, + title="", + cmap="seismic", + normalize=False, + disp_diag=True, + save_image=False, + output_root=None, + fix_lim=True, + center_0=True, + node_networks=None, + segmented=False, +): + """ - data must be a 2D dict of connectivity matrices sample: ROW1 (method_1) COLUMN1 (method_1) data[method_1][method_1] - 0.00 0.31 0.76 - 0.31 0.00 0.43 - 0.76 0.43 0.00 + 0.00 0.31 0.76 + 0.31 0.00 0.43 + 0.76 0.43 0.00 COLUMN2 (method_2) data[method_1][method_2] - 0.00 0.31 0.76 - 0.31 0.00 0.43 - 0.76 0.43 0.00 + 0.00 0.31 0.76 + 0.31 0.00 0.43 + 0.76 0.43 0.00 ROW2 (method_2) COLUMN1 (method_1) data[method_2][method_1] - 0.00 0.31 0.76 - 0.31 0.00 0.43 - 0.76 0.43 0.00 - ''' + 0.00 0.31 0.76 + 0.31 0.00 0.43 + 0.76 0.43 0.00 + """ zip_measure_name = True - sns.set_context("paper", - font_scale=3.5, - rc={"lines.linewidth": 3.0} - ) + sns.set_context("paper", font_scale=3.5, rc={"lines.linewidth": 3.0}) - sns.set_style('white') + sns.set_style("white") if node_networks is None: - fig_width = 30*(len(data)/10) + fig_width = 30 * (len(data) / 10) else: - fig_width = 55*(len(data)/10) + 4 + fig_width = 55 * (len(data) / 10) + 4 fig_height = fig_width * 1.0 - fig, axs = plt.subplots(len(data), len(data), figsize=(fig_width, fig_height), \ - facecolor='w', edgecolor='k') + fig, axs = plt.subplots( + len(data), + len(data), + figsize=(fig_width, fig_height), + facecolor="w", + edgecolor="k", + ) if not type(axs) is np.ndarray: axs = np.array([axs]) if show_title: - fig.suptitle(title, fontsize=25, y=0.98) #, fontsize=20, size=20 + fig.suptitle(title, fontsize=25, y=0.98) # , fontsize=20, size=20 # axs = axs.ravel() @@ -727,18 +792,19 @@ def visualize_conn_mat_2D_dict(data, title='', \ V_MAX_all = None for i, key_i in enumerate(data): for j, key_j in enumerate(data[key_i]): - + if segmented: C = segment_FC(data[key_i][key_j], node_networks) else: C = data[key_i][key_j] - if normalize: - C = dFC_mat_normalize(C[None,:,:], global_normalization=False, threshold=0.0)[0] + C = dFC_mat_normalize( + C[None, :, :], global_normalization=False, threshold=0.0 + )[0] if not disp_diag: - C = np.multiply(C, 1-np.eye(len(C))) + C = np.multiply(C, 1 - np.eye(len(C))) C = C + np.mean(C.flatten()) * np.eye(len(C)) if V_MAX_all is None: @@ -749,20 +815,20 @@ def visualize_conn_mat_2D_dict(data, title='', \ conn_mats.append(C) conn_mats = np.array(conn_mats) - if np.any(conn_mats<0) or center_0: + if np.any(conn_mats < 0) or center_0: V_MIN = -1 V_MAX = 1 - else: + else: V_MIN = 0 V_MAX = 1 if not fix_lim: V_MAX = V_MAX_all - if np.any(conn_mats<0) or center_0: + if np.any(conn_mats < 0) or center_0: V_MIN = -1 * V_MAX_all else: V_MIN = 0 - + # plot axs_plotted = list() for i, key_i in enumerate(data): @@ -775,22 +841,28 @@ def visualize_conn_mat_2D_dict(data, title='', \ C = data[key_i][key_j] if normalize: - C = dFC_mat_normalize(C[None,:,:], global_normalization=False, threshold=0.0)[0] + C = dFC_mat_normalize( + C[None, :, :], global_normalization=False, threshold=0.0 + )[0] if not disp_diag: - C = np.multiply(C, 1-np.eye(len(C))) + C = np.multiply(C, 1 - np.eye(len(C))) C = C + np.mean(C.flatten()) * 
np.eye(len(C)) if zip_measure_name: - mat_title=zip_name(key_i)+'-'+zip_name(key_j) + mat_title = zip_name(key_i) + "-" + zip_name(key_j) else: - mat_title=key_i+' and '+key_j + mat_title = key_i + " and " + key_j - im = visualize_conn_mat(C, axis=axs[i][j], title=mat_title, + im = visualize_conn_mat( + C, + axis=axs[i][j], + title=mat_title, cmap=cmap, - V_MIN=V_MIN, V_MAX=V_MAX, + V_MIN=V_MIN, + V_MAX=V_MAX, node_networks=node_networks, - title_fontsize=25 + title_fontsize=25, ) axs_plotted.append(axs[i][j]) @@ -799,48 +871,42 @@ def visualize_conn_mat_2D_dict(data, title='', \ for ax in axs.ravel(): if not ax in axs_plotted: ax.set_axis_off() - ax.xaxis.set_tick_params(which='both', labelbottom=True) + ax.xaxis.set_tick_params(which="both", labelbottom=True) fig.subplots_adjust( - bottom=0.1, \ - top=0.95, \ - left=0.1, \ - right=0.9, - wspace=0.001, \ - hspace=0.4\ + bottom=0.1, top=0.95, left=0.1, right=0.9, wspace=0.001, hspace=0.4 ) if not node_networks is None: - fig.subplots_adjust( - wspace=0.45, - hspace=0.50 - ) - - l, b, w, h = axs[-1][-1].get_position().bounds + fig.subplots_adjust(wspace=0.45, hspace=0.50) + + _, _, _, h = axs[-1][-1].get_position().bounds if node_networks is None: - cb_ax = fig.add_axes([0.91, 0.5-h/2, 0.007, h]) + cb_ax = fig.add_axes([0.91, 0.5 - h / 2, 0.007, h]) else: - cb_ax = fig.add_axes([0.91, 0.5-h/2, 0.015, h]) - cbar = fig.colorbar(im, cax=cb_ax, shrink=0.8) # shrink=0.8?? + cb_ax = fig.add_axes([0.91, 0.5 - h / 2, 0.015, h]) + fig.colorbar(im, cax=cb_ax, shrink=0.8) # shrink=0.8?? if save_image: - folder = output_root[:output_root.rfind('/')] + folder = output_root[: output_root.rfind("/")] if not os.path.exists(folder): os.makedirs(folder) - plt.savefig(output_root+title.replace(" ", "_")+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) + plt.savefig( + output_root + title.replace(" ", "_") + "." 
+ save_fig_format, + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) plt.close() else: plt.show() - + def visualize_FCS( - measure, - normalize=True, fix_lim=True, - save_image=False, output_root=None - ): - + measure, normalize=True, fix_lim=True, save_image=False, output_root=None +): + if measure.FCS == []: return @@ -849,74 +915,66 @@ def visualize_FCS( else: D = measure.FCS_dict - fig_width = 45*(len(D)/10) + fig_width = 45 * (len(D) / 10) fig_height = 8 - fig, axes = plt.subplots(2, len(D), figsize=(fig_width, fig_height), - facecolor='w', edgecolor='k') - - fig.subplots_adjust( - bottom=0.1, - top=0.85, - left=0.1, - right=0.9, - wspace=0.1, - hspace=0.6 + fig, axes = plt.subplots( + 2, len(D), figsize=(fig_width, fig_height), facecolor="w", edgecolor="k" ) + fig.subplots_adjust(bottom=0.1, top=0.85, left=0.1, right=0.9, wspace=0.1, hspace=0.6) + # plot mean activity for i, mean_act in enumerate(measure.mean_act): # setting vmin=-vmax to make 0 correspond to white color max_activity = np.max(np.abs(mean_act)) - min_activity = -1*max_activity + min_activity = -1 * max_activity plot_markers( - node_values=mean_act, - node_coords=measure.TS_info['nodes_locs'], - node_cmap='seismic', - node_vmin=min_activity, node_vmax=max_activity, - display_mode='z', - colorbar=False, axes=axes[1, i] + node_values=mean_act, + node_coords=measure.TS_info["nodes_locs"], + node_cmap="seismic", + node_vmin=min_activity, + node_vmax=max_activity, + display_mode="z", + colorbar=False, + axes=axes[1, i], ) # plot FC pattern - node_networks = node_info2network(measure.TS_info['nodes_info']) - - visualize_conn_mat_dict(data=D, \ - node_networks=node_networks, \ - title=measure.measure_name+' FCS', \ - save_image=save_image, \ - axes=axes[0, :], fig=fig, - output_root=output_root, \ - disp_diag=False, \ - fix_lim=fix_lim \ + node_networks = node_info2network(measure.TS_info["nodes_info"]) + + visualize_conn_mat_dict( + data=D, + node_networks=node_networks, + title=measure.measure_name + " FCS", + save_image=save_image, + axes=axes[0, :], + fig=fig, + output_root=output_root, + disp_diag=False, + fix_lim=fix_lim, ) - fig.subplots_adjust( - bottom=0.1, - top=0.85, - left=0.1, - right=0.9, - wspace=0.1, - hspace=1.0 - ) + fig.subplots_adjust(bottom=0.1, top=0.85, left=0.1, right=0.9, wspace=0.1, hspace=1.0) -''' + +""" ########## bundled brain graph visualizer ########## cvsopts = dict(plot_height=400, plot_width=400) def thresh_G(G, threshold): - + G_copy = deepcopy(G) - + if threshold > 1: labels = [d["weight"] for (u, v, d) in G_copy.edges(data=True)] labels.sort() threshold = labels[-1*threshold] - + ebunch = [(u, v) for u, v, d in G_copy.edges(data=True) if np.abs(d['weight']) < threshold] G_copy.remove_edges_from(ebunch) - + return G_copy def nodesplot(nodes, name=None, canvas=None, cat=None): @@ -924,23 +982,23 @@ def nodesplot(nodes, name=None, canvas=None, cat=None): # aggregator=None if cat is None else ds.count_cat(cat) # agg=canvas.points(nodes,'x','y',aggregator) aggc = canvas.points(nodes, 'x', 'y', ds.count_cat('cat')) #ds.by('cat', ds.count()) - + color_key = dict(cat_normal='#FF3333', cat_sig='#00FF00') - + return tf.spread(tf.shade(aggc, color_key=color_key), px=4, name=name) def edgesplot(edges, name=None, canvas=None): canvas = ds.Canvas(**cvsopts) if canvas is None else canvas return tf.shade(canvas.line(edges, 'x','y', agg=ds.count()), name=name) - + def graphplot(nodes, edges, name="", canvas=None, cat=None): - + if canvas is None: xr = 
nodes.x.min(), nodes.x.max() yr = nodes.y.min(), nodes.y.max() canvas = ds.Canvas(x_range=xr, y_range=yr, **cvsopts) - + np = nodesplot(nodes, name + " nodes", canvas, cat) ep = edgesplot(edges, name + " edges", canvas) return tf.stack(ep, np, how="over", name=name) @@ -951,19 +1009,19 @@ def ng(graph,name): def nx_layout(graph, view_degree=0, threshold=0): # layout = nx.circular_layout(graph) - + # Get node positions pos = nx.get_node_attributes(graph, 'pos') for key in pos: - if view_degree==0: + if view_degree==0: pos[key] = pos[key][:2] - if view_degree==1: + if view_degree==1: pos[key] = pos[key][1:3] - if view_degree==2: + if view_degree==2: pos[key] = pos[key][[0, 2]] - + # layout = pos - + cat = list() for key in graph.nodes(): cat.append('cat_normal') @@ -971,7 +1029,7 @@ def nx_layout(graph, view_degree=0, threshold=0): # cat.append( 'cat_sig') # else: # cat.append('cat_normal') - + data = [[node]+pos[node].tolist()+[cat[i]] for i, node in enumerate(graph.nodes)] nodes = pd.DataFrame(data, columns=['id', 'x', 'y','cat']) @@ -986,7 +1044,7 @@ def nx_layout(graph, view_degree=0, threshold=0): def nx_plot(graph, name="", view_degree=0, threshold=0): # print(graph.name, len(graph.edges)) nodes, edges = nx_layout(graph, view_degree=view_degree, threshold=threshold) - + direct = connect_edges(nodes, edges) bundled_bw005 = hammer_bundle(nodes, edges) bundled_bw030 = hammer_bundle(nodes, edges, initial_bandwidth=0.30) @@ -998,7 +1056,7 @@ def nx_plot(graph, name="", view_degree=0, threshold=0): graphplot(nodes, bundled_bw100, "Bundled bw=1.00", cat=None)] def batch_Adj2Net(FCS, nodes_info, is_digraph=False): - + np.fill_diagonal(FCS, 0) if is_digraph: G = nx.from_numpy_matrix(FCS, create_using=nx.DiGraph) @@ -1013,43 +1071,44 @@ def batch_Adj2Net(FCS, nodes_info, is_digraph=False): return G def set_locs_G(G, locs): - + G_copy = deepcopy(G) - - pos = nx.circular_layout(G_copy) + + pos = nx.circular_layout(G_copy) for i, key in enumerate(pos): pos[key] = locs[i] - - nx.set_node_attributes(G_copy, pos, "pos") - - - return G_copy + + nx.set_node_attributes(G_copy, pos, "pos") + + + return G_copy def visulize_brain_graph(FCS, nodes_info, locs, num_edges2show, \ title='', save_image=True, output_root=None \ ): - + # EXAMPLE: # visulize_brain_graph(measure.FCS_dict[FCS], measure.TS_info['nodes_info'], \ # measure.TS_info['nodes_locs'], num_edges2show=100, \ # title=FCS+'_'+measure.measure_name, save_image=save_image, output_root=output_root \ # ) - + G = batch_Adj2Net(FCS=FCS, nodes_info=nodes_info, is_digraph=False) - G = set_locs_G(G, locs=locs) + G = set_locs_G(G, locs=locs) plots = [nx_plot(ng(G, name="dFC"), view_degree=0, threshold=num_edges2show)] if save_image: - ds.utils.export_image(img=plots[0][2], filename=title+'_bundle_', - fmt=".png", background='black', + ds.utils.export_image(img=plots[0][2], filename=title+'_bundle_', + fmt=".png", background='black', export_path=output_root) - + # return plots[0][0] ############################## -''' +""" + def dFC_dict_normalize(D, global_normalization=False, threshold=0.0): @@ -1058,43 +1117,43 @@ def dFC_dict_normalize(D, global_normalization=False, threshold=0.0): C.append(D[key]) C = np.array(C) - C_z = dFC_mat_normalize(C, \ - global_normalization=global_normalization, \ - threshold=threshold \ + C_z = dFC_mat_normalize( + C, global_normalization=global_normalization, threshold=threshold ) D_z = {} for i, key in enumerate(D): - D_z[key] = C_z[i,:,:] + D_z[key] = C_z[i, :, :] return D_z + def dFC_mat_normalize(C_t, 
global_normalization=False, threshold=0.0): # threshold is ratio of connections wanted to be zero C_t_z = deepcopy(C_t) - if len(C_t_z.shape)<3: + if len(C_t_z.shape) < 3: C_t_z = np.expand_dims(C_t_z, axis=0) if global_normalization: - # transform the whole abs(dFC mat) to [0, 1] + # transform the whole abs(dFC mat) to [0, 1] signs = np.sign(C_t_z) C_t_z = np.abs(C_t_z) miN = list() for i in range(C_t_z.shape[0]): - slice = C_t_z[i,:,:] - slice_non_diag = slice[np.where(~np.eye(slice.shape[0],dtype=bool))] + slice = C_t_z[i, :, :] + slice_non_diag = slice[np.where(~np.eye(slice.shape[0], dtype=bool))] miN.append(np.min(slice_non_diag)) C_t_z = C_t_z - np.min(miN) maX = list() for i in range(C_t_z.shape[0]): - slice = C_t_z[i,:,:] - slice_non_diag = slice[np.where(~np.eye(slice.shape[0],dtype=bool))] + slice = C_t_z[i, :, :] + slice_non_diag = slice[np.where(~np.eye(slice.shape[0], dtype=bool))] maX.append(np.max(slice_non_diag)) if np.max(maX) != 0: @@ -1103,8 +1162,8 @@ def dFC_mat_normalize(C_t, global_normalization=False, threshold=0.0): # thresholding d = deepcopy(np.ravel(C_t_z)) d.sort() - new_threshold = d[int(threshold*len(d))] - C_t_z = np.multiply(C_t_z, (C_t_z>=new_threshold)) + new_threshold = d[int(threshold * len(d))] + C_t_z = np.multiply(C_t_z, (C_t_z >= new_threshold)) C_t_z = np.multiply(C_t_z, signs) else: @@ -1113,77 +1172,81 @@ def dFC_mat_normalize(C_t, global_normalization=False, threshold=0.0): signs = np.sign(C_t_z) C_t_z = np.abs(C_t_z) - + for i in range(C_t_z.shape[0]): - slice = C_t_z[i,:,:] - slice_non_diag = slice[np.where(~np.eye(slice.shape[0],dtype=bool))] + slice = C_t_z[i, :, :] + slice_non_diag = slice[np.where(~np.eye(slice.shape[0], dtype=bool))] slice = slice - np.min(slice_non_diag) - slice_non_diag = slice[np.where(~np.eye(slice.shape[0],dtype=bool))] + slice_non_diag = slice[np.where(~np.eye(slice.shape[0], dtype=bool))] if np.max(slice_non_diag) != 0: slice = np.divide(slice, np.max(slice_non_diag)) # thresholding d = deepcopy(np.ravel(slice)) d.sort() - new_threshold = d[int(threshold*len(d))] - slice = np.multiply(slice, (slice>=new_threshold)) + new_threshold = d[int(threshold * len(d))] + slice = np.multiply(slice, (slice >= new_threshold)) - C_t_z[i,:,:] = slice + C_t_z[i, :, :] = slice C_t_z = np.multiply(C_t_z, signs) # removing self connections for i in range(C_t_z.shape[1]): - C_t_z[:, i, i] = np.mean(C_t_z) # ????????????????? + C_t_z[:, i, i] = np.mean(C_t_z) # ????????????????? 
return C_t_z + def print_mat(mat, s=0): - if len(mat.shape)==1: + if len(mat.shape) == 1: mat = np.expand_dims(mat, axis=0) for i in mat: - print('\t'*s, end=" ") + print("\t" * s, end=" ") for j in i: - print("{:.2f}".format(j), end=" ") + print(f"{j:.2f}", end=" ") print() + def print_dict(t, s=0): - if not isinstance(t,dict) and not isinstance(t,list): - if isinstance(t,np.ndarray): + if not isinstance(t, dict) and not isinstance(t, list): + if isinstance(t, np.ndarray): print_mat(t, s) else: - if isinstance(t,float): - print('\t'*s+"{:.2f}".format(t)) + if isinstance(t, float): + print("\t" * s + f"{t:.2f}") else: - print('\t'*s+str(t)) + print("\t" * s + str(t)) else: for key in t: - print('\t'*s+str(key)) - if not isinstance(t,list): - print_dict(t[key],s+1) + print("\t" * s + str(key)) + if not isinstance(t, list): + print_dict(t[key], s + 1) + ############################# dFC Analyzer class ################################ """ todo: -- +- """ + class DFC_ANALYZER: # if self.n_jobs is None => no parallelization - def __init__(self, analysis_name='', **params): + def __init__(self, analysis_name="", **params): self.analysis_name = analysis_name - + self.params = params - if not 'n_jobs' in self.params: - self.params['n_jobs'] = -1 - if not 'verbose' in self.params: - self.params['verbose'] = 1 - if not 'backend' in self.params: - self.params['backend'] = 'loky' + if not "n_jobs" in self.params: + self.params["n_jobs"] = -1 + if not "verbose" in self.params: + self.params["verbose"] = 1 + if not "backend" in self.params: + self.params["backend"] = "loky" self.MEASURES_lst_ = None self.MEASURES_fit_lst_ = [] @@ -1194,8 +1257,7 @@ def __init__(self, analysis_name='', **params): @property def MEASURES_lst(self): - assert not self.MEASURES_lst_ is None, \ - 'first set the MEASURES_lst!' + assert not self.MEASURES_lst_ is None, "first set the MEASURES_lst!" return self.MEASURES_lst_ @property @@ -1209,10 +1271,9 @@ def set_MEASURES_fit_lst(self, MEASURES_fit_lst): self.MEASURES_fit_lst_ = MEASURES_fit_lst def measures_initializer(self, MEASURES_name_lst, params_methods, alter_hparams): - - ''' + """ - this will test values in hyper_params other than - values already in self.params. values in self.params + values already in self.params. 
values in self.params will be considered the reference sample: hyper_params = { \ @@ -1225,7 +1286,7 @@ def measures_initializer(self, MEASURES_name_lst, params_methods, alter_hparams) 'noise_ratio': [0.00, 0.50, 1.00], \ 'num_realization': [1, 2, 3], \ } - + MEASURES_name_lst = ( \ 'SlidingWindow', \ 'Time-Freq', \ @@ -1235,30 +1296,34 @@ def measures_initializer(self, MEASURES_name_lst, params_methods, alter_hparams) 'Clustering', \ 'DiscreteHMM' \ ) - ''' + """ self.MEASURES_name_lst = MEASURES_name_lst self.params_methods = params_methods self.alter_hparams = alter_hparams # a list of MEASURES with default parameter values - MEASURES_lst = self.create_measure_obj(MEASURES_name_lst=MEASURES_name_lst, **params_methods) + MEASURES_lst = self.create_measure_obj( + MEASURES_name_lst=MEASURES_name_lst, **params_methods + ) # adding MEASURES with alternative parameter values hyper_param_info = {} - hyper_param_info['default_values'] = params_methods + hyper_param_info["default_values"] = params_methods for hyper_param in alter_hparams: for value in alter_hparams[hyper_param]: params = deepcopy(params_methods) params[hyper_param] = value - hyper_param_info[hyper_param+'_'+str(value)] = deepcopy(params) - new_MEASURES = self.create_measure_obj(MEASURES_name_lst=MEASURES_name_lst, **params) + hyper_param_info[hyper_param + "_" + str(value)] = deepcopy(params) + new_MEASURES = self.create_measure_obj( + MEASURES_name_lst=MEASURES_name_lst, **params + ) for new_measure in new_MEASURES: - flag=0 + flag = 0 for MEASURE in MEASURES_lst: if new_measure.issame(MEASURE): - flag=1 - if flag==0: + flag = 1 + if flag == 0: MEASURES_lst.append(new_measure) self.hyper_param_info = hyper_param_info @@ -1271,45 +1336,45 @@ def create_measure_obj(self, MEASURES_name_lst, **params): for MEASURES_name in MEASURES_name_lst: ###### CAP ###### - if MEASURES_name=='CAP': + if MEASURES_name == "CAP": measure = CAP(**params) ###### CONTINUOUS HMM ###### - if MEASURES_name=='ContinuousHMM': + if MEASURES_name == "ContinuousHMM": measure = HMM_CONT(**params) ###### WINDOW_LESS ###### - if MEASURES_name=='Windowless': + if MEASURES_name == "Windowless": measure = WINDOWLESS(**params) ###### SLIDING WINDOW ###### - if MEASURES_name=='SlidingWindow': + if MEASURES_name == "SlidingWindow": measure = SLIDING_WINDOW(**params) ###### TIME FREQUENCY ###### - if MEASURES_name=='Time-Freq': + if MEASURES_name == "Time-Freq": measure = TIME_FREQ(**params) ###### SLIDING WINDOW + CLUSTERING ###### - if MEASURES_name=='Clustering': + if MEASURES_name == "Clustering": measure = SLIDING_WINDOW_CLUSTR(**params) ###### DISCRETE HMM ###### - if MEASURES_name=='DiscreteHMM': + if MEASURES_name == "DiscreteHMM": measure = HMM_DISC(**params) MEASURES_lst.append(measure) return MEASURES_lst - def SB_MEASURES_lst(self, MEASURES_lst): # returns state_based measures + def SB_MEASURES_lst(self, MEASURES_lst): # returns state_based measures SB_MEASURES = list() for measure in MEASURES_lst: if measure.is_state_based: SB_MEASURES.append(measure) return SB_MEASURES - def DD_MEASURES_lst(self, MEASURES_lst): # returns data_driven measures + def DD_MEASURES_lst(self, MEASURES_lst): # returns data_driven measures DD_MEASURES = list() for measure in MEASURES_lst: if not measure.is_state_based: @@ -1326,46 +1391,54 @@ def estimate_group_FCS(self, time_series_dict): time_series = time_series_dict[session] SB_MEASURES_lst = self.SB_MEASURES_lst(self.MEASURES_lst) - if self.params['n_jobs'] is None: + if self.params["n_jobs"] is None: SB_MEASURES_lst_NEW = 
list() for measure in SB_MEASURES_lst: - SB_MEASURES_lst_NEW.append( \ - measure.estimate_FCS(time_series=time_series) \ - ) + SB_MEASURES_lst_NEW.append( + measure.estimate_FCS(time_series=time_series) + ) else: - SB_MEASURES_lst_NEW = Parallel( \ - n_jobs=self.params['n_jobs'], verbose=self.params['verbose'], backend=self.params['backend'])( \ - delayed(measure.estimate_FCS)(time_series=time_series) \ - for measure in SB_MEASURES_lst) - self.MEASURES_fit_lst_[session] = self.DD_MEASURES_lst(self.MEASURES_lst) + SB_MEASURES_lst_NEW + SB_MEASURES_lst_NEW = Parallel( + n_jobs=self.params["n_jobs"], + verbose=self.params["verbose"], + backend=self.params["backend"], + )( + delayed(measure.estimate_FCS)(time_series=time_series) + for measure in SB_MEASURES_lst + ) + self.MEASURES_fit_lst_[session] = ( + self.DD_MEASURES_lst(self.MEASURES_lst) + SB_MEASURES_lst_NEW + ) ##################### dFCM ASSESSMENT ###################### def group_dFCM_assess(self, time_series_dict): # time_series_dict is a dict of time_series + SUBJECTs = common_subj_lst(time_series_dict) - SUBJ_s_dFCM_dict = {} - - SUBJECTs = common_subj_lst(time_series_dict) - - if self.params['n_jobs'] is None: + if self.params["n_jobs"] is None: OUT = list() for subject in SUBJECTs: - OUT.append( \ - self.subj_lvl_dFC_assess( \ - time_series_dict=get_subj_ts_dict(time_series_dict, subjs_id=subject), \ - )) + OUT.append( + self.subj_lvl_dFC_assess( + time_series_dict=get_subj_ts_dict( + time_series_dict, subjs_id=subject + ), + ) + ) else: - OUT = Parallel( \ - n_jobs=self.params['n_jobs'], \ - verbose=self.params['verbose'], \ - backend=self.params['backend'])( \ - delayed(self.subj_lvl_dFC_assess)( \ - time_series_dict=get_subj_ts_dict(time_series_dict, subjs_id=subject), \ - ) \ - for subject in SUBJECTs) - + OUT = Parallel( + n_jobs=self.params["n_jobs"], + verbose=self.params["verbose"], + backend=self.params["backend"], + )( + delayed(self.subj_lvl_dFC_assess)( + time_series_dict=get_subj_ts_dict(time_series_dict, subjs_id=subject), + ) + for subject in SUBJECTs + ) + return OUT def subj_lvl_dFC_assess(self, time_series_dict): @@ -1375,24 +1448,34 @@ def subj_lvl_dFC_assess(self, time_series_dict): dFCM_dict = {} # dFC_corr_assess_dict = {} - if self.params['n_jobs'] is None: + if self.params["n_jobs"] is None: dFCM_lst = list() for measure in self.MEASURES_fit_lst_: - dFCM_lst.append( \ - measure.estimate_dFCM(time_series=time_series_dict[measure.params['session']]) \ + dFCM_lst.append( + measure.estimate_dFCM( + time_series=time_series_dict[measure.params["session"]] + ) ) else: - dFCM_lst = Parallel( \ - n_jobs=self.params['n_jobs'], verbose=self.params['verbose'], backend=self.params['backend'])( \ - delayed(measure.estimate_dFCM)(time_series=time_series_dict[measure.params['session']]) \ - for measure in self.MEASURES_fit_lst_) + dFCM_lst = Parallel( + n_jobs=self.params["n_jobs"], + verbose=self.params["verbose"], + backend=self.params["backend"], + )( + delayed(measure.estimate_dFCM)( + time_series=time_series_dict[measure.params["session"]] + ) + for measure in self.MEASURES_fit_lst_ + ) - dFCM_dict['dFCM_lst'] = dFCM_lst + dFCM_dict["dFCM_lst"] = dFCM_lst return dFCM_dict + ################################# Similarity_Assessment class #################################### + class SIMILARITY_ASSESSMENT: def __init__(self, dFCM_lst): @@ -1411,15 +1494,15 @@ def FO_calc(self, dFCM_lst, common_TRs=None): FO_list = list() for dFCM in dFCM_lst: - + FO = {} if dFCM.measure.is_state_based: state_act_dict = 
dFCM.state_act_dict(TRs=common_TRs) - - for FCS_key in state_act_dict['state_TC']: - FO[FCS_key] = np.mean(state_act_dict['state_TC'][FCS_key]['act_TC']) + + for FCS_key in state_act_dict["state_TC"]: + FO[FCS_key] = np.mean(state_act_dict["state_TC"][FCS_key]["act_TC"]) FO_list.append(FO) @@ -1436,16 +1519,16 @@ def transition_stats(self, dFCM_lst, common_TRs=None): TRs_lst = list() for TR in common_TRs: - TRs_lst.append('TR'+str(TR)) + TRs_lst.append("TR" + str(TR)) output_lst = list() for dFCM in dFCM_lst: - + output_dict = {} if dFCM.measure.is_state_based: - # downsampled + # downsampled trans_freq = 0 dwell_time_lst = list() dwell_time = 0 @@ -1453,24 +1536,24 @@ def transition_stats(self, dFCM_lst, common_TRs=None): for TR in dFCM.FCS_idx: if TR in TRs_lst: if not last_TR is None: - if dFCM.FCS_idx[TR]!=dFCM.FCS_idx[last_TR]: + if dFCM.FCS_idx[TR] != dFCM.FCS_idx[last_TR]: dwell_time_lst.append(dwell_time) dwell_time = 0 trans_freq += 1 dwell_time += 1 last_TR = TR - output_dict['dwell_time'] = dwell_time_lst - output_dict['trans_freq'] = trans_freq + output_dict["dwell_time"] = dwell_time_lst + output_dict["trans_freq"] = trans_freq - # normalized (not downsampled) + # normalized (not downsampled) trans_norm = 0 dwell_time_lst = list() dwell_time = 0 last_TR = None for TR in dFCM.FCS_idx: if not last_TR is None: - if dFCM.FCS_idx[TR]!=dFCM.FCS_idx[last_TR]: + if dFCM.FCS_idx[TR] != dFCM.FCS_idx[last_TR]: dwell_time_lst.append(dwell_time / len(dFCM.FCS_idx)) dwell_time = 0 trans_norm += 1 @@ -1478,162 +1561,182 @@ def transition_stats(self, dFCM_lst, common_TRs=None): last_TR = TR trans_norm = trans_norm / len(dFCM.FCS_idx) - output_dict['dwell_time_norm'] = dwell_time_lst - output_dict['trans_norm'] = trans_norm + output_dict["dwell_time_norm"] = dwell_time_lst + output_dict["trans_norm"] = trans_norm output_lst.append(output_dict) - return output_lst + return output_lst def feature_all(self, dFC_mat): - vectorized_dFC = dFC_mat2vec(dFC_mat).flatten() # (time*connection, ) - vectorized_dFC = np.expand_dims(vectorized_dFC, axis=0) # (1, time*connection) + vectorized_dFC = dFC_mat2vec(dFC_mat).flatten() # (time*connection, ) + vectorized_dFC = np.expand_dims(vectorized_dFC, axis=0) # (1, time*connection) return vectorized_dFC def feature_spatial(self, dFC_mat): - conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) + conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) return conn_over_time def feature_temporal(self, dFC_mat): - conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) - time_over_conn = conn_over_time.T # (connection, time) + conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) + time_over_conn = conn_over_time.T # (connection, time) return time_over_conn def feature_inter_time_corr(self, dFC_mat): - ''' - returns correspondence of inter-time relation between results of dFC + """ + returns correspondence of inter-time relation between results of dFC measures in each subject - ''' - conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) - inter_time_corr = np.corrcoef(conn_over_time) # (time, time) + """ + conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) + inter_time_corr = np.corrcoef(conn_over_time) # (time, time) inter_time_corr = np.nan_to_num(inter_time_corr) - inter_time_corr = dFC_mat2vec(inter_time_corr) # (time*(time-1)/2, ) - inter_time_corr = np.expand_dims(inter_time_corr, axis=0) # (1, time*(time-1)/2) + inter_time_corr = dFC_mat2vec(inter_time_corr) # (time*(time-1)/2, ) + inter_time_corr = 
np.expand_dims(inter_time_corr, axis=0) # (1, time*(time-1)/2) return inter_time_corr def feature_inter_conn_corr(self, dFC_mat): - conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) - time_over_conn = conn_over_time.T # (connection, time) - inter_conn_corr = np.corrcoef(time_over_conn) # (connection, connection) + conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) + time_over_conn = conn_over_time.T # (connection, time) + inter_conn_corr = np.corrcoef(time_over_conn) # (connection, connection) inter_conn_corr = np.nan_to_num(inter_conn_corr) - inter_conn_corr = dFC_mat2vec(inter_conn_corr) # (connection*(connection-1)/2, ) - inter_conn_corr = np.expand_dims(inter_conn_corr, axis=0) # (1, connection*(connection-1)/2) + inter_conn_corr = dFC_mat2vec(inter_conn_corr) # (connection*(connection-1)/2, ) + inter_conn_corr = np.expand_dims( + inter_conn_corr, axis=0 + ) # (1, connection*(connection-1)/2) return inter_conn_corr def feature_dFC_avg(self, dFC_mat): - dFC_avg = np.mean(dFC_mat, axis=0) # (ROI, ROI) - vectorized_dFC_avg = dFC_mat2vec(dFC_avg) # (connection, ) - vectorized_dFC_avg = np.expand_dims(vectorized_dFC_avg, axis=0) # (1, connection) + dFC_avg = np.mean(dFC_mat, axis=0) # (ROI, ROI) + vectorized_dFC_avg = dFC_mat2vec(dFC_avg) # (connection, ) + vectorized_dFC_avg = np.expand_dims(vectorized_dFC_avg, axis=0) # (1, connection) return vectorized_dFC_avg def feature_dFC_var(self, dFC_mat): - dFC_var = np.var(dFC_mat, axis=0) # (ROI, ROI) - vectorized_dFC_var = dFC_mat2vec(dFC_var) # (connection, ) - vectorized_dFC_var = np.expand_dims(vectorized_dFC_var, axis=0) # (1, connection) + dFC_var = np.var(dFC_mat, axis=0) # (ROI, ROI) + vectorized_dFC_var = dFC_mat2vec(dFC_var) # (connection, ) + vectorized_dFC_var = np.expand_dims(vectorized_dFC_var, axis=0) # (1, connection) return vectorized_dFC_var def feature_graph_spatial(self, dFC_mat, graph_property): graph_feature_over_time = list() for FC_mat in dFC_mat: - graph_feature = calc_graph_propoerty(FC_mat, property=graph_property, threshold=False, binarize=False) + graph_feature = calc_graph_propoerty( + FC_mat, property=graph_property, threshold=False, binarize=False + ) graph_feature_over_time.append(graph_feature) - graph_feature_over_time = np.array(graph_feature_over_time) # (time, ROI) + graph_feature_over_time = np.array(graph_feature_over_time) # (time, ROI) return graph_feature_over_time def feature_graph_temporal(self, dFC_mat, graph_property): graph_feature_over_time = list() for FC_mat in dFC_mat: - graph_feature = calc_graph_propoerty(FC_mat, property=graph_property, threshold=False, binarize=False) + graph_feature = calc_graph_propoerty( + FC_mat, property=graph_property, threshold=False, binarize=False + ) graph_feature_over_time.append(graph_feature) - graph_feature_over_time = np.array(graph_feature_over_time) # (time, ROI) - graph_feature_over_node = graph_feature_over_time.T # (ROI, time) - graph_feature_avg = np.mean(graph_feature_over_node, axis=0) # (time, ) - graph_feature_avg = np.expand_dims(graph_feature_avg, axis=0) # (1, time) + graph_feature_over_time = np.array(graph_feature_over_time) # (time, ROI) + graph_feature_over_node = graph_feature_over_time.T # (ROI, time) + graph_feature_avg = np.mean(graph_feature_over_node, axis=0) # (time, ) + graph_feature_avg = np.expand_dims(graph_feature_avg, axis=0) # (1, time) return graph_feature_avg def extract_feature(self, dFC_mat, feature2extract, graph_property=None): - ''' + """ feature2extract_list = [ - 'all', - 'spatial', 'temporal', - 
'inter_time_corr', 'inter_conn_corr', - 'dFC_avg', 'dFC_var', + 'all', + 'spatial', 'temporal', + 'inter_time_corr', 'inter_conn_corr', + 'dFC_avg', 'dFC_var', 'graph_spatial', 'graph_temporal' ] - ''' + """ feature = None - if feature2extract=='all': + if feature2extract == "all": feature = self.feature_all(dFC_mat) - if feature2extract=='spatial': + if feature2extract == "spatial": feature = self.feature_spatial(dFC_mat) - if feature2extract=='temporal': + if feature2extract == "temporal": feature = self.feature_temporal(dFC_mat) - if feature2extract=='inter_time_corr': + if feature2extract == "inter_time_corr": feature = self.feature_inter_time_corr(dFC_mat) - if feature2extract=='inter_conn_corr': + if feature2extract == "inter_conn_corr": feature = self.feature_inter_conn_corr(dFC_mat) - if feature2extract=='dFC_avg': + if feature2extract == "dFC_avg": feature = self.feature_dFC_avg(dFC_mat) - if feature2extract=='dFC_var': + if feature2extract == "dFC_var": feature = self.feature_dFC_var(dFC_mat) - if feature2extract=='graph_spatial': + if feature2extract == "graph_spatial": feature = self.feature_graph_spatial(dFC_mat, graph_property=graph_property) - if feature2extract=='graph_temporal': + if feature2extract == "graph_temporal": feature = self.feature_graph_temporal(dFC_mat, graph_property=graph_property) return feature - def dFC_mat_lst_similarity(self, dFC_mat_lst, feature2extract, metric, graph_property=None): - + def dFC_mat_lst_similarity( + self, dFC_mat_lst, feature2extract, metric, graph_property=None + ): + sim_mat_over_sample = None for i, dFC_mat_i in enumerate(dFC_mat_lst): for j, dFC_mat_j in enumerate(dFC_mat_lst): - if j<=i: + if j <= i: continue - assert dFC_mat_i.shape==dFC_mat_j.shape,\ - 'shape mismatch' + assert dFC_mat_i.shape == dFC_mat_j.shape, "shape mismatch" feature_i = self.extract_feature( - dFC_mat_i, - feature2extract=feature2extract, - graph_property=graph_property - ) # (samples, variables) + dFC_mat_i, + feature2extract=feature2extract, + graph_property=graph_property, + ) # (samples, variables) feature_j = self.extract_feature( - dFC_mat_j, - feature2extract=feature2extract, - graph_property=graph_property - ) # (samples, variables) + dFC_mat_j, + feature2extract=feature2extract, + graph_property=graph_property, + ) # (samples, variables) sim_over_sample = list() for sample in range(feature_i.shape[0]): - if np.var(feature_i[sample, :])==0 or np.var(feature_j[sample, :])==0: + if ( + np.var(feature_i[sample, :]) == 0 + or np.var(feature_j[sample, :]) == 0 + ): sim = 0 else: - if metric=='corr': - sim = np.corrcoef(feature_i[sample, :], feature_j[sample, :])[0,1] - elif metric=='spearman': - sim, p = stats.spearmanr(feature_i[sample, :], feature_j[sample, :]) - elif metric=='MI': - sim = mutual_information(X=feature_i[sample, :], Y=feature_j[sample, :], N_bins=100) - elif metric=='euclidean_distance': + if metric == "corr": + sim = np.corrcoef(feature_i[sample, :], feature_j[sample, :])[ + 0, 1 + ] + elif metric == "spearman": + sim, p = stats.spearmanr( + feature_i[sample, :], feature_j[sample, :] + ) + elif metric == "MI": + sim = mutual_information( + X=feature_i[sample, :], Y=feature_j[sample, :], N_bins=100 + ) + elif metric == "euclidean_distance": # normalized euclidean is used - sim = normalized_euc_dist(x=feature_i[sample, :], y=feature_j[sample, :]) + sim = normalized_euc_dist( + x=feature_i[sample, :], y=feature_j[sample, :] + ) sim_over_sample.append(sim) - + if sim_mat_over_sample is None: - sim_mat_over_sample = 
np.zeros((len(sim_over_sample), len(dFC_mat_lst), len(dFC_mat_lst))) + sim_mat_over_sample = np.zeros( + (len(sim_over_sample), len(dFC_mat_lst), len(dFC_mat_lst)) + ) sim_mat_over_sample[:, i, j] = np.array(sim_over_sample) sim_mat_over_sample[:, j, i] = sim_mat_over_sample[:, i, j] return sim_mat_over_sample - def assess_similarity(self, dFCM_lst, downsampling_method='default', **param_dict): - ''' - downsampling_method: 'default' picks FCs at common_TRs + def assess_similarity(self, dFCM_lst, downsampling_method="default", **param_dict): + """ + downsampling_method: 'default' picks FCs at common_TRs while 'SWed' uses a sliding window to downsample - ''' + """ methods_assess = {} # sort dFCM_lst according to methods names @@ -1652,141 +1755,139 @@ def assess_similarity(self, dFCM_lst, downsampling_method='default', **param_dic for dFCM in dFCM_lst: measure_lst.append(dFCM.measure) TS_info_lst.append(dFCM.TS_info) - if downsampling_method=='SWed': - dFC_mat_lst.append( \ - dFCM.SWed_dFC_mat( \ - W=param_dict['W'], \ - n_overlap=param_dict['n_overlap'], \ - tapered_window=param_dict['tapered_window'] \ + if downsampling_method == "SWed": + dFC_mat_lst.append( + dFCM.SWed_dFC_mat( + W=param_dict["W"], + n_overlap=param_dict["n_overlap"], + tapered_window=param_dict["tapered_window"], ) ) else: dFC_mat_lst.append(dFCM.get_dFC_mat(TRs=common_TRs)) - methods_assess['measure_lst'] = measure_lst - methods_assess['TS_info_lst'] = TS_info_lst - methods_assess['common_TRs'] = common_TRs + methods_assess["measure_lst"] = measure_lst + methods_assess["TS_info_lst"] = TS_info_lst + methods_assess["common_TRs"] = common_TRs ########## dFCM samples ########## dFCM_samples = {} for i, dFC_mat in enumerate(dFC_mat_lst): dFCM_samples[str(i)] = dFC_mat - methods_assess['dFCM_samples'] = dFCM_samples + methods_assess["dFCM_samples"] = dFCM_samples ########## time record ########## - + time_record_dict = {} for i, dFCM in enumerate(dFCM_lst): time_record = {} - time_record['FCS_fit'] = dFCM.measure.FCS_fit_time - time_record['dFC_assess'] = dFCM.measure.dFC_assess_time + time_record["FCS_fit"] = dFCM.measure.FCS_fit_time + time_record["dFC_assess"] = dFCM.measure.dFC_assess_time time_record_dict[str(i)] = time_record - methods_assess['time_record_dict'] = time_record_dict + methods_assess["time_record_dict"] = time_record_dict ########## subj_dFC_sim ########## - # returns correlation/MI/spearman corr/euclidean distance between results of dFC + # returns correlation/MI/spearman corr/euclidean distance between results of dFC # measures in a subject feature2extract_list = [ - # 'all', - 'spatial', 'temporal', - 'inter_time_corr', 'inter_conn_corr', - 'dFC_avg', 'dFC_var', + # 'all', + "spatial", + "temporal", + "inter_time_corr", + "inter_conn_corr", + "dFC_avg", + "dFC_var", # 'graph_spatial', 'graph_temporal' ] - metric_list = [ - 'corr', - 'spearman', - 'MI', - 'euclidean_distance' - ] - graph_property_list = [ - 'ECM', - 'shortest_path', - 'degree', - 'clustering_coef' - ] - methods_assess['all'] = {} + metric_list = ["corr", "spearman", "MI", "euclidean_distance"] + graph_property_list = ["ECM", "shortest_path", "degree", "clustering_coef"] + methods_assess["all"] = {} for metric in metric_list: - methods_assess['all'][metric] = self.dFC_mat_lst_similarity( - dFC_mat_lst, - feature2extract='all', - metric=metric - ) - methods_assess['feature_based'] = {} + methods_assess["all"][metric] = self.dFC_mat_lst_similarity( + dFC_mat_lst, feature2extract="all", metric=metric + ) + methods_assess["feature_based"] = 
{} for feature2extract in feature2extract_list: - methods_assess['feature_based'][feature2extract] = self.dFC_mat_lst_similarity( - dFC_mat_lst, - feature2extract=feature2extract, - metric='spearman' - ) - methods_assess['graph_based'] = {} - methods_assess['graph_based']['graph_spatial'] = {} - methods_assess['graph_based']['graph_temporal'] = {} + methods_assess["feature_based"][feature2extract] = ( + self.dFC_mat_lst_similarity( + dFC_mat_lst, feature2extract=feature2extract, metric="spearman" + ) + ) + methods_assess["graph_based"] = {} + methods_assess["graph_based"]["graph_spatial"] = {} + methods_assess["graph_based"]["graph_temporal"] = {} for graph_property in graph_property_list: - methods_assess['graph_based']['graph_spatial'][graph_property] = self.dFC_mat_lst_similarity( - dFC_mat_lst, - feature2extract='graph_spatial', - metric='spearman', - graph_property=graph_property - ) - methods_assess['graph_based']['graph_temporal'][graph_property] = self.dFC_mat_lst_similarity( - dFC_mat_lst, - feature2extract='graph_temporal', - metric='spearman', - graph_property=graph_property - ) + methods_assess["graph_based"]["graph_spatial"][graph_property] = ( + self.dFC_mat_lst_similarity( + dFC_mat_lst, + feature2extract="graph_spatial", + metric="spearman", + graph_property=graph_property, + ) + ) + methods_assess["graph_based"]["graph_temporal"][graph_property] = ( + self.dFC_mat_lst_similarity( + dFC_mat_lst, + feature2extract="graph_temporal", + metric="spearman", + graph_property=graph_property, + ) + ) # ########## dFC temporal average and variance ########## - methods_assess['dFC_avg'] = [self.feature_dFC_avg(dFC_mat) for dFC_mat in dFC_mat_lst] + methods_assess["dFC_avg"] = [ + self.feature_dFC_avg(dFC_mat) for dFC_mat in dFC_mat_lst + ] + + methods_assess["dFC_var"] = [ + self.feature_dFC_var(dFC_mat) for dFC_mat in dFC_mat_lst + ] - methods_assess['dFC_var'] = [self.feature_dFC_var(dFC_mat) for dFC_mat in dFC_mat_lst] - ########## Fractional Occupancy ########## - FO_lst = self.FO_calc(dFCM_lst, \ - common_TRs=common_TRs \ - ) - methods_assess['FO'] = FO_lst + FO_lst = self.FO_calc(dFCM_lst, common_TRs=common_TRs) + methods_assess["FO"] = FO_lst ########## transition frequency ########## - transition_stats_lst = self.transition_stats(dFCM_lst, \ - common_TRs=common_TRs \ - ) - methods_assess['transition_stats'] = transition_stats_lst - + transition_stats_lst = self.transition_stats(dFCM_lst, common_TRs=common_TRs) + methods_assess["transition_stats"] = transition_stats_lst + ############################################## return methods_assess - def run(self, FILTERS, downsampling_method='default'): - ''' - downsampling_method: 'default' picks FCs at common_TRs + def run(self, FILTERS, downsampling_method="default"): + """ + downsampling_method: 'default' picks FCs at common_TRs while 'SWed' uses a sliding window to downsample - ''' + """ parallelize = True output = {} if parallelize: - out_lst = Parallel( \ - n_jobs=4, verbose=0, backend='loky')( \ - delayed(self.assess_similarity)(dFCM_lst=filter_dFCM_lst(self.dFCM_lst, **FILTERS[filter]), \ - downsampling_method=downsampling_method, \ - **FILTERS[filter]) \ - for filter in FILTERS) + out_lst = Parallel(n_jobs=4, verbose=0, backend="loky")( + delayed(self.assess_similarity)( + dFCM_lst=filter_dFCM_lst(self.dFCM_lst, **FILTERS[filter]), + downsampling_method=downsampling_method, + **FILTERS[filter], + ) + for filter in FILTERS + ) for i, filter in enumerate(FILTERS): output[filter] = out_lst[i] else: for filter in FILTERS: 
param_dict = FILTERS[filter] dFCM_lst2check = filter_dFCM_lst(self.dFCM_lst, **param_dict) - output[filter] = self.assess_similarity( \ - dFCM_lst=dFCM_lst2check, \ - downsampling_method=downsampling_method, \ - **param_dict \ - ) + output[filter] = self.assess_similarity( + dFCM_lst=dFCM_lst2check, + downsampling_method=downsampling_method, + **param_dict, + ) return output + ################################# dFC class #################################### """ @@ -1794,25 +1895,21 @@ def run(self, FILTERS, downsampling_method='default'): - type annotation """ + class dFC: - TF_methods_name_lst = [ \ - 'CWT_mag', \ - 'CWT_phase_r', \ - 'CWT_phase_a', \ - 'WTC' \ - ] + TF_methods_name_lst = ["CWT_mag", "CWT_phase_r", "CWT_phase_a", "WTC"] - sw_methods_name_lst = [ \ - 'pear_corr', \ - 'MI', \ - 'GraphLasso', \ + sw_methods_name_lst = [ + "pear_corr", + "MI", + "GraphLasso", ] - base_methods_name_lst = ['SlidingWindow', 'Time-Freq'] + base_methods_name_lst = ["SlidingWindow", "Time-Freq"] def __init__(self): - self.measure_name = '' + self.measure_name = "" self.is_state_based = bool() self._stat = [] self.TPM = [] @@ -1820,7 +1917,7 @@ def __init__(self): self.TS_info_ = {} self.FCS_fit_time_ = None self.dFC_assess_time_ = None - self.logs_ = '' + self.logs_ = "" @property def FCS_fit_time(self): @@ -1837,7 +1934,7 @@ def TS_info(self): @property def is_state_based(self): - return self.params['is_state_based'] + return self.params["is_state_based"] @property def FCS(self): @@ -1854,8 +1951,8 @@ def FCS_dict(self): C_A = self.FCS FCSs = {} for k in range(C_A.shape[0]): - FCSs['FCS'+str(k+1)] = C_A[k,:,:] - + FCSs["FCS" + str(k + 1)] = C_A[k, :, :] + return FCSs @property @@ -1867,7 +1964,7 @@ def logs(self): print(self.logs_) def issame(self, dFC): - if type(self)==type(dFC): + if type(self) == type(dFC): for param_name in self.params: if self.params[param_name] != dFC.params[param_name]: return False @@ -1875,7 +1972,7 @@ def issame(self, dFC): return False return True - #test + # test def param_match(self, **param_dict): for param in param_dict: if param in self.params: @@ -1883,7 +1980,7 @@ def param_match(self, **param_dict): if not self.params[param] in param_dict[param]: return False else: - if self.params[param]!=param_dict[param]: + if self.params[param] != param_dict[param]: return False return True @@ -1896,15 +1993,17 @@ def set_dFC_assess_time(self, time): def set_mean_activity(self, time_series): # mean activity of regions at each state if self.is_state_based: - if 'sw_method' in self.params_name_lst: + if "sw_method" in self.params_name_lst: SUBJECTs = time_series.subj_id_lst TS_data = None for subject in SUBJECTs: subj_TS = time_series.get_subj_ts(subjs_id=subject).data - new_TS_data = SW_downsample(data=subj_TS.T, \ - Fs=time_series.Fs, W=self.params['W'], \ - n_overlap=self.params['n_overlap'], \ - tapered_window=self.params['tapered_window'] \ + new_TS_data = SW_downsample( + data=subj_TS.T, + Fs=time_series.Fs, + W=self.params["W"], + n_overlap=self.params["n_overlap"], + tapered_window=self.params["tapered_window"], ).T if TS_data is None: TS_data = new_TS_data @@ -1914,7 +2013,7 @@ def set_mean_activity(self, time_series): TS_data = time_series.data mean_act = list() for i in np.unique(self.Z): - ids = np.array([int(state==i) for state in self.Z]) + ids = np.array([int(state == i) for state in self.Z]) mean_act.append(np.average(TS_data, weights=ids, axis=1)) self.mean_act = np.array(mean_act) else: @@ -1931,18 +2030,22 @@ def manipulate_time_series4FCS(self, 
time_series): new_time_series = deepcopy(time_series) # SUBJECTs - new_time_series.select_subjs(num_subj=self.params['num_subj']) + new_time_series.select_subjs(num_subj=self.params["num_subj"]) # SPATIAL RESOLUTION - new_time_series.spatial_downsample(num_select_nodes=self.params['num_select_nodes'], rand_node_slct=False) + new_time_series.spatial_downsample( + num_select_nodes=self.params["num_select_nodes"], rand_node_slct=False + ) # TEMPORAL RESOLUTION - new_time_series.Fs_resample(Fs_ratio=self.params['Fs_ratio']) + new_time_series.Fs_resample(Fs_ratio=self.params["Fs_ratio"]) # NORMALIZE - if self.params['normalization']: + if self.params["normalization"]: new_time_series.normalize() # NOISE - new_time_series.add_noise(noise_ratio=self.params['noise_ratio'], mean_noise=0) + new_time_series.add_noise(noise_ratio=self.params["noise_ratio"], mean_noise=0) # NUMBER OF TIME POINTS - new_time_series.truncate(start_point=0, end_point=self.params['num_time_point']-1) + new_time_series.truncate( + start_point=0, end_point=self.params["num_time_point"] - 1 + ) self.TS_info_ = new_time_series.info_dict @@ -1953,37 +2056,41 @@ def manipulate_time_series4dFC(self, time_series): new_time_series = deepcopy(time_series) # SPATIAL RESOLUTION - new_time_series.spatial_downsample(num_select_nodes=self.params['num_select_nodes'], rand_node_slct=False) + new_time_series.spatial_downsample( + num_select_nodes=self.params["num_select_nodes"], rand_node_slct=False + ) # TEMPORAL RESOLUTION - new_time_series.Fs_resample(Fs_ratio=self.params['Fs_ratio']) + new_time_series.Fs_resample(Fs_ratio=self.params["Fs_ratio"]) # NORMALIZE - if self.params['normalization']: + if self.params["normalization"]: new_time_series.normalize() # NOISE - new_time_series.add_noise(noise_ratio=self.params['noise_ratio'], mean_noise=0) + new_time_series.add_noise(noise_ratio=self.params["noise_ratio"], mean_noise=0) # NUMBER OF TIME POINTS - new_time_series.truncate(start_point=0, end_point=self.params['num_time_point']-1) + new_time_series.truncate( + start_point=0, end_point=self.params["num_time_point"] - 1 + ) return new_time_series - + def visualize_states(self): pass # todo : use FCS_dict func in this func def visualize_FCS( - self, - normalize=True, fix_lim=True, - save_image=False, output_root=None - ): - + self, normalize=True, fix_lim=True, save_image=False, output_root=None + ): + visualize_FCS( self, - normalize=normalize, fix_lim=fix_lim, - save_image=save_image, output_root=output_root + normalize=normalize, + fix_lim=fix_lim, + save_image=save_image, + output_root=output_root, ) def visualize_TPM(self, normalize=True, save_image=False, output_root=None): - + if self.TPM == []: return if normalize: @@ -1992,17 +2099,21 @@ def visualize_TPM(self, normalize=True, save_image=False, output_root=None): C = np.expand_dims(self.TPM, axis=0) plt.figure(figsize=(5, 5)) - plt.imshow(np.squeeze(C), interpolation='nearest', aspect='equal', cmap='jet') - cb=plt.colorbar(shrink=0.8) - plt.title(self.measure_name + ' TPM') - + plt.imshow(np.squeeze(C), interpolation="nearest", aspect="equal", cmap="jet") + plt.colorbar(shrink=0.8) + plt.title(self.measure_name + " TPM") + if save_image: - folder = output_root[:output_root.rfind('/')] + folder = output_root[: output_root.rfind("/")] if not os.path.exists(folder): os.makedirs(folder) - plt.savefig(output_root+'.'+save_fig_format, \ - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) + plt.savefig( + output_root + "." 
+ save_fig_format, + dpi=fig_dpi, + bbox_inches=fig_bbox_inches, + pad_inches=fig_pad, + format=save_fig_format, + ) plt.close() else: plt.show() @@ -2010,7 +2121,7 @@ def visualize_TPM(self, normalize=True, save_image=False, output_root=None): ################################## NEW METHOD ################################## -''' +""" by : web link Reference: ## @@ -2043,10 +2154,10 @@ def __init__(self, **params): self.params['specific_param'] = value self.params['measure_name'] = 'method_name' self.params['is_state_based'] = True/False - + @property def measure_name(self): - return self.params['measure_name'] + return self.params['measure_name'] def estimate_FCS(self, time_series): @@ -2071,7 +2182,7 @@ def estimate_FCS(self, time_series): return self def estimate_dFCM(self, time_series): - + assert type(time_series) is TIME_SERIES, \ "time_series must be of TIME_SERIES class." @@ -2087,15 +2198,15 @@ def estimate_dFCM(self, time_series): # record time self.set_dFC_assess_time(time.time() - tic) - + dFCM = DFCM(measure=self) dFCM.set_dFC(FCSs=self.FCS_, FCS_idx=FCS_idx, TS_info=time_series.info_dict) return dFCM -''' +""" ################################## CAP ################################## -''' +""" by : web link Reference: ## @@ -2108,32 +2219,44 @@ def estimate_dFCM(self, time_series): Sample spacing. todo: -''' +""" from sklearn.cluster import KMeans + class CAP(dFC): def __init__(self, **params): - self.logs_ = '' + self.logs_ = "" self.FCS_ = [] self.mean_act = [] self.FCS_fit_time_ = None self.dFC_assess_time_ = None - self.params_name_lst = ['measure_name', 'is_state_based', 'n_states', \ - 'n_subj_clstrs', 'normalization', 'num_subj', 'num_select_nodes', 'num_time_point', \ - 'Fs_ratio', 'noise_ratio', 'num_realization', 'session'] + self.params_name_lst = [ + "measure_name", + "is_state_based", + "n_states", + "n_subj_clstrs", + "normalization", + "num_subj", + "num_select_nodes", + "num_time_point", + "Fs_ratio", + "noise_ratio", + "num_realization", + "session", + ] self.params = {} for params_name in self.params_name_lst: if params_name in params: self.params[params_name] = params[params_name] - self.params['measure_name'] = 'CAP' - self.params['is_state_based'] = True + self.params["measure_name"] = "CAP" + self.params["is_state_based"] = True @property def measure_name(self): - return self.params['measure_name'] + return self.params["measure_name"] def act_vec2FCS(self, act_vecs): FCS_ = list() @@ -2150,8 +2273,9 @@ def cluster_act_vec(self, act_vecs, n_clusters): def estimate_FCS(self, time_series): - assert type(time_series) is TIME_SERIES, \ - "time_series must be of TIME_SERIES class." + assert ( + type(time_series) is TIME_SERIES + ), "time_series must be of TIME_SERIES class." 
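As a quick illustration of the KMeans step that CAP.cluster_act_vec wraps: activity vectors of shape (time, region) are clustered into co-activation states. The toy data, n_clusters=4, and the outer-product FC construction below are assumptions for illustration only (act_vec2FCS's exact construction is not shown in full here):

import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
act_vecs = rng.standard_normal((300, 50))        # 300 time points, 50 regions (toy)
kmeans = KMeans(n_clusters=4, n_init=10).fit(act_vecs)
act_centers = kmeans.cluster_centers_            # (4, 50): mean activity per state
FCS = np.array([np.outer(c, c) for c in act_centers])  # one illustrative FC pattern per state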
time_series = self.manipulate_time_series4FCS(time_series) @@ -2162,29 +2286,30 @@ def estimate_FCS(self, time_series): SUBJECTs = time_series.subj_id_lst act_center_1st_level = None for subject in SUBJECTs: - + act_vecs = time_series.get_subj_ts(subjs_id=subject).data.T # test - if act_vecs.shape[0]=periods)) + X_corrected = np.multiply(X, (coi >= periods)) return X_corrected def WT_dFC(self, Y1, Y2, Fs, J, s0, dj): - if self.params['TF_method']=='CWT_mag' or self.params['TF_method']=='CWT_phase_r' or self.params['TF_method']=='CWT_phase_a': + if ( + self.params["TF_method"] == "CWT_mag" + or self.params["TF_method"] == "CWT_phase_r" + or self.params["TF_method"] == "CWT_phase_a" + ): # Cross Wavelet Transform - WT_xy, coi, freqs, _ = wavelet.xwt(Y1, Y2, dt=1/Fs, dj=dj, s0=s0, J=J, - significance_level=0.95, wavelet='morlet', normalize=True) + WT_xy, coi, freqs, _ = wavelet.xwt( + Y1, + Y2, + dt=1 / Fs, + dj=dj, + s0=s0, + J=J, + significance_level=0.95, + wavelet="morlet", + normalize=True, + ) - if self.params['TF_method']=='CWT_mag': + if self.params["TF_method"] == "CWT_mag": WT_xy_corrected = self.coi_correct(WT_xy, coi, freqs) wt = np.abs(np.mean(WT_xy_corrected, axis=0)) - if self.params['TF_method']=='CWT_phase_r' or self.params['TF_method']=='CWT_phase_a': + if ( + self.params["TF_method"] == "CWT_phase_r" + or self.params["TF_method"] == "CWT_phase_a" + ): cosA = np.cos(np.angle(WT_xy)) sinA = np.sin(np.angle(WT_xy)) cosA_corrected = self.coi_correct(cosA, coi, freqs) sinA_corrected = self.coi_correct(sinA, coi, freqs) - A = (cosA_corrected + sinA_corrected * 1j) + A = cosA_corrected + sinA_corrected * 1j - if self.params['TF_method']=='CWT_phase_r': + if self.params["TF_method"] == "CWT_phase_r": wt = np.abs(np.mean(A, axis=0)) else: wt = np.angle(np.mean(A, axis=0)) - - if self.params['TF_method']=='WTC': + + if self.params["TF_method"] == "WTC": # Wavelet Transform Coherence - WT_xy, _, coi, freqs, _ = wavelet.wct(Y1, Y2, dt=1/Fs, dj=dj, s0=s0, J=J, - sig=False, significance_level=0.95, wavelet='morlet', normalize=True) + WT_xy, _, coi, freqs, _ = wavelet.wct( + Y1, + Y2, + dt=1 / Fs, + dj=dj, + s0=s0, + J=J, + sig=False, + significance_level=0.95, + wavelet="morlet", + normalize=True, + ) WT_xy_corrected = self.coi_correct(WT_xy, coi, freqs) wt = np.abs(np.mean(WT_xy_corrected, axis=0)) return wt def estimate_dFCM(self, time_series): - - ''' + """ we assume calc is applied on subjects separately - ''' - assert len(time_series.subj_id_lst)==1, \ - 'this function takes only one subject as input.' + """ + assert ( + len(time_series.subj_id_lst) == 1 + ), "this function takes only one subject as input." # params - J = 100 # -1 - s0 = 1 # -1 - dj = 1/12 # 1/12 + J = 100 # -1 + s0 = 1 # -1 + dj = 1 / 12 # 1/12 - assert type(time_series) is TIME_SERIES, \ - "time_series must be of TIME_SERIES class." + assert ( + type(time_series) is TIME_SERIES + ), "time_series must be of TIME_SERIES class." 
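The cone-of-influence masking in coi_correct is a plain broadcasting exercise: wavelet coefficients whose period falls outside the COI at a given time point are zeroed. A sketch under assumed shapes (X: (n_scales, n_times), freqs: (n_scales,), coi: (n_times,)), mirroring the np.multiply(X, (coi >= periods)) step above:

import numpy as np

def coi_mask(X, coi, freqs):
    # Grid of periods per (scale, time) cell and of the COI per time point.
    periods = np.repeat((1.0 / freqs)[:, None], X.shape[1], axis=1)
    coi_grid = np.repeat(coi[None, :], X.shape[0], axis=0)
    # Keep only coefficients whose period lies inside the cone of influence.
    return np.multiply(X, coi_grid >= periods)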
time_series = self.manipulate_time_series4dFC(time_series) # start timing tic = time.time() - WT = np.zeros((time_series.n_time, \ - time_series.n_regions, time_series.n_regions)) + WT = np.zeros((time_series.n_time, time_series.n_regions, time_series.n_regions)) for i in range(time_series.n_regions): - if self.params['n_jobs'] is None: + if self.params["n_jobs"] is None: Q = list() for j in range(time_series.n_regions): - Q.append(self.WT_dFC( \ - Y1=time_series.data[i, :], \ - Y2=time_series.data[j, :], \ - Fs=time_series.Fs, \ - J=J, s0=s0, dj=dj)) + Q.append( + self.WT_dFC( + Y1=time_series.data[i, :], + Y2=time_series.data[j, :], + Fs=time_series.Fs, + J=J, + s0=s0, + dj=dj, + ) + ) else: - Q = Parallel( \ - n_jobs=self.params['n_jobs'], verbose=self.params['verbose'], backend=self.params['backend'])( \ - delayed(self.WT_dFC)( \ - Y1=time_series.data[i, :], \ - Y2=time_series.data[j, :], \ - Fs=time_series.Fs, \ - J=J, s0=s0, dj=dj) \ - for j in range(time_series.n_regions) \ - ) + Q = Parallel( + n_jobs=self.params["n_jobs"], + verbose=self.params["verbose"], + backend=self.params["backend"], + )( + delayed(self.WT_dFC)( + Y1=time_series.data[i, :], + Y2=time_series.data[j, :], + Fs=time_series.Fs, + J=J, + s0=s0, + dj=dj, + ) + for j in range(time_series.n_regions) + ) WT[:, i, :] = np.array(Q).T # record time @@ -2605,6 +2820,7 @@ def estimate_dFCM(self, time_series): dFCM.set_dFC(FCSs=WT, TS_info=time_series.info_dict) return dFCM + ################################# Sliding-Window ################################# """ @@ -2621,59 +2837,71 @@ def estimate_dFCM(self, time_series): from sklearn.covariance import GraphicalLassoCV, graphical_lasso + class SLIDING_WINDOW(dFC): def __init__(self, **params): - self.logs_ = '' + self.logs_ = "" self.TPM = [] self.FCS_ = [] self.FCS_fit_time_ = None self.dFC_assess_time_ = None - self.params_name_lst = ['measure_name', 'is_state_based', 'sw_method', 'tapered_window', \ - 'W', 'n_overlap', 'normalization', \ - 'num_select_nodes', 'num_time_point', 'Fs_ratio', \ - 'noise_ratio', 'num_realization', 'session'] + self.params_name_lst = [ + "measure_name", + "is_state_based", + "sw_method", + "tapered_window", + "W", + "n_overlap", + "normalization", + "num_select_nodes", + "num_time_point", + "Fs_ratio", + "noise_ratio", + "num_realization", + "session", + ] self.params = {} for params_name in self.params_name_lst: if params_name in params: self.params[params_name] = params[params_name] - self.params['measure_name'] = 'SlidingWindow' - self.params['is_state_based'] = False + self.params["measure_name"] = "SlidingWindow" + self.params["is_state_based"] = False + + assert ( + self.params["sw_method"] in self.sw_methods_name_lst + ), "sw_method not recognized." - assert self.params['sw_method'] in self.sw_methods_name_lst, \ - "sw_method not recognized." 
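The serial/parallel fork that recurs throughout these classes (estimate_FCS, estimate_dFCM, WT_dFC) follows one joblib pattern; a minimal sketch with a toy worker standing in for the real estimator:

from joblib import Parallel, delayed

def work(i):
    # Stand-in for e.g. measure.estimate_FCS or WT_dFC.
    return i * i

n_jobs = -1  # the default DFC_ANALYZER falls back to
if n_jobs is None:
    results = [work(i) for i in range(8)]  # plain loop
else:
    results = Parallel(n_jobs=n_jobs, verbose=0, backend="loky")(
        delayed(work)(i) for i in range(8)  # parallel equivalent
    )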
- - @property def measure_name(self): - return self.params['measure_name'] #+ '_' + self.sw_method + return self.params["measure_name"] # + '_' + self.sw_method def shan_entropy(self, c): c_normalized = c / float(np.sum(c)) c_normalized = c_normalized[np.nonzero(c_normalized)] - H = -sum(c_normalized* np.log2(c_normalized)) + H = -sum(c_normalized * np.log2(c_normalized)) return H def calc_MI(self, X, Y): - + bins = 20 - - c_XY = np.histogram2d(X,Y,bins)[0] - c_X = np.histogram(X,bins)[0] - c_Y = np.histogram(Y,bins)[0] - + + c_XY = np.histogram2d(X, Y, bins)[0] + c_X = np.histogram(X, bins)[0] + c_Y = np.histogram(Y, bins)[0] + H_X = self.shan_entropy(c_X) H_Y = self.shan_entropy(c_Y) H_XY = self.shan_entropy(c_XY) - + MI = H_X + H_Y - H_XY return MI def FC(self, time_series): - - if self.params['sw_method']=='GraphLasso': + + if self.params["sw_method"] == "GraphLasso": model = GraphicalLassoCV() model.fit(time_series.T) C = model.covariance_ @@ -2681,69 +2909,75 @@ def FC(self, time_series): C = np.zeros((time_series.shape[0], time_series.shape[0])) for i in range(time_series.shape[0]): for j in range(i, time_series.shape[0]): - + X = time_series[i, :] Y = time_series[j, :] - if self.params['sw_method']=='MI': + if self.params["sw_method"] == "MI": ########### Mutual Information ############## C[j, i] = self.calc_MI(X, Y) else: ########### Pearson Correlation ############## - if np.var(X)==0 or np.var(Y)==0: + if np.var(X) == 0 or np.var(Y) == 0: C[j, i] = 0 else: C[j, i] = np.corrcoef(X, Y)[0, 1] - C[i, j] = C[j, i] - + C[i, j] = C[j, i] + return C def dFC(self, time_series, W=None, n_overlap=None, tapered_window=False): # W is in time samples - + L = time_series.shape[1] - step = int((1-n_overlap)*W) + step = int((1 - n_overlap) * W) if step == 0: step = 1 - window_taper = signal.windows.gaussian(W, std=3*W/22) + window_taper = signal.windows.gaussian(W, std=3 * W / 22) # C = DFCM(measure=self) FCSs = list() TR_array = list() - for l in range(0, L-W+1, step): + for l in range(0, L - W + 1, step): ######### creating a rectangle window ############ window = np.zeros((L)) - window[l:l+W] = 1 - + window[l : l + W] = 1 + ########### tapering the window ############## if tapered_window: - window = signal.convolve(window, window_taper, mode='same') / sum(window_taper) + window = signal.convolve(window, window_taper, mode="same") / sum( + window_taper + ) - window = np.repeat(np.expand_dims(window, axis=0), time_series.shape[0], axis=0) + window = np.repeat( + np.expand_dims(window, axis=0), time_series.shape[0], axis=0 + ) # int(l-W/2):int(l+3*W/2) is the nonzero interval after tapering - FCSs.append(self.FC( \ - np.multiply(time_series, window)[ \ - :,max(int(l-W/2),0):min(int(l+3*W/2),L) \ - ] \ - ) - ) - TR_array.append(int((l + (l+W)) / 2) ) + FCSs.append( + self.FC( + np.multiply(time_series, window)[ + :, max(int(l - W / 2), 0) : min(int(l + 3 * W / 2), L) + ] + ) + ) + TR_array.append(int((l + (l + W)) / 2)) return np.array(FCSs), np.array(TR_array) - + def estimate_dFCM(self, time_series): - - ''' + """ we assume calc is applied on subjects separately - ''' - assert len(time_series.subj_id_lst)==1, \ - 'this function takes only one subject as input.' + """ + assert ( + len(time_series.subj_id_lst) == 1 + ), "this function takes only one subject as input." - assert type(time_series) is TIME_SERIES, \ - "time_series must be of TIME_SERIES class." + assert ( + type(time_series) is TIME_SERIES + ), "time_series must be of TIME_SERIES class."
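The window construction in SLIDING_WINDOW.dFC above is easy to sanity-check in isolation: a W-sample rectangle convolved with a Gaussian taper of std 3*W/22 (cf. the Allen et al. description below), normalized by the taper's sum. The length and offset here are toy values:

import numpy as np
from scipy import signal

L, W, l = 200, 44, 30                  # series length, window width, offset (toy)
window = np.zeros(L)
window[l:l + W] = 1                    # rectangle window
taper = signal.windows.gaussian(W, std=3 * W / 22)
tapered = signal.convolve(window, taper, mode="same") / sum(taper)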
time_series = self.manipulate_time_series4dFC(time_series) @@ -2751,11 +2985,12 @@ def estimate_dFCM(self, time_series): tic = time.time() # W is converted from sec to samples - FCSs, TR_array = self.dFC(time_series=time_series.data, \ - W=int(self.params['W'] * time_series.Fs) , \ - n_overlap=self.params['n_overlap'], \ - tapered_window=self.params['tapered_window'] \ - ) + FCSs, TR_array = self.dFC( + time_series=time_series.data, + W=int(self.params["W"] * time_series.Fs), + n_overlap=self.params["n_overlap"], + tapered_window=self.params["tapered_window"], + ) # record time self.set_dFC_assess_time(time.time() - tic) @@ -2769,12 +3004,12 @@ def estimate_dFCM(self, time_series): ########################### Sliding_Window + Clustering ############################ """ -- We used a tapered window as in Allen et al., created by convolving a rectangle (width = 22 TRs = 44s) +- We used a tapered window as in Allen et al., created by convolving a rectangle (width = 22 TRs = 44s) with a Gaussian (σ = 3 TRs) and slid in steps of 1 TR, resulting in W= 126 windows (Allen et al., 2014). - Kmeans Clustering is repeated 500 times to escape local minima (Allen et al., 2014) - for clustering, we have a 2-level kmeans clustering. First, we cluster FCSs of each subject. Then, we cluster all clustering centers from all subjects. the final estimate_dFCM is using the second kmeans - model (Allen et al., 2014; Ou et al., 2015). + model (Allen et al., 2014; Ou et al., 2015). Parameters ---------- y1, y2 : numpy.ndarray, list Input signals. dt : float Sample spacing. todo: - pyclustering(manhattan) has a problem when using predict """ -from sklearn.cluster import KMeans -from pyclustering.cluster.kmeans import kmeans from pyclustering.cluster.center_initializer import kmeans_plusplus_initializer +from pyclustering.cluster.kmeans import kmeans from pyclustering.utils.metric import distance_metric, type_metric +from sklearn.cluster import KMeans + class SLIDING_WINDOW_CLUSTR(dFC): - def __init__(self, clstr_distance='euclidean', **params): + def __init__(self, clstr_distance="euclidean", **params): - assert clstr_distance=='euclidean' or clstr_distance=='manhattan', \ - "Clustering distance not recognized. It must be either \ + assert ( + clstr_distance == "euclidean" or clstr_distance == "manhattan" + ), "Clustering distance not recognized. It must be either \ euclidean or manhattan."
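The two-level scheme the docstring above describes can be sketched end to end with toy sizes (5 subjects, 8 subject-level clusters, 4 group states; all illustrative):

import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
subject_centers = []
for _ in range(5):                                  # one KMeans per subject
    F = rng.standard_normal((120, 100))             # (windows, connections), toy
    km = KMeans(n_clusters=8, n_init=10).fit(F)     # ~ n_subj_clstrs
    subject_centers.append(km.cluster_centers_)
pooled = np.concatenate(subject_centers, axis=0)    # pool subject-level centers
group_km = KMeans(n_clusters=4, n_init=10).fit(pooled)  # ~ n_states; the final model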
- self.logs_ = '' + self.logs_ = "" self.TPM = [] self.FCS_ = [] self.mean_act = [] self.FCS_fit_time_ = None self.dFC_assess_time_ = None - self.params_name_lst = ['measure_name', 'is_state_based', 'clstr_base_measure', 'sw_method', 'tapered_window', \ - 'clstr_distance', 'coi_correction', \ - 'n_subj_clstrs', 'W', 'n_overlap', 'n_states', 'normalization', \ - 'n_jobs', 'verbose', 'backend', \ - 'num_subj', 'num_select_nodes', 'num_time_point', 'Fs_ratio', \ - 'noise_ratio', 'num_realization', 'session'] + self.params_name_lst = [ + "measure_name", + "is_state_based", + "clstr_base_measure", + "sw_method", + "tapered_window", + "clstr_distance", + "coi_correction", + "n_subj_clstrs", + "W", + "n_overlap", + "n_states", + "normalization", + "n_jobs", + "verbose", + "backend", + "num_subj", + "num_select_nodes", + "num_time_point", + "Fs_ratio", + "noise_ratio", + "num_realization", + "session", + ] self.params = {} for params_name in self.params_name_lst: if params_name in params: self.params[params_name] = params[params_name] - - self.params['measure_name'] = 'Clustering' - self.params['is_state_based'] = True - self.params['clstr_distance'] = clstr_distance - assert self.params['clstr_base_measure'] in self.base_methods_name_lst, \ - "Base method not recognized." + self.params["measure_name"] = "Clustering" + self.params["is_state_based"] = True + self.params["clstr_distance"] = clstr_distance + + assert ( + self.params["clstr_base_measure"] in self.base_methods_name_lst + ), "Base method not recognized." @property def measure_name(self): - return self.params['measure_name'] #+ '_' + self.base_method + return self.params["measure_name"] # + '_' + self.base_method def dFC_mat2vec(self, C_t): return dFC_mat2vec(C_t) @@ -2868,7 +3124,7 @@ def cluster_FC(self, FCS_raw, n_clusters, n_regions): F = self.dFC_mat2vec(FCS_raw) - if self.params['clstr_distance']=='manhattan': + if self.params["clstr_distance"] == "manhattan": pass # ########### Manhattan Clustering ############## # # Prepare initial centers using K-Means++ method. @@ -2889,11 +3145,11 @@ def cluster_FC(self, FCS_raw, n_clusters, n_regions): FCS_ = self.dFC_vec2mat(F_cent, N=n_regions) return FCS_, kmeans_ - def estimate_FCS(self, time_series): - assert type(time_series) is TIME_SERIES, \ - "time_series must be of TIME_SERIES class." + assert ( + type(time_series) is TIME_SERIES + ), "time_series must be of TIME_SERIES class." time_series = self.manipulate_time_series4FCS(time_series) @@ -2901,9 +3157,9 @@ def estimate_FCS(self, time_series): tic = time.time() base_dFC = None - if self.params['clstr_base_measure']=='Time-Freq': + if self.params["clstr_base_measure"] == "Time-Freq": base_dFC = TIME_FREQ(**self.params) - if self.params['clstr_base_measure']=='SlidingWindow': + if self.params["clstr_base_measure"] == "SlidingWindow": base_dFC = SLIDING_WINDOW(**self.params) # 1-level clustering @@ -2920,38 +3176,41 @@ def estimate_FCS(self, time_series): FCS_1st_level = None SW_dFC = None for subject in SUBJECTs: - - dFCM_raw = base_dFC.estimate_dFCM( \ - time_series=time_series.get_subj_ts(subjs_id=subject) \ - ) + + dFCM_raw = base_dFC.estimate_dFCM( + time_series=time_series.get_subj_ts(subjs_id=subject) + ) # test - if dFCM_raw.n_time