From 4c862f02264852c2e1adc4e376f1875467a1ab99 Mon Sep 17 00:00:00 2001
From: Remi Gau
Date: Thu, 11 Jan 2024 14:15:09 +0100
Subject: [PATCH] set up config for precommit, black, flake8, codespell, isort
 and run

---
 .flake8 | 42 +
 .github/dependabot.yml | 9 +
 .github/workflows/run_precommit.yml | 15 +
 .github/workflows/update_precommit_hooks.yml | 53 +
 .gitignore | 12 +-
 .pre-commit-config.yaml | 60 +
 README.rst | 4 +-
 build/lib/pydfc/__init__.py | 29 -
 build/lib/pydfc/comparison/__init__.py | 10 -
 build/lib/pydfc/comparison/analytical.py | 316 --
 build/lib/pydfc/comparison/plotting.py | 914 ------
 .../pydfc/comparison/similarity_assessment.py | 418 ---
 build/lib/pydfc/data_loader.py | 310 --
 build/lib/pydfc/dfc.py | 303 --
 build/lib/pydfc/dfc_methods/__init__.py | 17 -
 .../lib/pydfc/dfc_methods/base_dfc_method.py | 312 --
 build/lib/pydfc/dfc_methods/cap.py | 148 -
 build/lib/pydfc/dfc_methods/continuous_hmm.py | 117 -
 build/lib/pydfc/dfc_methods/discrete_hmm.py | 168 -
 build/lib/pydfc/dfc_methods/sliding_window.py | 177 -
 .../dfc_methods/sliding_window_clustr.py | 260 --
 build/lib/pydfc/dfc_methods/time_freq.py | 198 --
 build/lib/pydfc/dfc_methods/windowless.py | 120 -
 build/lib/pydfc/dfc_utils.py | 1181 -------
 build/lib/pydfc/multi_analysis.py | 245 --
 build/lib/pydfc/task_utils.py | 99 -
 build/lib/pydfc/time_series.py | 415 ---
 dFC_methods_demo.ipynb | 153 +-
 dist/pydfc-1.0.1-py3-none-any.whl | Bin 53105 -> 0 bytes
 dist/pydfc-1.0.1.tar.gz | Bin 42731 -> 0 bytes
 multi_analysis_demo.ipynb | 156 +-
 pydfc.egg-info/PKG-INFO | 15 -
 pydfc.egg-info/SOURCES.txt | 27 -
 pydfc.egg-info/dependency_links.txt | 1 -
 pydfc.egg-info/requires.txt | 15 -
 pydfc.egg-info/top_level.txt | 1 -
 pydfc/__init__.py | 21 +-
 pydfc/comparison/__init__.py | 7 +-
 pydfc/comparison/analytical.py | 160 +-
 pydfc/comparison/plotting.py | 924 +++---
 pydfc/comparison/similarity_assessment.py | 385 +--
 pydfc/data_loader.py | 278 +-
 pydfc/dfc.py | 195 +-
 pydfc/dfc_methods/__init__.py | 20 +-
 pydfc/dfc_methods/base_dfc_method.py | 139 +-
 pydfc/dfc_methods/cap.py | 99 +-
 pydfc/dfc_methods/continuous_hmm.py | 69 +-
 pydfc/dfc_methods/discrete_hmm.py | 134 +-
 pydfc/dfc_methods/sliding_window.py | 142 +-
 pydfc/dfc_methods/sliding_window_clustr.py | 167 +-
 pydfc/dfc_methods/time_freq.py | 186 +-
 pydfc/dfc_methods/windowless.py | 81 +-
 pydfc/dfc_utils.py | 838 ++---
 pydfc/multi_analysis.py | 167 +-
 pydfc/task_utils.py | 73 +-
 pydfc/time_series.py | 259 +-
 pyproject.toml | 18 +
 rest_dFC/FCS_estimate.py | 165 +-
 rest_dFC/dFC_assessment.py | 65 +-
 rest_dFC/functions/dFC_funcs.py | 2852 +++++++++--------
 rest_dFC/functions/post_analysis_funcs.py | 1184 ++++---
 rest_dFC/main.py | 332 +-
 rest_dFC/post_analysis.py | 1343 ++++----
 rest_dFC/similarity_measurement.py | 53 +-
 rest_dFC/test_dFC.py | 359 ++-
 rest_dFC/visualization.py | 1442 +++++----
 setup.py | 67 +-
 task_dFC/FCS_estimate.py | 153 +-
 task_dFC/dFC_assessment.py | 73 +-
 task_dFC/nifti_to_roi_signal.py | 109 +-
 task_dFC/validation.py | 39 +-
 71 files changed, 7459 insertions(+), 11459 deletions(-)
 create mode 100644 .flake8
 create mode 100644 .github/dependabot.yml
 create mode 100644 .github/workflows/run_precommit.yml
 create mode 100644 .github/workflows/update_precommit_hooks.yml
 create mode 100644 .pre-commit-config.yaml
 delete mode 100644 build/lib/pydfc/__init__.py
 delete mode 100644 build/lib/pydfc/comparison/__init__.py
 delete mode 100644 build/lib/pydfc/comparison/analytical.py
 delete mode 100644 build/lib/pydfc/comparison/plotting.py
 delete mode 100644 build/lib/pydfc/comparison/similarity_assessment.py
 delete mode 100644 build/lib/pydfc/data_loader.py
 delete mode 100644 build/lib/pydfc/dfc.py
 delete mode 100644 build/lib/pydfc/dfc_methods/__init__.py
 delete mode 100644 build/lib/pydfc/dfc_methods/base_dfc_method.py
 delete mode 100644 build/lib/pydfc/dfc_methods/cap.py
 delete mode 100644 build/lib/pydfc/dfc_methods/continuous_hmm.py
 delete mode 100644 build/lib/pydfc/dfc_methods/discrete_hmm.py
 delete mode 100644 build/lib/pydfc/dfc_methods/sliding_window.py
 delete mode 100644 build/lib/pydfc/dfc_methods/sliding_window_clustr.py
 delete mode 100644 build/lib/pydfc/dfc_methods/time_freq.py
 delete mode 100644 build/lib/pydfc/dfc_methods/windowless.py
 delete mode 100644 build/lib/pydfc/dfc_utils.py
 delete mode 100644 build/lib/pydfc/multi_analysis.py
 delete mode 100644 build/lib/pydfc/task_utils.py
 delete mode 100644 build/lib/pydfc/time_series.py
 delete mode 100644 dist/pydfc-1.0.1-py3-none-any.whl
 delete mode 100644 dist/pydfc-1.0.1.tar.gz
 delete mode 100644 pydfc.egg-info/PKG-INFO
 delete mode 100644 pydfc.egg-info/SOURCES.txt
 delete mode 100644 pydfc.egg-info/dependency_links.txt
 delete mode 100644 pydfc.egg-info/requires.txt
 delete mode 100644 pydfc.egg-info/top_level.txt
 create mode 100644 pyproject.toml

diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..dfe5f73
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,42 @@
+[flake8]
+exclude =
+    .git,
+    __pycache__,
+    build,
+    dist,
+select = D,E,F,W
+docstring-convention = numpy
+max-line-length = 250
+# For PEP8 error codes see
+# http://pep8.readthedocs.org/en/latest/intro.html#error-codes
+    # D100-D104: missing docstring
+    # D105: missing docstring in magic method
+    # D107: missing docstring in __init__
+    # W504: line break after binary operator
+per-file-ignores =
+    **/__init__.py: D104
+ignore =
+    BLK100,
+    D105,
+    D107,
+    E402,
+    E266,
+    E721,
+    E731,
+    E713,
+    E714,
+    E741,
+    F403,
+    F405,
+    E401,
+    F401,
+    F811,
+    F841,
+    F821,
+    FS001,
+    W503,
+    W504,
+    W605,
+# for compatibility with black
+# https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#flake8
+extend-ignore = E203
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000..5ab0ddf
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,9 @@
+---
+# Documentation
+# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
+version: 2
+updates:
+-   package-ecosystem: github-actions
+    directory: /
+    schedule:
+        interval: monthly
diff --git a/.github/workflows/run_precommit.yml b/.github/workflows/run_precommit.yml
new file mode 100644
index 0000000..97100d0
--- /dev/null
+++ b/.github/workflows/run_precommit.yml
@@ -0,0 +1,15 @@
+---
+name: pre-commit
+
+on:
+    pull_request:
+    push:
+        branches: [main]
+
+jobs:
+    pre-commit:
+        runs-on: ubuntu-latest
+        steps:
+        -   uses: actions/checkout@v4
+        -   uses: actions/setup-python@v5
+        -   uses: pre-commit/action@v3.0.0
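The pre-commit workflow above runs every hook from .pre-commit-config.yaml (added later in this patch) on each pull request and on every push to main. A minimal sketch of the equivalent local usage, assuming pre-commit is installed from PyPI (these are standard pre-commit CLI commands, not part of the patch):

    pip install pre-commit
    pre-commit install            # run the hooks automatically on every git commit
    pre-commit run --all-files    # one-off check of the whole tree, as the CI job does
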
diff --git a/.github/workflows/update_precommit_hooks.yml b/.github/workflows/update_precommit_hooks.yml
new file mode 100644
index 0000000..a0cf4bf
--- /dev/null
+++ b/.github/workflows/update_precommit_hooks.yml
@@ -0,0 +1,53 @@
+---
+name: Update precommit hooks
+
+
+on:
+
+# Uses the cron schedule for github actions
+#
+# https://docs.github.com/en/free-pro-team@latest/actions/reference/events-that-trigger-workflows#scheduled-events
+#
+# ┌───────────── minute (0 - 59)
+# │ ┌───────────── hour (0 - 23)
+# │ │ ┌───────────── day of the month (1 - 31)
+# │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
+# │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
+# │ │ │ │ │
+# │ │ │ │ │
+# │ │ │ │ │
+# * * * * *
+    schedule:
+    -   cron: 0 0 * * 1   # every monday
+
+    # Allows you to run this workflow manually from the Actions tab
+    workflow_dispatch:
+
+jobs:
+    update_precommit_hooks:
+
+        # only run on upstream repo
+        if: github.repository_owner == 'SIMEXP'
+
+        runs-on: ubuntu-latest
+        steps:
+        -   name: Checkout repo
+            uses: actions/checkout@v4
+        -   name: Set up Python
+            uses: actions/setup-python@v5
+            with:
+                python-version: '3.12'
+                allow-prereleases: false
+        -   name: Install pre-commit
+            run: pip install pre-commit
+        -   name: Update pre-commit hooks
+            run: pre-commit autoupdate
+        -   name: Create Pull Request
+            uses: peter-evans/create-pull-request@v5
+            with:
+                commit-message: pre-commit hooks auto-update
+                base: main
+                token: ${{ secrets.GITHUB_TOKEN }}
+                delete-branch: true
+                title: '[BOT] update pre-commit hooks'
+                body: done via this [GitHub Action](https://github.com/${{ github.repository_owner }}/giga_connectome/blob/main/.github/workflows/update_precommit_hooks.yml)
diff --git a/.gitignore b/.gitignore
index 044e1c4..f84f908 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,14 +2,4 @@
 __pycache__
 *.pyc
 *.cpython
-sample_data/sub-0001_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz
-sample_data/sub-0001_task-restingstate_acq-mb3_desc-confounds_regressors.tsv
-sample_data/sub-0002_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz
-sample_data/sub-0002_task-restingstate_acq-mb3_desc-confounds_regressors.tsv
-sample_data/sub-0003_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz
-sample_data/sub-0003_task-restingstate_acq-mb3_desc-confounds_regressors.tsv
-sample_data/sub-0004_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz
-sample_data/sub-0004_task-restingstate_acq-mb3_desc-confounds_regressors.tsv
-sample_data/sub-0005_task-restingstate_acq-mb3_space-MNI152NLin2009cAsym_desc-preproc_bold.nii.gz
-sample_data/sub-0005_task-restingstate_acq-mb3_desc-confounds_regressors.tsv
-
+sample_data/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..f2ba1e3
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,60 @@
+---
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+
+repos:
+-   repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+    -   id: check-ast
+    -   id: check-case-conflict
+    -   id: check-json
+    -   id: check-merge-conflict
+    -   id: check-toml
+    -   id: check-yaml
+    -   id: end-of-file-fixer
+    -   id: mixed-line-ending
+    -   id: trailing-whitespace
+
+-   repo: https://github.com/ikamensh/flynt/
+    rev: 1.0.1
+    hooks:
+    -   id: flynt
+
+-   repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+    -   id: isort
+        args: [--profile, black]
+
+-   repo: https://github.com/psf/black-pre-commit-mirror
+    rev: 23.12.1
+    hooks:
+    -   id: black-jupyter
+        args: [--config, pyproject.toml]
+
+-   repo: https://github.com/codespell-project/codespell
+    rev: v2.2.6
+    hooks:
+    -   id: codespell
+        args: [--toml, pyproject.toml]
+        additional_dependencies: [tomli]
+
+-   repo: https://github.com/jumanjihouse/pre-commit-hook-yamlfmt
+    rev: 0.2.3
+    hooks:
+    -   id: yamlfmt
+        args: [--mapping, '4', --sequence, '4', --offset, '0']
+
+-   repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
+    rev: v2.12.0
+    hooks:
+    -   id: pretty-format-toml
+        args: [--autofix, --indent, '4']
+
+-   repo: https://github.com/pyCQA/flake8
+    rev: 7.0.0
+    hooks:
+    -   id: flake8
+        args: [--config, .flake8, --verbose, pydfc, rest_dFC, task_dFC]
+        additional_dependencies: [flake8-use-fstring]
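The flake8 hook above passes the three source trees explicitly, so the same lint run can be reproduced outside pre-commit. A sketch mirroring the hook's args, assuming flake8 and the flake8-use-fstring plugin are installed (not part of the patch):

    pip install flake8 flake8-use-fstring
    flake8 --config .flake8 --verbose pydfc rest_dFC task_dFC
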
diff --git a/README.rst b/README.rst
index 3c48514..5080c62 100644
--- a/README.rst
+++ b/README.rst
@@ -1,6 +1,6 @@
 .. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.10211966.svg
    :target: https://doi.org/10.5281/zenodo.10211966
-
+
 pydfc
 =======
 An implementation of several well-known dynamic Functional Connectivity (dFC) assessment methods.
@@ -16,4 +16,4 @@ The ``multi_analysis_demo.ipynb`` illustrates how to use the ``pydfc`` toolbox t
 For more details about the implemented methods and the comparison analysis see `our paper `_.

- * Torabi M, Mitsis GD, Poline JB. On the variability of dynamic functional connectivity assessment methods. bioRxiv. 2023:2023-07.
+ * Torabi M, Mitsis GD, Poline JB. On the variability of dynamic functional connectivity assessment methods. bioRxiv. 2023:2023-07.
diff --git a/build/lib/pydfc/__init__.py b/build/lib/pydfc/__init__.py
deleted file mode 100644
index 0bca438..0000000
--- a/build/lib/pydfc/__init__.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""The pydFC toolbox.
-
-Submodules
----------
-
-dfc_methods --- implementation of dFC methods
-multi_analysis --- multi analysis class implementing
-                   multiple dFC methods simultaneously
-time_series --- time series class
-dfc --- dfc class
-data_loader --- load data
-dfc_utils --- functions used for dFC analysis
-comparison --- functions used for dFC results comparison
-
-"""
-
-from . import dfc_methods
-from .multi_analysis import MultiAnalysis
-from .time_series import TIME_SERIES
-from .dfc import DFC
-
-__all__ = ['MultiAnalysis',
-           'TIME_SERIES',
-           'DFC',
-           'data_loader',
-           'dfc_methods',
-           'dfc_utils',
-           'comparison'
-           ]
diff --git a/build/lib/pydfc/comparison/__init__.py b/build/lib/pydfc/comparison/__init__.py
deleted file mode 100644
index 9bb7f49..0000000
--- a/build/lib/pydfc/comparison/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-"""The pydfc toolbox."""
-
-from .similarity_assessment import SimilarityAssessment
-
-
-__all__ = [
-    'SimilarityAssessment',
-    'plotting',
-    'analytical'
-]
diff --git a/build/lib/pydfc/comparison/analytical.py b/build/lib/pydfc/comparison/analytical.py
deleted file mode 100644
index 8ec2a2c..0000000
--- a/build/lib/pydfc/comparison/analytical.py
+++ /dev/null
@@ -1,316 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Analytical functions for the comparison framework.
-
-Created on Jun 29 2023
-@author: Mohammad Torabi
-"""
-
-import numpy as np
-import statsmodels.api as sm
-from scipy import stats
-from statsmodels.formula.api import ols
-import pandas as pd
-from copy import deepcopy
-
-from ..dfc_utils import zip_name, mat_reorder, dFC_mat2vec
-
-################################# Analytical Functions ####################################
-
-############## STAT Functions ##############
-
-def make_sim_distribution(sim_mats_lst, name_lst, zip_names=True):
-    '''
-    each sim_mat in the sim_mat_lst corresponds
-    to a subj.
the name_lst must correspond to the - columns and rows of sim_mats - ''' - output = {} - for sim_mat in sim_mats_lst: - - for i, name_i in enumerate(name_lst): - for j, name_j in enumerate(name_lst): - - if j>=i: - continue - - if zip_names: - name_i_used = zip_name(name_i) - name_j_used = zip_name(name_j) - else: - name_i_used = name_i - name_j_used = name_j - - if not name_i_used in output: - output[name_i_used] = {} - if not name_j_used in output[name_i_used]: - output[name_i_used][name_j_used] = {'sim':list(), '':list()} - - output[name_i_used][name_j_used]['sim'].append(sim_mat[i, j]) - output[name_i_used][name_j_used][''].append('sim') - return output - -def two_way_anova(data): - ''' - perform two-way anova - target: sim - factor1: session - factor2: direction - ''' - df = pd.DataFrame(data) - - # Performing two-way ANOVA - model = ols('sim ~ C(session) + C(direction) +\ - C(session):C(direction)', - data=df).fit() - - return sm.stats.anova_lm(model, type=2) - -def convert_pvalue_to_asterisks(pvalue): - if pvalue <= 0.0001: - return "****" - elif pvalue <= 0.001: - return "***" - elif pvalue <= 0.01: - return "**" - elif pvalue <= 0.05: - return "*" - return "ns" - -############## Randomization Functions ############## - -def randomize_time(dFC_dict, N): - ''' - ''' - output = {} - for n in range(N): - - for i, measure_i_name in enumerate(dFC_dict): - - dFC_mat_i = dFC_dict[measure_i_name] - - # randomize the temporal order - n_time = dFC_mat_i.shape[0] - idx = np.random.choice(n_time, n_time, replace=False) - dFC_mat_i = dFC_mat_i[idx, :, :] - - dFC_mat_i_vec = dFC_mat2vec(dFC_mat_i) - - for j, measure_j_name in enumerate(dFC_dict): - - if j>i: - continue - if not measure_i_name in output: - output[measure_i_name] = {} - if not measure_j_name in output[measure_i_name]: - output[measure_i_name][measure_j_name] = {'sim':list(), '':list()} - - dFC_mat_j = dFC_dict[measure_j_name] - - # randomize the temporal order - n_time = dFC_mat_j.shape[0] - idx = np.random.choice(n_time, n_time, replace=False) - dFC_mat_j = dFC_mat_j[idx, :, :] - - dFC_mat_j_vec = dFC_mat2vec(dFC_mat_j) - - sim, p = stats.spearmanr(dFC_mat_i_vec.flatten(), dFC_mat_j_vec.flatten()) - output[measure_i_name][measure_j_name]['sim'].append(sim) - output[measure_i_name][measure_j_name][''].append('sim') - - return output - - -def suffle_dFC(dFC_mat, mode): - ''' - dFC_mat = ndarray(time, region, region) - mode can be 'temporal', 'spatial', - or 'all' - ''' - new_dFC_mat = deepcopy(dFC_mat) - if mode=='temporal': - n_time = new_dFC_mat.shape[0] - new_order = np.random.choice(n_time, n_time, replace=False) - new_dFC_mat = new_dFC_mat[new_order, :, :] - elif mode=='spatial': - n_region = new_dFC_mat.shape[1] - new_order = np.random.choice(n_region, n_region, replace=False) - for k, mat in enumerate(new_dFC_mat): - new_dFC_mat[k, :, :] = mat_reorder(new_dFC_mat[k, :, :], new_order) - elif mode=='all': - #spatial - n_region = new_dFC_mat.shape[1] - new_order_regions = np.random.choice(n_region, n_region, replace=False) - for k, mat in enumerate(new_dFC_mat): - new_dFC_mat[k, :, :] = mat_reorder(new_dFC_mat[k, :, :], new_order_regions) - #temporal - n_time = new_dFC_mat.shape[0] - new_order_time = np.random.choice(n_time, n_time, replace=False) - new_dFC_mat = new_dFC_mat[new_order_time, :, :] - - return new_dFC_mat - - -def randomized_dFC_sim(dFC_dict, N, mode): - ''' - mode can be 'temporal', 'spatial', - or 'all' - 'spatial': this will result in different methods having - different spatial/region orders but still the - 
same temporal order - 'temporal': this will result in different methods having - different temporal orders but still the - same spatial order - 'all': this will result in different methods having - different temporal orders AND different spatial order - ''' - output = {} - for n in range(N): - - for i, measure_i_name in enumerate(dFC_dict): - - dFC_mat_i = dFC_dict[measure_i_name] - - # randomize the spatial (regions) order - dFC_mat_i = suffle_dFC(dFC_mat_i, mode=mode) - - dFC_mat_i_vec = dFC_mat2vec(dFC_mat_i) - - for j, measure_j_name in enumerate(dFC_dict): - - if j>i: - continue - if not measure_i_name in output: - output[measure_i_name] = {} - if not measure_j_name in output[measure_i_name]: - output[measure_i_name][measure_j_name] = {'sim':list(), '':list()} - - dFC_mat_j = dFC_dict[measure_j_name] - - # randomize the temporal order - dFC_mat_j = suffle_dFC(dFC_mat_j, mode=mode) - - dFC_mat_j_vec = dFC_mat2vec(dFC_mat_j) - - sim, p = stats.spearmanr(dFC_mat_i_vec.flatten(), dFC_mat_j_vec.flatten()) - output[measure_i_name][measure_j_name]['sim'].append(sim) - output[measure_i_name][measure_j_name][''].append('sim') - - return output - - -def dFC_rand_generator(FCS, n_time): - ''' - generate a dFC mat of length n_time - using spatial FC patterns in FCS = (num_pattern, ROI, ROI) - ''' - dFC_rand = None - idx = np.random.choice(FCS.shape[0], n_time, replace=True) - dFC_rand = FCS[idx, :, :] - return dFC_rand - -def dFC_rand_sim(FCS_dict, n_time, N): - ''' - for random state TC similarity assessment - ''' - output = {} - for n in range(N): - - for i, measure_i_name in enumerate(FCS_dict): - - dFC_rand = dFC_rand_generator(FCS_dict[measure_i_name], n_time=n_time) - dFC_mat_i = dFC_rand - dFC_mat_i_vec = dFC_mat2vec(dFC_mat_i) - - for j, measure_j_name in enumerate(FCS_dict): - - if j>i: - continue - if not measure_i_name in output: - output[measure_i_name] = {} - if not measure_j_name in output[measure_i_name]: - output[measure_i_name][measure_j_name] = {'sim':list(), '':list()} - - dFC_rand = dFC_rand_generator(FCS_dict[measure_j_name], n_time=n_time) - dFC_mat_j = dFC_rand - dFC_mat_j_vec = dFC_mat2vec(dFC_mat_j) - - sim, p = stats.spearmanr(dFC_mat_i_vec.flatten(), dFC_mat_j_vec.flatten()) - output[measure_i_name][measure_j_name]['sim'].append(sim) - output[measure_i_name][measure_j_name][''].append('sim') - return output - -############## Hierarchical Clustering ############## - -def correct_order(s): - list = s.split('-') - list = [int(item) for item in list] - list.sort() - return '-'.join(str(x) for x in list) - -def open_trees(Z, num_leaf): - ''' - replace trees in Z by their leaves - ''' - Z_copy = deepcopy(Z) - Z_new = [] - for tree in Z_copy: - Z_new.append([tree[0], tree[1], tree[2], tree[3]]) - encode_dict = {} - counter = num_leaf - for tree in Z_new: - if tree[0]>=num_leaf: - tree[0] = encode_dict[tree[0]] - else: - tree[0] = str(int(tree[0])) - if tree[1]>=num_leaf: - tree[1] = encode_dict[tree[1]] - else: - tree[1] = str(int(tree[1])) - encode_dict[counter] = tree[0]+'-'+tree[1] - encode_dict[counter] = correct_order(encode_dict[counter]) - counter += 1 - return Z_new - -def is_trees_equal(trees_1, trees_2): - ''' - trees_2 is the reference - ''' - for tree in trees_1: - if (not [tree[0], tree[1]] in trees_2) \ - and (not [tree[1], tree[0]] in trees_2): - return False - return True - -def is_in_Z_clstrs(trees, Z_clstrs, trees_key): - for key in Z_clstrs: - if is_trees_equal(trees, Z_clstrs[key][trees_key]): - return key - return None - -def cluster_Z(Z_lst, num_leaf): - 
''' - Z_lst is the list of linkages of samples - num_leaf is the number of objects in clustering - ''' - Z_clstrs = {} - counter = 0 - for Z in Z_lst: - # replace trees in Z by their leaves - Z_open = open_trees(Z, num_leaf) - trees = [[tree[0], tree[1]] for tree in Z_open] - distances = [tree[2] for tree in Z] - clstr_idx = is_in_Z_clstrs(trees, Z_clstrs, trees_key='trees') - - if clstr_idx is None: - Z_clstrs[counter] = { - 'Z': Z, - 'trees': trees, - 'freq': 1, - 'distance_lst': [distances] - } - counter += 1 - else: - Z_clstrs[clstr_idx]['freq'] += 1 - Z_clstrs[clstr_idx]['distance_lst'].append(distances) - return Z_clstrs \ No newline at end of file diff --git a/build/lib/pydfc/comparison/plotting.py b/build/lib/pydfc/comparison/plotting.py deleted file mode 100644 index a719126..0000000 --- a/build/lib/pydfc/comparison/plotting.py +++ /dev/null @@ -1,914 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Plotting function for visualizing comparison results. - -Created on Jun 29 2023 -@author: Mohammad Torabi -""" - -import warnings -import numpy as np -import math -import scipy.cluster.hierarchy as shc -import scipy.spatial.distance as ssd -from sklearn.manifold import TSNE - -import matplotlib.pyplot as plt -import matplotlib as mpl -from nilearn.plotting import plot_markers -from matplotlib.colors import ListedColormap -import seaborn as sns -import pandas as pd -import os - -from ..dfc_utils import visualize_conn_mat_dict - -################################# Parameters #################################### - -fig_dpi = 120 -fig_bbox_inches = 'tight' -fig_pad = 0.1 -show_title = False -save_fig_format = 'png' # pdf, png, - -################################# Plotting Functions #################################### - -def title2file_name(title): - ''' - change all spaces in the title to _ - the original string remains unchanged - ''' - return title.replace(" ", "_") - -def plot_sample_dFC(D, x, - title='', - cmap='seismic', - normalize=False, - disp_diag=True, - save_image=False, output_root=None, - fix_lim=True, center_0=True, - node_networks=None, segmented=False - ): - ''' - D is a dictionary of dFC samples. each - key is the name of a dFC matrix (e.g. 
method - used for assessing it), and D[key][x] contains the - the dFC matrix as a numpy ndarray - ''' - - num_dFC = len(D) - names_lst = [key for key in D] - num_time = len(D[names_lst[0]][x]) - - fig_width = 48*(num_time/10) - fig_height = 55*(num_dFC/10) - - fig, axes = plt.subplots(num_dFC, num_time, figsize=(fig_width, fig_height), \ - facecolor='w', edgecolor='k') - - fig.subplots_adjust( - bottom=0.1, - top=0.85, - left=0.1, - right=0.9, - wspace=0.5, - hspace=0.6 - ) - - for i, dFC_mat_name in enumerate(D): - visualize_conn_mat_dict(data=D[dFC_mat_name][x], - node_networks=node_networks, - title=dFC_mat_name, - cmap=cmap, center_0=center_0, - normalize=normalize, fix_lim=fix_lim, - disp_diag=disp_diag, - segmented=segmented, - save_image=False, output_root=output_root, - axes=axes[i, :], fig=fig, - ) - - fig.subplots_adjust( - bottom=0.1, - top=0.85, - left=0.1, - right=0.9, - wspace=0.5, - hspace=0.6 - ) - - # set row names - for i, dFC_mat_name in enumerate(D): - axes[i, 0].set_ylabel(dFC_mat_name, fontdict={'fontsize': 25, 'fontweight': 'bold'}, rotation=90) - - if save_image: - folder = output_root[:output_root.rfind('/')] - if not os.path.exists(folder): - os.makedirs(folder) - fig.savefig(output_root+title.replace(" ", "_")+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) - plt.close() - else: - plt.show() - - -def pairwise_cat_plots(data=None, x=None, y=None, z=None, - title='', - label_dict={}, - save_image=False, output_root=None - ): - ''' - data is a dictionary with different vars as keys - if z is specidied, it will be used as a out of distribution - sample, e.g. actual similarity when plotting randomized - distribution. - ''' - - sns.set_context("paper", - font_scale=2.5, - rc={"lines.linewidth": 3.0} - ) - - row_keys = [key for key in data] - n_rows = len(row_keys) - column_keys = [key for key in data[row_keys[-1]]] - n_columns = len(column_keys) - - sns.set_style('darkgrid') - - fig_width = n_columns * 5 - fig_height = n_rows * 5 - fig, axs = plt.subplots(n_rows, n_columns, figsize=(fig_width, fig_height), \ - facecolor='w', edgecolor='k', sharex=True, sharey=True) - - axs_plotted = list() - for i, key_i in enumerate(data): - for j, key_j in enumerate(data[key_i]): - df = pd.DataFrame(data[key_i][key_j]) - - if not z is None: - sns.stripplot(ax=axs[i, j], data=df, x=x, y=z, color='red', jitter=False, size=10) - sns.violinplot(ax=axs[i, j], data=df, x=x, y=y) - - axs[i, j].set_title(key_i+'-'+key_j, fontdict={'fontsize': 25, 'fontweight': 'bold'}) - - ## set labels - ylabel = axs[i, j].get_ylabel() - if ylabel in label_dict: - ylabel = label_dict[ylabel] - axs[i, j].set_ylabel(ylabel, fontdict={'fontsize': 20, 'fontweight': 'bold'}) - xlabel = axs[i, j].get_xlabel() - if xlabel in label_dict: - xlabel = label_dict[xlabel] - axs[i, j].set_xlabel(xlabel, fontdict={'fontsize': 20, 'fontweight': 'bold'}) - # set font size of the tick labels and make them bold - tick_labels = axs[i, j].get_xticklabels() + axs[i, j].get_yticklabels() - for label in tick_labels: - label.set_fontweight('bold') - - axs_plotted.append(axs[i, j]) - - # remove extra subplots - for ax in axs.ravel(): - if not ax in axs_plotted: - ax.set_axis_off() - ax.xaxis.set_tick_params(which='both', labelbottom=True) - - if show_title: - plt.suptitle(title, fontsize=15, y=0.90) - - if save_image: - folder = output_root[:output_root.rfind('/')] - if not os.path.exists(folder): - os.makedirs(folder) - 
plt.savefig(output_root+title2file_name(title)+'.'+save_fig_format, \ - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format \ - ) - plt.close() - else: - plt.show() - -def joint_dist_plot(data, - title='', - label_dict={}, - save_image=False, output_root=None - ): - ''' - data is a dictionary including list of dFC values - of each dFC method - ''' - df = pd.DataFrame(data) - fig_width = 5*len(data) - fig_height = 5*len(data) - - sns.set_context("paper", - font_scale=2.5, - rc={"lines.linewidth": 3.0} - ) - - sns.set_style('darkgrid') - - g = sns.PairGrid(df) - - g.map_diag(sns.histplot) - g.map_offdiag(sns.histplot) - - g.fig.set_figwidth(fig_width) - g.fig.set_figheight(fig_height) - g.fig.subplots_adjust(top=0.95) - - for i in range(g.axes.shape[0]): - for j in range(g.axes.shape[1]): - - ## set labels - ylabel = g.axes[i, j].get_ylabel() - if ylabel in label_dict: - ylabel = label_dict[ylabel] - g.axes[i, j].set_ylabel(ylabel, fontdict={'fontsize': 25, 'fontweight': 'bold'}) - xlabel = g.axes[i, j].get_xlabel() - if xlabel in label_dict: - xlabel = label_dict[xlabel] - g.axes[i, j].set_xlabel(xlabel, fontdict={'fontsize': 25, 'fontweight': 'bold'}) - - # set font size of the tick labels and make them bold - tick_labels = g.axes[i, j].get_xticklabels() + g.axes[i, j].get_yticklabels() - for label in tick_labels: - label.set_fontweight('bold') - - if show_title: - plt.suptitle(title, fontsize=50, y=0.98) - - if save_image: - folder = output_root[:output_root.rfind('/')] - if not os.path.exists(folder): - os.makedirs(folder) - plt.savefig(output_root+title2file_name(title)+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) - plt.close() - else: - plt.show() - -def pairwise_scatter_plots(data, x, y, - title='', hist=False, - label_dict={}, - equal_axis_lim=False, show_x_equal_y=False, - save_image=False, output_root=None - ): - ''' - data is a dictionary with different vars as keys - ''' - - sns.set_context("paper", - font_scale=2.5, - rc={"lines.linewidth": 3.0} - ) - - row_keys = [key for key in data] - n_rows = len(row_keys) - column_keys = [key for key in data[row_keys[-1]]] - n_columns = len(column_keys) - - sns.set_style('darkgrid') - - fig_width = n_columns * 5 - fig_height = n_rows * 5 - fig, axs = plt.subplots(n_rows, n_columns, figsize=(fig_width, fig_height), \ - facecolor='w', edgecolor='k', sharex=True, sharey=True) - - # equal x_lim and y_lim - if equal_axis_lim or show_x_equal_y: - min_lim = None - max_lim = None - for i, key_i in enumerate(data): - for j, key_j in enumerate(data[key_i]): - df = pd.DataFrame(data[key_i][key_j]) - m = np.minimum(df[x].min(), df[y].min()) - M = np.maximum(df[x].max(), df[y].max()) - if min_lim is None: - min_lim = m - max_lim = M - else: - min_lim = np.minimum(m, min_lim) - max_lim = np.maximum(M, max_lim) - - lim_L = max_lim - min_lim - min_lim = min_lim - lim_L*0.1 - max_lim = max_lim + lim_L*0.1 - - axs_plotted = list() - for i, key_i in enumerate(data): - for j, key_j in enumerate(data[key_i]): - df = pd.DataFrame(data[key_i][key_j]) - if hist: - g = sns.histplot(ax=axs[i, j], data=df, x=x, y=y, bins=50) - else: - g = sns.scatterplot(ax=axs[i, j], data=df, x=x, y=y, s=50) - axs[i, j].set_title(key_i+'-'+key_j, fontdict={'fontsize': 25, 'fontweight': 'bold'}) - - ## set labels and font sizes - ylabel = g.get_ylabel() - if ylabel in label_dict: - ylabel = label_dict[ylabel] - g.set_ylabel(ylabel, fontdict={'fontsize': 18, 'fontweight': 'bold'}) - 
xlabel = g.get_xlabel() - if xlabel in label_dict: - xlabel = label_dict[xlabel] - g.set_xlabel(xlabel, fontdict={'fontsize': 18, 'fontweight': 'bold'}) - g.tick_params(axis='x', which='major', labelsize=18) - g.tick_params(axis='y', which='major', labelsize=18) - tick_labels = g.get_xticklabels() + g.get_yticklabels() - for label in tick_labels: - label.set_fontweight('bold') - - # equal x_lim and y_lim - if equal_axis_lim: - axs[i, j].set_xlim(min_lim, max_lim) - axs[i, j].set_ylim(min_lim, max_lim) - - # y=x line - if show_x_equal_y: - X_plot = np.linspace(min_lim, max_lim, 100) - Y_plot = X_plot - axs[i, j].plot(X_plot, Y_plot, color='r') - - axs_plotted.append(axs[i, j]) - # remove extra subplots - for ax in axs.ravel(): - if not ax in axs_plotted: - ax.set_axis_off() - ax.xaxis.set_tick_params(which='both', labelbottom=True) - - if show_title: - plt.suptitle(title, fontsize=15, y=0.90) - - if save_image: - folder = output_root[:output_root.rfind('/')] - if not os.path.exists(folder): - os.makedirs(folder) - plt.savefig(output_root+title2file_name(title)+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) - plt.close() - else: - plt.show() - -def scatter_plot(data, x, y, - labels=None, hue=None, - title='', hist=False, - label_dict={}, - equal_axis_lim=False, show_x_equal_y=False, - c=0.25, - save_image=False, output_root=None - ): - ''' - data is a dictionary with different vars as keys - c determines how far the annotation will be from dots - ''' - df = pd.DataFrame(data) - - sns.set_context("paper", - font_scale=2.5, - rc={"lines.linewidth": 3.0} - ) - - fig_width = 20 - fig_height = 20 - plt.figure(figsize=(fig_width, fig_height)) - sns.set_style('darkgrid') - if hist: - g = sns.histplot(data=df, x=x, y=y, hue=hue) - else: - g = sns.scatterplot(data=df, x=x, y=y, s=100, hue=hue) - - ## set labels and font sizes - ylabel = g.get_ylabel() - if ylabel in label_dict: - ylabel = label_dict[ylabel] - g.set_ylabel(ylabel, fontdict={'fontsize': 35, 'fontweight': 'bold'}) - xlabel = g.get_xlabel() - if xlabel in label_dict: - xlabel = label_dict[xlabel] - g.set_xlabel(xlabel, fontdict={'fontsize': 35, 'fontweight': 'bold'}) - g.tick_params(axis='x', which='major', labelsize=30) - g.tick_params(axis='y', which='major', labelsize=30) - tick_labels = g.get_xticklabels() + g.get_yticklabels() - for label in tick_labels: - label.set_fontweight('bold') - - # equal x_lim and y_lim - if equal_axis_lim: - min_lim = np.minimum(df[x].min(), df[y].min()) - max_lim = np.maximum(df[x].max(), df[y].max()) - lim_L = max_lim - min_lim - min_lim = min_lim - lim_L*0.1 - max_lim = max_lim + lim_L*0.1 - g.set_xlim(min_lim, max_lim) - g.set_ylim(min_lim, max_lim) - - # y=x line - if show_x_equal_y: - min_lim = np.minimum(df[x].min(), df[y].min()) - max_lim = np.maximum(df[x].max(), df[y].max()) - lim_L = max_lim - min_lim - min_lim = min_lim - lim_L*0.1 - max_lim = max_lim + lim_L*0.1 - X_plot = np.linspace(min_lim, max_lim, 100) - Y_plot = X_plot - plt.plot(X_plot, Y_plot, color='r') - - if (not labels is None) and (not hist): - # the labels are located smartly - # the direction will be away from mean - # the distance will be inverse proportional to - # distance from mean - mid_x = (np.max(df[x]) + np.min(df[x]))/2 - mid_y = (np.max(df[y]) + np.min(df[y]))/2 - x_range = max(np.max(df[x])-mid_x, mid_x-np.min(df[x])) - y_range = max(np.max(df[y])-mid_y, mid_y-np.min(df[y])) - distance_from_mean_range = math.sqrt(x_range**2+y_range**2) - for i in 
range(len(df[x])): - distance_from_mean = math.sqrt((df[x][i]-mid_x)**2+(df[y][i]-mid_y)**2) - text_x = df[x][i]+c*np.sign(df[x][i]-mid_x)*np.abs(df[x][i]-mid_x)*(distance_from_mean_range-distance_from_mean)/distance_from_mean - text_y = df[y][i]+c*np.sign(df[y][i]-mid_y)*np.abs(df[y][i]-mid_y)*(distance_from_mean_range-distance_from_mean)/distance_from_mean - plt.text( - x=text_x, - y=text_y, - s=df[labels][i], - fontdict=dict(color='black', size=14, weight='bold'), - ) - plt.plot( - [df[x][i], text_x], [df[y][i], text_y], - 'k', linewidth=0.5 - ) - - if show_title: - plt.title(title, fontsize=15) - - if save_image: - folder = output_root[:output_root.rfind('/')] - if not os.path.exists(folder): - os.makedirs(folder) - plt.savefig(output_root+title2file_name(title)+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) - plt.close() - else: - plt.show() - -def cat_plot(data, x, y, - kind='bar', - scale_dist=False, - log=False, - title='', - label_dict={}, - y_lim=None, - save_image=False, output_root=None - ): - ''' - data is a dictionary with different vars as keys - kind can be = box or violin or bar - scale_dist is only for kind=='violin' - ''' - - sns.set_context("paper", - font_scale=1.0, - rc={"lines.linewidth": 1.0} - ) - - sns.set_style('darkgrid') - - df = pd.DataFrame(data) - - fig_width = 2*len(np.unique(data[x])) - fig_height = 5 - - if kind=='violin' and scale_dist: - g = sns.catplot(data=df, x=x, y=y, kind=kind, - scale='width' - # errorbar=("pi", 95) - ) - elif kind=='bar': - g = sns.catplot(data=df, x=x, y=y, kind=kind, - width=0.25 - # errorbar=("pi", 95) - ) - elif kind=='box': - g = sns.catplot(data=df, x=x, y=y, kind=kind, - width=0.25, fliersize=1.0 - # errorbar=("pi", 95) - ) - else: - g = sns.catplot(data=df, x=x, y=y, kind=kind, - # errorbar=("pi", 95) - ) - - if log: - plt.yscale('log') - - g.fig.set_figwidth(fig_width) - g.fig.set_figheight(fig_height) - - ## set labels - ylabel = g.ax.get_ylabel() - if ylabel in label_dict: - ylabel = label_dict[ylabel] - g.ax.set_ylabel(ylabel, fontdict={'fontsize': 13, 'fontweight': 'bold'}) - xlabel = g.ax.get_xlabel() - if xlabel in label_dict: - xlabel = label_dict[xlabel] - g.ax.set_xlabel(xlabel, fontdict={'fontsize': 13, 'fontweight': 'bold'}) - # set font size of the tick labels and make them bold - tick_labels = g.ax.get_xticklabels() + g.ax.get_yticklabels() - for label in tick_labels: - label.set_fontweight('bold') - if not y_lim is None: - g.ax.set_ylim(y_lim) - - if show_title: - plt.title(title, fontsize=15) - if save_image: - folder = output_root[:output_root.rfind('/')] - if not os.path.exists(folder): - os.makedirs(folder) - plt.savefig(output_root+title2file_name(title)+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) - plt.close() - else: - plt.show() - -def visualize_sim_mat(data, mat_key, title='', - name_lst_key=None, - cmap='viridis', - annot=True, fmt=2, - label_dict={}, - show_diag=False, show_sig=False, no_color=False, - save_image=False, output_root=None, axes=None, fig=None, - ): - - ''' - - name_lst_key is the key to list of names - - data must be a dict of correlation/connectivity matrices - - masks the nan values - sample: - Suptitle1 - corr_mat - 0.00 0.31 0.76 - 0.31 0.00 0.43 - 0.76 0.43 0.00 - measure_lst - ContinuousHMM - Windowless - Clustering_pear_corr - Suptitle1 - corr_mat - 0.00 0.32 0.76 - 0.32 0.00 0.45 - 0.76 0.45 0.00 - measure_lst - ContinuousHMM - Windowless - 
Clustering_pear_corr - ''' - - sns.set_context("paper", - font_scale=1.0, - rc={"lines.linewidth": 1.0} - ) - - sns.set_style('white') - - if no_color: - cmap = ListedColormap(['white']) - - if name_lst_key is None: - fig_width = int(25*(len(data)/10)) - else: - fig_width = int(60*(len(data)/10) + 1) - fig_height = 5 - - fig_flag = True - if axes is None or fig is None: - fig_flag = False - - if not fig_flag: - fig, axes = plt.subplots(1, len(data), figsize=(fig_width, fig_height), - facecolor='w', edgecolor='k', sharey=False - ) - - if not type(axes) is np.ndarray: - axes = np.array([axes]) - - if show_title: - fig.suptitle(title, fontsize=20, y=0.98) #, fontsize=20, size=20 - - axes = axes.ravel() - - # normalizing and scale - sim_mats = list() - for i, key in enumerate(data): - sim_mats.append(data[key][mat_key]) - sim_mats = np.array(sim_mats) - - # plot - for i, key in enumerate(data): - - C = sim_mats[i,:,:] - - name_lst = None - if not name_lst_key is None: - name_lst = data[key][name_lst_key] - - cbar_flag = False - # if i==(len(data)-1): - # cbar_flag = True - - if annot: - C_forlabels = C.copy() - if not show_diag: - np.fill_diagonal(C_forlabels, np.nan) - df = pd.DataFrame(C_forlabels) - if show_sig: - annot_labels = df.applymap(lambda v: '' if np.isnan(v) else str(round(v, fmt))+''.join(['*' for t in [.05, .01, .001] if v<=t])) - else: - annot_labels = df.applymap(lambda v: '' if np.isnan(v) else str(round(v, fmt))) - else: - annot_labels = False - - # borderlines color - if no_color: - linecolor = 'black' - annot_kws={'weight': 'bold'} - else: - linecolor = 'w' - annot_kws={'weight': 'bold'} - - im = sns.heatmap(C, - annot=annot_labels, annot_kws=annot_kws, - fmt='', cmap=cmap, - xticklabels=name_lst, yticklabels=name_lst, - ax=axes[i], cbar=cbar_flag, - square=True, linewidth=2, linecolor=linecolor - ) - axis_title = key - if key in label_dict: - axis_title = label_dict[key] - axes[i].set_title(axis_title, fontdict= {'fontsize': 18, 'fontweight':'bold'}) - im.set_xticklabels(im.get_xticklabels(), fontdict= {'fontsize': 14, 'fontweight':'bold'}, rotation=90) - im.set_yticklabels(im.get_yticklabels(), fontdict= {'fontsize': 14, 'fontweight':'bold'}, rotation=0) - - if not fig_flag: - - fig.subplots_adjust( - bottom=0.1, - top=0.85, - left=0.1, - right=0.9, - ) - - if not name_lst is None: - fig.subplots_adjust( - wspace=0.5 - ) - - if save_image: - folder = output_root[:output_root.rfind('/')] - if not os.path.exists(folder): - os.makedirs(folder) - plt.savefig(output_root+title2file_name(title)+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) - plt.close() - else: - plt.show() - -def distance2Z(dist_mat, method='ward'): - # convert the redundant n*n square matrix form into a condensed nC2 array - distArray = ssd.squareform(dist_mat) - Z = shc.linkage(distArray, method=method) - return Z - -def dist_mat_dendo(Z, labels, - distances_CI=None, - title='', \ - save_image=False, output_root=None, \ - ): - ''' - if distances_CI is provided, confidence intervals (CI) - of the distances will be shown. 
the order should be the same as - Z - ''' - - sns.set_context("paper", - font_scale=3.5, - rc={"lines.linewidth": 3.0, - 'font.weight': 'bold' - } - ) - - sns.set_style('darkgrid') - - width = int(2.5*len(Z)) - fig = plt.figure(figsize=(width, 5)) - ax = fig.add_subplot(1, 1, 1) - with mpl.rc_context({'lines.linewidth': 3}): - - dend = shc.dendrogram(Z, distance_sort='ascending', no_plot=False, labels=labels) - - # show confidence interval of distances - if not distances_CI is None: - - max_y_lim = None - for i, d in zip(dend['icoord'], dend['dcoord']): - - # we have to match the distances in dcoord - # with those in Z, because the orders are not - # the same - count = 0 - for idx, clstr in enumerate(Z): - if np.isclose(Z[idx][2], d[1]): - count += 1 - Z_CI = distances_CI[idx] - - if count > 1 or count==0: - warnings.warn( - 'Error in finding std of linkage.', - UserWarning - ) - - x = 0.5 * sum(i[1:3]) - y = d[1] - ci_line_y = np.linspace(y-Z_CI, y+Z_CI, 100) - ci_line_x = x * np.ones(100) - # cut start and the end for better - # visualization - ci_line_y = ci_line_y[5:-5] - ci_line_x = ci_line_x[5:-5] - - plt.plot(ci_line_x, ci_line_y, 'black') - plt.plot(x, y-Z_CI, 'k_', markersize=15, linewidth=15) - plt.plot(x, y+Z_CI, 'k_', markersize=15, linewidth=15) - plt.plot(x, y, 'wo', markersize=5, mec='k') - # plt.annotate("%.2g" % y, (x, y), xytext=(15, 13), - # fontsize = 11, - # fontweight= 'bold', - # textcoords='offset points', - # va='top', ha='center') - if max_y_lim is None: - max_y_lim = y+Z_CI - else: - max_y_lim = max(y+Z_CI, max_y_lim) - plt.ylim(0, max_y_lim*1.1) - - if show_title: - plt.title(title, fontsize=15) - - # set font size of the tick labels and make them bold - ax.tick_params(axis='x', which='major', labelsize=15) - ax.tick_params(axis='y', which='major', labelsize=15) - tick_labels = ax.get_xticklabels() + ax.get_yticklabels() - for label in tick_labels: - label.set_fontweight('bold') - - # save figure - if save_image: - folder = output_root[:output_root.rfind('/')] - if not os.path.exists(folder): - os.makedirs(folder) - plt.savefig(output_root+title2file_name(title)+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) - plt.close() - -def plot_TSNE( - dist_mat, - sample_measure_lst, - color_dict, - projection='2d', - title='', - save_image=False, output_root=None, - ): - - sns.set_context("paper", - font_scale=2.5, - rc={ - "lines.linewidth": 3.0, - "lines.markersize": 10.0 - } - ) - - sns.set_style('darkgrid') - - fig_width = 20 - fig_height = 20 - - if projection=='2d': - X_embedded = TSNE( - n_components=2, - learning_rate='auto', - init='random', perplexity=30, - metric='precomputed' - ).fit_transform(dist_mat) - - # 2D plot - plt.figure(figsize=(fig_width, fig_height)) - sns.scatterplot( - x=X_embedded[:, 0], y=X_embedded[:, 1], - hue=sample_measure_lst, - palette=color_dict, - alpha=0.7 - ) - elif projection=='3d': - X_embedded = TSNE( - n_components=3, - learning_rate='auto', - init='random', perplexity=30, - metric='precomputed' - ).fit_transform(dist_mat) - - measures_lst = list(set(sample_measure_lst)) - measures_lst.sort() - - # 3D plot - fig = plt.figure(figsize=(fig_width, fig_height)) - ax = fig.add_subplot(projection='3d') - sample_measure_array = np.array(sample_measure_lst) - for measure in measures_lst: - scatter = ax.scatter( - X_embedded[sample_measure_array==measure, 0], - X_embedded[sample_measure_array==measure, 1], - X_embedded[sample_measure_array==measure, 2], - 
c=color_dict[measure], - label=measure - ) - ax.legend() - - if show_title: - plt.title(title, fontsize=15) - - # save figure - if save_image: - folder = output_root[:output_root.rfind('/')] - if not os.path.exists(folder): - os.makedirs(folder) - plt.savefig(output_root+title2file_name(title)+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) - plt.close() - -def plot_brain_act(act_vec, locs, axes, - title='', save_image=False, output_root='' - ): - - plot_markers( - node_values=act_vec, node_coords=locs, - node_cmap='hot', - display_mode='z', - colorbar=False, axes=axes - ) - - if save_image: - folder = output_root[:output_root.rfind('/')] - if not os.path.exists(folder): - os.makedirs(folder) - plt.savefig(output_root+title2file_name(title)+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) - plt.close() - -def visualize_state_TC(TC_lst, \ - TRs, \ - state_lst, \ - TC_name_lst, \ - title='', \ - save_image=None, output_root=None\ - ): - - color_lst = ['k', 'b', 'g', 'r'] - - if 'on' in state_lst and 'off' in state_lst: - ticks = range(2) - else: - ticks = range(1, len(state_lst)+1) - - plt.figure(figsize=(25, 5)) - for i, TC in enumerate(TC_lst): - plt.plot(TRs, TC, color_lst[i], linewidth=2) - plt.xlabel('TR') - plt.yticks(ticks=ticks, labels=state_lst) - plt.legend(TC_name_lst) - if show_title: - plt.title(title) - if save_image: - folder = output_root[:output_root.rfind('/')] - if not os.path.exists(folder): - os.makedirs(folder) - plt.savefig(output_root+title2file_name(title)+'.'+save_fig_format, - dpi=fig_dpi, bbox_inches=fig_bbox_inches, pad_inches=fig_pad, format=save_fig_format - ) - plt.close() - # else: - # plt.show() - - return \ No newline at end of file diff --git a/build/lib/pydfc/comparison/similarity_assessment.py b/build/lib/pydfc/comparison/similarity_assessment.py deleted file mode 100644 index fcbec47..0000000 --- a/build/lib/pydfc/comparison/similarity_assessment.py +++ /dev/null @@ -1,418 +0,0 @@ -# -*- coding: utf-8 -*- -""" -SimilarityAssessment class -functions to assess similarity between dFC results - -Created on Jun 29 2023 -@author: Mohammad Torabi -""" - -import numpy as np -from scipy import stats -from joblib import Parallel, delayed -from copy import deepcopy - -from ..dfc_utils import ( - calc_graph_propoerty, TR_intersection, - dFC_mat2vec, normalized_euc_dist, - mutual_information, find_new_order, - filter_dFC_lst -) - -################################# SimilarityAssessment class #################################### - -class SimilarityAssessment: - - def __init__(self, dFC_lst): - self.dFC_lst = dFC_lst - - ##################### dFC FEATURES ###################### - - def FO_calc(self, dFC_lst, common_TRs=None): - - # returns, for each state the Fractional Occupancy (FO) - # see Vidaurre et al., 2017 - # it only considers TRs in common_TRs - - if common_TRs is None: - common_TRs = TR_intersection(dFC_lst) - - FO_list = list() - for dFC in dFC_lst: - - FO = {} - - if dFC.measure.is_state_based: - - state_act_dict = dFC.state_act_dict(TRs=common_TRs) - - for FCS_key in state_act_dict['state_TC']: - FO[FCS_key] = np.mean(state_act_dict['state_TC'][FCS_key]['act_TC']) - - FO_list.append(FO) - - return FO_list - - def transition_stats(self, dFC_lst, common_TRs=None): - # returns the number of total state transition within common_TRs -> trans_freq - # and the number of total state transitions regardless of common_TRs - # but normalized 
by total number of TRs -> trans_norm - # and a list of all dwell times - - if common_TRs is None: - common_TRs = TR_intersection(dFC_lst) - - TRs_lst = list() - for TR in common_TRs: - TRs_lst.append('TR'+str(TR)) - - output_lst = list() - for dFC in dFC_lst: - - output_dict = {} - - if dFC.measure.is_state_based: - - # downsampled - trans_freq = 0 - dwell_time_lst = list() - dwell_time = 0 - last_TR = None - for TR in dFC.FCS_idx: - if TR in TRs_lst: - if not last_TR is None: - if dFC.FCS_idx[TR]!=dFC.FCS_idx[last_TR]: - dwell_time_lst.append(dwell_time) - dwell_time = 0 - trans_freq += 1 - dwell_time += 1 - last_TR = TR - - output_dict['dwell_time'] = dwell_time_lst - output_dict['trans_freq'] = trans_freq - - # normalized (not downsampled) - trans_norm = 0 - dwell_time_lst = list() - dwell_time = 0 - last_TR = None - for TR in dFC.FCS_idx: - if not last_TR is None: - if dFC.FCS_idx[TR]!=dFC.FCS_idx[last_TR]: - dwell_time_lst.append(dwell_time / len(dFC.FCS_idx)) - dwell_time = 0 - trans_norm += 1 - dwell_time += 1 - last_TR = TR - trans_norm = trans_norm / len(dFC.FCS_idx) - - output_dict['dwell_time_norm'] = dwell_time_lst - output_dict['trans_norm'] = trans_norm - - output_lst.append(output_dict) - - return output_lst - - def feature_all(self, dFC_mat): - vectorized_dFC = dFC_mat2vec(dFC_mat).flatten() # (time*connection, ) - vectorized_dFC = np.expand_dims(vectorized_dFC, axis=0) # (1, time*connection) - return vectorized_dFC - - def feature_spatial(self, dFC_mat): - conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) - return conn_over_time - - def feature_temporal(self, dFC_mat): - conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) - time_over_conn = conn_over_time.T # (connection, time) - return time_over_conn - - def feature_inter_time_corr(self, dFC_mat): - ''' - returns correspondence of inter-time relation between results of dFC - measures in each subject - ''' - conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) - inter_time_corr = np.corrcoef(conn_over_time) # (time, time) - inter_time_corr = np.nan_to_num(inter_time_corr) - inter_time_corr = dFC_mat2vec(inter_time_corr) # (time*(time-1)/2, ) - inter_time_corr = np.expand_dims(inter_time_corr, axis=0) # (1, time*(time-1)/2) - return inter_time_corr - - def feature_inter_conn_corr(self, dFC_mat): - conn_over_time = dFC_mat2vec(dFC_mat) # (time, connection) - time_over_conn = conn_over_time.T # (connection, time) - inter_conn_corr = np.corrcoef(time_over_conn) # (connection, connection) - inter_conn_corr = np.nan_to_num(inter_conn_corr) - inter_conn_corr = dFC_mat2vec(inter_conn_corr) # (connection*(connection-1)/2, ) - inter_conn_corr = np.expand_dims(inter_conn_corr, axis=0) # (1, connection*(connection-1)/2) - return inter_conn_corr - - def feature_dFC_avg(self, dFC_mat): - dFC_avg = np.mean(dFC_mat, axis=0) # (ROI, ROI) - vectorized_dFC_avg = dFC_mat2vec(dFC_avg) # (connection, ) - vectorized_dFC_avg = np.expand_dims(vectorized_dFC_avg, axis=0) # (1, connection) - return vectorized_dFC_avg - - def feature_dFC_var(self, dFC_mat): - dFC_var = np.var(dFC_mat, axis=0) # (ROI, ROI) - vectorized_dFC_var = dFC_mat2vec(dFC_var) # (connection, ) - vectorized_dFC_var = np.expand_dims(vectorized_dFC_var, axis=0) # (1, connection) - return vectorized_dFC_var - - def feature_graph_spatial(self, dFC_mat, graph_property): - graph_feature_over_time = list() - for FC_mat in dFC_mat: - graph_feature = calc_graph_propoerty(FC_mat, property=graph_property, threshold=False, binarize=False) - 
graph_feature_over_time.append(graph_feature) - graph_feature_over_time = np.array(graph_feature_over_time) # (time, ROI) - return graph_feature_over_time - - def feature_graph_temporal(self, dFC_mat, graph_property): - graph_feature_over_time = list() - for FC_mat in dFC_mat: - graph_feature = calc_graph_propoerty(FC_mat, property=graph_property, threshold=False, binarize=False) - graph_feature_over_time.append(graph_feature) - graph_feature_over_time = np.array(graph_feature_over_time) # (time, ROI) - graph_feature_over_node = graph_feature_over_time.T # (ROI, time) - graph_feature_avg = np.mean(graph_feature_over_node, axis=0) # (time, ) - graph_feature_avg = np.expand_dims(graph_feature_avg, axis=0) # (1, time) - return graph_feature_avg - - def extract_feature(self, dFC_mat, feature2extract, graph_property=None): - ''' - feature2extract_list = [ - 'all', - 'spatial', 'temporal', - 'inter_time_corr', 'inter_conn_corr', - 'dFC_avg', 'dFC_var', - 'graph_spatial', 'graph_temporal' - ] - ''' - feature = None - if feature2extract=='all': - feature = self.feature_all(dFC_mat) - if feature2extract=='spatial': - feature = self.feature_spatial(dFC_mat) - if feature2extract=='temporal': - feature = self.feature_temporal(dFC_mat) - if feature2extract=='inter_time_corr': - feature = self.feature_inter_time_corr(dFC_mat) - if feature2extract=='inter_conn_corr': - feature = self.feature_inter_conn_corr(dFC_mat) - if feature2extract=='dFC_avg': - feature = self.feature_dFC_avg(dFC_mat) - if feature2extract=='dFC_var': - feature = self.feature_dFC_var(dFC_mat) - if feature2extract=='graph_spatial': - feature = self.feature_graph_spatial(dFC_mat, graph_property=graph_property) - if feature2extract=='graph_temporal': - feature = self.feature_graph_temporal(dFC_mat, graph_property=graph_property) - - return feature - - def dFC_mat_lst_similarity(self, dFC_mat_lst, feature2extract, metric, graph_property=None): - - sim_mat_over_sample = None - for i, dFC_mat_i in enumerate(dFC_mat_lst): - for j, dFC_mat_j in enumerate(dFC_mat_lst): - - if j<=i: - continue - - assert dFC_mat_i.shape==dFC_mat_j.shape,\ - 'shape mismatch' - - feature_i = self.extract_feature( - dFC_mat_i, - feature2extract=feature2extract, - graph_property=graph_property - ) # (samples, variables) - feature_j = self.extract_feature( - dFC_mat_j, - feature2extract=feature2extract, - graph_property=graph_property - ) # (samples, variables) - - sim_over_sample = list() - for sample in range(feature_i.shape[0]): - if np.var(feature_i[sample, :])==0 or np.var(feature_j[sample, :])==0: - sim = 0 - else: - if metric=='corr': - sim = np.corrcoef(feature_i[sample, :], feature_j[sample, :])[0,1] - elif metric=='spearman': - sim, p = stats.spearmanr(feature_i[sample, :], feature_j[sample, :]) - elif metric=='MI': - sim = mutual_information(X=feature_i[sample, :], Y=feature_j[sample, :], N_bins=100) - elif metric=='euclidean_distance': - # normalized euclidean is used - sim = normalized_euc_dist(x=feature_i[sample, :], y=feature_j[sample, :]) - sim_over_sample.append(sim) - - if sim_mat_over_sample is None: - sim_mat_over_sample = np.zeros((len(sim_over_sample), len(dFC_mat_lst), len(dFC_mat_lst))) - sim_mat_over_sample[:, i, j] = np.array(sim_over_sample) - sim_mat_over_sample[:, j, i] = sim_mat_over_sample[:, i, j] - - return sim_mat_over_sample - - def assess_similarity(self, dFC_lst, downsampling_method='default', **param_dict): - ''' - downsampling_method: 'default' picks FCs at common_TRs - while 'SWed' uses a sliding window to downsample - ''' - 
methods_assess = {} - - # sort dFC_lst according to methods names - old_list = [dFC.measure.measure_name for dFC in dFC_lst] - new_list = deepcopy(old_list) - new_list.sort() - - new_order = find_new_order(old_list, new_list) - dFC_lst = [dFC_lst[i] for i in new_order] - - common_TRs = TR_intersection(dFC_lst) - - measure_lst = list() - TS_info_lst = list() - dFC_mat_lst = list() - for dFC in dFC_lst: - measure_lst.append(dFC.measure) - TS_info_lst.append(dFC.TS_info) - if downsampling_method=='SWed': - dFC_mat_lst.append( \ - dFC.SWed_dFC_mat( \ - W=param_dict['W'], \ - n_overlap=param_dict['n_overlap'], \ - tapered_window=param_dict['tapered_window'] \ - ) - ) - else: - dFC_mat_lst.append(dFC.get_dFC_mat(TRs=common_TRs)) - - methods_assess['measure_lst'] = measure_lst - methods_assess['TS_info_lst'] = TS_info_lst - methods_assess['common_TRs'] = common_TRs - - ########## dFC samples ########## - - dFC_samples = {} - for i, dFC_mat in enumerate(dFC_mat_lst): - dFC_samples[str(i)] = dFC_mat - methods_assess['dFC_samples'] = dFC_samples - - ########## time record ########## - - time_record_dict = {} - for i, dFC in enumerate(dFC_lst): - time_record = {} - time_record['FCS_fit'] = dFC.measure.FCS_fit_time - time_record['dFC_assess'] = dFC.measure.dFC_assess_time - time_record_dict[str(i)] = time_record - methods_assess['time_record_dict'] = time_record_dict - - ########## subj_dFC_sim ########## - # returns correlation/MI/spearman corr/euclidean distance between results of dFC - # measures in a subject - feature2extract_list = [ - # 'all', - 'spatial', 'temporal', - 'inter_time_corr', 'inter_conn_corr', - 'dFC_avg', 'dFC_var', - # 'graph_spatial', 'graph_temporal' - ] - metric_list = [ - 'corr', - 'spearman', - 'MI', - 'euclidean_distance' - ] - graph_property_list = [ - 'ECM', - 'shortest_path', - 'degree', - 'clustering_coef' - ] - methods_assess['all'] = {} - for metric in metric_list: - methods_assess['all'][metric] = self.dFC_mat_lst_similarity( - dFC_mat_lst, - feature2extract='all', - metric=metric - ) - methods_assess['feature_based'] = {} - for feature2extract in feature2extract_list: - methods_assess['feature_based'][feature2extract] = self.dFC_mat_lst_similarity( - dFC_mat_lst, - feature2extract=feature2extract, - metric='spearman' - ) - methods_assess['graph_based'] = {} - methods_assess['graph_based']['graph_spatial'] = {} - methods_assess['graph_based']['graph_temporal'] = {} - for graph_property in graph_property_list: - methods_assess['graph_based']['graph_spatial'][graph_property] = self.dFC_mat_lst_similarity( - dFC_mat_lst, - feature2extract='graph_spatial', - metric='spearman', - graph_property=graph_property - ) - methods_assess['graph_based']['graph_temporal'][graph_property] = self.dFC_mat_lst_similarity( - dFC_mat_lst, - feature2extract='graph_temporal', - metric='spearman', - graph_property=graph_property - ) - # ########## dFC temporal average and variance ########## - - methods_assess['dFC_avg'] = [self.feature_dFC_avg(dFC_mat) for dFC_mat in dFC_mat_lst] - - methods_assess['dFC_var'] = [self.feature_dFC_var(dFC_mat) for dFC_mat in dFC_mat_lst] - - ########## Fractional Occupancy ########## - - FO_lst = self.FO_calc(dFC_lst, \ - common_TRs=common_TRs \ - ) - methods_assess['FO'] = FO_lst - - ########## transition frequency ########## - - transition_stats_lst = self.transition_stats(dFC_lst, \ - common_TRs=common_TRs \ - ) - methods_assess['transition_stats'] = transition_stats_lst - - ############################################## - return methods_assess - - def 
-    def run(self, FILTERS, downsampling_method='default'):
-        '''
-        downsampling_method: 'default' picks FCs at common_TRs
-        while 'SWed' uses a sliding window to downsample
-        '''
-        parallelize = True
-        output = {}
-        if parallelize:
-            out_lst = Parallel( \
-                n_jobs=4, verbose=0, backend='loky')( \
-                delayed(self.assess_similarity)(dFC_lst=filter_dFC_lst(self.dFC_lst, **FILTERS[filter]), \
-                    downsampling_method=downsampling_method, \
-                    **FILTERS[filter]) \
-                for filter in FILTERS)
-            for i, filter in enumerate(FILTERS):
-                output[filter] = out_lst[i]
-        else:
-            for filter in FILTERS:
-                param_dict = FILTERS[filter]
-                dFC_lst2check = filter_dFC_lst(self.dFC_lst, **param_dict)
-                output[filter] = self.assess_similarity( \
-                    dFC_lst=dFC_lst2check, \
-                    downsampling_method=downsampling_method, \
-                    **param_dict \
-                )
-
-        return output
-
-#################################################################################################
\ No newline at end of file
diff --git a/build/lib/pydfc/data_loader.py b/build/lib/pydfc/data_loader.py
deleted file mode 100644
index b373f20..0000000
--- a/build/lib/pydfc/data_loader.py
+++ /dev/null
@@ -1,310 +0,0 @@
-
-"""
-Implementation of dFC methods.
-
-Created on Jun 29 2023
-@author: Mohammad Torabi
-"""
-
-from re import S
-from tkinter import N
-import numpy as np
-import hdf5storage
-import scipy.io as sio
-import os
-
-from .dfc_utils import intersection, label2network
-from .time_series import TIME_SERIES
-
-################################# DATA_LOADER functions ######################################
-
-def find_subj_list(data_root, sessions):
-    '''
-    find the list of subjects in data_root
-    the files must follow the format: subjectID_sessionID
-    only these files should be in the data_root
-    '''
-    ALL_FILES = os.listdir(data_root)
-    FOLDERS = [item for item in ALL_FILES if os.path.isdir(data_root+item)]
-
-    FOLDERS.sort()
-    SUBJECTS = list()
-    for s in FOLDERS:
-        num = s[:s.find('_')]
-        SUBJECTS.append(num)
-    # the subjects might be repeated because of different sessions
-    SUBJECTS = list(set(SUBJECTS))
-    SUBJECTS.sort()
-
-    print( str(len(SUBJECTS)) + ' subjects were found. ')
-
-    failed_subjs = []
-    kept_subjs = []
-    for subj in SUBJECTS:
-        kept_subjs.append(subj)
-        for session in sessions:
-            if not os.path.exists(data_root+subj+'_'+session):
-                failed_subjs.append(subj)
-                kept_subjs.remove(subj)
-                break
-
-    print( str(len(failed_subjs)) + ' subjects had missing sessions. ' + str(len(kept_subjs)) + ' subjects were kept. ')
-
-    return kept_subjs
-
-def load_from_array(subj_id2load=None, **params):
-    '''
-    load fMRI data from numpy or mat files
-    input time_series.shape must be (time, roi)
-    returns a dictionary of TIME_SERIES objects
-    each corresponding to a session
-
-    - if the file_name is a .mat file, it will be loaded using hdf5storage
-    - if the file_name is a .npy file, it will be loaded using np.load
-
-    - the roi locations should be in the same folder and a .npy file
-      with the name: params['roi_locs_file']
-      it must be a numpy array of shape (n_regions, 3)
-
-    - and the roi labels should be in the same folder and a .npy file
-      with the name: params['roi_labels_file']
-      it must be a list of strings
-
-    - labels should be in the format: Hemisphere_Network_ID
-      otherwise, the networks2include option will not work properly
-    '''
-
-    SESSIONs = params['SESSIONs'] # list of sessions
-    if subj_id2load is None:
-        SUBJECTS = find_subj_list(params['data_root'], sessions=SESSIONs)
-    else:
-        SUBJECTS = [subj_id2load]
-
-    # LOAD Region Location DATA
-    locs = np.load(params['data_root']+params['roi_locs_file'], allow_pickle='True').item()
-    locs = locs['locs']
-
-    # LOAD Region Labels DATA
-    labels = np.load(params['data_root']+params['roi_labels_file'], allow_pickle='True').item()
-    labels = labels['labels']
-
-    assert type(locs) is np.ndarray, 'locs must be a numpy array'
-    assert type(labels) is list, 'labels must be a list'
-    assert locs.shape[0] == len(labels), 'locs and labels must have the same length'
-    assert locs.shape[1] == 3, 'locs must have 3 columns'
-
-    # apply networks2include
-    # if params['networks2include'] is None, all the regions will be included
-    if not params['networks2include'] is None:
-        nodes2include = [i for i, x in enumerate(labels) if label2network(x) in params['networks2include']]
-    else:
-        nodes2include = [i for i, x in enumerate(labels)]
-    locs = locs[nodes2include, :]
-    labels = [x for node, x in enumerate(labels) if node in nodes2include]
-
-
-    BOLD = {}
-    for session in SESSIONs:
-        BOLD[session] = None
-        for subject in SUBJECTS:
-
-            subj_fldr = subject + '_' + session
-
-            # LOAD BOLD Data
-
-            if params['file_name'][params['file_name'].find('.'):] == '.mat':
-                DATA = hdf5storage.loadmat(params['data_root']+subj_fldr+'/'+params['file_name'])
-            elif params['file_name'][params['file_name'].find('.'):] == '.npy':
-                DATA = np.load(params['data_root']+subj_fldr+'/'+params['file_name'], allow_pickle='True').item()
-            time_series = DATA['ROI_data'] # time_series.shape = (time, roi)
-
-            # change time_series.shape to (roi, time)
-            time_series = time_series.T
-
-            # apply networks2include
-            time_series = time_series[nodes2include, :]
-
-            if BOLD[session] is None:
-                BOLD[session] = TIME_SERIES(data=time_series, subj_id=subject,
-                    Fs=params['Fs'],
-                    locs=locs, node_labels=labels,
-                    TS_name='BOLD Real', session_name=session
-                )
-            else:
-                BOLD[session].append_ts(new_time_series=time_series, subj_id=subject)
-
-        print( '*** Session ' + session + ': ' )
-        print( 'number of regions= '+str(BOLD[session].n_regions) + ', number of time points= ' + str(BOLD[session].n_time) )
-
-    return BOLD
-
-
-def nifti2array(nifti_file,
-    confound_strategy='none', standardize=False,
-    n_rois=100
-    ):
-    '''
-    this function uses nilearn maskers to extract
-    BOLD signals from nifti files
-    For now it only works with schaefer atlas,
-    but you can set the number of rois to extract
-    {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000}
-
-    returns a numpy array of shape (time, roi)
-    and labels and locs of rois
-
-    confound_strategy:
-        'none': no confounds are used
-        'no_motion': motion parameters are used
-        'no_motion_no_gsr': motion parameters are used
-            and global signal regression
-            is applied.
-    '''
-    from nilearn.maskers import NiftiLabelsMasker
-    from nilearn import datasets
-    from nilearn.plotting import find_parcellation_cut_coords
-    from nilearn.interfaces.fmriprep import load_confounds
-
-    parc = datasets.fetch_atlas_schaefer_2018(n_rois=n_rois)
-    atlas_filename = parc.maps
-    labels = parc.labels
-    # The list of labels does not contain 'Background' by default.
-    # To have proper indexing, you should either manually add 'Background' to the list of labels:
-    # Prepend background label
-    labels = np.insert(labels, 0, 'Background')
-
-    # extract locs
-    # test!
-    # check if order is the same as labels
-    locs, labels_ = find_parcellation_cut_coords(
-        atlas_filename,
-        background_label=0,
-        return_label_names=True
-    )
-
-    # create the masker for extracting time series
-    masker = NiftiLabelsMasker(
-        labels_img=atlas_filename,
-        labels=labels,
-        resampling_target='data',
-        standardize=standardize,
-    )
-
-    labels = np.delete(labels, 0) # remove the background label
-    labels = [label.decode() for label in labels]
-
-    ### extract the timeseries
-    if confound_strategy=='none':
-        time_series = masker.fit_transform(nifti_file)
-    elif confound_strategy=='no_motion':
-        confounds_simple, sample_mask = load_confounds(
-            nifti_file,
-            strategy=["high_pass", "motion", "wm_csf"],
-            motion="basic", wm_csf="basic"
-        )
-        time_series = masker.fit_transform(
-            nifti_file,
-            confounds=confounds_simple,
-            sample_mask=sample_mask
-        )
-    elif confound_strategy=='no_motion_no_gsr':
-        confounds_simple, sample_mask = load_confounds(
-            nifti_file,
-            strategy=["high_pass", "motion", "wm_csf", "global_signal"],
-            motion="basic", wm_csf="basic", global_signal="basic"
-        )
-        time_series = masker.fit_transform(
-            nifti_file,
-            confounds=confounds_simple,
-            sample_mask=sample_mask
-        )
-
-    return time_series, labels, locs
-
-
-def nifti2timeseries(
-    nifti_file,
-    n_rois, Fs,
-    subj_id,
-    confound_strategy='none',
-    standardize=False,
-    TS_name=None,
-    session=None,
-    ):
-    '''
-    this function is only for single subject and single session data loading
-    it uses nilearn maskers to extract ROI signals from nifti files
-    and returns a TIME_SERIES object
-
-    For now it only works with schaefer atlas,
-    but you can set the number of rois to extract
-    {100, 200, 300, 400, 500, 600, 700, 800, 900, 1000}
-    '''
-    time_series, labels, locs = nifti2array(
-        nifti_file=nifti_file,
-        confound_strategy=confound_strategy,
-        standardize=standardize,
-        n_rois=n_rois
-    )
-
-    assert type(locs) is np.ndarray, 'locs must be a numpy array'
-    assert type(labels) is list, 'labels must be a list'
-    assert locs.shape[0] == len(labels), 'locs and labels must have the same length'
-    assert locs.shape[1] == 3, 'locs must have 3 columns'
-
-    # change time_series.shape to (roi, time)
-    time_series = time_series.T
-
-    if TS_name is None:
-        TS_name = subj_id + ' time series'
-
-    BOLD = TIME_SERIES(
-        data=time_series, subj_id=subj_id,
-        Fs=Fs,
-        locs=locs, node_labels=labels,
-        TS_name=TS_name, session_name=session
-    )
-
-    return BOLD
-
-
-def multi_nifti2timeseries(
-    nifti_files_list,
-    subj_id_list,
-    n_rois, Fs,
-    confound_strategy='none',
-    standardize=False,
-    TS_name=None,
-    session=None,
-):
-    '''
-    loading data of multiple subjects, but single session, from their nifti files
-    '''
-    BOLD_multi = None
-    for subj_id, nifti_file in zip(subj_id_list, nifti_files_list):
-        if BOLD_multi is None:
-            BOLD_multi = nifti2timeseries(
-                nifti_file=nifti_file,
-                n_rois=n_rois, Fs=Fs,
-                subj_id=subj_id,
-                confound_strategy=confound_strategy,
-                standardize=standardize,
-                TS_name=TS_name,
-                session=session,
-            )
-        else:
-            BOLD_multi.concat_ts(
-                nifti2timeseries(
-                    nifti_file=nifti_file,
-                    n_rois=n_rois, Fs=Fs,
-                    subj_id=subj_id,
-                    confound_strategy=confound_strategy,
-                    standardize=standardize,
-                    TS_name=TS_name,
-                    session=session,
-                )
-            )
-    return BOLD_multi
-
-
-####################################################################################################################################
\ No newline at end of file
diff --git a/build/lib/pydfc/dfc.py b/build/lib/pydfc/dfc.py
deleted file mode 100644
index 08cfb50..0000000
--- a/build/lib/pydfc/dfc.py
+++ /dev/null
@@ -1,303 +0,0 @@
-
-"""
-dFC class
-
-Created on Jun 29 2023
-@author: Mohammad Torabi
-"""
-
-import numpy as np
-
-from .dfc_utils import node_info2network, node_labels2networks, rank_norm_dFC_dict, visualize_conn_mat_dict, SW_downsample
-
-################################# DFC class ######################################
-
-"""
-Parameters
-    ----------
-    TR_array : an array labeling
-        timepoints by their TRs
-        starts from 0
-
-Variables
-    ----------
-    FCSs : Functional Connectivity
-        States patterns
-    FCS_idx : the index of the
-        FCS that corresponds to each
-        timepoint
-
-
-todo:
--
-"""
-
-class DFC():
-    def __init__(self, measure=None):
-
-        # assert not measure is None, \
-        #     "measure arg must be provided."
-        self.measure_ = measure
-        self.FCSs_ = None # is a dict
-        self.FCS_idx_ = None # is a dict
-        # info of the time series used for dFC estimation
-        self.TS_info_ = None
-        self.TR_array_ = None
-        self.n_regions_ = None
-        self.n_time_ = -1
-
-    @classmethod
-    def from_numpy(cls, array=None):
-        pass
-
-    @property
-    def measure(self):
-        return self.measure_
-
-    @property
-    def TR_array(self):
-        return self.TR_array_.astype(int)
-
-    @property
-    def TR_keys(self):
-        TRs_lst = list()
-        for TR in self.TR_array:
-            TRs_lst.append('TR'+str(TR))
-        return TRs_lst
-
-    @property
-    def n_regions(self):
-        return self.n_regions_
-
-    @property
-    def n_time(self):
-        return self.n_time_
-
-    # test this
-    @property
-    def FCSs(self):
-        return self.FCSs_
-
-    # test this
-    @property
-    def FCS_idx(self):
-        return self.FCS_idx_
-
-    # test this
-    @property
-    def FCS_idx_array(self):
-        return np.array([int(self.FCS_idx[TR][self.FCS_idx[TR].find('S')+1:])-1 for TR in self.FCS_idx])
-
-    @property
-    def TS_info(self):
-        # info of the time series used for dFC estimation
-        return self.TS_info_
-
-
-    # test
-    def state_TC(self, TRs=None, \
-        state_match=False, state_match_dict=None \
-        ):
-        # returns a np array of state indices over TRs in TRs
-
-        if TRs is None:
-            TRs = self.TR_array
-
-        if not type(TRs[0]) is str:
-            TRs_lst = list()
-            for TR in TRs:
-                TRs_lst.append('TR'+str(TR))
-        else:
-            TRs_lst = TRs
-
-        state_TC = list()
-        for key in self.FCS_idx:
-            if key in TRs_lst:
-                state = self.FCS_idx[key]
-                if state_match:
-                    match = state_match_dict['FCS_match'][state]['match']
-                    state_TC.append(int(match[match.find('FCS')+3:]))
-                else:
-                    state_TC.append(int(state[state.find('FCS')+3:]))
-
-        state_TC = np.array(state_TC)
-        return state_TC
-
-    # test
-    def state_act_dict(self, TRs=None):
-        # returns a dict including each FCS and its activation times
-        # the TRs arg can be used to set a common set of TRs
-
-        if TRs is None:
-            TRs = self.TR_array
-
-        TRs_lst = list()
-        for TR in TRs:
-            TRs_lst.append('TR'+str(TR))
-
-        state_act_dict = {}
-        state_act_dict['state_TC'] = {}
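The dictionary assembled in state_act_dict is essentially a one-hot encoding of state activity over the selected TRs. A minimal standalone illustration of the same bookkeeping, with hypothetical values (none of these names are part of pydfc):

import numpy as np

fcs_idx = {'TR0': 'FCS1', 'TR1': 'FCS2', 'TR2': 'FCS1'}  # TR -> active state
trs = ['TR0', 'TR1', 'TR2']
# one activation time course per state, 1 where that state is active
act = {state: np.zeros(len(trs)) for state in set(fcs_idx.values())}
for t, tr in enumerate(trs):
    act[fcs_idx[tr]][t] = 1
# act == {'FCS1': array([1., 0., 1.]), 'FCS2': array([0., 1., 0.])}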
-        state_act_dict['TR_array'] = TRs
-        for FCS_key in self.FCSs:
-            state_act_dict['state_TC'][FCS_key] = {}
-            state_act_dict['state_TC'][FCS_key]['FCS'] = self.FCSs[FCS_key]
-            state_act_dict['state_TC'][FCS_key]['act_TC'] = np.zeros((len(TRs),))
-        t=0
-        for TR in self.FCS_idx:
-            if TR in TRs_lst:
-                state_act_dict['state_TC'][self.FCS_idx[TR]]['act_TC'][t] = 1
-                t=t+1
-        assert t==len(TRs), 'error!'
-
-        return state_act_dict
-
-    # test
-    def dFC2dict(self, TRs=None):
-        # return dFC samples as a dictionary
-        if TRs is None:
-            TRs = self.TR_array
-        if type(TRs) is list:
-            TRs = np.array(TRs)
-        TRs = TRs.astype(int)
-        dFC_mat = self.get_dFC_mat(TRs=TRs)
-        dFC_dict = {}
-        for k, TR in enumerate(TRs):
-            dFC_dict['TR'+str(TR)] = dFC_mat[k, :, :]
-        return dFC_dict
-
-    # test this
-    def get_dFC_mat(self, TRs=None, num_samples=None):
-        '''
-        get dFC matrices corresponding to
-        the specified TRs
-        TRs should be list/ndarray not necessarily in order ?
-        if num_samples specified, it will downsample
-        TRs to reach that number of samples and will also
-        return picked TRs
-        if num_samples > len(TRs) -> picks all TRs
-        '''
-
-        if TRs is None:
-            TRs = self.TR_array
-
-        if type(TRs) is np.int32 or type(TRs) is np.int64 or type(TRs) is int:
-            TRs = [TRs]
-
-        if not num_samples is None:
-            if num_samples < len(TRs):
-                TRs = TRs[np.linspace(0, len(TRs), num_samples, endpoint=False, dtype=int)]
-
-        dFC_mat = list()
-        for TR in TRs:
-            dFC_mat.append(self.FCSs[self.FCS_idx['TR'+str(TR)]])
-
-        dFC_mat = np.array(dFC_mat)
-
-        if num_samples is None:
-            return dFC_mat
-        else:
-            return dFC_mat, TRs
-
-    def SWed_dFC_mat(self, W=None, n_overlap=None, tapered_window=False):
-        '''
-        the time samples will be picked after
-        averaging over a window which slides
-        W is in sec
-        '''
-        dFC_mat = self.get_dFC_mat()
-
-        # method not applicable to SW-based methods
-        if 'sw_method' in self.measure.info:
-            return dFC_mat
-
-        dFC_mat_new = SW_downsample(data=dFC_mat, \
-            Fs=self.TS_info['Fs'], W=W, n_overlap=n_overlap, tapered_window=tapered_window \
-        )
-
-        return dFC_mat_new
-
-
-    def set_dFC(self, FCSs, FCS_idx=None, TS_info=None, TR_array=None):
-
-        if len(FCSs.shape)==2:
-            FCSs = np.expand_dims(FCSs, axis=0)
-
-        if FCS_idx is None:
-            # usually for state-free methods like sliding window when we don't have FCSs
-            # we consider each FC a FCS
-            FCS_idx = np.arange(start=0, stop=FCSs.shape[0], step=1, dtype=int)
-
-        if type(FCS_idx) is list:
-            FCS_idx = np.array(FCS_idx)
-
-        if len(FCS_idx.shape)>1:
-            FCS_idx = np.squeeze(FCS_idx)
-
-        assert FCSs.shape[1] == FCSs.shape[2], \
-            "FC matrices must be square."
-
-        assert self.n_time==-1, \
-            'why n_time is not -1 ? Are you adding a dFC to an existing dFC ?'
-
-        if TR_array is None:
-            # self.n_time is -1 at first. if it is not -1, it means that a dFC is already set and
-            # we are adding a new dFC to it.
-            TR_array = np.arange(start=self.n_time+1, stop=self.n_time+len(FCS_idx)+1, step=1, dtype=int)
-
-        assert np.sum(np.abs(np.sort(TR_array)-TR_array))==0.0, \
-            'TRs not sorted !'
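As a hedged usage sketch of set_dFC, based only on the defaults visible in this file (shapes and values are illustrative): a state-free result can be stored by passing one square FC matrix per TR and letting FCS_idx and TR_array take their defaults.

import numpy as np
from pydfc.dfc import DFC

n_time, n_roi = 50, 10
rng = np.random.default_rng(0)
mats = rng.standard_normal((n_time, n_roi, n_roi))
mats = (mats + mats.transpose(0, 2, 1)) / 2  # square, symmetric FC per TR

dfc = DFC()
dfc.set_dFC(FCSs=mats)  # FCS_idx=None -> each FC is treated as its own FCS
assert dfc.n_time == n_time
assert dfc.get_dFC_mat().shape == (n_time, n_roi, n_roi)
assert dfc.FCS_idx['TR0'] == 'FCS1'  # indices are shifted to start at FCS1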
-
-        # the input FCS_idx is ranged from 0 to len(FCS)-1 but we shift it to 1 to len(FCS)
-        self.FCSs_ = {}
-        for i, FCS in enumerate(FCSs):
-            self.FCSs_['FCS'+str(i+1)] = FCS
-
-        self.FCS_idx_ = {}
-        for i, idx in enumerate(FCS_idx):
-            self.FCS_idx_['TR'+str(TR_array[i])] = 'FCS'+str(idx+1)
-
-        self.TS_info_ = TS_info
-        self.n_regions_ = FCSs.shape[1]
-        self.n_time_ = len(self.FCS_idx_)
-        self.TR_array_ = TR_array
-
-
-    def visualize_dFC(self, TRs=None, normalize=False,
-        show_networks=False,
-        rank_norm=False,
-        threshold=0.0,
-        fix_lim=False,
-        save_image=False, fig_name=None,
-        ):
-
-        assert not self.measure is None, \
-            'Measure is not provided.'
-
-        if TRs is None:
-            TRs = self.TR_array
-
-        if show_networks:
-            if 'nodes_info' in self.TS_info:
-                node_networks = node_info2network(self.TS_info['nodes_info'])
-            elif 'node_labels' in self.TS_info:
-                node_networks = node_labels2networks(self.TS_info['node_labels'])
-        else:
-            node_networks = None
-
-        if rank_norm:
-            dFC_dict = rank_norm_dFC_dict(self.dFC2dict(TRs=TRs))
-            cmap = 'plasma'
-            center_0 = False
-        else:
-            dFC_dict = self.dFC2dict(TRs=TRs)
-            cmap = 'seismic'
-            center_0 = True
-
-        visualize_conn_mat_dict(data=dFC_dict,
-            title=self.measure.measure_name+' dFC',
-            fix_lim=fix_lim, normalize=normalize,
-            node_networks=node_networks,
-            cmap=cmap, center_0=center_0,
-            save_image=save_image,
-            output_root=fig_name,
-        )
diff --git a/build/lib/pydfc/dfc_methods/__init__.py b/build/lib/pydfc/dfc_methods/__init__.py
deleted file mode 100644
index 22adf1b..0000000
--- a/build/lib/pydfc/dfc_methods/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-"""The :mod:`pydfc.dfc_methods` contains dFC methods objects."""
-
-from .base_dfc_method import BaseDFCMethod
-from .cap import CAP
-from .sliding_window_clustr import SLIDING_WINDOW_CLUSTR
-from .discrete_hmm import HMM_DISC
-from .continuous_hmm import HMM_CONT
-from .sliding_window import SLIDING_WINDOW
-from .time_freq import TIME_FREQ
-from .windowless import WINDOWLESS
-
-__all__ = ['BaseDFCMethod',
-    'CAP', 'SLIDING_WINDOW_CLUSTR',
-    'HMM_CONT', 'HMM_DISC',
-    'SLIDING_WINDOW', 'TIME_FREQ',
-    'WINDOWLESS'
-    ]
diff --git a/build/lib/pydfc/dfc_methods/base_dfc_method.py b/build/lib/pydfc/dfc_methods/base_dfc_method.py
deleted file mode 100644
index 12e10a2..0000000
--- a/build/lib/pydfc/dfc_methods/base_dfc_method.py
+++ /dev/null
@@ -1,312 +0,0 @@
-"""
-Implementation of dFC methods.
-the parent dFC method class
-
-Created on Jun 29 2023
-@author: Mohammad Torabi
-"""
-
-import numpy as np
-from copy import deepcopy
-
-from ..dfc_utils import SW_downsample, visualize_FCS
-
-################################# BaseDFCMethod class ####################################
-
-"""
-todo:
-- type annotation
-"""
-
-class BaseDFCMethod:
-
-    TF_methods_name_lst = [ \
-        'CWT_mag', \
-        'CWT_phase_r', \
-        'CWT_phase_a', \
-        'WTC' \
-    ]
-
-    sw_methods_name_lst = [ \
-        'pear_corr', \
-        'MI', \
-        'GraphLasso', \
-    ]
-
-    base_methods_name_lst = ['SlidingWindow', 'Time-Freq']
-
-    def __init__(self):
-        self.measure_name = ''
-        self.is_state_based = bool()
-        self._stat = []
-        self.TPM = []
-        self.params = {}
-        self.TS_info_ = {}
-        self.FCS_fit_time_ = None
-        self.dFC_assess_time_ = None
-        self.logs_ = ''
-
-    @property
-    def FCS_fit_time(self):
-        return self.FCS_fit_time_
-
-    @property
-    def dFC_assess_time(self):
-        return self.dFC_assess_time_
-
-    @property
-    def TS_info(self):
-        # info of the time series used to train/estimate FCSs
-        return self.TS_info_
-
-    @property
-    def is_state_based(self):
-        return self.params['is_state_based']
-
-    @property
-    def FCS(self):
-        return self.FCS_
-
-    # test
-    @property
-    def FCS_dict(self):
-        # returns a dict including FCS matrices
-
-        if not self.is_state_based:
-            return None
-
-        C_A = self.FCS
-        FCSs = {}
-        for k in range(C_A.shape[0]):
-            FCSs['FCS'+str(k+1)] = C_A[k,:,:]
-
-        return FCSs
-
-    @property
-    def info(self):
-        return self.params
-
-    @property
-    def logs(self):
-        print(self.logs_)
-
-    def issame(self, dFC):
-        if type(self)==type(dFC):
-            for param_name in self.params:
-                if self.params[param_name] != dFC.params[param_name]:
-                    return False
-        else:
-            return False
-        return True
-
-    #test
-    def param_match(self, **param_dict):
-        for param in param_dict:
-            if param in self.params:
-                if type(param_dict[param]) is list:
-                    if not self.params[param] in param_dict[param]:
-                        return False
-                else:
-                    if self.params[param]!=param_dict[param]:
-                        return False
-        return True
-
-    def set_FCS_fit_time(self, time):
-        self.FCS_fit_time_ = time
-
-    def set_dFC_assess_time(self, time):
-        self.dFC_assess_time_ = time
-
-    def set_mean_activity(self, time_series):
-        # mean activity of regions at each state
-        if self.is_state_based:
-            if 'sw_method' in self.params_name_lst:
-                SUBJECTs = time_series.subj_id_lst
-                TS_data = None
-                for subject in SUBJECTs:
-                    subj_TS = time_series.get_subj_ts(subjs_id=subject).data
-                    new_TS_data = SW_downsample(data=subj_TS.T, \
-                        Fs=time_series.Fs, W=self.params['W'], \
-                        n_overlap=self.params['n_overlap'], \
-                        tapered_window=self.params['tapered_window'] \
-                    ).T
-                    if TS_data is None:
-                        TS_data = new_TS_data
-                    else:
-                        TS_data = np.concatenate((TS_data, new_TS_data), axis=1)
-            else:
-                TS_data = time_series.data
-            mean_act = list()
-            for i in np.unique(self.Z):
-                ids = np.array([int(state==i) for state in self.Z])
-                mean_act.append(np.average(TS_data, weights=ids, axis=1))
-            self.mean_act = np.array(mean_act)
-        else:
-            self.mean_act = None
-
-    def estimate_FCS(self, time_series=None):
-        pass
-
-    def estimate_dFC(self, time_series=None):
-        pass
-
-    def manipulate_time_series4FCS(self, time_series):
-        '''
-        passing None to params will not change the time series
-        num_realization is not implemented yet
-        '''
-
-        new_time_series = deepcopy(time_series)
-
-        # SUBJECTs
-        if not self.params['num_subj'] is None:
-            new_time_series.select_subjs(num_subj=self.params['num_subj'])
-        # SPATIAL RESOLUTION
-        if not self.params['num_select_nodes'] is None:
-            new_time_series.spatial_downsample(num_select_nodes=self.params['num_select_nodes'], rand_node_slct=False)
-        # TEMPORAL RESOLUTION
-        if not self.params['Fs_ratio'] is None:
-            new_time_series.Fs_resample(Fs_ratio=self.params['Fs_ratio'])
-        # NORMALIZE
-        if self.params['normalization']:
-            new_time_series.normalize()
-        # NOISE
-        if not self.params['noise_ratio'] is None:
-            new_time_series.add_noise(noise_ratio=self.params['noise_ratio'], mean_noise=0)
-        # NUMBER OF TIME POINTS
-        if not self.params['num_time_point'] is None:
-            new_time_series.truncate(start_point=0, end_point=self.params['num_time_point']-1)
-
-        self.TS_info_ = new_time_series.info_dict
-
-        return new_time_series
-
-    def manipulate_time_series4dFC(self, time_series):
-        '''
-        passing None to params will not change the time series
-        num_realization is not implemented yet
-        '''
-
-        new_time_series = deepcopy(time_series)
-
-        # SPATIAL RESOLUTION
-        if not self.params['num_select_nodes'] is None:
-            new_time_series.spatial_downsample(num_select_nodes=self.params['num_select_nodes'], rand_node_slct=False)
-        # TEMPORAL RESOLUTION
-        if not self.params['Fs_ratio'] is None:
-            new_time_series.Fs_resample(Fs_ratio=self.params['Fs_ratio'])
-        # NORMALIZE
-        if self.params['normalization']:
-            new_time_series.normalize()
-        # NOISE
-        if not self.params['noise_ratio'] is None:
-            new_time_series.add_noise(noise_ratio=self.params['noise_ratio'], mean_noise=0)
-        # NUMBER OF TIME POINTS
-        if not self.params['num_time_point'] is None:
-            new_time_series.truncate(start_point=0, end_point=self.params['num_time_point']-1)
-
-        return new_time_series
-
-    def visualize_states(self):
-        pass
-
-    # todo : use FCS_dict func in this func
-    def visualize_FCS(
-        self,
-        normalize=True, fix_lim=True,
-        save_image=False, output_root=None
-    ):
-
-        visualize_FCS(
-            self,
-            normalize=normalize, fix_lim=fix_lim,
-            save_image=save_image, output_root=output_root
-        )
-
-
-################################## NEW METHOD ##################################
-
-'''
-by : web link
-
-Reference: ##
-
-Parameters
-    ----------
-    y1, y2 : numpy.ndarray, list
-        Input signals.
-    dt : float
-        Sample spacing.
-
-todo:
-
-import needed_toolbox
-
-class method_name(dFC):
-
-    def __init__(self, **params):
-        self.FCS_ = []
-        self.logs_ = ''
-
-        self.params_name_lst = ['measure_name', 'is_state_based', 'n_states',
-            'normalization', 'num_subj', 'num_select_nodes', 'num_time_point',
-            'Fs_ratio', 'noise_ratio', 'num_realization', 'session']
-        self.params = {}
-        for params_name in self.params_name_lst:
-            if params_name in params:
-                self.params[params_name] = params[params_name]
-            else:
-                self.params[params_name] = None
-
-        self.params['specific_param'] = value
-        self.params['measure_name'] = 'method_name'
-        self.params['is_state_based'] = True/False
-
-    @property
-    def measure_name(self):
-        return self.params['measure_name']
-
-    def estimate_FCS(self, time_series):
-
-        assert type(time_series) is TIME_SERIES, \
-            "time_series must be of TIME_SERIES class."
-
-        time_series = self.manipulate_time_series4FCS(time_series)
-
-        # start timing
-        tic = time.time()
-
-        # calc FCSs
-
-        # calc self.Z
-
-        # mean activation of states
-        self.set_mean_activity(time_series)
-
-        # record time
-        self.set_FCS_fit_time(time.time() - tic)
-
-        return self
-
-    def estimate_dFC(self, time_series):
-
-        assert type(time_series) is TIME_SERIES, \
-            "time_series must be of TIME_SERIES class."
-
-        assert len(time_series.subj_id_lst)==1, \
-            'this function takes only one subject as input.'
-
-        time_series = self.manipulate_time_series4dFC(time_series)
-
-        # start timing
-        tic = time.time()
-
-        # calc FCSs and FCS_idx
-
-        # record time
-        self.set_dFC_assess_time(time.time() - tic)
-
-        dFC = DFC(measure=self)
-        dFC.set_dFC(FCSs=self.FCS_, FCS_idx=FCS_idx, TS_info=time_series.info_dict)
-        return dFC
-'''
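To make the template above concrete, here is a minimal, hypothetical filling-in for a state-free measure that assigns one static Pearson FC matrix to every TR. It is a sketch following the estimate_dFC contract shown in the template, not an actual pydfc method; the class name and its behaviour are invented for illustration.

import time
import numpy as np

from pydfc.dfc import DFC
from pydfc.dfc_methods import BaseDFCMethod
from pydfc.time_series import TIME_SERIES

class StaticFC(BaseDFCMethod):  # hypothetical example, not part of pydfc
    def __init__(self, **params):
        self.FCS_ = []
        self.logs_ = ''
        self.FCS_fit_time_ = None
        self.dFC_assess_time_ = None
        self.params_name_lst = ['measure_name', 'is_state_based', 'n_states',
            'normalization', 'num_subj', 'num_select_nodes', 'num_time_point',
            'Fs_ratio', 'noise_ratio', 'num_realization', 'session']
        self.params = {name: params.get(name) for name in self.params_name_lst}
        self.params['measure_name'] = 'StaticFC'
        self.params['is_state_based'] = False

    @property
    def measure_name(self):
        return self.params['measure_name']

    def estimate_dFC(self, time_series):
        assert type(time_series) is TIME_SERIES, \
            "time_series must be of TIME_SERIES class."
        time_series = self.manipulate_time_series4dFC(time_series)
        tic = time.time()
        # one Pearson FC matrix over the whole scan, repeated at every TR
        fc = np.corrcoef(time_series.data)  # data is (roi, time) -> (roi, roi)
        self.FCS_ = np.tile(fc, (time_series.n_time, 1, 1))
        self.set_dFC_assess_time(time.time() - tic)
        dFC = DFC(measure=self)
        # FCS_idx is left to its default: one "state" per TR for state-free measures
        dFC.set_dFC(FCSs=self.FCS_, TS_info=time_series.info_dict)
        return dFC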
diff --git a/build/lib/pydfc/dfc_methods/cap.py b/build/lib/pydfc/dfc_methods/cap.py
deleted file mode 100644
index 6480a97..0000000
--- a/build/lib/pydfc/dfc_methods/cap.py
+++ /dev/null
@@ -1,148 +0,0 @@
-"""
-Implementation of dFC methods.
-
-Created on Jun 29 2023
-@author: Mohammad Torabi
-"""
-
-import numpy as np
-import time
-from sklearn.cluster import KMeans
-
-from .base_dfc_method import BaseDFCMethod
-from ..time_series import TIME_SERIES
-from ..dfc import DFC
-
-################################## CAP ##################################
-
-'''
-by : web link
-
-Reference: ##
-
-Parameters
-    ----------
-    y1, y2 : numpy.ndarray, list
-        Input signals.
-    dt : float
-        Sample spacing.
-
-todo:
-'''
-
-
-class CAP(BaseDFCMethod):
-
-    def __init__(self, **params):
-        self.logs_ = ''
-        self.FCS_ = []
-        self.mean_act = []
-        self.FCS_fit_time_ = None
-        self.dFC_assess_time_ = None
-
-        self.params_name_lst = ['measure_name', 'is_state_based', 'n_states',
-            'n_subj_clstrs', 'normalization', 'num_subj', 'num_select_nodes', 'num_time_point',
-            'Fs_ratio', 'noise_ratio', 'num_realization', 'session']
-        self.params = {}
-        for params_name in self.params_name_lst:
-            if params_name in params:
-                self.params[params_name] = params[params_name]
-            else:
-                self.params[params_name] = None
-
-        self.params['measure_name'] = 'CAP'
-        self.params['is_state_based'] = True
-
-    @property
-    def measure_name(self):
-        return self.params['measure_name']
-
-    def act_vec2FCS(self, act_vecs):
-        FCS_ = list()
-        for act_vec in act_vecs:
-            FCS_.append(np.multiply(act_vec[:, np.newaxis], act_vec[np.newaxis, :]))
-        return np.array(FCS_)
-
-    def cluster_act_vec(self, act_vecs, n_clusters):
-
-        kmeans_ = KMeans(n_clusters=n_clusters, n_init=500).fit(act_vecs)
-        act_centroids = kmeans_.cluster_centers_
-
-        return act_centroids, kmeans_
-
-    def estimate_FCS(self, time_series):
-
-        assert type(time_series) is TIME_SERIES, \
-            "time_series must be of TIME_SERIES class."
-
-        time_series = self.manipulate_time_series4FCS(time_series)
-
-        # start timing
-        tic = time.time()
-
-        # 2-level clustering
-        SUBJECTs = time_series.subj_id_lst
-        act_center_1st_level = None
-        for subject in SUBJECTs:
-
-            act_vecs = time_series.get_subj_ts(subjs_id=subject).data.T
-
-            # test
-            if act_vecs.shape[0]