diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..ea51302 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +repos: + - repo: https://github.com/snakemake/snakefmt + rev: v0.10.2 # Replace by any tag/version ≥0.2.4 : https://github.com/snakemake/snakefmt/releases + hooks: + - id: snakefmt + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.4 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/benchmarking_pipeline/launch_pipeline.py b/benchmarking_pipeline/launch_pipeline.py index 7732d48..7e4f2b1 100644 --- a/benchmarking_pipeline/launch_pipeline.py +++ b/benchmarking_pipeline/launch_pipeline.py @@ -1,33 +1,37 @@ import argparse from pathlib import Path import subprocess -import tempfile + def parse_args(): - parser = argparse.ArgumentParser(description="set a seed and run the pipeline") - parser.add_argument("--seed", type=int, help="seed") - parser.add_argument("--config_template",type = str, help = "Provide path to template config file") + parser = argparse.ArgumentParser(description='set a seed and run the pipeline') + parser.add_argument('--seed', type=int, help='seed') + parser.add_argument( + '--config_template', type=str, help='Provide path to template config file' + ) return parser.parse_args() + def main(args): with open(args.config_template, 'r') as file: config_content = file.read() + config_content = config_content.replace('__seed__', str(args.seed)) - config_content = config_content.replace("__seed__", str(args.seed)) - - - with open(Path(args.config_template).parent / f"config_{args.seed}.yaml", 'w') as file: + with open( + Path(args.config_template).parent / f'config_{args.seed}.yaml', 'w' + ) as file: file.write(config_content) command = ( - "source $(conda info --base)/etc/profile.d/conda.sh && conda activate snakemake; " - f"sbatch --time=24:00:00 --wrap=\"snakemake --cores 10 --configfile config/config_{args.seed}.yaml " - "--software-deployment-method conda --rerun-incomplete -p --keep-going --profile profiles/slurm/\"" + 'source $(conda info --base)/etc/profile.d/conda.sh && conda activate snakemake; ' + f'sbatch --time=24:00:00 --wrap="snakemake --cores 10 --configfile config/config_{args.seed}.yaml ' + '--software-deployment-method conda --rerun-incomplete -p --keep-going --profile profiles/slurm/"' ) subprocess.run(command, shell=True, check=True, executable='/bin/bash') + if __name__ == '__main__': args = parse_args() main(args) diff --git a/benchmarking_pipeline/tests/test_common.py b/benchmarking_pipeline/tests/test_common.py index 66b416f..b632635 100644 --- a/benchmarking_pipeline/tests/test_common.py +++ b/benchmarking_pipeline/tests/test_common.py @@ -4,8 +4,8 @@ script_directory = Path(__file__).resolve().parent -path = (script_directory.parent / "workflow" / "rules" / "common.smk").as_posix() -spec = importlib.util.spec_from_file_location("common", path) +path = (script_directory.parent / 'workflow' / 'rules' / 'common.smk').as_posix() +spec = importlib.util.spec_from_file_location('common', path) target = importlib.util.module_from_spec(spec) spec.loader.exec_module(target) get_split_files = target.get_split_files @@ -15,14 +15,29 @@ class TestCommon(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestCommon, self).__init__(*args, **kwargs) - self.sample_list = script_directory / "sample_list_test.txt" - + self.sample_list = script_directory / 'sample_list_test.txt' + def test_get_split_files(self): 
split_files = get_split_files(self.sample_list) - self.assertEqual(split_files, [self.sample_list.parent / "loom1_split1.loom", self.sample_list.parent / "loom1_split2.loom", self.sample_list.parent / "loom2_split1.loom",self.sample_list.parent / "loom2_split2.loom"]) + self.assertEqual( + split_files, + [ + self.sample_list.parent / 'loom1_split1.loom', + self.sample_list.parent / 'loom1_split2.loom', + self.sample_list.parent / 'loom2_split1.loom', + self.sample_list.parent / 'loom2_split2.loom', + ], + ) def test_determine_number_of_different_donors(self): - self.assertEqual(target.determine_number_of_different_donors(script_directory / 'pools_test.txt'), "(0.0)", 4) + self.assertEqual( + target.determine_number_of_different_donors( + script_directory / 'pools_test.txt', + '(0.0)', + ), + 4, + ) + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/benchmarking_pipeline/tests/test_create_demultiplexing_scheme.py b/benchmarking_pipeline/tests/test_create_demultiplexing_scheme.py index 830c37f..70d4b11 100644 --- a/benchmarking_pipeline/tests/test_create_demultiplexing_scheme.py +++ b/benchmarking_pipeline/tests/test_create_demultiplexing_scheme.py @@ -4,32 +4,85 @@ script_directory = Path(__file__).resolve().parent -path = (script_directory.parent / "workflow" / "scripts" / "create_demultiplexing_scheme.py").as_posix() -spec = importlib.util.spec_from_file_location("create_demultiplexing_scheme", path) +path = ( + script_directory.parent / 'workflow' / 'scripts' / 'create_demultiplexing_scheme.py' +).as_posix() +spec = importlib.util.spec_from_file_location('create_demultiplexing_scheme', path) target = importlib.util.module_from_spec(spec) spec.loader.exec_module(target) multiplexing_scheme_format2pool_format = target.multiplexing_scheme_format2pool_format select_samples_for_pooling = target.select_samples_for_pooling -define_demultiplexing_scheme_optimal_case = target.define_demultiplexing_scheme_optimal_case +define_demultiplexing_scheme_optimal_case = ( + target.define_demultiplexing_scheme_optimal_case +) + class TestCreateDemultiplexingScheme(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestCreateDemultiplexingScheme, self).__init__(*args, **kwargs) - self.multiplexing_scheme = {1: (0, 1, 0), 2: (0, 2, 0), 3: (1, 2, 0), 4: (0, 0, 0), 5: (1, 1, 0), 6: (2, 2, 0), 7: (0, 1, 1), 8: (0, 2, 1), 9: (1, 2, 1), 10: (0, 0, 1), 11: (1, 1, 1), 12: (2, 2, 1)} - self.multiplexing_scheme_pool_format = {(0, 0): [1,2, 4,-4], (1, 0): [-1,3,5,-5], (2, 0): [-2, -3, 6, -6], (0, 1): [7, 8, 10, -10], (1,1): [-7, 9, 11, -11], (2, 1): [-8, -9, 12, -12]} - self.mock_samples = ['/test/folder/loom1.loom', '/test/folder/loom2.loom', '/test/folder/loom3.loom'] + self.multiplexing_scheme = { + 1: (0, 1, 0), + 2: (0, 2, 0), + 3: (1, 2, 0), + 4: (0, 0, 0), + 5: (1, 1, 0), + 6: (2, 2, 0), + 7: (0, 1, 1), + 8: (0, 2, 1), + 9: (1, 2, 1), + 10: (0, 0, 1), + 11: (1, 1, 1), + 12: (2, 2, 1), + } + self.multiplexing_scheme_pool_format = { + (0, 0): [1, 2, 4, -4], + (1, 0): [-1, 3, 5, -5], + (2, 0): [-2, -3, 6, -6], + (0, 1): [7, 8, 10, -10], + (1, 1): [-7, 9, 11, -11], + (2, 1): [-8, -9, 12, -12], + } + self.mock_samples = [ + '/test/folder/loom1.loom', + '/test/folder/loom2.loom', + '/test/folder/loom3.loom', + ] self.mock_path = '/another/test/folder' - self.another_pool_scheme = {(0,0): [1,-1,3], (1,0): [2,-2,-3]} + self.another_pool_scheme = {(0, 0): [1, -1, 3], (1, 0): [2, -2, -3]} def test_multiplexing_scheme_format2pool_format(self): - 
self.assertEqual(multiplexing_scheme_format2pool_format(self.multiplexing_scheme), self.multiplexing_scheme_pool_format) - + self.assertEqual( + multiplexing_scheme_format2pool_format(self.multiplexing_scheme), + self.multiplexing_scheme_pool_format, + ) + def test_select_samples_for_pooling(self): - self.assertEqual(select_samples_for_pooling(self.another_pool_scheme, self.mock_path, self.mock_samples), {'(0.0)': ['/another/test/folder/loom1_split1.loom', '/another/test/folder/loom1_split2.loom', '/another/test/folder/loom3_split1.loom'], - '(1.0)': ['/another/test/folder/loom2_split1.loom', '/another/test/folder/loom2_split2.loom', '/another/test/folder/loom3_split2.loom']}) + self.assertEqual( + select_samples_for_pooling( + self.another_pool_scheme, self.mock_path, self.mock_samples + ), + { + '(0.0)': [ + '/another/test/folder/loom1_split1.loom', + '/another/test/folder/loom1_split2.loom', + '/another/test/folder/loom3_split1.loom', + ], + '(1.0)': [ + '/another/test/folder/loom2_split1.loom', + '/another/test/folder/loom2_split2.loom', + '/another/test/folder/loom3_split2.loom', + ], + }, + ) def test_define_demultiplexing_scheme_optimal_case(self): - self.assertEqual(define_demultiplexing_scheme_optimal_case(maximal_number_of_samples = 3, maximal_pool_size = 2, n_samples = 3), {1: (0, 1, 0), 2: (0, 0, 0), 3: (1, 1, 0)}) + self.assertEqual( + define_demultiplexing_scheme_optimal_case( + maximal_number_of_samples=3, maximal_pool_size=2, n_samples=3 + ), + {1: (0, 1, 0), 2: (0, 0, 0), 3: (1, 1, 0)}, + ) + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/benchmarking_pipeline/tests/test_split_loom_files.py b/benchmarking_pipeline/tests/test_split_loom_files.py index afdc70a..146299d 100644 --- a/benchmarking_pipeline/tests/test_split_loom_files.py +++ b/benchmarking_pipeline/tests/test_split_loom_files.py @@ -6,12 +6,19 @@ import numpy as np import loompy -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.DEBUG) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.DEBUG, +) script_directory = Path(__file__).resolve().parent -split_loom_file_path = (script_directory.parent / "workflow" / "scripts" / "split_loom_files.py").as_posix() -spec = importlib.util.spec_from_file_location("split_loom_file", split_loom_file_path) +split_loom_file_path = ( + script_directory.parent / 'workflow' / 'scripts' / 'split_loom_files.py' +).as_posix() +spec = importlib.util.spec_from_file_location('split_loom_file', split_loom_file_path) target = importlib.util.module_from_spec(spec) spec.loader.exec_module(target) split_loom_file = target.split_loom_file @@ -20,36 +27,42 @@ class TestSplitLoomFiles(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestSplitLoomFiles, self).__init__(*args, **kwargs) - self.filename = "split_loom_data_test.loom" - #matrix = np.arange(4).reshape(2,2) - #row_attrs = { "SomeRowAttr": np.arange(2), "OtherRowAttr": ['A','B'] } - #col_attrs = { "SomeColAttr": np.arange(2), "OtherColAttr": ['C','D'] } - #other_layer = np.arange(4, 8).reshape(2,2) - #loompy.create((script_directory / self.filename).as_posix(), layers = {"":matrix, "other": other_layer}, row_attrs = row_attrs, col_attrs = col_attrs) + self.filename = 'split_loom_data_test.loom' + # matrix = np.arange(4).reshape(2,2) + # row_attrs = { "SomeRowAttr": np.arange(2), "OtherRowAttr": ['A','B'] } + # col_attrs = { 
"SomeColAttr": np.arange(2), "OtherColAttr": ['C','D'] } + # other_layer = np.arange(4, 8).reshape(2,2) + # loompy.create((script_directory / self.filename).as_posix(), layers = {"":matrix, "other": other_layer}, row_attrs = row_attrs, col_attrs = col_attrs) def test_split_loom_files(self): loom_file = script_directory / self.filename - split_loom_file(0.5, script_directory / self.filename, loom_file.parent, seed = 42) - with loompy.connect(script_directory / 'temp' / f'{Path(self.filename).stem}_split1.loom') as ds: + split_loom_file( + 0.5, script_directory / self.filename, loom_file.parent, seed=42 + ) + with loompy.connect( + script_directory / 'temp' / f'{Path(self.filename).stem}_split1.loom' + ) as ds: self.assertEqual(ds.shape[1], 1) self.assertEqual(ds.shape[0], 2) - self.assertTrue(np.all(ds[:,:] == np.array([[0], [2]]))) - self.assertTrue(np.all(ds.layers["other"][:,:] == np.array([[4], [6]]))) - self.assertTrue(np.all(ds.ca["SomeColAttr"] == np.array([0]))) - self.assertTrue(np.all(ds.ca["OtherColAttr"] == np.array(['C']))) - self.assertTrue(np.all(ds.ra["SomeRowAttr"] == np.array([0, 1]))) - self.assertTrue(np.all(ds.ra["OtherRowAttr"] == np.array(['A', 'B']))) - - with loompy.connect(script_directory / 'temp' / f'{Path(self.filename).stem}_split2.loom') as ds: + self.assertTrue(np.all(ds[:, :] == np.array([[0], [2]]))) + self.assertTrue(np.all(ds.layers['other'][:, :] == np.array([[4], [6]]))) + self.assertTrue(np.all(ds.ca['SomeColAttr'] == np.array([0]))) + self.assertTrue(np.all(ds.ca['OtherColAttr'] == np.array(['C']))) + self.assertTrue(np.all(ds.ra['SomeRowAttr'] == np.array([0, 1]))) + self.assertTrue(np.all(ds.ra['OtherRowAttr'] == np.array(['A', 'B']))) + + with loompy.connect( + script_directory / 'temp' / f'{Path(self.filename).stem}_split2.loom' + ) as ds: self.assertEqual(ds.shape[1], 1) self.assertEqual(ds.shape[0], 2) - self.assertTrue(np.all(ds[:,:] == np.array([[1], [3]]))) - self.assertTrue(np.all(ds.layers["other"][:,:] == np.array([[5], [7]]))) - self.assertTrue(np.all(ds.ca["SomeColAttr"] == np.array([1]))) - self.assertTrue(np.all(ds.ca["OtherColAttr"] == np.array(['D']))) - self.assertTrue(np.all(ds.ra["SomeRowAttr"] == np.array([0, 1]))) - self.assertTrue(np.all(ds.ra["OtherRowAttr"] == np.array(['A', 'B']))) + self.assertTrue(np.all(ds[:, :] == np.array([[1], [3]]))) + self.assertTrue(np.all(ds.layers['other'][:, :] == np.array([[5], [7]]))) + self.assertTrue(np.all(ds.ca['SomeColAttr'] == np.array([1]))) + self.assertTrue(np.all(ds.ca['OtherColAttr'] == np.array(['D']))) + self.assertTrue(np.all(ds.ra['SomeRowAttr'] == np.array([0, 1]))) + self.assertTrue(np.all(ds.ra['OtherRowAttr'] == np.array(['A', 'B']))) + if __name__ == '__main__': unittest.main() - diff --git a/benchmarking_pipeline/workflow/Snakefile b/benchmarking_pipeline/workflow/Snakefile index 2c584b7..d0bc779 100644 --- a/benchmarking_pipeline/workflow/Snakefile +++ b/benchmarking_pipeline/workflow/Snakefile @@ -1,19 +1,22 @@ include: Path("rules/common.smk") + rule all: input: - output_files - + output_files, + checkpoint get_input_samples: output: - samples_file = output_folder / "seed_{seed}" / "sample_list_{pool_size}.txt" + samples_file=output_folder / "seed_{seed}" / "sample_list_{pool_size}.txt", params: - number_of_samples = lambda wildcards: int(int(wildcards.pool_size)*(int(wildcards.pool_size) + 1)/2), - loom_file_path = loom_file_path, - seed = lambda w: w.seed, + number_of_samples=lambda wildcards: int( + int(wildcards.pool_size) * (int(wildcards.pool_size) + 1) / 2 + ), + 
loom_file_path=loom_file_path, + seed=lambda w: w.seed, log: - output_folder / 'log' / 'generate_input_samples_{seed}_{pool_size}.log' + output_folder / "log" / "generate_input_samples_{seed}_{pool_size}.log", shell: """ python {workflow.basedir}/scripts/get_input_samples.py \ @@ -24,19 +27,20 @@ checkpoint get_input_samples: &> {log} """ -rule split_files: ### This part of the pipeline is still not fully reproducible, because the files are split randomly. If I use the seed though, every file will be split the same way, which is not what I want. + +rule split_files: ### This part of the pipeline is still not fully reproducible, because the files are split randomly. If I use the seed though, every file will be split the same way, which is not what I want. input: - loom_file_path / '{loom_file}.loom', + loom_file_path / "{loom_file}.loom", output: - split1 = output_folder / "seed_{seed}" / "loom" / '{loom_file}_split1.loom', - split2 = output_folder / "seed_{seed}" / "loom" / '{loom_file}_split2.loom', + split1=output_folder / "seed_{seed}" / "loom" / "{loom_file}_split1.loom", + split2=output_folder / "seed_{seed}" / "loom" / "{loom_file}_split2.loom", conda: "envs/loom.yml" resources: mem_mb_per_cpu=3000, runtime=90, log: - output_folder / "seed_{seed}" / 'log' / 'split_{loom_file}.log', + output_folder / "seed_{seed}" / "log" / "split_{loom_file}.log", shell: """ python {workflow.basedir}/scripts/split_loom_files.py \ @@ -49,14 +53,19 @@ rule split_files: ### This part of the pipeline is still not fully reproducible, checkpoint create_demultiplexing_scheme: input: - sample_file = output_folder / "seed_{seed}" / 'sample_list_{pool_size}.txt' + sample_file=output_folder / "seed_{seed}" / "sample_list_{pool_size}.txt", output: - pools = output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt" + pools=output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", params: - number_of_samples = lambda wildcards: int(int(wildcards.pool_size)*(int(wildcards.pool_size) + 1)/2), - input_dir = lambda w: output_folder / f"seed_{w.seed}" / "loom" + number_of_samples=lambda wildcards: int( + int(wildcards.pool_size) * (int(wildcards.pool_size) + 1) / 2 + ), + input_dir=lambda w: output_folder / f"seed_{w.seed}" / "loom", log: - output_folder / "seed_{seed}" / 'log' / 'create_demultiplexing_scheme_{seed}_{pool_size}_robust{robust}.log' + output_folder + / "seed_{seed}" + / "log" + / "create_demultiplexing_scheme_{seed}_{pool_size}_robust{robust}.log", shell: """ python {workflow.basedir}/scripts/create_demultiplexing_scheme.py \ @@ -70,16 +79,22 @@ checkpoint create_demultiplexing_scheme: """ -rule pooling: ##Need to introduce proper downsampling of cells +rule pooling: ##Need to introduce proper downsampling of cells input: - split_files = lambda w: yaml.safe_load( + split_files=lambda w: yaml.safe_load( open( - checkpoints.create_demultiplexing_scheme.get(seed=w.seed, pool_size = w.pool_size, robust = w.robust).output.pools - ) - )[f"({w.pool_ID})"], - pool_data = output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", + checkpoints.create_demultiplexing_scheme.get( + seed=w.seed, pool_size=w.pool_size, robust=w.robust + ).output.pools + ) + )[f"({w.pool_ID})"], + pool_data=output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", output: - output_folder / "seed_{seed}" / paramspace.wildcard_pattern / "pools" / "pool_{pool_ID}_{seed}.csv" + output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "pools" + / "pool_{pool_ID}_{seed}.csv", 
params: minGQ=config.get("mosaic", {}).get("minGQ", 30), minDP=config.get("mosaic", {}).get("minDP", 10), @@ -91,15 +106,19 @@ rule pooling: ##Need to introduce proper downsampling of cells minHomVAF=config.get("mosaic", {}).get("minHomVAF", 0.95), minHetVAF=config.get("mosaic", {}).get("minHetVAF", 0.35), proximity=config.get("mosaic", {}).get("proximity", "25 50 100 200"), - ratios = lambda w, input: sample_mixing_ratios(w.seed, len(input.split_files)), + ratios=lambda w, input: sample_mixing_ratios(w.seed, len(input.split_files)), resources: mem_mb_per_cpu=10000, runtime=90, conda: "envs/mosaic.yml" - group: "demultiplexing_simulation" + group: + "demultiplexing_simulation" log: - output_folder / "seed_{seed}" / 'log' / 'pooling_{pool_ID}_{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log' + output_folder + / "seed_{seed}" + / "log" + / "pooling_{pool_ID}_{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log", shell: """ python {workflow.basedir}/scripts/mosaic_processing.py \ @@ -124,29 +143,50 @@ rule pooling: ##Need to introduce proper downsampling of cells rule demultiplexing_demoTape: - input: - variants = output_folder / "seed_{seed}" / paramspace.wildcard_pattern / "pools" / "pool_{pool_ID}_{seed}.csv", - pools = output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", - output: - profiles = output_folder / "seed_{seed}" / paramspace.wildcard_pattern / 'demultiplexed' / 'pool_{pool_ID}_{seed}_demultiplexed.profiles.tsv', - assignments = output_folder / "seed_{seed}" / paramspace.wildcard_pattern / 'demultiplexed' / 'pool_{pool_ID}_{seed}_demultiplexed.assignments.tsv' - params: - cluster_number= lambda w, input: determine_number_of_different_donors( - checkpoints.create_demultiplexing_scheme.get(seed=w.seed, pool_size = w.pool_size, robust = w.robust).output.pools, - "({})".format(w.pool_ID) - ), - outbase = lambda w: output_folder / "seed_{}".format(w.seed) / paramspace.wildcard_pattern / 'demultiplexed' / "pool_{}_{}_demultiplexed".format(w.pool_ID, w.seed), - output_folder = output_folder - resources: - mem_mb_per_cpu=4096, - runtime=90, - conda: - "envs/sample_assignment.yml" - group: "demultiplexing_simulation" - log: - output_folder / 'log' / 'demultiplexing_{pool_ID}_{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log', - shell: - """ + input: + variants=output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "pools" + / "pool_{pool_ID}_{seed}.csv", + pools=output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", + output: + profiles=output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "demultiplexed" + / "pool_{pool_ID}_{seed}_demultiplexed.profiles.tsv", + assignments=output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "demultiplexed" + / "pool_{pool_ID}_{seed}_demultiplexed.assignments.tsv", + params: + cluster_number=lambda w, input: determine_number_of_different_donors( + checkpoints.create_demultiplexing_scheme.get( + seed=w.seed, pool_size=w.pool_size, robust=w.robust + ).output.pools, + "({})".format(w.pool_ID), + ), + outbase=lambda w: output_folder + / "seed_{}".format(w.seed) + / paramspace.wildcard_pattern + / "demultiplexed" + / "pool_{}_{}_demultiplexed".format(w.pool_ID, w.seed), + output_folder=output_folder, + resources: + mem_mb_per_cpu=4096, + runtime=90, + conda: + "envs/sample_assignment.yml" + group: + "demultiplexing_simulation" + log: + output_folder + / "log" + / 
"demultiplexing_{pool_ID}_{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log", + shell: + """ python {workflow.basedir}/scripts/demultiplex_distance.py \ --input {input.variants} \ -n {params.cluster_number} \ @@ -156,20 +196,28 @@ rule demultiplexing_demoTape: rule label_samples: - input: - demultiplexing_assignments = lambda wildcards: get_all_demultiplexed_assignments(wildcards, paramspace.wildcard_pattern), - output: - output_folder / "seed_{seed}" / paramspace.wildcard_pattern / 'sample_identity.yaml', - resources: - mem_mb_per_cpu=4096, - runtime=90, - conda: - "envs/sample_assignment.yml" - group: "assess_simulation_results" - log: - output_folder / 'log' / 'label_samples.{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log', - shell: - """ + input: + demultiplexing_assignments=lambda wildcards: get_all_demultiplexed_assignments( + wildcards, paramspace.wildcard_pattern + ), + output: + output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "sample_identity.yaml", + resources: + mem_mb_per_cpu=4096, + runtime=90, + conda: + "envs/sample_assignment.yml" + group: + "assess_simulation_results" + log: + output_folder + / "log" + / "label_samples.{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log", + shell: + """ python {workflow.basedir}/scripts/assign_subsample_simulation.py \ --demultiplexing_assignment {input.demultiplexing_assignments} \ --output {output} \ @@ -178,17 +226,28 @@ rule label_samples: rule assign_samples: - input: - demultiplexing_genotypes = lambda wildcards: get_all_demultiplexed_profiles(wildcards, paramspace.wildcard_pattern), - pools = output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", + input: + demultiplexing_genotypes=lambda wildcards: get_all_demultiplexed_profiles( + wildcards, paramspace.wildcard_pattern + ), + pools=output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", output: - sample_assignment = output_folder / "seed_{seed}" / paramspace.wildcard_pattern / 'sample_assignment.yaml', - heatmap = output_folder / "seed_{seed}" / paramspace.wildcard_pattern / 'sample_assignment_heatmap.png', + sample_assignment=output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "sample_assignment.yaml", + heatmap=output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "sample_assignment_heatmap.png", conda: "envs/sample_assignment.yml" - group: "assess_simulation_results" + group: + "assess_simulation_results" log: - output_folder / 'log' / 'assign_samples.{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log', + output_folder + / "log" + / "assign_samples.{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log", shell: """ python {workflow.basedir}/scripts/compare_distances.py \ @@ -198,4 +257,4 @@ rule assign_samples: --output {output.sample_assignment} \ --robust {wildcards.robust} \ &> {log} - """ \ No newline at end of file + """ diff --git a/benchmarking_pipeline/workflow/rules/common.smk b/benchmarking_pipeline/workflow/rules/common.smk index 685aa17..27a75e5 100644 --- a/benchmarking_pipeline/workflow/rules/common.smk +++ b/benchmarking_pipeline/workflow/rules/common.smk @@ -7,12 +7,14 @@ import numpy as np import pandas as pd from snakemake.utils import Paramspace -paramspace = Paramspace(pd.read_csv(Path(workflow.basedir) / 'sandbox' / 'parameter_space.tsv', sep='\t')) +paramspace = Paramspace( + pd.read_csv(Path(workflow.basedir) / "sandbox" / "parameter_space.tsv", sep="\t") +) output_folder = 
Path(config["output_folder"]) loom_file_path = Path(config["loom_files"]) -#number_of_samples = config.get("n_samples", 10) -seed=config["seed"] +# number_of_samples = config.get("n_samples", 10) +seed = config["seed"] loom_files = [] for p in loom_file_path.iterdir(): @@ -20,11 +22,9 @@ for p in loom_file_path.iterdir(): loom_files.append(p) - - def get_variant_list(pool_ID): file = output_folder / f"pool_{pool_ID}.txt" - with open(file, 'r') as f: + with open(file, "r") as f: return [line.strip() for line in f] @@ -32,7 +32,7 @@ def get_split_files(path): if isinstance(path, str): path = Path(path) split_files = [] - with open(path, 'r') as f: + with open(path, "r") as f: for line in f: path2 = Path(line.strip()) split_files.append(path.parent / "loom" / f"{path2.stem}_split1.loom") @@ -41,23 +41,28 @@ def get_split_files(path): return split_files - def sample_mixing_ratios(seed, max_pool_size): - concentrations = [40/max_pool_size] * max_pool_size + concentrations = [40 / max_pool_size] * max_pool_size rng = np.random.default_rng(seed=int(seed)) ratios = rng.dirichlet(concentrations) - + return " ".join(str(item) for item in ratios) def get_demultiplexed_samples(wildcards): - pools = yaml.safe_load(open(checkpoints.create_demultiplexing_scheme.get(seed=wildcard.seed).output, 'r')) + pools = yaml.safe_load( + open( + checkpoints.create_demultiplexing_scheme.get(seed=wildcard.seed).output, "r" + ) + ) pools = list(pools.keys()) - demultiplexed_samples = [f"pool_{pool}_{wildcard.seed}_demultiplexed.assignments.tsv" for pool in pools] + demultiplexed_samples = [ + f"pool_{pool}_{wildcard.seed}_demultiplexed.assignments.tsv" for pool in pools + ] + - def determine_number_of_different_donors(pool_file, pool_ID): - samples = yaml.safe_load(open(pool_file, 'r'))[pool_ID] + samples = yaml.safe_load(open(pool_file, "r"))[pool_ID] samples = [Path(sample) for sample in samples] unique_patterns = set() for sample in samples: @@ -66,35 +71,66 @@ def determine_number_of_different_donors(pool_file, pool_ID): unique_patterns.add(sample.stem.replace("_split1", "")) elif "_split2" in sample.stem: unique_patterns.add(sample.stem.replace("_split2", "")) - + return len(unique_patterns) def get_all_demultiplexed_assignments(wildcards, wildcard_pattern): - demultiplexing_scheme = yaml.safe_load(open(checkpoints.create_demultiplexing_scheme.get(seed=wildcards.seed, pool_size = wildcards.pool_size, robust = wildcards.robust).output.pools, 'r')) + demultiplexing_scheme = yaml.safe_load( + open( + checkpoints.create_demultiplexing_scheme.get( + seed=wildcards.seed, + pool_size=wildcards.pool_size, + robust=wildcards.robust, + ).output.pools, + "r", + ) + ) pool_names = list(demultiplexing_scheme.keys()) pool_names = [pool_name.strip("()") for pool_name in pool_names] - + demultiplexed_files = [] - file_path = output_folder / f"seed_{wildcards.seed}" / wildcard_pattern / 'demultiplexed' + file_path = ( + output_folder / f"seed_{wildcards.seed}" / wildcard_pattern / "demultiplexed" + ) for pool in pool_names: - demultiplexed_files.append(file_path / f"pool_{pool}_{wildcards.seed}_demultiplexed.assignments.tsv") + demultiplexed_files.append( + file_path / f"pool_{pool}_{wildcards.seed}_demultiplexed.assignments.tsv" + ) return demultiplexed_files def get_all_demultiplexed_profiles(wildcards, wildcard_pattern): - demultiplexing_scheme = yaml.safe_load(open(checkpoints.create_demultiplexing_scheme.get(seed=wildcards.seed, pool_size = wildcards.pool_size, robust = wildcards.robust).output.pools, 'r')) + 
demultiplexing_scheme = yaml.safe_load( + open( + checkpoints.create_demultiplexing_scheme.get( + seed=wildcards.seed, + pool_size=wildcards.pool_size, + robust=wildcards.robust, + ).output.pools, + "r", + ) + ) pool_names = list(demultiplexing_scheme.keys()) pool_names = [pool_name.strip("()") for pool_name in pool_names] - + demultiplexed_files = [] - file_path = output_folder / f"seed_{wildcards.seed}" / wildcard_pattern / 'demultiplexed' + file_path = ( + output_folder / f"seed_{wildcards.seed}" / wildcard_pattern / "demultiplexed" + ) for pool in pool_names: - demultiplexed_files.append(file_path / f"pool_{pool}_{wildcards.seed}_demultiplexed.profiles.tsv") + demultiplexed_files.append( + file_path / f"pool_{pool}_{wildcards.seed}_demultiplexed.profiles.tsv" + ) return demultiplexed_files - -output_files = expand(output_folder / f"seed_{seed}" / "{params}" / 'sample_assignment.yaml', params=paramspace.instance_patterns) + expand(output_folder / f"seed_{seed}" / "{params}" / 'sample_identity.yaml', params=paramspace.instance_patterns) +output_files = expand( + output_folder / f"seed_{seed}" / "{params}" / "sample_assignment.yaml", + params=paramspace.instance_patterns, +) + expand( + output_folder / f"seed_{seed}" / "{params}" / "sample_identity.yaml", + params=paramspace.instance_patterns, +) diff --git a/benchmarking_pipeline/workflow/sandbox/compute_performance_score.py b/benchmarking_pipeline/workflow/sandbox/compute_performance_score.py index 9aa4227..806a899 100644 --- a/benchmarking_pipeline/workflow/sandbox/compute_performance_score.py +++ b/benchmarking_pipeline/workflow/sandbox/compute_performance_score.py @@ -6,7 +6,12 @@ from sklearn.metrics import v_measure_score -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.INFO) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.INFO, +) def compute_score(labels, estimate): @@ -18,10 +23,8 @@ def compute_score(labels, estimate): for key2, sample_assignment in assignments: if sample_assignment != labels[str(key1)][str(key2)]: mismatches += 1 - - return mismatches / total_counts - + return mismatches / total_counts def load_data(path): @@ -32,7 +35,7 @@ def load_data(path): with open(path / 'sample_assignment.yaml', 'r') as file: sample_assignment = yaml.safe_load(file) - + return sample_identity, sample_assignment @@ -47,7 +50,9 @@ def preprocess_labels(sample_identity): value_dict[sub_key] = int(value) if len(values) != len(set(values)): - logging.error(f"Duplicate values found in sample_identity for key {key}: {values}") + logging.error( + f'Duplicate values found in sample_identity for key {key}: {values}' + ) pathological_sample_identity = True return pathological_sample_identity, sample_identity @@ -60,9 +65,11 @@ def preprocess_sample_assignment(sample_assignment): def process_simulation_run(path): sample_identity, sample_assignment = load_data(path) - sample_identity_is_pathological, sample_identity = preprocess_labels(sample_identity) + sample_identity_is_pathological, sample_identity = preprocess_labels( + sample_identity + ) preprocess_sample_assignment(sample_assignment) - + logging.debug(sample_identity_is_pathological) logging.debug(sample_assignment) logging.debug(sample_identity) @@ -72,17 +79,22 @@ def process_simulation_run(path): return sample_identity_is_pathological, score - def compute_clustering_performance_score(subfolder): demultiplexed_folder = subfolder / 
'demultiplexed' v_measures = [] for assignment_file in demultiplexed_folder.glob('*assignments.tsv'): df = pd.read_csv(assignment_file, sep='\t') - + # Extract ground truth labels and cluster assignments - ground_truth_labels = df.columns[1:].map(lambda col: 'doublet' if '+' in col else col.split('_')[1].split('-')[1]).tolist() + ground_truth_labels = ( + df.columns[1:] + .map( + lambda col: 'doublet' if '+' in col else col.split('_')[1].split('-')[1] + ) + .tolist() + ) cluster_assignments = df.iloc[0, 1:].tolist() # Compute ARI @@ -100,31 +112,46 @@ def main(input_dir): for subfolder in subfolder_0.glob('*0'): sample_identity_path = subfolder / 'sample_identity.yaml' sample_assignment_path = subfolder / 'sample_assignment.yaml' - + if sample_identity_path.exists() and sample_assignment_path.exists(): - sample_identity_is_pathological, score = process_simulation_run(subfolder) + sample_identity_is_pathological, score = process_simulation_run( + subfolder + ) mean_v_score = compute_clustering_performance_score(subfolder) seed_number = int(seed_folder.name.split('_')[1]) - - results.append({ - 'seed': seed_number, - 'doublet_rate': float(subfolder_0.name), - 'cell_count': int(subfolder.name), - 'score': score, - 'sample_identity_is_pathological': sample_identity_is_pathological, - 'v_score': mean_v_score - }) + + results.append( + { + 'seed': seed_number, + 'doublet_rate': float(subfolder_0.name), + 'cell_count': int(subfolder.name), + 'score': score, + 'sample_identity_is_pathological': sample_identity_is_pathological, + 'v_score': mean_v_score, + } + ) # Periodically save results to a file to avoid memory burden if len(results) % 100 == 0: temp_df = pd.DataFrame(results) - temp_df.to_csv('more_intermediate_results.csv', mode='a', header=False, index=False) + temp_df.to_csv( + 'more_intermediate_results.csv', + mode='a', + header=False, + index=False, + ) results.clear() if results: df = pd.DataFrame(results) - df.to_csv('more_intermediate_results.csv', mode='a', header=not Path('more_intermediate_results.csv').exists(), index=False) + df.to_csv( + 'more_intermediate_results.csv', + mode='a', + header=not Path('more_intermediate_results.csv').exists(), + index=False, + ) + """ def main(input_dir): if not isinstance(input_dir, Path): @@ -168,7 +195,9 @@ def main(input_dir): """ if __name__ == '__main__': - input_dir = Path('/cluster/work/bewi/members/jgawron/projects/Demultiplexing/AML_data/output')# = args.input - - #process_simulation_run('/cluster/work/bewi/members/jgawron/projects/Demultiplexing/AML_data/output/seed_4/robust~True/pool_size~4/doublet_rate~0.05/cell_count~1000') + input_dir = Path( + '/cluster/work/bewi/members/jgawron/projects/Demultiplexing/AML_data/output' + ) # = args.input + + # process_simulation_run('/cluster/work/bewi/members/jgawron/projects/Demultiplexing/AML_data/output/seed_4/robust~True/pool_size~4/doublet_rate~0.05/cell_count~1000') main(input_dir) diff --git a/benchmarking_pipeline/workflow/sandbox/figure3.py b/benchmarking_pipeline/workflow/sandbox/figure3.py index e10c36a..cad4bf6 100644 --- a/benchmarking_pipeline/workflow/sandbox/figure3.py +++ b/benchmarking_pipeline/workflow/sandbox/figure3.py @@ -4,48 +4,80 @@ import matplotlib.pyplot as plt import seaborn as sns -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.INFO) - -data = pd.read_csv('intermediate_results.csv', header = None) -data.columns = ['seed', 'doublet_rate', 'cell_count', 'score', 'sample_identity_is_pathological', 
'v_score'] -sns.set_theme(style="whitegrid") +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.INFO, +) + +data = pd.read_csv('intermediate_results.csv', header=None) +data.columns = [ + 'seed', + 'doublet_rate', + 'cell_count', + 'score', + 'sample_identity_is_pathological', + 'v_score', +] +sns.set_theme(style='whitegrid') data = data.loc[data['cell_count'].isin([1000, 2000, 3000, 4000])] -data['score'] = 1-data['score'] +data['score'] = 1 - data['score'] # Customize the appearance of the plots -plt.rcParams.update({ - 'axes.titlesize': 26, - 'axes.labelsize': 26, - 'xtick.labelsize': 20, - 'ytick.labelsize': 26, - 'legend.fontsize': 16, - 'figure.titlesize': 18 -}) +plt.rcParams.update( + { + 'axes.titlesize': 26, + 'axes.labelsize': 26, + 'xtick.labelsize': 20, + 'ytick.labelsize': 26, + 'legend.fontsize': 16, + 'figure.titlesize': 18, + } +) fig, axes = plt.subplots(1, 5, figsize=(30, 6), sharey=True) doublet_rates = [0, 0.02, 0.04, 0.06, 0.08] for ax, rate in zip(axes, doublet_rates): - sns.boxplot(x='cell_count', y='score', data=data.loc[(~data['sample_identity_is_pathological']) & (data['doublet_rate'] == rate)], ax=ax, color='#C97B84',flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black')) + sns.boxplot( + x='cell_count', + y='score', + data=data.loc[ + (~data['sample_identity_is_pathological']) & (data['doublet_rate'] == rate) + ], + ax=ax, + color='#C97B84', + flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black'), + ) ax.set_title(f'Doublet Rate: {rate}') ax.set_xlabel('Cell Count') axes[0].set_ylabel('Sample assignment score') plt.tight_layout() -logging.info("Saving figure4a.png") -plt.savefig("figure4a.png", dpi = 300) +logging.info('Saving figure4a.png') +plt.savefig('figure4a.png', dpi=300) plt.close() -data = pd.read_csv('results_for_robust.csv', header = 0) -data.columns = ['seed', 'doublet_rate', 'cell_count', 'score', 'sample_identity_is_pathological', 'v_score'] -data['score'] = 1-data['score'] -data = data[data['score'] != 1] ### This outlier datapoint goes back to a mistake in the simulation! +data = pd.read_csv('results_for_robust.csv', header=0) +data.columns = [ + 'seed', + 'doublet_rate', + 'cell_count', + 'score', + 'sample_identity_is_pathological', + 'v_score', +] +data['score'] = 1 - data['score'] +data = data[ + data['score'] != 1 +] ### This outlier datapoint goes back to a mistake in the simulation! 
plt.figure(figsize=(10, 8)) -plt.scatter(x='v_score', y='score', data = data, color = '#C97B84') +plt.scatter(x='v_score', y='score', data=data, color='#C97B84') # Customize the plot plt.grid(True) @@ -60,13 +92,19 @@ plt.gca().tick_params(axis='x', colors='#444444') plt.gca().tick_params(axis='y', colors='#444444') -logging.info("Saving figure4b.png") -plt.savefig('figure4b.png', dpi = 300) +logging.info('Saving figure4b.png') +plt.savefig('figure4b.png', dpi=300) plt.close() plt.figure(figsize=(10, 8)) -sns.boxplot(x='sample_identity_is_pathological', y='score', data=data, color='#C97B84', flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black')) +sns.boxplot( + x='sample_identity_is_pathological', + y='score', + data=data, + color='#C97B84', + flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black'), +) # Customize the plot plt.grid(True) @@ -81,8 +119,7 @@ plt.gca().tick_params(axis='x', colors='#444444') plt.gca().tick_params(axis='y', colors='#444444') -logging.info("Saving figure4c.png") -plt.savefig('figure4c.png', dpi = 300) +logging.info('Saving figure4c.png') +plt.savefig('figure4c.png', dpi=300) plt.close() # Show the plot - diff --git a/benchmarking_pipeline/workflow/sandbox/figure4.py b/benchmarking_pipeline/workflow/sandbox/figure4.py index e258ed8..eee9d30 100644 --- a/benchmarking_pipeline/workflow/sandbox/figure4.py +++ b/benchmarking_pipeline/workflow/sandbox/figure4.py @@ -2,27 +2,45 @@ import matplotlib.pyplot as plt import seaborn as sns -data = pd.read_csv('intermediate_results.csv', header = None) -data.columns = ['seed', 'doublet_rate', 'cell_count', 'score', 'sample_identity_is_pathological', 'v_score'] -sns.set_theme(style="whitegrid") +data = pd.read_csv('intermediate_results.csv', header=None) +data.columns = [ + 'seed', + 'doublet_rate', + 'cell_count', + 'score', + 'sample_identity_is_pathological', + 'v_score', +] +sns.set_theme(style='whitegrid') # Customize the appearance of the plots -plt.rcParams.update({ - 'axes.titlesize': 16, - 'axes.labelsize': 14, - 'xtick.labelsize': 12, - 'ytick.labelsize': 12, - 'legend.fontsize': 12, - 'figure.titlesize': 18 -}) +plt.rcParams.update( + { + 'axes.titlesize': 16, + 'axes.labelsize': 14, + 'xtick.labelsize': 12, + 'ytick.labelsize': 12, + 'legend.fontsize': 12, + 'figure.titlesize': 18, + } +) fig, axes = plt.subplots(1, 6, figsize=(30, 6), sharey=True) doublet_rates = [0, 0.02, 0.04, 0.06, 0.08] for ax, rate in zip(axes, doublet_rates): - sns.boxplot(x='cell_count', y='score', data=data.loc[(~data['sample_identity_is_pathological']) & (data['doublet_rate'] == rate)], ax=ax, color='#C97B84',flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black')) + sns.boxplot( + x='cell_count', + y='score', + data=data.loc[ + (~data['sample_identity_is_pathological']) & (data['doublet_rate'] == rate) + ], + ax=ax, + color='#C97B84', + flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black'), + ) ax.set_title(f'Doublet Rate: {rate}') ax.set_xlabel('Cell Count') axes[0].set_ylabel('Score') plt.tight_layout() -plt.savefig("figure3.png") +plt.savefig('figure3.png') diff --git a/benchmarking_pipeline/workflow/sandbox/generate_parameter_space.py b/benchmarking_pipeline/workflow/sandbox/generate_parameter_space.py index aa59ff1..7b1e6da 100644 --- a/benchmarking_pipeline/workflow/sandbox/generate_parameter_space.py +++ b/benchmarking_pipeline/workflow/sandbox/generate_parameter_space.py @@ -8,7 +8,9 @@ robust_values = [True, False] # Generate all 
combinations of parameters -combinations = list(itertools.product(robust_values, pool_sizes, doublet_rates, cell_counts)) +combinations = list( + itertools.product(robust_values, pool_sizes, doublet_rates, cell_counts) +) # Define the output file path output_file = 'parameter_space.tsv' @@ -22,4 +24,4 @@ for combination in combinations: writer.writerow(combination) -print(f"Parameter space written to {output_file}") \ No newline at end of file +print(f'Parameter space written to {output_file}') diff --git a/benchmarking_pipeline/workflow/sandbox/get_read_count_distribution_of_samples.py b/benchmarking_pipeline/workflow/sandbox/get_read_count_distribution_of_samples.py deleted file mode 100644 index f02abea..0000000 --- a/benchmarking_pipeline/workflow/sandbox/get_read_count_distribution_of_samples.py +++ /dev/null @@ -1,38 +0,0 @@ -from pathlib import Path - -import loompy -import pandas as pd -import numpy as np - -def process_loom_files(folder_path): - folder = Path(folder_path) - files = [] - cell_counts = [] - for filename in Path(folder).rglob("*.loom"): - with loompy.connect(str(filename)) as ds: - cell_counts.append(ds.shape[1]) - files.append(filename) - data = pd.DataFrame({'cell_count': cell_counts}, index = files) - import matplotlib.pyplot as plt - - # Plot the histogram of cell counts - plt.hist(cell_counts, bins=30, density=True, alpha=0.6, color='g', label='Cell count histogram') - - # Plot the Poisson distribution density - - # Fit a Normal distribution to the cell counts - mean_count = np.mean(cell_counts) - std_dev = np.std(cell_counts) - - x = np.arange(0, max(cell_counts) + 1) - from scipy.stats import norm - plt.plot(x, norm.pdf(x, mean_count, std_dev), 'r-', label='Normal fit') - - plt.xlabel('Cell count') - plt.ylabel('Density') - plt.title('Histogram of Cell Counts with Poisson Distribution Fit') - plt.legend() - plt.savefig(folder / 'cell_count_distribution.png') - -folder_path = '/cluster/work/bewi/members/jgawron/projects/Demultiplexing/AML_data/loom_files_complete' -process_loom_files(folder_path) \ No newline at end of file diff --git a/benchmarking_pipeline/workflow/scripts/assign_subsample_simulation.py b/benchmarking_pipeline/workflow/scripts/assign_subsample_simulation.py index 211a878..f9f3e29 100644 --- a/benchmarking_pipeline/workflow/scripts/assign_subsample_simulation.py +++ b/benchmarking_pipeline/workflow/scripts/assign_subsample_simulation.py @@ -8,27 +8,44 @@ logging.basicConfig(level=logging.INFO) + def parse_args(): parser = argparse.ArgumentParser(description='Assign subsample simulation') - parser.add_argument('--demultiplexing_assignment', nargs = '+', type=str, help='A tsv file that specifies which cell is assigned to which cluster') - parser.add_argument('--output', type=str, help='Output file for the sample assignment') + parser.add_argument( + '--demultiplexing_assignment', + nargs='+', + type=str, + help='A tsv file that specifies which cell is assigned to which cluster', + ) + parser.add_argument( + '--output', type=str, help='Output file for the sample assignment' + ) return parser.parse_args() def analyse_demultiplexing_assignment(assignment_file, pool_name): - data = pd.read_csv(assignment_file, sep='\t', index_col = 0, header = None) + data = pd.read_csv(assignment_file, sep='\t', index_col=0, header=None) - data.iloc[0, :] = data.iloc[0, :].map(lambda x: "+".join([part.split('_')[1].split('.')[0].split("-")[1] for part in x.split('+')]) if '+' in x else x.split('_')[1].split('.')[0].split("-")[1]) + data.iloc[0, :] = data.iloc[0, 
:].map( + lambda x: '+'.join( + [part.split('_')[1].split('.')[0].split('-')[1] for part in x.split('+')] + ) + if '+' in x + else x.split('_')[1].split('.')[0].split('-')[1] + ) - unique_values = data.iloc[1,:].unique() + unique_values = data.iloc[1, :].unique() colors = plt.cm.tab20.colors - color_map = {label: colors[i % len(colors)] for i, label in enumerate(data.iloc[0, :].unique())} - + color_map = { + label: colors[i % len(colors)] + for i, label in enumerate(data.iloc[0, :].unique()) + } + fig, axes = plt.subplots(1, len(unique_values), figsize=(15, 5)) sample_assignment = {} - + for ax, value in zip(axes, unique_values): subset = data.loc[:, data.iloc[1, :] == value] counts = subset.iloc[0, :].value_counts() @@ -36,15 +53,23 @@ def analyse_demultiplexing_assignment(assignment_file, pool_name): major_label = counts.idxmax() sample_assignment[value] = major_label if not any(counts > 0.75 * counts.sum()): - logging.warning(f'Demultiplexing results in contaminated cell clusters.') - counts.plot.pie(autopct='%1.1f%%', startangle=90, ax=ax, colors=[color_map[label] for label in counts.index]) + logging.warning('Demultiplexing results in contaminated cell clusters.') + counts.plot.pie( + autopct='%1.1f%%', + startangle=90, + ax=ax, + colors=[color_map[label] for label in counts.index], + ) ax.set_title(f'Abundancies for {value}') ax.set_ylabel('') plt.tight_layout() - plt.savefig(Path(assignment_file).parent / f'pool_{pool_name}_demultiplexing_assignment.png') + plt.savefig( + Path(assignment_file).parent / f'pool_{pool_name}_demultiplexing_assignment.png' + ) return sample_assignment + def main(args): sample_assignments = {} if isinstance(args.demultiplexing_assignment, str): @@ -53,13 +78,16 @@ def main(args): assignment_files = args.demultiplexing_assignment for assignment_file in assignment_files: pool_name = str(Path(assignment_file).stem).split('_')[1] - sample_assignment = analyse_demultiplexing_assignment(assignment_file, pool_name) - sample_assignments[f"({pool_name})"] = sample_assignment + sample_assignment = analyse_demultiplexing_assignment( + assignment_file, pool_name + ) + sample_assignments[f'({pool_name})'] = sample_assignment with open(args.output, 'w') as outfile: yaml.dump(sample_assignments, outfile) logging.info('Success.') + if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/compare_distances.py b/benchmarking_pipeline/workflow/scripts/compare_distances.py index da39ccc..cacb546 100644 --- a/benchmarking_pipeline/workflow/scripts/compare_distances.py +++ b/benchmarking_pipeline/workflow/scripts/compare_distances.py @@ -1,5 +1,4 @@ import argparse -import logging import pandas as pd import seaborn as sns @@ -9,22 +8,25 @@ from pathlib import Path - def pool_format2multiplexing_scheme(pool_scheme): demultiplexing_scheme = {} for pool, samples in pool_scheme.items(): for sample in samples: - if "_split1.loom" in sample: + if '_split1.loom' in sample: sample_name = str(Path(sample).stem) sample_name = sample_name.replace('_split1', '') - number_of_iterations =pool.split('.')[1][:-1] + number_of_iterations = pool.split('.')[1][:-1] pool_ID = pool.split('.')[0][1:] if sample_name not in demultiplexing_scheme.keys(): - demultiplexing_scheme[sample_name] = [int(pool_ID), np.inf, int(number_of_iterations)] + demultiplexing_scheme[sample_name] = [ + int(pool_ID), + np.inf, + int(number_of_iterations), + ] else: demultiplexing_scheme[sample_name][0] = int(pool_ID) 
demultiplexing_scheme[sample_name][2] = int(number_of_iterations) - if "_split2.loom" in sample: + if '_split2.loom' in sample: sample_name = str(Path(sample).stem) sample_name = sample_name.replace('_split2', '') pool_ID = pool.split('.')[0][1:] @@ -33,25 +35,32 @@ def pool_format2multiplexing_scheme(pool_scheme): else: demultiplexing_scheme[sample_name][1] = int(pool_ID) # Swap keys and items in the dictionary - swapped_demultiplexing_scheme = {tuple(v): k for k, v in demultiplexing_scheme.items()} - - return swapped_demultiplexing_scheme - + swapped_demultiplexing_scheme = { + tuple(v): k for k, v in demultiplexing_scheme.items() + } + return swapped_demultiplexing_scheme # Create a custom colormap -cmap = sns.color_palette("viridis", as_cmap=True) +cmap = sns.color_palette('viridis', as_cmap=True) cmap.set_bad(color='grey') def parse_args(): - parser = argparse.ArgumentParser(description="Compare distances between samples") - parser.add_argument("--tsv_files", nargs='+' , type=str, help="Path to TSV files containing the genotypes of the clusters. Must be in the same order as the pools in the pooling scheme.") - parser.add_argument("--pool_scheme", type=str, help="Path to the pool scheme file") - parser.add_argument("--output_plot", type=str, help="Output heatmap") - parser.add_argument("--output", type=str, help="sample assignment") - parser.add_argument("--robust", type=bool, required=False, default = False, help="Robust assignment") + parser = argparse.ArgumentParser(description='Compare distances between samples') + parser.add_argument( + '--tsv_files', + nargs='+', + type=str, + help='Path to TSV files containing the genotypes of the clusters. Must be in the same order as the pools in the pooling scheme.', + ) + parser.add_argument('--pool_scheme', type=str, help='Path to the pool scheme file') + parser.add_argument('--output_plot', type=str, help='Output heatmap') + parser.add_argument('--output', type=str, help='sample assignment') + parser.add_argument( + '--robust', type=bool, required=False, default=False, help='Robust assignment' + ) return parser.parse_args() @@ -72,19 +81,30 @@ def compute_ratio(matrix): def permute_tsv_files(tsv_files, pool_scheme): pools = [f"({Path(file).name.split("_")[1]})" for file in tsv_files] - + pool_scheme_keys = list(pool_scheme.keys()) pool_permutation = [pools.index(pool) for pool in pool_scheme_keys] - - if not all((pool1 == pool2 for pool1, pool2 in zip([pools[permuted_idx] for permuted_idx in pool_permutation], pool_scheme_keys))): - raise ValueError("The pools in the pool scheme do not match the pools in the TSV files") + + if not all( + ( + pool1 == pool2 + for pool1, pool2 in zip( + [pools[permuted_idx] for permuted_idx in pool_permutation], + pool_scheme_keys, + ) + ) + ): + raise ValueError( + 'The pools in the pool scheme do not match the pools in the TSV files' + ) return pool_permutation + args = parse_args() -def load_data(args): +def load_data(args): with open(args.pool_scheme, 'r') as file: pooling_scheme = yaml.safe_load(file) @@ -94,11 +114,12 @@ def load_data(args): tsv_files = [args.tsv_files] tsv_files = list(args.tsv_files) tsv_files_permuted = [tsv_files[i] for i in permutation_of_pools] - - dfs = [pd.read_csv(f, sep='\t', index_col = 0) for f in tsv_files_permuted] + + dfs = [pd.read_csv(f, sep='\t', index_col=0) for f in tsv_files_permuted] return dfs, pooling_scheme + dfs, pooling_scheme = load_data(args) # Get a set of all column names @@ -109,14 +130,14 @@ def load_data(args): print(all_columns) -for idx,df in 
enumerate(dfs): +for idx, df in enumerate(dfs): for col in all_columns: if col not in df.columns: dfs[idx][col] = np.nan -for idx,df in enumerate(dfs): +for idx, df in enumerate(dfs): dfs[idx] = df[sorted(all_columns)] - + # Concatenate all DataFrames concatenated_df = pd.concat(dfs, ignore_index=True)[sorted(all_columns)] @@ -135,7 +156,7 @@ def load_data(args): concatenated_df.drop(columns=cols_to_remove, inplace=True) concatenated_df.drop(columns=cols_to_remove_45_55, inplace=True) -for idx,df in enumerate(dfs): +for idx, df in enumerate(dfs): dfs[idx] = df.drop(columns=cols_to_remove, inplace=False) dfs[idx] = df.drop(columns=cols_to_remove_45_55, inplace=False) @@ -151,7 +172,7 @@ def load_data(args): # Compute the distance matrix considering only non-NaN entries and normalizing - # Compute the distance matrix for all pairs of DataFrames +# Compute the distance matrix for all pairs of DataFrames distance_matrices = np.empty((len(dfs), len(dfs)), dtype=object) for data1, df1 in enumerate(dfs): @@ -159,7 +180,9 @@ def load_data(args): distance_matrix = np.zeros((df1.shape[0], df2.shape[0])) for i in range(df1.shape[0]): for j in range(df2.shape[0]): - distance_matrix[i, j] = custom_distance(df1.iloc[i].values, df2.iloc[j].values) + distance_matrix[i, j] = custom_distance( + df1.iloc[i].values, df2.iloc[j].values + ) distance_matrices[data1, data2] = distance_matrix @@ -200,7 +223,6 @@ def load_data(args): plt.savefig(heatmap_plot) - fig, axes = plt.subplots(nrows=1, ncols=len(dfs), figsize=(15, 5)) for i, df in enumerate(dfs): @@ -209,9 +231,6 @@ def load_data(args): plt.savefig(genotype_plot) - - - ratios = [] for i in range(len(dfs)): for j in range(len(dfs)): @@ -224,7 +243,7 @@ def load_data(args): lowest_value_pairs = [] -used_samples = [[]*len(dfs) for _ in range(len(dfs))] +used_samples = [[] * len(dfs) for _ in range(len(dfs))] """ if remove is not None: for j in range(len(dfs)): if j != remove: @@ -272,27 +291,31 @@ def load_data(args): # Sort the distance matrices by the recomputed ratio, from largest to smallest sorted_ratios = sorted(ratios, key=lambda x: x[2], reverse=True) else: - print(f"skipping assignment in matrix {i},{j}") + print(f'skipping assignment in matrix {i},{j}') # Print the pairs with the lowest values for i, j, row, col in lowest_value_pairs: - print(f'Lowest value in Distance Matrix DF{i} vs DF{j}: Row = {row}, Column = {col}') + print( + f'Lowest value in Distance Matrix DF{i} vs DF{j}: Row = {row}, Column = {col}' + ) - -demultiplexing_scheme = pool_format2multiplexing_scheme(pooling_scheme) +demultiplexing_scheme = pool_format2multiplexing_scheme(pooling_scheme) for i, j, row, col in lowest_value_pairs: if i > j: i, j = j, i row, col = col, row print( f'Sample {row} from pool {i} and sample {col} from pool {j}: {demultiplexing_scheme[(i,j,0)]}' ) -if args.robust == False: +if not args.robust: for i in range(len(used_samples)): - unused_sample = next(sample for sample in range(len(dfs[i])) if sample not in used_samples[i]) - print(f"Sample {unused_sample} from pool {i}: {demultiplexing_scheme[(i,i,0)]}") - + unused_sample = next( + sample for sample in range(len(dfs[i])) if sample not in used_samples[i] + ) + print(f'Sample {unused_sample} from pool {i}: {demultiplexing_scheme[(i,i,0)]}') pooling_scheme_keys = list(pooling_scheme.keys()) @@ -302,21 +325,35 @@ def load_data(args): i, j = j, i row, col = col, row if pooling_scheme_keys[i] not in 
sample_assignment.keys(): - sample_assignment[pooling_scheme_keys[i]] = {row: demultiplexing_scheme[(i, j, 0)]} + sample_assignment[pooling_scheme_keys[i]] = { + row: demultiplexing_scheme[(i, j, 0)] + } else: - sample_assignment[pooling_scheme_keys[i]][row] = demultiplexing_scheme[(i, j, 0)] + sample_assignment[pooling_scheme_keys[i]][row] = demultiplexing_scheme[ + (i, j, 0) + ] if pooling_scheme_keys[j] not in sample_assignment.keys(): - sample_assignment[pooling_scheme_keys[j]] = {col: demultiplexing_scheme[(i, j, 0)]} + sample_assignment[pooling_scheme_keys[j]] = { + col: demultiplexing_scheme[(i, j, 0)] + } else: - sample_assignment[pooling_scheme_keys[j]][col] = demultiplexing_scheme[(i, j, 0)] + sample_assignment[pooling_scheme_keys[j]][col] = demultiplexing_scheme[ + (i, j, 0) + ] -if args.robust == False: +if not args.robust: for i in range(len(used_samples)): - unused_sample = next(sample for sample in range(len(dfs[i])) if sample not in used_samples[i]) + unused_sample = next( + sample for sample in range(len(dfs[i])) if sample not in used_samples[i] + ) if pooling_scheme_keys[i] not in sample_assignment.keys(): - sample_assignment[pooling_scheme_keys[i]] = {unused_sample: demultiplexing_scheme[(i, i, 0)]} + sample_assignment[pooling_scheme_keys[i]] = { + unused_sample: demultiplexing_scheme[(i, i, 0)] + } else: - sample_assignment[pooling_scheme_keys[i]][unused_sample] = demultiplexing_scheme[(i, i, 0)] + sample_assignment[pooling_scheme_keys[i]][unused_sample] = ( + demultiplexing_scheme[(i, i, 0)] + ) with open(args.output, 'w') as yaml_file: yaml.dump(sample_assignment, yaml_file) diff --git a/benchmarking_pipeline/workflow/scripts/create_demultiplexing_scheme.py b/benchmarking_pipeline/workflow/scripts/create_demultiplexing_scheme.py index 331816d..5e16c61 100644 --- a/benchmarking_pipeline/workflow/scripts/create_demultiplexing_scheme.py +++ b/benchmarking_pipeline/workflow/scripts/create_demultiplexing_scheme.py @@ -6,39 +6,56 @@ import numpy as np -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.INFO) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.INFO, +) -def define_demultiplexing_scheme_optimal_case(maximal_number_of_samples, maximal_pool_size, n_samples, robust): +def define_demultiplexing_scheme_optimal_case( + maximal_number_of_samples, maximal_pool_size, n_samples, robust +): if n_samples % maximal_number_of_samples != 0: - raise ValueError("Number of samples must be a multiple of maximal number of samples to run this function!") + raise ValueError( + 'Number of samples must be a multiple of maximal number of samples to run this function!' 
+ ) demultiplexing_scheme = {} number_of_iterations = int(n_samples / maximal_number_of_samples) if not robust: - unordered_unique_pairs = list(itertools.combinations(range(maximal_pool_size), 2)) + unordered_unique_pairs = list( + itertools.combinations(range(maximal_pool_size), 2) + ) diagonal = list(zip(range(maximal_pool_size), range(maximal_pool_size))) unordered_pairs = unordered_unique_pairs + diagonal else: - unordered_pairs = list(itertools.combinations(range(maximal_pool_size+1), 2)) + unordered_pairs = list(itertools.combinations(range(maximal_pool_size + 1), 2)) for idx1 in range(number_of_iterations): for idx2, pair in enumerate(unordered_pairs): - demultiplexing_scheme[int(idx1 * maximal_number_of_samples + idx2 + 1)] = pair + (idx1,) # naming of samples starts with 1 + demultiplexing_scheme[int(idx1 * maximal_number_of_samples + idx2 + 1)] = ( + pair + (idx1,) + ) # naming of samples starts with 1 return demultiplexing_scheme - def find_demultiplexing_scheme(maximal_pool_size, n_samples, robust): - maximal_number_of_samples = (maximal_pool_size * (maximal_pool_size+1))/2 + maximal_number_of_samples = (maximal_pool_size * (maximal_pool_size + 1)) / 2 if n_samples % maximal_number_of_samples != 0: - logging.error("So far, only defined for certain cohort sizes") + logging.error('So far, only defined for certain cohort sizes') raise NotImplementedError else: - logging.info("Creating multiplexing scheme") - demultiplexing_scheme = define_demultiplexing_scheme_optimal_case(maximal_number_of_samples = maximal_number_of_samples, maximal_pool_size = maximal_pool_size, n_samples = n_samples, robust = robust) - + logging.info('Creating multiplexing scheme') + demultiplexing_scheme = define_demultiplexing_scheme_optimal_case( + maximal_number_of_samples=maximal_number_of_samples, + maximal_pool_size=maximal_pool_size, + n_samples=n_samples, + robust=robust, + ) + logging.info(f'Demultiplexing scheme: {demultiplexing_scheme}') return demultiplexing_scheme @@ -46,91 +63,135 @@ def find_demultiplexing_scheme(maximal_pool_size, n_samples, robust): def multiplexing_scheme_format2pool_format(demultiplexing_scheme): pool_scheme = {} - total_no_pools_per_repetition = np.max([pool[:-1] for pool in list(demultiplexing_scheme.values())]) - total_no_of_repetitions = np.max([pool[-1] for pool in list(demultiplexing_scheme.values())]) - - pools = list(itertools.product(range(total_no_pools_per_repetition + 1), range(total_no_of_repetitions + 1))) + total_no_pools_per_repetition = np.max( + [pool[:-1] for pool in list(demultiplexing_scheme.values())] + ) + total_no_of_repetitions = np.max( + [pool[-1] for pool in list(demultiplexing_scheme.values())] + ) + + pools = list( + itertools.product( + range(total_no_pools_per_repetition + 1), range(total_no_of_repetitions + 1) + ) + ) for pool in pools: pool_scheme[pool] = [] for key, value in demultiplexing_scheme.items(): - pool_scheme[value[0],value[-1]].append(int(key)) - pool_scheme[value[1],value[-1]].append(-int(key)) + pool_scheme[value[0], value[-1]].append(int(key)) + pool_scheme[value[1], value[-1]].append(-int(key)) # The pool scheme has pool identifier as keys and split libraries as values, where the first split is encoded as a positive integer and the second split as a negative integer return pool_scheme - def select_samples_for_pooling(pool_scheme, input_dir, sample_list): if isinstance(sample_list[0], str): - for idx,sample in enumerate(sample_list): + for idx, sample in enumerate(sample_list): sample_list[idx] = Path(sample) if 
isinstance(input_dir, str): input_dir = Path(input_dir) logging.info('Selecting samples for pooling') pools_summary = {} for pool, samples in pool_scheme.items(): - - logging.info(f"Retrieving list of samples to be pooled for pool {pool}") - logging.debug(f"Samples: {samples}") - logging.debug(f"Sample list: {sample_list}") + logging.info(f'Retrieving list of samples to be pooled for pool {pool}') + logging.debug(f'Samples: {samples}') + logging.debug(f'Sample list: {sample_list}') loom_files = [] for sample in samples: if sample > 0: - loom_files.append((input_dir / f"{sample_list[sample-1].stem}_split1.loom").as_posix()) + loom_files.append( + (input_dir / f'{sample_list[sample-1].stem}_split1.loom').as_posix() + ) else: - loom_files.append((input_dir / f"{sample_list[-sample-1].stem}_split2.loom").as_posix()) - - logging.info(f"Done retrieving list of samples to be pooled for pool {pool}") + loom_files.append( + ( + input_dir / f'{sample_list[-sample-1].stem}_split2.loom' + ).as_posix() + ) - pools_summary[f"({pool[0]}.{pool[1]})"] = loom_files + logging.info(f'Done retrieving list of samples to be pooled for pool {pool}') + + pools_summary[f'({pool[0]}.{pool[1]})'] = loom_files return pools_summary def write_output(pools_summary, output): - with open(output, "w") as f: + with open(output, 'w') as f: yaml.dump(pools_summary, f, default_flow_style=False) def load_input_samples(input_sample_file): - with open(input_sample_file, "r") as f: + with open(input_sample_file, 'r') as f: input_samples = f.readlines() - + return input_samples def parse_args(): - parser = argparse.ArgumentParser(description="Output a multiplexing scheme under pool size constraint for a predefined number of samples") - parser.add_argument("--robust", type=bool, required=False, default = False, help="Input list of loom files") - parser.add_argument("-k", "--maximal_pool_size", type=int, required=False, help="The maximal amount of samples to be multiplexed in a pool") - parser.add_argument("--n_samples", type=int, required=True, help="The number of samples to be sequenced in the cohort") - parser.add_argument("--output", type=str, required=True, help="Where to store the output file") - parser.add_argument("--input_dir", type=str, required=True, help="Directory to the loom files") - parser.add_argument("--input_sample_file", type=str, required=True, help="Path to the file containing the list of loom files") - - logging.info("Parsing arguments") + parser = argparse.ArgumentParser( + description='Output a multiplexing scheme under pool size constraint for a predefined number of samples' + ) + parser.add_argument( + '--robust', + type=bool, + required=False, + default=False, + help='Create a robust multiplexing scheme in which the two splits of a sample are never assigned to the same pool', + ) + parser.add_argument( + '-k', + '--maximal_pool_size', + type=int, + required=False, + help='The maximal number of samples to be multiplexed in a pool', + ) + parser.add_argument( + '--n_samples', + type=int, + required=True, + help='The number of samples to be sequenced in the cohort', + ) + parser.add_argument( + '--output', type=str, required=True, help='Where to store the output file' + ) + parser.add_argument( + '--input_dir', type=str, required=True, help='Directory to the loom files' + ) + parser.add_argument( + '--input_sample_file', + type=str, + required=True, + help='Path to the file containing the list of loom files', + ) + + logging.info('Parsing arguments') args = parser.parse_args() - + if args.n_samples == 0: - logging.error("Number of samples must be greater than 0") + logging.error('Number of
samples must be greater than 0') raise ValueError - + return args def main(args): input_samples = load_input_samples(args.input_sample_file) n_samples = len(input_samples) - demultiplexing_scheme = find_demultiplexing_scheme(args.maximal_pool_size, n_samples, args.robust) + demultiplexing_scheme = find_demultiplexing_scheme( + args.maximal_pool_size, n_samples, args.robust + ) pool_scheme = multiplexing_scheme_format2pool_format(demultiplexing_scheme) - pools_summary = select_samples_for_pooling(pool_scheme, args.input_dir, input_samples) - logging.debug(f"Output: {pools_summary.keys()}") - logging.info(f"Writing output to {args.output}") + pools_summary = select_samples_for_pooling( + pool_scheme, args.input_dir, input_samples + ) + logging.debug(f'Output: {pools_summary.keys()}') + logging.info(f'Writing output to {args.output}') write_output(pools_summary, args.output) - logging.info("Success.") + logging.info('Success.') -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/create_heatmap.py b/benchmarking_pipeline/workflow/scripts/create_heatmap.py deleted file mode 100644 index 15c857c..0000000 --- a/benchmarking_pipeline/workflow/scripts/create_heatmap.py +++ /dev/null @@ -1,293 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import copy -import os -import re -import tkinter as tk - -from matplotlib import pyplot as plt -import numpy as np -import pandas as pd -from scipy.spatial.distance import pdist, cityblock -from scipy.cluster.hierarchy import linkage -import seaborn as sns - -VCF_COLS = [ - 'CHROM', - 'POS', - 'ID', - 'REF', - 'ALT', - 'QUAL', - 'FILTER', - 'INFO', - 'FORMAT', - 'sample', -] -GT_MAP = {'0/0': 0, '0/1': 1, '1/1': 2, './.': np.nan} -DEF_DIST = 'Euclidean' -PID2SID = { - 'PIDasdf': 'SIDjklö' -} ### These can be adapted manually to have more control on how the sampled should be named in the output -SID2PAPER = {'SIDjklö': 'S1'} - -FILE_EXT = '.png' -FONTSIZE = 30 -DPI = 300 -sns.set_style('whitegrid') # darkgrid, whitegrid, dark, white, ticks -sns.set_context( - 'paper', - rc={ - 'lines.linewidth': 2, - 'axes.axisbelow': True, - 'font.size': FONTSIZE, - 'axes.labelsize': 'medium', - 'axes.titlesize': 'large', - 'xtick.labelsize': 'medium', - 'ytick.labelsize': 'medium', - 'legend.fontsize': 'medium', - 'legend.title_fontsize': 'large', - 'axes.labelticksize': 50, - }, -) -plt.rcParams['xtick.major.size'] = 10 -plt.rcParams['xtick.major.width'] = 3 -plt.rcParams['xtick.bottom'] = True - - -def dist_nan(u, v): - valid = ~np.isnan(u) & ~np.isnan(v) - # return euclidean(u[valid], v[valid]) / np.sqrt(np.sum(2**2 * valid.sum())) # Euclidean - # return (u[valid].round() != v[valid].round()).mean() # Hamming - return cityblock(u[valid], v[valid]) / (2 * valid.sum()) # Manhattan - - -def get_clone_profiles(in_file): - df = pd.read_csv(in_file) - df['REF'].fillna('*', inplace=True) - df['ALT'].fillna('*', inplace=True) - idx_str = ( - df['CHR'].astype(str) - + '_' - + df['POS'].astype(str) - + '_' - + df['REF'] - + '_' - + df['ALT'] - ) - idx_str = idx_str.str.replace('chr', '') - gt = df.iloc[:, 7:].map(lambda x: int(x.split(':')[-1])) - gt.replace(3, np.nan, inplace=True) - - df = pd.DataFrame(gt.mean(axis=1).values, index=idx_str, columns=['Cluster']) - return df - - -def plot_heatmap(data, out_file): - cmap = plt.get_cmap('YlOrRd', 100) - dist_row = pdist(data.values, dist_nan) - dist_col = pdist(data.values.T, dist_nan) - Z_row = 
linkage(np.nan_to_num(dist_row, 10), 'ward') - Z_col = linkage(np.nan_to_num(dist_col, 10), 'ward') - try: - cm = sns.clustermap( - data, - row_linkage=Z_row, - col_linkage=Z_col, - row_cluster=True, - col_cluster=True, - col_colors=None, - row_colors=None, - vmin=0, - vmax=2, - cmap=cmap, - dendrogram_ratio=(0.1, 0.1), - figsize=(25, 15), - cbar_kws={ - 'ticks': [0, 1, 2], - }, - # cbar_pos=(0, 0, 0.0, 0.0), - tree_kws=dict(linewidths=2), - ) - except tk.TclError: - import matplotlib - - matplotlib.use('Agg') - cm = sns.clustermap( - data, - row_linkage=Z_row, - col_linkage=Z_col, - row_cluster=True, - col_cluster=True, - col_colors=None, - row_colors=None, - vmin=0, - vmax=2, - cmap=cmap, - dendrogram_ratio=(0.1, 0.1), - figsize=(25, 15), - cbar_kws={ - 'ticks': [0, 1, 2], - }, - # cbar_pos=(0, 0, 0.0, 0.0), - tree_kws=dict(linewidths=2), - ) - - cm.ax_heatmap.set_facecolor('#EAEAEA') - cm.ax_heatmap.set_ylabel('\nProfiles') - cm.ax_heatmap.set_xlabel('SNPs') - labels_pretty = [ - 'chr{}:{} {}>{}'.format(*i.get_text().split('_')) - for i in cm.ax_heatmap.get_xticklabels() - ] - cm.ax_heatmap.set_xticklabels(labels_pretty, rotation=45, ha='right', va='top') - - cm.ax_cbar.set_title('Genotype') - cm.ax_cbar.set_yticklabels(['0|0', '0|1', '1|1']) - - cm.fig.tight_layout() - if out_file: - cm.fig.savefig(out_file, dpi=DPI) - else: - plt.show() - - -def plot_distance(df, out_file): - col_no = 1 - row_no = 1 - fig, ax = plt.subplots( - nrows=row_no, ncols=col_no, figsize=(col_no * 12, row_no * 12) - ) - - sns.set(font_scale=2.5) - font_size = 30 - - df.sort_index(ascending=False, inplace=True) - df = df[sorted(df.columns)] - - df.columns = [i.split('_')[0] for i in df.columns] - df.index = [i.split('_')[0] for i in df.index] - - sns.heatmap( - df, - annot=True, - square=True, - cmap='viridis_r', - cbar_kws={'shrink': 0.5, 'label': 'Distance'}, - linewidths=0, - ax=ax, - ) - ax.set_xticklabels(ax.get_xmajorticklabels(), fontsize=font_size) - ax.set_yticklabels(ax.get_ymajorticklabels(), fontsize=font_size) - ax.set_xlabel(df.index.name, fontsize=font_size + 10) - ax.set_ylabel(df.columns.name, fontsize=font_size + 10) - - fig.tight_layout() - if out_file: - fig.savefig(out_file, dpi=300) - else: - plt.show() - - -def assign_clusters(df_in, assignment_output): - df = copy.deepcopy(df_in) - # import pdb;pdb.set_trace() - for i in range(df_in.shape[0]): - cl_idx, other_idx = np.where(df == df.min().min()) - cl = df.index[cl_idx[0]] - other = df.columns[other_idx[0]] - print(f'Assigning: {cl} -> {other}') - with open(assignment_output, 'a') as file: - file.write(f'{cl}: {other}\n') - - df.drop(cl, axis=0, inplace=True) - df.drop(other, axis=1, inplace=True) - - -def main(args): - if not args.outfile: - out_base = os.path.splitext(args.input[0])[0] - out_end = FILE_EXT - else: - out_base, out_end = os.path.splitext(args.outfile) - - names = [] - for i, in_file in enumerate(args.input): - - cluster = re.compile(r'([0-9]*)_variants.csv').search(in_file).group(1) - name = f'{cluster}_DNA' - - names.append(name) - df_new = get_clone_profiles(in_file) - df_new.rename({'Cluster': name}, inplace=True, axis=1) - - match_df = df_new.merge(SNP_df, left_index=True, right_index=True, how='inner') - - if match_df.size == 0: - print('No overlap in RNA and DNA data.') - exit() - try: - df = df.merge( - match_df[name], left_index=True, right_index=True, how='outer' - ) - except NameError: - df = match_df - - df = df[sorted(df.columns)] - # remove SNPs that are similar in all DNA clusters - df = df[(df[names] > 
1.95).sum(axis=1) != len(names)] - df = df[(df[names] < 0.05).sum(axis=1) != len(names)] - df = df[ - ((df[names] > 0.95).sum(axis=1) != len(names)) - | ((df[names] < 1.05).sum(axis=1) != len(names)) - ] - # Remove SNPS not called in the scDNA-seq - df = df[(df[names] >= 0).sum(axis=1) > 0] - - dists = [] - for name in names: - dists.append( - np.apply_along_axis( - dist_nan, 1, df[SNP_df.columns].values.T, df[name].values.T - ) - ) - - dist_df = pd.DataFrame(dists, index=names, columns=SNP_df.columns) - - assign_clusters(dist_df, args.assignment_output) - print(f'\n{dist_df}\n') - - dist_df.index.name = 'scDNA-seq' - dist_df.columns.name = label - - dist_out = f'{out_base}_distance{out_end}' - print(f'Generating plot: {dist_out}') - plot_distance(dist_df, dist_out) - - hm_out = f'{out_base}_heatmap{out_end}' - print(f'Generating plot: {hm_out}') - plot_heatmap(df.T, hm_out) - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - '-i', '--input', type=str, nargs='+', help='mutation_profiles of the demultiplexed samples' - ) - parser.add_argument( - '-o', '--outfile', type=str, default='', help='Output file for heatmap.' - ) - parser.add_argument( - '-ao', - '--assignment_output', - type=str, - help='Output yaml-file in which the automated sample matching is saved.', - ) - - return parser.parse_args() - - -if __name__ == '__main__': - args = parse_args() - main(args) \ No newline at end of file diff --git a/benchmarking_pipeline/workflow/scripts/demo_tape.py b/benchmarking_pipeline/workflow/scripts/demo_tape.py index 7b83023..e051d29 100644 --- a/benchmarking_pipeline/workflow/scripts/demo_tape.py +++ b/benchmarking_pipeline/workflow/scripts/demo_tape.py @@ -17,20 +17,57 @@ EPSILON = np.finfo(np.float64).resolution -COLORS = ['#e41a1c', '#377eb8', '#4daf4a', '#E4761A' , '#2FBF85'] # '#A4CA55' -COLORS = ['#a6cee3', '#1f78b4', '#b2df8a', '#33a02c', '#fb9a99', '#e31a1c', - '#fdbf6f', '#ff7f00', '#cab2d6', '#6a3d9a', '#ffff99', '#b15928'] - -COLORS = ['#e6194b', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4', '#46f0f0', '#f032e6', '#bcf60c', '#fabebe', '#008080', '#e6beff', '#9a6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1', '#000075', '#808080', '#ffffff', '#000000'] +COLORS = ['#e41a1c', '#377eb8', '#4daf4a', '#E4761A', '#2FBF85'] # '#A4CA55' +COLORS = [ + '#a6cee3', + '#1f78b4', + '#b2df8a', + '#33a02c', + '#fb9a99', + '#e31a1c', + '#fdbf6f', + '#ff7f00', + '#cab2d6', + '#6a3d9a', + '#ffff99', + '#b15928', +] + +COLORS = [ + '#e6194b', + '#3cb44b', + '#ffe119', + '#4363d8', + '#f58231', + '#911eb4', + '#46f0f0', + '#f032e6', + '#bcf60c', + '#fabebe', + '#008080', + '#e6beff', + '#9a6324', + '#fffac8', + '#800000', + '#aaffc3', + '#808000', + '#ffd8b1', + '#000075', + '#808080', + '#ffffff', + '#000000', +] COLORS_STR = ['red', 'blue', 'green', 'orange', 'mint'] FILE_EXT = 'png' FONTSIZE = 30 DPI = 300 -sns.set_style('whitegrid') #darkgrid, whitegrid, dark, white, ticks -sns.set_context('paper', - rc={'lines.linewidth': 2, +sns.set_style('whitegrid') # darkgrid, whitegrid, dark, white, ticks +sns.set_context( + 'paper', + rc={ + 'lines.linewidth': 2, 'axes.axisbelow': True, 'font.size': FONTSIZE, 'axes.labelsize': 'medium', @@ -39,8 +76,9 @@ 'ytick.labelsize': 'medium', 'legend.fontsize': 'medium', 'legend.title_fontsize': 'large', - 'axes.labelticksize': 50 -}) + 'axes.labelticksize': 50, + }, +) plt.rcParams['xtick.major.size'] = 5 plt.rcParams['xtick.major.width'] = 1.5 plt.rcParams['xtick.bottom'] = True @@ -54,8 +92,8 @@ def 
__init__(self, in_file, cl_no): # Get full data df = pd.read_csv(in_file, index_col=[0, 1], dtype={'CHR': str}) self.SNPs = df.apply( - lambda x: f'chr{x.name[0]}:{x.name[1]} {x["REF"]}>{x["ALT"]}', - axis=1) + lambda x: f'chr{x.name[0]}:{x.name[1]} {x["REF"]}>{x["ALT"]}', axis=1 + ) df.drop(['REF', 'ALT', 'REGION', 'NAME', 'FREQ'], axis=1, inplace=True) self.cells = df.columns.values @@ -78,11 +116,13 @@ def __init__(self, in_file, cl_no): self.ref = df.applymap(lambda x: int(x.split(':')[0])).values.T self.alt = df.applymap(lambda x: int(x.split(':')[1])).values.T self.dp = self.ref + self.alt - self.VAF = np.clip(np.where(self.dp > 0, self.alt / self.dp, np.nan), - EPSILON, 1 - EPSILON) + self.VAF = np.clip( + np.where(self.dp > 0, self.alt / self.dp, np.nan), EPSILON, 1 - EPSILON + ) self.RAF = 1 - self.VAF - self.norm_const = np.arange(self.dp.max() * 2 + 1) \ - * np.log(np.arange(self.dp.max() * 2 + 1)) + self.norm_const = np.arange(self.dp.max() * 2 + 1) * np.log( + np.arange(self.dp.max() * 2 + 1) + ) self.reads = np.hstack([self.ref, self.alt, self.dp]) self.metric = self.dist_reads @@ -95,25 +135,25 @@ def __init__(self, in_file, cl_no): self.dbt_ids = np.array([]) self.dbt_map = {} - def __str__(self): - out_str = '\ndemoTape:\n' \ - f'\tFile: {self.in_file}\n' \ - f'\t# Samples: {self.cl_no}\n' \ - f'\tCells: {self.cells}:\n' \ - f'\tSNPs: {self.SNPs}\n' \ + out_str = ( + '\ndemoTape:\n' + f'\tFile: {self.in_file}\n' + f'\t# Samples: {self.cl_no}\n' + f'\tCells: {self.cells}:\n' + f'\tSNPs: {self.SNPs}\n' f'\tDistance metric: reads' + ) return out_str - @staticmethod def dist_reads(c1, c2): r1 = np.reshape(c1, (3, -1)) r2 = np.reshape(c2, (3, -1)) valid = (r1[2] > 0) & (r2[2] > 0) - r1 = r1[:,valid] - r2 = r2[:,valid] + r1 = r1[:, valid] + r2 = r2[:, valid] p1 = np.clip(r1[0] / r1[2], EPSILON, 1 - EPSILON) p2 = np.clip(r2[0] / r2[2], EPSILON, 1 - EPSILON) @@ -121,49 +161,59 @@ def dist_reads(c1, c2): p12 = np.clip((r1[0] + r2[0]) / (dp_total), EPSILON, 1 - EPSILON) p12_inv = 1 - p12 - logl = r1[0] * np.log(p1 / p12) + r1[1] * np.log((1 - p1) / p12_inv) \ - + r2[0] * np.log(p2 / p12) + r2[1] * np.log((1 - p2) / p12_inv) + logl = ( + r1[0] * np.log(p1 / p12) + + r1[1] * np.log((1 - p1) / p12_inv) + + r2[0] * np.log(p2 / p12) + + r2[1] * np.log((1 - p2) / p12_inv) + ) - norm = np.log(dp_total) * (dp_total) \ - - r1[2] * np.log(r1[2]) - r2[2] * np.log(r2[2]) + norm = ( + np.log(dp_total) * (dp_total) + - r1[2] * np.log(r1[2]) + - r2[2] * np.log(r2[2]) + ) return np.sum(logl / norm) / valid.sum() - def get_pairwise_dists(self): dist = [] for i in np.arange(self.cells.size - 1): - valid = (self.dp[i] > 0) & (self.dp[i+1:] > 0) - dp_total = self.dp[i] + self.dp[i+1:] - p12 = np.clip((self.alt[i] + self.alt[i+1:]) / dp_total, EPSILON, 1 - EPSILON) + valid = (self.dp[i] > 0) & (self.dp[i + 1 :] > 0) + dp_total = self.dp[i] + self.dp[i + 1 :] + p12 = np.clip( + (self.alt[i] + self.alt[i + 1 :]) / dp_total, EPSILON, 1 - EPSILON + ) p12_inv = 1 - p12 - logl = self.alt[i] * np.log(self.VAF[i] / p12) \ - + self.ref[i] * np.log(self.RAF[i] / p12_inv) \ - + self.alt[i+1:] * np.log(self.VAF[i+1:] / p12) \ - + self.ref[i+1:] * np.log(self.RAF[i+1:] / p12_inv) + logl = ( + self.alt[i] * np.log(self.VAF[i] / p12) + + self.ref[i] * np.log(self.RAF[i] / p12_inv) + + self.alt[i + 1 :] * np.log(self.VAF[i + 1 :] / p12) + + self.ref[i + 1 :] * np.log(self.RAF[i + 1 :] / p12_inv) + ) - norm = self.norm_const[dp_total] \ - - self.norm_const[self.dp[i]] - self.norm_const[self.dp[i+1:]] + norm = ( + 
self.norm_const[dp_total] + - self.norm_const[self.dp[i]] + - self.norm_const[self.dp[i + 1 :]] + ) dist.append(np.nansum(logl / norm, axis=1) / valid.sum(axis=1)) return np.concatenate(dist) - def demultiplex(self): self.init_dendrogram() self.identify_doublets() self.merge_surplus_singlets() - -# -------------------------------- DENDROGRAM ---------------------------------- + # -------------------------------- DENDROGRAM ---------------------------------- def init_dendrogram(self): dist = self.get_pairwise_dists() self.Z = linkage(dist, method='ward') - -# ---------------------------- IDENTIFY DOUBLETS ------------------------------- + # ---------------------------- IDENTIFY DOUBLETS ------------------------------- def identify_doublets(self): cl_no = self.sgt_cl_no + self.dbt_cl_no @@ -175,14 +225,13 @@ def identify_doublets(self): if self.dbt_ids.size == self.dbt_cl_no: break elif cl_no == (self.sgt_cl_no + self.dbt_cl_no) * 3: - print(f'Could not identify all doublets.') + print('Could not identify all doublets.') self.set_assigment(self.sgt_cl_no + self.dbt_cl_no) self.set_dbt_ids() break cl_no += 1 print(f'Increasing cuttree clusters to {cl_no}') - def set_assigment(self, cl_no): self.assignment = cut_tree(self.Z, n_clusters=cl_no).flatten() clusters = np.unique(self.assignment) @@ -190,19 +239,16 @@ def set_assigment(self, cl_no): for cl_id, cl in enumerate(clusters): self.profiles[cl_id] = self.get_profile(cl) - def get_profile(self, cl): cells = np.isin(self.assignment, cl) - p = np.average(np.nan_to_num(self.VAF[cells]), weights=self.dp[cells], - axis=0) + p = np.average(np.nan_to_num(self.VAF[cells]), weights=self.dp[cells], axis=0) dp = np.mean(self.dp[cells], axis=0).round() alt = (dp * p).round() ref = dp - alt return np.hstack([alt, ref, dp]) - def set_dbt_ids(self): - cl_size = np.unique(self.assignment, return_counts=True)[1] / self.cells.size + cl_size = np.unique(self.assignment, return_counts=True)[1] / self.cells.size cl_no = cl_size.size dbt_map = {} @@ -217,8 +263,9 @@ def set_dbt_ids(self): df_dbt = pd.DataFrame([], columns=dbt_combs, index=range(cl_no)) for i in range(cl_no): - df_dbt.loc[i] = np.apply_along_axis(self.metric, 1, dbt_profiles, - self.profiles[i]) + df_dbt.loc[i] = np.apply_along_axis( + self.metric, 1, dbt_profiles, self.profiles[i] + ) for j in dbt_combs: # set doublet combo dist to np.nan if: # 1. 
Singlet is included in Dbt cluster @@ -254,26 +301,28 @@ def set_dbt_ids(self): df_dbt.drop(rm_rows, axis=0, inplace=True, errors='ignore') # Remove doublet rows including cluster rm_cols = [dbt_id] + [i for i in dbt_combs if cl_id in i] - df_dbt.drop(rm_cols, axis=1, inplace=True, errors='ignore') # Ignore error on already dropped labels + df_dbt.drop( + rm_cols, axis=1, inplace=True, errors='ignore' + ) # Ignore error on already dropped labels - self.sgt_ids = np.array([i for i in np.unique(self.assignment) \ - if not i in dbt_ids]) + self.sgt_ids = np.array( + [i for i in np.unique(self.assignment) if i not in dbt_ids] + ) self.dbt_ids = np.array(dbt_ids) self.dbt_map = dbt_map # self.plot_heatmap() # import pdb; pdb.set_trace() - def check_hom_match(self, scl, dcl): - rs = np.reshape(self.profiles[scl], (3, -1)) + rs = np.reshape(self.profiles[scl], (3, -1)) rd1 = np.reshape(self.profiles[dcl[0]], (3, -1)) rd2 = np.reshape(self.profiles[dcl[1]], (3, -1)) valid = (rs[2] > 0) & (rd1[2] > 0) & (rd2[2] > 0) - rs = rs[:,valid] - rd1 = rd1[:,valid] - rd2 = rd2[:,valid] + rs = rs[:, valid] + rd1 = rd1[:, valid] + rd2 = rd2[:, valid] ps = rs[0] / rs[2] pd1 = rd1[0] / rd1[2] @@ -318,7 +367,6 @@ def check_hom_match(self, scl, dcl): return (~hom_match).sum() - def get_cl_map(self): if self.sgt_ids.size == 0: cl_map = {i: str(i) for i in np.unique(self.assignment)} @@ -334,8 +382,7 @@ def get_cl_map(self): assignment_str[i] = cl_map[j] return cl_map, assignment_str - -# ------------------------- MERGE SURPLUS SINGLETS ----------------------------- + # ------------------------- MERGE SURPLUS SINGLETS ----------------------------- def merge_surplus_singlets(self): # Merge surplus single clusters @@ -391,8 +438,9 @@ def merge_surplus_singlets(self): self.assignment[self.assignment == rem_dbt_cl] = cl1 self.assignment[self.assignment > rem_dbt_cl] -= 1 - self.dbt_ids = np.delete(self.dbt_ids, - np.argwhere(self.dbt_ids == rem_dbt_cl)) + self.dbt_ids = np.delete( + self.dbt_ids, np.argwhere(self.dbt_ids == rem_dbt_cl) + ) self.sgt_ids[self.sgt_ids > rem_dbt_cl] -= 1 self.dbt_ids[self.dbt_ids > rem_dbt_cl] -= 1 @@ -400,13 +448,17 @@ def merge_surplus_singlets(self): if len(self.dbt_map.values()) != len(set(self.dbt_map.values())): print('Removing 1 equal doublet cluster') - eq_dbt = [i for i,j in self.dbt_map.items() \ - if list(self.dbt_map.values()).count(j) > 1] + eq_dbt = [ + i + for i, j in self.dbt_map.items() + if list(self.dbt_map.values()).count(j) > 1 + ] self.assignment[self.assignment == eq_dbt[1]] = eq_dbt[0] self.assignment[self.assignment > eq_dbt[1]] -= 1 - self.dbt_ids = np.delete(self.dbt_ids, - np.argwhere(self.dbt_ids == eq_dbt[1])) + self.dbt_ids = np.delete( + self.dbt_ids, np.argwhere(self.dbt_ids == eq_dbt[1]) + ) self.sgt_ids[self.sgt_ids > eq_dbt[1]] -= 1 self.dbt_ids[self.dbt_ids > eq_dbt[1]] -= 1 @@ -418,16 +470,15 @@ def merge_surplus_singlets(self): # Update profile self.profiles[cl1] = self.get_profile(cl1) - def get_sgt_dist_matrix(self): mat = np.zeros((self.sgt_ids.size, self.sgt_ids.size)) for i, j in enumerate(self.sgt_ids): - mat[i] = np.apply_along_axis(self.metric, 1, - self.profiles[self.sgt_ids], self.profiles[j]) + mat[i] = np.apply_along_axis( + self.metric, 1, self.profiles[self.sgt_ids], self.profiles[j] + ) mat[np.tril_indices(self.sgt_ids.size)] = np.nan return mat - def update_dbt_map(self, new_cl, del_cl): dbt_map_new = {} consistent = True @@ -454,42 +505,39 @@ def update_dbt_map(self, new_cl, del_cl): return dbt_map_new, consistent - # 
------------------------ HEATMAP GENERATION ------------------------------ @staticmethod def get_cmap(): # Edit this gradient at https://eltos.github.io/gradient/ - cmap = LinearSegmentedColormap.from_list('my_gradient', ( - (0.000, (1.000, 1.000, 1.000)), - (0.167, (1.000, 0.882, 0.710)), - (0.333, (1.000, 0.812, 0.525)), - (0.500, (1.000, 0.616, 0.000)), - (0.667, (1.000, 0.765, 0.518)), - (0.833, (1.000, 0.525, 0.494)), - (1.000, (1.000, 0.000, 0.000))) + cmap = LinearSegmentedColormap.from_list( + 'my_gradient', + ( + (0.000, (1.000, 1.000, 1.000)), + (0.167, (1.000, 0.882, 0.710)), + (0.333, (1.000, 0.812, 0.525)), + (0.500, (1.000, 0.616, 0.000)), + (0.667, (1.000, 0.765, 0.518)), + (0.833, (1.000, 0.525, 0.494)), + (1.000, (1.000, 0.000, 0.000)), + ), ) return cmap - def get_hm_data(self): - df = np.nan_to_num(self.VAF, nan=-1) + df = np.nan_to_num(self.VAF, nan=-1) mask = np.zeros(self.VAF.shape, dtype=bool) mask[self.dp == 0] = True return df, mask - @staticmethod def get_hm_specifics(): - return {'vmin': 0, 'vmax': 1, - 'cbar_kws': {'ticks': [0, 1], 'label': 'VAF'}} - + return {'vmin': 0, 'vmax': 1, 'cbar_kws': {'ticks': [0, 1], 'label': 'VAF'}} @staticmethod def apply_cm_specifics(cm): pass - def plot_heatmap(self, out_file='', cluster=True): cmap = self.get_cmap() @@ -536,10 +584,11 @@ def get_row_cols(assignment): cmap=cmap, figsize=(25, 10), xticklabels=self.SNPs, - cbar_kws={'ticks': [0, 1], 'label': 'VAF'} + cbar_kws={'ticks': [0, 1], 'label': 'VAF'}, ) except tk.TclError: import matplotlib + matplotlib.use('Agg') cm = clustermap( df_plot, @@ -547,12 +596,13 @@ def get_row_cols(assignment): row_cluster=cluster, col_cluster=cluster, mask=mask, - vmin=0, vmax=1, + vmin=0, + vmax=1, row_colors=r_colors, cmap=cmap, figsize=(25, 10), xticklabels=self.SNPs, - cbar_kws={'ticks': [0, 1], 'label': 'VAF'} + cbar_kws={'ticks': [0, 1], 'label': 'VAF'}, ) cm.ax_heatmap.set_facecolor('#5B566C') @@ -561,8 +611,13 @@ def get_row_cols(assignment): cm.ax_heatmap.set_yticks([]) cm.ax_heatmap.set_xlabel('SNPs') - cm.ax_heatmap.set_xticklabels(cm.ax_heatmap.get_xticklabels(), - rotation=45, fontsize=5, ha='right', va='top') + cm.ax_heatmap.set_xticklabels( + cm.ax_heatmap.get_xticklabels(), + rotation=45, + fontsize=5, + ha='right', + va='top', + ) cm.ax_col_dendrogram.set_visible(False) self.apply_cm_specifics(cm) @@ -574,7 +629,6 @@ def get_row_cols(assignment): else: plt.show() - # -------------------------- GENERATE OUTPUT ------------------------------- def safe_results(self, output): @@ -590,11 +644,11 @@ def safe_results(self, output): # Safe SNV profiles to identy patients VAF_df = self.get_VAF_profile(self.sgt_ids) - VAF_df.index = [f'{cl_map[i]} ({COLORS_STR[int(cl_map[i])]})' \ - for i in self.sgt_ids] + VAF_df.index = [ + f'{cl_map[i]} ({COLORS_STR[int(cl_map[i])]})' for i in self.sgt_ids + ] VAF_df.round(2).to_csv(f'{output}.profiles.tsv', sep='\t') - def get_VAF_profile(self, cl): reads_all = np.reshape(self.profiles[cl], (cl.size, 3, self.SNPs.size)) VAF_raw = [] @@ -602,7 +656,6 @@ def get_VAF_profile(self, cl): VAF_raw.append(reads_cl[0] / reads_cl[2]) return pd.DataFrame(VAF_raw, index=cl, columns=self.SNPs) - def print_summary(self, cl_map): GTs = {'WT': (0, 0.35), 'HET': (0.35, 0.95), 'HOM': (0.95, 1)} for cl_id, cl_size in zip(*np.unique(self.assignment, return_counts=True)): @@ -612,8 +665,10 @@ def print_summary(self, cl_map): cl_color = f'{COLORS_STR[cl1]}+{COLORS_STR[cl2]}' else: cl_color = COLORS_STR[int(cl_name)] - print(f'Cluster {cl_name} ({cl_color}): {cl_size: >4} 
cells ' \ - f'({cl_size / self.cells.size * 100: >2.0f}%)') + print( + f'Cluster {cl_name} ({cl_color}): {cl_size: >4} cells ' + f'({cl_size / self.cells.size * 100: >2.0f}%)' + ) VAF_cl = self.VAF[self.assignment == cl_id] VAF_cl_called = (VAF_cl >= EPSILON).sum(axis=0) @@ -651,22 +706,37 @@ def main(args): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input', type=str, nargs='+', - help='Input _variants.csv file(s).') - parser.add_argument('-o', '--output', type=str, default='', - help='Output base name: <DIR>/<> .') - parser.add_argument('-n', '--clusters', type=int, default=1, - help='Number of clusters to define. Default = 1.') + parser.add_argument( + '-i', '--input', type=str, nargs='+', help='Input _variants.csv file(s).' + ) + parser.add_argument( + '-o', '--output', type=str, default='', help='Output base name: <DIR>/<> .' + ) + parser.add_argument( + '-n', + '--clusters', + type=int, + default=1, + help='Number of clusters to define. Default = 1.', + ) plotting = parser.add_argument_group('plotting') - plotting.add_argument('-op', '--output_plot', action='store_true', - help='Output file for heatmap with dendrogram to "<INPUT>.hm.png".') - plotting.add_argument('-sp', '--show_plot', action='store_true', - help='Show heatmap with dendrogram at stdout.') + plotting.add_argument( + '-op', + '--output_plot', + action='store_true', + help='Output file for heatmap with dendrogram to "<INPUT>.hm.png".', + ) + plotting.add_argument( + '-sp', + '--show_plot', + action='store_true', + help='Show heatmap with dendrogram at stdout.', + ) return parser.parse_args() if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/demultiplex_distance.py b/benchmarking_pipeline/workflow/scripts/demultiplex_distance.py index d4fcd27..6b45ed7 100644 --- a/benchmarking_pipeline/workflow/scripts/demultiplex_distance.py +++ b/benchmarking_pipeline/workflow/scripts/demultiplex_distance.py @@ -3,7 +3,6 @@ import argparse from itertools import combinations import os -import re import numpy as np from matplotlib import pyplot as plt @@ -104,30 +103,30 @@ def __init__(self, in_file, cl_no, distance='manhattan'): partial_order = {} for cell in df.columns: if '+' in cell: - s1 = cell.split('+')[0].split("_")[1] - s2 = cell.split('+')[1].split("_")[1] + s1 = cell.split('+')[0].split('_')[1] + s2 = cell.split('+')[1].split('_')[1] # Create a dictionary to store partial orders if s2 in partial_order.keys(): - if s1 in partial_order[s2]: ## This means s2<s1 - s2,s1 = s1,s2 - elif not s1 in partial_order.keys(): - partial_order[s1] = [s2] - elif not s2 in partial_order[s1]: + if s1 in partial_order[s2]: ## This means s2<s1 + s2, s1 = s1, s2 + elif s1 not in partial_order.keys(): + partial_order[s1] = [s2] + elif s2 not in partial_order[s1]: partial_order[s1].append(s2) true_cl.append('+'.join([s1, s2])) else: - s1 = cell.split("_")[1] + s1 = cell.split('_')[1] true_cl.append(s1) - #if 'pat' in cell: + # if 'pat' in cell: # s1 = int(re.split(r'[\.,]', cell.split('+')[0])[-1][3:]) - #else: + # else: # s1 = 0 - #if '+' in cell: + # if '+' in cell: # s2 = int(re.split(r'[\.,]', cell.split('+')[1])[-1][3:]) # true_cl.append('+'.join([str(j) for j in sorted([s1, s2])])) - #else: + # else: # true_cl.append(s1) self.true_cl = np.array(true_cl) @@ -169,13 +168,11 @@ def check_hom_match(self, scl, dcl): def demultiplex(self): self.init_dendrogram() - #self.identify_doublets() + # 
self.identify_doublets() self.set_assignment(self.sgt_cl_no) - self.sgt_ids = np.array( - [i for i in np.unique(self.assignment)] - ) + self.sgt_ids = np.array([i for i in np.unique(self.assignment)]) self.delete_garbage_cluster() - #self.merge_surplus_singlets() + # self.merge_surplus_singlets() # -------------------------------- DENDROGRAM ---------------------------------- @@ -970,7 +967,7 @@ def main(args): for in_file in in_files: if args.metric == 'reads': - dt = demoTape_reads(in_file, args.clusters+1) + dt = demoTape_reads(in_file, args.clusters + 1) else: dt = demoTape_gt(in_file, args.clusters, args.metric) @@ -1028,4 +1025,4 @@ def parse_args(): if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/get_input_samples.py b/benchmarking_pipeline/workflow/scripts/get_input_samples.py index be79f2b..5fa89a5 100644 --- a/benchmarking_pipeline/workflow/scripts/get_input_samples.py +++ b/benchmarking_pipeline/workflow/scripts/get_input_samples.py @@ -3,20 +3,44 @@ from pathlib import Path import logging -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.DEBUG) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.DEBUG, +) + def generate_input_samples(loom_files, number_of_samples): return random.sample(loom_files, number_of_samples) - - def parse_args(): - parser = argparse.ArgumentParser(description="Get a random list of samples to be pooled") - parser.add_argument("--loom_file_path", type=str, help="Path to a folder containing loom files") - parser.add_argument("--output", type=str, default=None, help="path to a directory where the sample list will be stored") - parser.add_argument("--seed", type=int, required=True, default=None, help="random seed for reproducibility of sampling") - parser.add_argument("--number_of_samples", type=int, default=1, help="number of samples to be selected") + parser = argparse.ArgumentParser( + description='Get a random list of samples to be pooled' + ) + parser.add_argument( + '--loom_file_path', type=str, help='Path to a folder containing loom files' + ) + parser.add_argument( + '--output', + type=str, + default=None, + help='path to a directory where the sample list will be stored', + ) + parser.add_argument( + '--seed', + type=int, + required=True, + default=None, + help='random seed for reproducibility of sampling', + ) + parser.add_argument( + '--number_of_samples', + type=int, + default=1, + help='number of samples to be selected', + ) return parser.parse_args() @@ -24,23 +48,24 @@ def parse_args(): def main(args): loom_files = [] for p in Path(args.loom_file_path).iterdir(): - if p.is_file() and p.suffix == ".loom": + if p.is_file() and p.suffix == '.loom': loom_files.append(p) - + number_of_samples = args.number_of_samples if args.number_of_samples > len(loom_files): - logging.error(f"Number of samples ({number_of_samples}) is greater than the number of loom files ({len(loom_files)}). Continuing with maximal number of samples.") + logging.error( + f'Number of samples ({args.number_of_samples}) is greater than the number of loom files ({len(loom_files)}). Continuing with maximal number of samples.' 
+ ) input_samples = loom_files number_of_samples = len(loom_files) - else: - random.seed(args.seed) - input_samples = random.sample(loom_files, args.number_of_samples) - - with open(args.output, "w") as f: + random.seed(args.seed) + input_samples = random.sample(loom_files, number_of_samples) + + with open(args.output, 'w') as f: for sample in input_samples: - f.write(f"{sample}\n") + f.write(f'{sample}\n') -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/mosaic_processing.py b/benchmarking_pipeline/workflow/scripts/mosaic_processing.py index 9db9a73..afd6d3c 100644 --- a/benchmarking_pipeline/workflow/scripts/mosaic_processing.py +++ b/benchmarking_pipeline/workflow/scripts/mosaic_processing.py @@ -1,4 +1,4 @@ -#Authors: +# Authors: # 1. Nico Borgsmüller # 2. Johannes Gawron @@ -19,7 +19,12 @@ import pandas as pd -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.DEBUG) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.DEBUG, +) EPSILON = np.finfo(np.float64).resolution @@ -99,7 +104,7 @@ def main(args): logging.warning( "Computation of the VAFs based on all cells' genotype call at a position. There is no correction of the VAFs for tumor purity and ploidy." ) - + """ logging.error( 'Unexpected behaviour in the computation of variant allele frequencies. Values may exceed 1.' @@ -395,15 +400,15 @@ def merge_gt(gt_in): return 1 -def subsample_cells(no_cells, ratios, total_cell_count = np.inf): +def subsample_cells(no_cells, ratios, total_cell_count=np.inf): if not isinstance(ratios, np.ndarray): - raise TypeError("Ratios need to be a numpy array!") + raise TypeError('Ratios need to be a numpy array!') smallest_sample = np.argmin(no_cells) if total_cell_count < np.inf: samples_total = (args.cell_no * ratios).astype(int) else: - samples_total = (no_cells.min() * ratios/ratios[smallest_sample]).astype(int) - for i,size in enumerate(samples_total): + samples_total = (no_cells.min() * ratios / ratios[smallest_sample]).astype(int) + for i, size in enumerate(samples_total): samples_total[i] = min(size, no_cells[i]) return samples_total @@ -415,7 +420,7 @@ def multiplex_looms(args): 'No. input files has to be the same as no. ratios. ' f'({no_samples} != {len(args.ratio)})' ) - #assert np.sum(args.ratio) <= 1, 'Ratios cannot sum up to >1.' + # assert np.sum(args.ratio) <= 1, 'Ratios cannot sum up to >1.' no_cells = np.zeros(no_samples, dtype=int) with tempfile.TemporaryDirectory() as temp_dir: @@ -424,10 +429,10 @@ def multiplex_looms(args): shutil.copy2(in_file, temp_in_file) with loompy.connect(temp_in_file) as ds: no_cells[i] = ds.shape[1] - - ##To obtain an unbiased cell count and doublet rate, we initially start with the number of cells and add (number of cells) * doublet_rate many samples. Later, pairs of cells will be merged to a doublet, which will reduce the total cell count again accordingly. + + ##To obtain an unbiased cell count and doublet rate, we initially start with the number of cells and add (number of cells) * doublet_rate many samples. Later, pairs of cells will be merged to a doublet, which will reduce the total cell count again accordingly. 
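+ # subsample_cells (defined above) draws per-sample cell counts that follow the requested mixing ratios, capped at the number of cells available in each loom file.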
samples_size = subsample_cells(no_cells, np.array(args.ratio), args.cell_no) - + samples = {} for i, size in enumerate(samples_size): size = min(size, no_cells[i]) @@ -439,10 +444,10 @@ def multiplex_looms(args): print(f'Reading: {in_file}') # Open loom file and read data with loompy.connect(temp_in_file) as ds: - index = concat_str_arrays( + index = concat_str_arrays( [ds.ra['CHROM'], ds.ra['POS'], ds.ra['REF'], ds.ra['ALT']] ) - + cols = samples[i]['idx'] + ((i + 1) / 10) df_new = pd.DataFrame( ds[:, samples[i]['idx']], index=index, columns=cols @@ -480,7 +485,7 @@ def multiplex_looms(args): AD_new = AD_new.join(pd.DataFrame(view.layers['AD'][:, :], index=index, columns=cols)) RO_new = RO_new.join(pd.DataFrame(view.layers['RO'][:, :], index=index, columns=cols)) """ - + try: barcodes = np.char.add( ds.col_attrs['barcode'][samples[i]['idx']], f'.pat{i:.0f}' @@ -489,22 +494,25 @@ def multiplex_looms(args): samples[i]['name'] = barcodes.astype(f'<U{2*barcode_length+1}') except (AttributeError, TypeError) as error: logging.error(error) - logging.error("No barcode information found in loom file. Continuing with generic cell names.") + logging.error( + 'No barcode information found in loom file. Continuing with generic cell names.' + ) in_file_stripped = str(Path(in_file).stem) match = re.search(r'_split([12])$', in_file_stripped) in_file_stripped = re.sub(r'_split[12]$', '', in_file_stripped) if not match: - raise ValueError("The file names must end with '_split1' or '_split2'.") + raise ValueError( + "The file names must end with '_split1' or '_split2'." + ) split_number = match.group(1) matched_pattern = f'split{split_number}' - base_name = f"_{in_file_stripped}.{matched_pattern}" + base_name = f'_{in_file_stripped}.{matched_pattern}' counts = np.arange(samples[i]['idx'].size) + 1 counts_string = np.char.mod('%d', counts) samples[i]['name'] = np.char.add(counts_string, base_name) - - # First sample, nothing to merge + # First sample, nothing to merge if i == 0: df = df_new ampl = ampl_new @@ -514,7 +522,7 @@ def multiplex_looms(args): RO = RO_new continue - # Merge amplicons (keep all) + # Merge amplicons (keep all) ampl = ampl.combine_first(ampl_new) df = df.merge(df_new, left_index=True, right_index=True) DP = DP.merge(DP_new, left_index=True, right_index=True) @@ -530,23 +538,22 @@ def multiplex_looms(args): del RO_new gc.collect() - # Add doublets (if doublets arg specified) print('Generating doublets') samples['dbt'] = {'name': []} - s_probs = samples_size/np.sum(samples_size) + s_probs = samples_size / np.sum(samples_size) drop_cells = [] dbt_gt = {} dbt_DP = {} dbt_GQ = {} dbt_AD = {} dbt_RO = {} - - #We have sample_size + doublet_rate * sample_size many cells in total. - #The number of doublets is doublet_rate* sample_size = doublet_rate/(1+doublet_rate) * total_cell_count - - dbt_total = int(args.doublets/(1 + args.doublets) * np.sum(samples_size)) + + # We have sample_size + doublet_rate * sample_size many cells in total. 
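+ # E.g., with hypothetical numbers: 2000 cells at a doublet rate of 0.05 give 2100 draws in total; merging 0.05/1.05 * 2100 = 100 pairs leaves 2000 cells, of which 5% are doublets.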
+ # The number of doublets is doublet_rate* sample_size = doublet_rate/(1+doublet_rate) * total_cell_count + + dbt_total = int(args.doublets / (1 + args.doublets) * np.sum(samples_size)) for i in range(dbt_total): s1, s2 = np.random.choice(no_samples, size=2, replace=False, p=s_probs) c1_idx = np.random.choice(samples[s1]['idx'].size) @@ -556,10 +563,10 @@ def multiplex_looms(args): c2 = samples[s2]['idx'][c2_idx] + ((s2 + 1) / 10) new_id = f'{c1}+{c2}' - new_name_sample1 = samples[s1]["name"][c1_idx] - new_name_sample2 = samples[s2]["name"][c2_idx] + new_name_sample1 = samples[s1]['name'][c1_idx] + new_name_sample2 = samples[s2]['name'][c2_idx] new_name = f'{new_name_sample1}+{new_name_sample2}' - + dbt_gt[new_id] = np.apply_along_axis(merge_gt, axis=1, arr=df[[c1, c2]]) dbt_GQ[new_id] = GQ[[c1, c2]].mean(axis=1) dbt_DP[new_id] = (DP[[c1, c2]].mean(axis=1).round()).astype(int) @@ -664,17 +671,16 @@ def multiplex_looms(args): } cell_data = np.empty((len(variants_info['CHR']), len(cells)), dtype='<U11') - cell_data = np.char.add(np.char.add(RO.astype(str), ':'), np.char.add(AD.astype(str), ':')) + cell_data = np.char.add( + np.char.add(RO.astype(str), ':'), np.char.add(AD.astype(str), ':') + ) cell_data = np.char.add(cell_data, gt.astype(str)) for col, cell in enumerate(cells): name = cell_names[col] variants_info[name] = cell_data[:, col].tolist() - - return pd.DataFrame(variants_info), gt, VAF - - + return pd.DataFrame(variants_info), gt, VAF def compute_pseudobulk_VAFs(df1, gt1): @@ -888,9 +894,9 @@ def parse_args(): if __name__ == '__main__': - logging.info("Parsing arguments.") + logging.info('Parsing arguments.') args = parse_args() if args.whitelist: update_whitelist(args.whitelist) - logging.info("Running main program.") - main(args) \ No newline at end of file + logging.info('Running main program.') + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/split_loom_files.py b/benchmarking_pipeline/workflow/scripts/split_loom_files.py index d7f5dca..3c2ace1 100644 --- a/benchmarking_pipeline/workflow/scripts/split_loom_files.py +++ b/benchmarking_pipeline/workflow/scripts/split_loom_files.py @@ -13,7 +13,12 @@ import loompy -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.DEBUG) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.DEBUG, +) def batched(iterable, chunk_size): @@ -26,77 +31,83 @@ class FancyLoomConnection(loompy.LoomConnection): def __init__(self, filename: str, mode: str = 'r+', *, validate: bool = True): super().__init__(filename, mode, validate=validate) - def add_columns_batched(self, loom_view_manager, col_indices, batch_size = 30): - logging.info("subselecting cells in batched mode") + def add_columns_batched(self, loom_view_manager, col_indices, batch_size=30): + logging.info('subselecting cells in batched mode') batches = list(batched(col_indices, batch_size)) for idx, batch in enumerate(batches): - logging.info(f"Adding batch {idx+1} of {len(batches)}") + logging.info(f'Adding batch {idx+1} of {len(batches)}') ds_batch = loom_view_manager[:, np.array(batch)] - self.add_columns(ds_batch.layers, col_attrs = ds_batch.ca, row_attrs = ds_batch.ra) + self.add_columns( + ds_batch.layers, col_attrs=ds_batch.ca, row_attrs=ds_batch.ra + ) -def fancy_loompy_connect(filename: str, mode: str = 'r+', *, validate: bool = True) -> FancyLoomConnection: +def fancy_loompy_connect( + filename: str, mode: str = 'r+', *, validate: 
bool = True +) -> FancyLoomConnection: """ - Establish a connection to a .loom file. + Establish a connection to a .loom file. - Args: - filename: Path to the Loom file to open - mode: Read/write mode, 'r+' (read/write) or 'r' (read-only), defaults to 'r+' - validate: Validate the file structure against the Loom file format specification - Returns: - A LoomConnection instance. + Args: + filename: Path to the Loom file to open + mode: Read/write mode, 'r+' (read/write) or 'r' (read-only), defaults to 'r+' + validate: Validate the file structure against the Loom file format specification + Returns: + A LoomConnection instance. - Remarks: - This function should typically be used as a context manager (i.e. inside a ``with``-block): + Remarks: + This function should typically be used as a context manager (i.e. inside a ``with``-block): - .. highlight:: python - .. code-block:: python + .. highlight:: python + .. code-block:: python - import loompy - with loompy.connect("mydata.loom") as ds: - print(ds.ca.keys()) + import loompy + with loompy.connect("mydata.loom") as ds: + print(ds.ca.keys()) - This ensures that the file will be closed automatically when the context block ends + This ensures that the file will be closed automatically when the context block ends - Note: if validation is requested, an exception is raised if validation fails. - """ + Note: if validation is requested, an exception is raised if validation fails. + """ return FancyLoomConnection(filename, mode, validate=validate) -def fancy_loompy_new(filename: str, *, file_attrs: Optional[Dict[str, str]] = None) -> FancyLoomConnection: - """ - Create an empty Loom file, and return it as a context manager. - """ - if filename.startswith("~/"): - filename = os.path.expanduser(filename) - if file_attrs is None: - file_attrs = {} - - # Create the file (empty). - # Yes, this might cause an exception, which we prefer to send to the caller - f = h5py.File(name=filename, mode='w') - f.create_group('/attrs') # v3.0.0 - f.create_group('/layers') - f.create_group('/row_attrs') - f.create_group('/col_attrs') - f.create_group('/row_graphs') - f.create_group('/col_graphs') - f.flush() - f.close() - - ds = fancy_loompy_connect(filename, validate=False) - for vals in file_attrs: - if file_attrs[vals] is None: - ds.attrs[vals] = "None" - else: - ds.attrs[vals] = file_attrs[vals] - # store creation date - ds.attrs['CreationDate'] = loompy.timestamp() - ds.attrs["LOOM_SPEC_VERSION"] = loompy.loom_spec_version - return ds - - -def sample_split(alpha, seed = None): +def fancy_loompy_new( + filename: str, *, file_attrs: Optional[Dict[str, str]] = None +) -> FancyLoomConnection: + """ + Create an empty Loom file, and return it as a context manager. + """ + if filename.startswith('~/'): + filename = os.path.expanduser(filename) + if file_attrs is None: + file_attrs = {} + + # Create the file (empty). 
+ # Yes, this might cause an exception, which we prefer to send to the caller + f = h5py.File(name=filename, mode='w') + f.create_group('/attrs') # v3.0.0 + f.create_group('/layers') + f.create_group('/row_attrs') + f.create_group('/col_attrs') + f.create_group('/row_graphs') + f.create_group('/col_graphs') + f.flush() + f.close() + + ds = fancy_loompy_connect(filename, validate=False) + for vals in file_attrs: + if file_attrs[vals] is None: + ds.attrs[vals] = 'None' + else: + ds.attrs[vals] = file_attrs[vals] + # store creation date + ds.attrs['CreationDate'] = loompy.timestamp() + ds.attrs['LOOM_SPEC_VERSION'] = loompy.loom_spec_version + return ds + + +def sample_split(alpha, seed=None): if seed: rng = np.random.default_rng(seed=seed) else: @@ -106,8 +117,7 @@ def sample_split(alpha, seed = None): return splitting_ratio - -def split_loom_file(splitting_ratio, loom_file, output_dir, seed = None): +def split_loom_file(splitting_ratio, loom_file, output_dir, seed=None): if seed: rng = np.random.default_rng(seed=seed) else: @@ -116,12 +126,13 @@ def split_loom_file(splitting_ratio, loom_file, output_dir, seed = None): with tempfile.TemporaryDirectory() as temp_dir: temp_loom_file = Path(temp_dir) / Path(loom_file).name shutil.copy2(loom_file, temp_loom_file) - with loompy.connect(temp_loom_file) as ds: n_cells = ds.shape[1] - split_vector = rng.choice([0,1], size=n_cells, p=[1-splitting_ratio, splitting_ratio]) - + split_vector = rng.choice( + [0, 1], size=n_cells, p=[1 - splitting_ratio, splitting_ratio] + ) + indices_first_sample = np.where(split_vector == 1)[0] indices_second_sample = np.where(split_vector == 0)[0] @@ -130,45 +141,62 @@ def split_loom_file(splitting_ratio, loom_file, output_dir, seed = None): output_1 = (output_dir / f'{loom_file.stem}_split1.loom').as_posix() with fancy_loompy_new(output_1) as ds1: - logging.info(f"Connection to {output_1} established") - ds1.add_columns_batched(ds_view, indices_first_sample, batch_size = 30) + logging.info(f'Connection to {output_1} established') + ds1.add_columns_batched(ds_view, indices_first_sample, batch_size=30) number_of_cells1 = ds1.shape[1] - logging.info("Writing file 1 was successful!") + logging.info('Writing file 1 was successful!') output_2 = (output_dir / f'{loom_file.stem}_split2.loom').as_posix() with fancy_loompy_new(output_2) as ds2: - logging.info(f"Connection to {output_2} established") - ds2.add_columns_batched(ds_view, indices_second_sample, batch_size = 30) + logging.info(f'Connection to {output_2} established') + ds2.add_columns_batched(ds_view, indices_second_sample, batch_size=30) number_of_cells2 = ds2.shape[1] - logging.info("Writing file 2 was successful!") + logging.info('Writing file 2 was successful!') if number_of_cells1 + number_of_cells2 == n_cells: - logging.info("Splitting was successful!") + logging.info('Splitting was successful!') else: - logging.error("The number of cells in the split files does not match the number of cells in the original file") - raise ValueError(f"Split 1 has {number_of_cells1} cells and split 2 has {number_of_cells2} cells, but the original file has {n_cells} cells") + logging.error( + 'The number of cells in the split files does not match the number of cells in the original file' + ) + raise ValueError( + f'Split 1 has {number_of_cells1} cells and split 2 has {number_of_cells2} cells, but the original file has {n_cells} cells' + ) return None def parse_args(): - parser = argparse.ArgumentParser(description="Simulate multiplexed single-cell (or multi-sample) mutation data") 
- parser.add_argument("--input", type=str, required=True, help="Input list of loom files") - parser.add_argument("--alpha", type=float, required=True, help="Dirichlet distribution parameter") - parser.add_argument("--output", type=str, required=True, help="One of the output files. The other output file is saves to the same directory") - - logging.info("Parsing arguments") + parser = argparse.ArgumentParser( + description='Simulate multiplexed single-cell (or multi-sample) mutation data' + ) + parser.add_argument( + '--input', type=str, required=True, help='Input list of loom files' + ) + parser.add_argument( + '--alpha', type=float, required=True, help='Dirichlet distribution parameter' + ) + parser.add_argument( + '--output', + type=str, + required=True, + help='One of the output files. The other output file is saves to the same directory', + ) + + logging.info('Parsing arguments') args = parser.parse_args() return args def main(args): - logging.info(f"Splitting file {args.input}") - loom_file = args.input - splitting_ratio = sample_split(alpha=args.alpha) - output_dir = Path(args.output).parent - split_loom_file(splitting_ratio = splitting_ratio, loom_file = loom_file, output_dir = output_dir) - logging.info(f"Done splitting file {args.input}") + logging.info(f'Splitting file {args.input}') + loom_file = args.input + splitting_ratio = sample_split(alpha=args.alpha) + output_dir = Path(args.output).parent + split_loom_file( + splitting_ratio=splitting_ratio, loom_file=loom_file, output_dir=output_dir + ) + logging.info(f'Done splitting file {args.input}') -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/mixmax/create_demultiplexing_scheme.py b/mixmax/create_demultiplexing_scheme.py index 331816d..5e16c61 100644 --- a/mixmax/create_demultiplexing_scheme.py +++ b/mixmax/create_demultiplexing_scheme.py @@ -6,39 +6,56 @@ import numpy as np -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.INFO) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.INFO, +) -def define_demultiplexing_scheme_optimal_case(maximal_number_of_samples, maximal_pool_size, n_samples, robust): +def define_demultiplexing_scheme_optimal_case( + maximal_number_of_samples, maximal_pool_size, n_samples, robust +): if n_samples % maximal_number_of_samples != 0: - raise ValueError("Number of samples must be a multiple of maximal number of samples to run this function!") + raise ValueError( + 'Number of samples must be a multiple of maximal number of samples to run this function!' 
+ ) demultiplexing_scheme = {} number_of_iterations = int(n_samples / maximal_number_of_samples) if not robust: - unordered_unique_pairs = list(itertools.combinations(range(maximal_pool_size), 2)) + unordered_unique_pairs = list( + itertools.combinations(range(maximal_pool_size), 2) + ) diagonal = list(zip(range(maximal_pool_size), range(maximal_pool_size))) unordered_pairs = unordered_unique_pairs + diagonal else: - unordered_pairs = list(itertools.combinations(range(maximal_pool_size+1), 2)) + unordered_pairs = list(itertools.combinations(range(maximal_pool_size + 1), 2)) for idx1 in range(number_of_iterations): for idx2, pair in enumerate(unordered_pairs): - demultiplexing_scheme[int(idx1 * maximal_number_of_samples + idx2 + 1)] = pair + (idx1,) # naming of samples starts with 1 + demultiplexing_scheme[int(idx1 * maximal_number_of_samples + idx2 + 1)] = ( + pair + (idx1,) + ) # naming of samples starts with 1 return demultiplexing_scheme - def find_demultiplexing_scheme(maximal_pool_size, n_samples, robust): - maximal_number_of_samples = (maximal_pool_size * (maximal_pool_size+1))/2 + maximal_number_of_samples = (maximal_pool_size * (maximal_pool_size + 1)) / 2 if n_samples % maximal_number_of_samples != 0: - logging.error("So far, only defined for certain cohort sizes") + logging.error('So far, only defined for certain cohort sizes') raise NotImplementedError else: - logging.info("Creating multiplexing scheme") - demultiplexing_scheme = define_demultiplexing_scheme_optimal_case(maximal_number_of_samples = maximal_number_of_samples, maximal_pool_size = maximal_pool_size, n_samples = n_samples, robust = robust) - + logging.info('Creating multiplexing scheme') + demultiplexing_scheme = define_demultiplexing_scheme_optimal_case( + maximal_number_of_samples=maximal_number_of_samples, + maximal_pool_size=maximal_pool_size, + n_samples=n_samples, + robust=robust, + ) + logging.info(f'Demultiplexing scheme: {demultiplexing_scheme}') return demultiplexing_scheme @@ -46,91 +63,135 @@ def find_demultiplexing_scheme(maximal_pool_size, n_samples, robust): def multiplexing_scheme_format2pool_format(demultiplexing_scheme): pool_scheme = {} - total_no_pools_per_repetition = np.max([pool[:-1] for pool in list(demultiplexing_scheme.values())]) - total_no_of_repetitions = np.max([pool[-1] for pool in list(demultiplexing_scheme.values())]) - - pools = list(itertools.product(range(total_no_pools_per_repetition + 1), range(total_no_of_repetitions + 1))) + total_no_pools_per_repetition = np.max( + [pool[:-1] for pool in list(demultiplexing_scheme.values())] + ) + total_no_of_repetitions = np.max( + [pool[-1] for pool in list(demultiplexing_scheme.values())] + ) + + pools = list( + itertools.product( + range(total_no_pools_per_repetition + 1), range(total_no_of_repetitions + 1) + ) + ) for pool in pools: pool_scheme[pool] = [] for key, value in demultiplexing_scheme.items(): - pool_scheme[value[0],value[-1]].append(int(key)) - pool_scheme[value[1],value[-1]].append(-int(key)) + pool_scheme[value[0], value[-1]].append(int(key)) + pool_scheme[value[1], value[-1]].append(-int(key)) # The pool scheme has pool identifier as keys and split libraries as values, where the first split is encoded as a positive integer and the second split as a negative integer return pool_scheme - def select_samples_for_pooling(pool_scheme, input_dir, sample_list): if isinstance(sample_list[0], str): - for idx,sample in enumerate(sample_list): + for idx, sample in enumerate(sample_list): sample_list[idx] = Path(sample) if 
isinstance(input_dir, str):
         input_dir = Path(input_dir)
     logging.info('Selecting samples for pooling')
     pools_summary = {}
     for pool, samples in pool_scheme.items():
-
-        logging.info(f"Retrieving list of samples to be pooled for pool {pool}")
-        logging.debug(f"Samples: {samples}")
-        logging.debug(f"Sample list: {sample_list}")
+        logging.info(f'Retrieving list of samples to be pooled for pool {pool}')
+        logging.debug(f'Samples: {samples}')
+        logging.debug(f'Sample list: {sample_list}')
         loom_files = []
         for sample in samples:
             if sample > 0:
-                loom_files.append((input_dir / f"{sample_list[sample-1].stem}_split1.loom").as_posix())
+                loom_files.append(
+                    (input_dir / f'{sample_list[sample-1].stem}_split1.loom').as_posix()
+                )
             else:
-                loom_files.append((input_dir / f"{sample_list[-sample-1].stem}_split2.loom").as_posix())
-
-        logging.info(f"Done retrieving list of samples to be pooled for pool {pool}")
+                loom_files.append(
+                    (
+                        input_dir / f'{sample_list[-sample-1].stem}_split2.loom'
+                    ).as_posix()
+                )

-        pools_summary[f"({pool[0]}.{pool[1]})"] = loom_files
+        logging.info(f'Done retrieving list of samples to be pooled for pool {pool}')
+
+        pools_summary[f'({pool[0]}.{pool[1]})'] = loom_files

     return pools_summary


 def write_output(pools_summary, output):
-    with open(output, "w") as f:
+    with open(output, 'w') as f:
         yaml.dump(pools_summary, f, default_flow_style=False)


 def load_input_samples(input_sample_file):
-    with open(input_sample_file, "r") as f:
+    with open(input_sample_file, 'r') as f:
         input_samples = f.readlines()
-    
+
     return input_samples


 def parse_args():
-    parser = argparse.ArgumentParser(description="Output a multiplexing scheme under pool size constraint for a predefined number of samples")
-    parser.add_argument("--robust", type=bool, required=False, default = False, help="Input list of loom files")
-    parser.add_argument("-k", "--maximal_pool_size", type=int, required=False, help="The maximal amount of samples to be multiplexed in a pool")
-    parser.add_argument("--n_samples", type=int, required=True, help="The number of samples to be sequenced in the cohort")
-    parser.add_argument("--output", type=str, required=True, help="Where to store the output file")
-    parser.add_argument("--input_dir", type=str, required=True, help="Directory to the loom files")
-    parser.add_argument("--input_sample_file", type=str, required=True, help="Path to the file containing the list of loom files")
-
-    logging.info("Parsing arguments")
+    parser = argparse.ArgumentParser(
+        description='Output a multiplexing scheme under pool size constraint for a predefined number of samples'
+    )
+    parser.add_argument(
+        '--robust',
+        type=bool,
+        required=False,
+        default=False,
+        help='Whether to create a robust multiplexing scheme in which the two splits of a sample are never placed in the same pool',
+    )
+    parser.add_argument(
+        '-k',
+        '--maximal_pool_size',
+        type=int,
+        required=False,
+        help='The maximal number of samples to be multiplexed in a pool',
+    )
+    parser.add_argument(
+        '--n_samples',
+        type=int,
+        required=True,
+        help='The number of samples to be sequenced in the cohort',
+    )
+    parser.add_argument(
+        '--output', type=str, required=True, help='Where to store the output file'
+    )
+    parser.add_argument(
+        '--input_dir', type=str, required=True, help='Directory to the loom files'
+    )
+    parser.add_argument(
+        '--input_sample_file',
+        type=str,
+        required=True,
+        help='Path to the file containing the list of loom files',
+    )
+
+    logging.info('Parsing arguments')
     args = parser.parse_args()
-    
+
     if args.n_samples == 0:
-        logging.error('Number of 
samples must be greater than 0') raise ValueError - + return args def main(args): input_samples = load_input_samples(args.input_sample_file) n_samples = len(input_samples) - demultiplexing_scheme = find_demultiplexing_scheme(args.maximal_pool_size, n_samples, args.robust) + demultiplexing_scheme = find_demultiplexing_scheme( + args.maximal_pool_size, n_samples, args.robust + ) pool_scheme = multiplexing_scheme_format2pool_format(demultiplexing_scheme) - pools_summary = select_samples_for_pooling(pool_scheme, args.input_dir, input_samples) - logging.debug(f"Output: {pools_summary.keys()}") - logging.info(f"Writing output to {args.output}") + pools_summary = select_samples_for_pooling( + pool_scheme, args.input_dir, input_samples + ) + logging.debug(f'Output: {pools_summary.keys()}') + logging.info(f'Writing output to {args.output}') write_output(pools_summary, args.output) - logging.info("Success.") + logging.info('Success.') -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/mixmax/match_samples.py b/mixmax/match_samples.py index da39ccc..cacb546 100644 --- a/mixmax/match_samples.py +++ b/mixmax/match_samples.py @@ -1,5 +1,4 @@ import argparse -import logging import pandas as pd import seaborn as sns @@ -9,22 +8,25 @@ from pathlib import Path - def pool_format2multiplexing_scheme(pool_scheme): demultiplexing_scheme = {} for pool, samples in pool_scheme.items(): for sample in samples: - if "_split1.loom" in sample: + if '_split1.loom' in sample: sample_name = str(Path(sample).stem) sample_name = sample_name.replace('_split1', '') - number_of_iterations =pool.split('.')[1][:-1] + number_of_iterations = pool.split('.')[1][:-1] pool_ID = pool.split('.')[0][1:] if sample_name not in demultiplexing_scheme.keys(): - demultiplexing_scheme[sample_name] = [int(pool_ID), np.inf, int(number_of_iterations)] + demultiplexing_scheme[sample_name] = [ + int(pool_ID), + np.inf, + int(number_of_iterations), + ] else: demultiplexing_scheme[sample_name][0] = int(pool_ID) demultiplexing_scheme[sample_name][2] = int(number_of_iterations) - if "_split2.loom" in sample: + if '_split2.loom' in sample: sample_name = str(Path(sample).stem) sample_name = sample_name.replace('_split2', '') pool_ID = pool.split('.')[0][1:] @@ -33,25 +35,32 @@ def pool_format2multiplexing_scheme(pool_scheme): else: demultiplexing_scheme[sample_name][1] = int(pool_ID) # Swap keys and items in the dictionary - swapped_demultiplexing_scheme = {tuple(v): k for k, v in demultiplexing_scheme.items()} - - return swapped_demultiplexing_scheme - + swapped_demultiplexing_scheme = { + tuple(v): k for k, v in demultiplexing_scheme.items() + } + return swapped_demultiplexing_scheme # Create a custom colormap -cmap = sns.color_palette("viridis", as_cmap=True) +cmap = sns.color_palette('viridis', as_cmap=True) cmap.set_bad(color='grey') def parse_args(): - parser = argparse.ArgumentParser(description="Compare distances between samples") - parser.add_argument("--tsv_files", nargs='+' , type=str, help="Path to TSV files containing the genotypes of the clusters. 
Must be in the same order as the pools in the pooling scheme.") - parser.add_argument("--pool_scheme", type=str, help="Path to the pool scheme file") - parser.add_argument("--output_plot", type=str, help="Output heatmap") - parser.add_argument("--output", type=str, help="sample assignment") - parser.add_argument("--robust", type=bool, required=False, default = False, help="Robust assignment") + parser = argparse.ArgumentParser(description='Compare distances between samples') + parser.add_argument( + '--tsv_files', + nargs='+', + type=str, + help='Path to TSV files containing the genotypes of the clusters. Must be in the same order as the pools in the pooling scheme.', + ) + parser.add_argument('--pool_scheme', type=str, help='Path to the pool scheme file') + parser.add_argument('--output_plot', type=str, help='Output heatmap') + parser.add_argument('--output', type=str, help='sample assignment') + parser.add_argument( + '--robust', type=bool, required=False, default=False, help='Robust assignment' + ) return parser.parse_args() @@ -72,19 +81,30 @@ def compute_ratio(matrix): def permute_tsv_files(tsv_files, pool_scheme): pools = [f"({Path(file).name.split("_")[1]})" for file in tsv_files] - + pool_scheme_keys = list(pool_scheme.keys()) pool_permutation = [pools.index(pool) for pool in pool_scheme_keys] - - if not all((pool1 == pool2 for pool1, pool2 in zip([pools[permuted_idx] for permuted_idx in pool_permutation], pool_scheme_keys))): - raise ValueError("The pools in the pool scheme do not match the pools in the TSV files") + + if not all( + ( + pool1 == pool2 + for pool1, pool2 in zip( + [pools[permuted_idx] for permuted_idx in pool_permutation], + pool_scheme_keys, + ) + ) + ): + raise ValueError( + 'The pools in the pool scheme do not match the pools in the TSV files' + ) return pool_permutation + args = parse_args() -def load_data(args): +def load_data(args): with open(args.pool_scheme, 'r') as file: pooling_scheme = yaml.safe_load(file) @@ -94,11 +114,12 @@ def load_data(args): tsv_files = [args.tsv_files] tsv_files = list(args.tsv_files) tsv_files_permuted = [tsv_files[i] for i in permutation_of_pools] - - dfs = [pd.read_csv(f, sep='\t', index_col = 0) for f in tsv_files_permuted] + + dfs = [pd.read_csv(f, sep='\t', index_col=0) for f in tsv_files_permuted] return dfs, pooling_scheme + dfs, pooling_scheme = load_data(args) # Get a set of all column names @@ -109,14 +130,14 @@ def load_data(args): print(all_columns) -for idx,df in enumerate(dfs): +for idx, df in enumerate(dfs): for col in all_columns: if col not in df.columns: dfs[idx][col] = np.nan -for idx,df in enumerate(dfs): +for idx, df in enumerate(dfs): dfs[idx] = df[sorted(all_columns)] - + # Concatenate all DataFrames concatenated_df = pd.concat(dfs, ignore_index=True)[sorted(all_columns)] @@ -135,7 +156,7 @@ def load_data(args): concatenated_df.drop(columns=cols_to_remove, inplace=True) concatenated_df.drop(columns=cols_to_remove_45_55, inplace=True) -for idx,df in enumerate(dfs): +for idx, df in enumerate(dfs): dfs[idx] = df.drop(columns=cols_to_remove, inplace=False) dfs[idx] = df.drop(columns=cols_to_remove_45_55, inplace=False) @@ -151,7 +172,7 @@ def load_data(args): # Compute the distance matrix considering only non-NaN entries and normalizing - # Compute the distance matrix for all pairs of DataFrames +# Compute the distance matrix for all pairs of DataFrames distance_matrices = np.empty((len(dfs), len(dfs)), dtype=object) for data1, df1 in enumerate(dfs): @@ -159,7 +180,9 @@ def load_data(args): distance_matrix = 
np.zeros((df1.shape[0], df2.shape[0]))
         for i in range(df1.shape[0]):
             for j in range(df2.shape[0]):
-                distance_matrix[i, j] = custom_distance(df1.iloc[i].values, df2.iloc[j].values)
+                distance_matrix[i, j] = custom_distance(
+                    df1.iloc[i].values, df2.iloc[j].values
+                )

         distance_matrices[data1, data2] = distance_matrix

@@ -200,7 +223,6 @@ def load_data(args):

 plt.savefig(heatmap_plot)

-
 fig, axes = plt.subplots(nrows=1, ncols=len(dfs), figsize=(15, 5))

 for i, df in enumerate(dfs):
@@ -209,9 +231,6 @@ def load_data(args):

 plt.savefig(genotype_plot)

-
-
-
 ratios = []
 for i in range(len(dfs)):
     for j in range(len(dfs)):
@@ -224,7 +243,7 @@ def load_data(args):

 lowest_value_pairs = []

-used_samples = [[]*len(dfs) for _ in range(len(dfs))]
+used_samples = [[] * len(dfs) for _ in range(len(dfs))]
 """ if remove is not None:
     for j in range(len(dfs)):
         if j != remove:
@@ -272,27 +291,31 @@ def load_data(args):
             # Sort the distance matrices by the recomputed ratio, from largest to smallest
             sorted_ratios = sorted(ratios, key=lambda x: x[2], reverse=True)
         else:
-            print(f"skipping assignment in matrix {i},{j}")
+            print(f'skipping assignment in matrix {i},{j}')


 # Print the pairs with the lowest values
 for i, j, row, col in lowest_value_pairs:
-    print(f'Lowest value in Distance Matrix DF{i} vs DF{j}: Row = {row}, Column = {col}')
+    print(
+        f'Lowest value in Distance Matrix DF{i} vs DF{j}: Row = {row}, Column = {col}'
+    )

-
-demultiplexing_scheme = pool_format2multiplexing_scheme(pooling_scheme)
+demultiplexing_scheme = pool_format2multiplexing_scheme(pooling_scheme)

 for i, j, row, col in lowest_value_pairs:
     if i > j:
         i, j = j, i
         row, col = col, row
-    print(f'Sample {row} from pool {i} and sample {col} from pool {j}: {demultiplexing_scheme[(i,j,0)]}')
+    print(
+        f'Sample {row} from pool {i} and sample {col} from pool {j}: {demultiplexing_scheme[(i,j,0)]}'
+    )

-if args.robust == False:
+if not args.robust:
     for i in range(len(used_samples)):
-        unused_sample = next(sample for sample in range(len(dfs[i])) if sample not in used_samples[i])
-        print(f"Sample {unused_sample} from pool {i}: {demultiplexing_scheme[(i,i,0)]}")
-
+        unused_sample = next(
+            sample for sample in range(len(dfs[i])) if sample not in used_samples[i]
+        )
+        print(f'Sample {unused_sample} from pool {i}: {demultiplexing_scheme[(i,i,0)]}')

 pooling_scheme_keys = list(pooling_scheme.keys())

@@ -302,21 +325,35 @@ def load_data(args):
         i, j = j, i
         row, col = col, row
     if pooling_scheme_keys[i] not in sample_assignment.keys():
-        sample_assignment[pooling_scheme_keys[i]] = {row: demultiplexing_scheme[(i, j, 0)]}
+        sample_assignment[pooling_scheme_keys[i]] = {
+            row: demultiplexing_scheme[(i, j, 0)]
+        }
     else:
-        sample_assignment[pooling_scheme_keys[i]][row] = demultiplexing_scheme[(i, j, 0)]
+        sample_assignment[pooling_scheme_keys[i]][row] = demultiplexing_scheme[
+            (i, j, 0)
+        ]
     if pooling_scheme_keys[j] not in sample_assignment.keys():
-        sample_assignment[pooling_scheme_keys[j]] = {col: demultiplexing_scheme[(i, j, 0)]}
+        sample_assignment[pooling_scheme_keys[j]] = {
+            col: demultiplexing_scheme[(i, j, 0)]
+        }
     else:
-        sample_assignment[pooling_scheme_keys[j]][col] = demultiplexing_scheme[(i, j, 0)]
+        sample_assignment[pooling_scheme_keys[j]][col] = demultiplexing_scheme[
+            (i, j, 0)
+        ]

-if args.robust == False:
+if not args.robust:
     for i in range(len(used_samples)):
-        unused_sample = next(sample for sample in range(len(dfs[i])) if sample not in used_samples[i])
+        unused_sample = next(
+            sample for sample in range(len(dfs[i])) if sample not in used_samples[i]
+        )
if pooling_scheme_keys[i] not in sample_assignment.keys(): - sample_assignment[pooling_scheme_keys[i]] = {unused_sample: demultiplexing_scheme[(i, i, 0)]} + sample_assignment[pooling_scheme_keys[i]] = { + unused_sample: demultiplexing_scheme[(i, i, 0)] + } else: - sample_assignment[pooling_scheme_keys[i]][unused_sample] = demultiplexing_scheme[(i, i, 0)] + sample_assignment[pooling_scheme_keys[i]][unused_sample] = ( + demultiplexing_scheme[(i, i, 0)] + ) with open(args.output, 'w') as yaml_file: yaml.dump(sample_assignment, yaml_file)
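Not part of the diff: a minimal sketch of the pool-format encoding produced by multiplexing_scheme_format2pool_format in mixmax/create_demultiplexing_scheme.py (positive sample IDs mark the first split, negative IDs the second split). The toy scheme, the importlib loading pattern, and the repository-root working directory are assumptions for illustration only.

    import importlib.util

    # Load the script by path; assumes it sits at mixmax/create_demultiplexing_scheme.py
    # relative to the current working directory, as in the diff above.
    spec = importlib.util.spec_from_file_location(
        'create_demultiplexing_scheme', 'mixmax/create_demultiplexing_scheme.py'
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Hypothetical scheme: sample 1 is split across pools 0 and 1 of repetition 0,
    # sample 2 has both of its splits in pool 0 of repetition 0.
    toy_scheme = {1: (0, 1, 0), 2: (0, 0, 0)}
    print(module.multiplexing_scheme_format2pool_format(toy_scheme))
    # Expected output: {(0, 0): [1, 2, -2], (1, 0): [-1]}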