diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..ea51302 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +repos: + - repo: https://github.com/snakemake/snakefmt + rev: v0.10.2 # Replace by any tag/version ≥0.2.4 : https://github.com/snakemake/snakefmt/releases + hooks: + - id: snakefmt + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.4 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/benchmarking_pipeline/launch_pipeline.py b/benchmarking_pipeline/launch_pipeline.py index 7732d48..7e4f2b1 100644 --- a/benchmarking_pipeline/launch_pipeline.py +++ b/benchmarking_pipeline/launch_pipeline.py @@ -1,33 +1,37 @@ import argparse from pathlib import Path import subprocess -import tempfile + def parse_args(): - parser = argparse.ArgumentParser(description="set a seed and run the pipeline") - parser.add_argument("--seed", type=int, help="seed") - parser.add_argument("--config_template",type = str, help = "Provide path to template config file") + parser = argparse.ArgumentParser(description='set a seed and run the pipeline') + parser.add_argument('--seed', type=int, help='seed') + parser.add_argument( + '--config_template', type=str, help='Provide path to template config file' + ) return parser.parse_args() + def main(args): with open(args.config_template, 'r') as file: config_content = file.read() + config_content = config_content.replace('__seed__', str(args.seed)) - config_content = config_content.replace("__seed__", str(args.seed)) - - - with open(Path(args.config_template).parent / f"config_{args.seed}.yaml", 'w') as file: + with open( + Path(args.config_template).parent / f'config_{args.seed}.yaml', 'w' + ) as file: file.write(config_content) command = ( - "source $(conda info --base)/etc/profile.d/conda.sh && conda activate snakemake; " - f"sbatch --time=24:00:00 --wrap=\"snakemake --cores 10 --configfile config/config_{args.seed}.yaml " - "--software-deployment-method conda --rerun-incomplete -p --keep-going --profile profiles/slurm/\"" + 'source $(conda info --base)/etc/profile.d/conda.sh && conda activate snakemake; ' + f'sbatch --time=24:00:00 --wrap="snakemake --cores 10 --configfile config/config_{args.seed}.yaml ' + '--software-deployment-method conda --rerun-incomplete -p --keep-going --profile profiles/slurm/"' ) subprocess.run(command, shell=True, check=True, executable='/bin/bash') + if __name__ == '__main__': args = parse_args() main(args) diff --git a/benchmarking_pipeline/tests/test_common.py b/benchmarking_pipeline/tests/test_common.py index 66b416f..b632635 100644 --- a/benchmarking_pipeline/tests/test_common.py +++ b/benchmarking_pipeline/tests/test_common.py @@ -4,8 +4,8 @@ script_directory = Path(__file__).resolve().parent -path = (script_directory.parent / "workflow" / "rules" / "common.smk").as_posix() -spec = importlib.util.spec_from_file_location("common", path) +path = (script_directory.parent / 'workflow' / 'rules' / 'common.smk').as_posix() +spec = importlib.util.spec_from_file_location('common', path) target = importlib.util.module_from_spec(spec) spec.loader.exec_module(target) get_split_files = target.get_split_files @@ -15,14 +15,29 @@ class TestCommon(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestCommon, self).__init__(*args, **kwargs) - self.sample_list = script_directory / "sample_list_test.txt" - + self.sample_list = script_directory / 'sample_list_test.txt' + def test_get_split_files(self): 
split_files = get_split_files(self.sample_list) - self.assertEqual(split_files, [self.sample_list.parent / "loom1_split1.loom", self.sample_list.parent / "loom1_split2.loom", self.sample_list.parent / "loom2_split1.loom",self.sample_list.parent / "loom2_split2.loom"]) + self.assertEqual( + split_files, + [ + self.sample_list.parent / 'loom1_split1.loom', + self.sample_list.parent / 'loom1_split2.loom', + self.sample_list.parent / 'loom2_split1.loom', + self.sample_list.parent / 'loom2_split2.loom', + ], + ) def test_determine_number_of_different_donors(self): - self.assertEqual(target.determine_number_of_different_donors(script_directory / 'pools_test.txt'), "(0.0)", 4) + self.assertEqual( + target.determine_number_of_different_donors( + script_directory / 'pools_test.txt', + '(0.0)', + ), + 4, + ) + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/benchmarking_pipeline/tests/test_create_demultiplexing_scheme.py b/benchmarking_pipeline/tests/test_create_demultiplexing_scheme.py index 830c37f..70d4b11 100644 --- a/benchmarking_pipeline/tests/test_create_demultiplexing_scheme.py +++ b/benchmarking_pipeline/tests/test_create_demultiplexing_scheme.py @@ -4,32 +4,85 @@ script_directory = Path(__file__).resolve().parent -path = (script_directory.parent / "workflow" / "scripts" / "create_demultiplexing_scheme.py").as_posix() -spec = importlib.util.spec_from_file_location("create_demultiplexing_scheme", path) +path = ( + script_directory.parent / 'workflow' / 'scripts' / 'create_demultiplexing_scheme.py' +).as_posix() +spec = importlib.util.spec_from_file_location('create_demultiplexing_scheme', path) target = importlib.util.module_from_spec(spec) spec.loader.exec_module(target) multiplexing_scheme_format2pool_format = target.multiplexing_scheme_format2pool_format select_samples_for_pooling = target.select_samples_for_pooling -define_demultiplexing_scheme_optimal_case = target.define_demultiplexing_scheme_optimal_case +define_demultiplexing_scheme_optimal_case = ( + target.define_demultiplexing_scheme_optimal_case +) + class TestCreateDemultiplexingScheme(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestCreateDemultiplexingScheme, self).__init__(*args, **kwargs) - self.multiplexing_scheme = {1: (0, 1, 0), 2: (0, 2, 0), 3: (1, 2, 0), 4: (0, 0, 0), 5: (1, 1, 0), 6: (2, 2, 0), 7: (0, 1, 1), 8: (0, 2, 1), 9: (1, 2, 1), 10: (0, 0, 1), 11: (1, 1, 1), 12: (2, 2, 1)} - self.multiplexing_scheme_pool_format = {(0, 0): [1,2, 4,-4], (1, 0): [-1,3,5,-5], (2, 0): [-2, -3, 6, -6], (0, 1): [7, 8, 10, -10], (1,1): [-7, 9, 11, -11], (2, 1): [-8, -9, 12, -12]} - self.mock_samples = ['/test/folder/loom1.loom', '/test/folder/loom2.loom', '/test/folder/loom3.loom'] + self.multiplexing_scheme = { + 1: (0, 1, 0), + 2: (0, 2, 0), + 3: (1, 2, 0), + 4: (0, 0, 0), + 5: (1, 1, 0), + 6: (2, 2, 0), + 7: (0, 1, 1), + 8: (0, 2, 1), + 9: (1, 2, 1), + 10: (0, 0, 1), + 11: (1, 1, 1), + 12: (2, 2, 1), + } + self.multiplexing_scheme_pool_format = { + (0, 0): [1, 2, 4, -4], + (1, 0): [-1, 3, 5, -5], + (2, 0): [-2, -3, 6, -6], + (0, 1): [7, 8, 10, -10], + (1, 1): [-7, 9, 11, -11], + (2, 1): [-8, -9, 12, -12], + } + self.mock_samples = [ + '/test/folder/loom1.loom', + '/test/folder/loom2.loom', + '/test/folder/loom3.loom', + ] self.mock_path = '/another/test/folder' - self.another_pool_scheme = {(0,0): [1,-1,3], (1,0): [2,-2,-3]} + self.another_pool_scheme = {(0, 0): [1, -1, 3], (1, 0): [2, -2, -3]} def test_multiplexing_scheme_format2pool_format(self): - 
self.assertEqual(multiplexing_scheme_format2pool_format(self.multiplexing_scheme), self.multiplexing_scheme_pool_format) - + self.assertEqual( + multiplexing_scheme_format2pool_format(self.multiplexing_scheme), + self.multiplexing_scheme_pool_format, + ) + def test_select_samples_for_pooling(self): - self.assertEqual(select_samples_for_pooling(self.another_pool_scheme, self.mock_path, self.mock_samples), {'(0.0)': ['/another/test/folder/loom1_split1.loom', '/another/test/folder/loom1_split2.loom', '/another/test/folder/loom3_split1.loom'], - '(1.0)': ['/another/test/folder/loom2_split1.loom', '/another/test/folder/loom2_split2.loom', '/another/test/folder/loom3_split2.loom']}) + self.assertEqual( + select_samples_for_pooling( + self.another_pool_scheme, self.mock_path, self.mock_samples + ), + { + '(0.0)': [ + '/another/test/folder/loom1_split1.loom', + '/another/test/folder/loom1_split2.loom', + '/another/test/folder/loom3_split1.loom', + ], + '(1.0)': [ + '/another/test/folder/loom2_split1.loom', + '/another/test/folder/loom2_split2.loom', + '/another/test/folder/loom3_split2.loom', + ], + }, + ) def test_define_demultiplexing_scheme_optimal_case(self): - self.assertEqual(define_demultiplexing_scheme_optimal_case(maximal_number_of_samples = 3, maximal_pool_size = 2, n_samples = 3), {1: (0, 1, 0), 2: (0, 0, 0), 3: (1, 1, 0)}) + self.assertEqual( + define_demultiplexing_scheme_optimal_case( + maximal_number_of_samples=3, maximal_pool_size=2, n_samples=3 + ), + {1: (0, 1, 0), 2: (0, 0, 0), 3: (1, 1, 0)}, + ) + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/benchmarking_pipeline/tests/test_split_loom_files.py b/benchmarking_pipeline/tests/test_split_loom_files.py index afdc70a..146299d 100644 --- a/benchmarking_pipeline/tests/test_split_loom_files.py +++ b/benchmarking_pipeline/tests/test_split_loom_files.py @@ -6,12 +6,19 @@ import numpy as np import loompy -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.DEBUG) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.DEBUG, +) script_directory = Path(__file__).resolve().parent -split_loom_file_path = (script_directory.parent / "workflow" / "scripts" / "split_loom_files.py").as_posix() -spec = importlib.util.spec_from_file_location("split_loom_file", split_loom_file_path) +split_loom_file_path = ( + script_directory.parent / 'workflow' / 'scripts' / 'split_loom_files.py' +).as_posix() +spec = importlib.util.spec_from_file_location('split_loom_file', split_loom_file_path) target = importlib.util.module_from_spec(spec) spec.loader.exec_module(target) split_loom_file = target.split_loom_file @@ -20,36 +27,42 @@ class TestSplitLoomFiles(unittest.TestCase): def __init__(self, *args, **kwargs): super(TestSplitLoomFiles, self).__init__(*args, **kwargs) - self.filename = "split_loom_data_test.loom" - #matrix = np.arange(4).reshape(2,2) - #row_attrs = { "SomeRowAttr": np.arange(2), "OtherRowAttr": ['A','B'] } - #col_attrs = { "SomeColAttr": np.arange(2), "OtherColAttr": ['C','D'] } - #other_layer = np.arange(4, 8).reshape(2,2) - #loompy.create((script_directory / self.filename).as_posix(), layers = {"":matrix, "other": other_layer}, row_attrs = row_attrs, col_attrs = col_attrs) + self.filename = 'split_loom_data_test.loom' + # matrix = np.arange(4).reshape(2,2) + # row_attrs = { "SomeRowAttr": np.arange(2), "OtherRowAttr": ['A','B'] } + # col_attrs = { 
"SomeColAttr": np.arange(2), "OtherColAttr": ['C','D'] } + # other_layer = np.arange(4, 8).reshape(2,2) + # loompy.create((script_directory / self.filename).as_posix(), layers = {"":matrix, "other": other_layer}, row_attrs = row_attrs, col_attrs = col_attrs) def test_split_loom_files(self): loom_file = script_directory / self.filename - split_loom_file(0.5, script_directory / self.filename, loom_file.parent, seed = 42) - with loompy.connect(script_directory / 'temp' / f'{Path(self.filename).stem}_split1.loom') as ds: + split_loom_file( + 0.5, script_directory / self.filename, loom_file.parent, seed=42 + ) + with loompy.connect( + script_directory / 'temp' / f'{Path(self.filename).stem}_split1.loom' + ) as ds: self.assertEqual(ds.shape[1], 1) self.assertEqual(ds.shape[0], 2) - self.assertTrue(np.all(ds[:,:] == np.array([[0], [2]]))) - self.assertTrue(np.all(ds.layers["other"][:,:] == np.array([[4], [6]]))) - self.assertTrue(np.all(ds.ca["SomeColAttr"] == np.array([0]))) - self.assertTrue(np.all(ds.ca["OtherColAttr"] == np.array(['C']))) - self.assertTrue(np.all(ds.ra["SomeRowAttr"] == np.array([0, 1]))) - self.assertTrue(np.all(ds.ra["OtherRowAttr"] == np.array(['A', 'B']))) - - with loompy.connect(script_directory / 'temp' / f'{Path(self.filename).stem}_split2.loom') as ds: + self.assertTrue(np.all(ds[:, :] == np.array([[0], [2]]))) + self.assertTrue(np.all(ds.layers['other'][:, :] == np.array([[4], [6]]))) + self.assertTrue(np.all(ds.ca['SomeColAttr'] == np.array([0]))) + self.assertTrue(np.all(ds.ca['OtherColAttr'] == np.array(['C']))) + self.assertTrue(np.all(ds.ra['SomeRowAttr'] == np.array([0, 1]))) + self.assertTrue(np.all(ds.ra['OtherRowAttr'] == np.array(['A', 'B']))) + + with loompy.connect( + script_directory / 'temp' / f'{Path(self.filename).stem}_split2.loom' + ) as ds: self.assertEqual(ds.shape[1], 1) self.assertEqual(ds.shape[0], 2) - self.assertTrue(np.all(ds[:,:] == np.array([[1], [3]]))) - self.assertTrue(np.all(ds.layers["other"][:,:] == np.array([[5], [7]]))) - self.assertTrue(np.all(ds.ca["SomeColAttr"] == np.array([1]))) - self.assertTrue(np.all(ds.ca["OtherColAttr"] == np.array(['D']))) - self.assertTrue(np.all(ds.ra["SomeRowAttr"] == np.array([0, 1]))) - self.assertTrue(np.all(ds.ra["OtherRowAttr"] == np.array(['A', 'B']))) + self.assertTrue(np.all(ds[:, :] == np.array([[1], [3]]))) + self.assertTrue(np.all(ds.layers['other'][:, :] == np.array([[5], [7]]))) + self.assertTrue(np.all(ds.ca['SomeColAttr'] == np.array([1]))) + self.assertTrue(np.all(ds.ca['OtherColAttr'] == np.array(['D']))) + self.assertTrue(np.all(ds.ra['SomeRowAttr'] == np.array([0, 1]))) + self.assertTrue(np.all(ds.ra['OtherRowAttr'] == np.array(['A', 'B']))) + if __name__ == '__main__': unittest.main() - diff --git a/benchmarking_pipeline/workflow/Snakefile b/benchmarking_pipeline/workflow/Snakefile index 2c584b7..d0bc779 100644 --- a/benchmarking_pipeline/workflow/Snakefile +++ b/benchmarking_pipeline/workflow/Snakefile @@ -1,19 +1,22 @@ include: Path("rules/common.smk") + rule all: input: - output_files - + output_files, + checkpoint get_input_samples: output: - samples_file = output_folder / "seed_{seed}" / "sample_list_{pool_size}.txt" + samples_file=output_folder / "seed_{seed}" / "sample_list_{pool_size}.txt", params: - number_of_samples = lambda wildcards: int(int(wildcards.pool_size)*(int(wildcards.pool_size) + 1)/2), - loom_file_path = loom_file_path, - seed = lambda w: w.seed, + number_of_samples=lambda wildcards: int( + int(wildcards.pool_size) * (int(wildcards.pool_size) + 1) / 2 + ), + 
loom_file_path=loom_file_path, + seed=lambda w: w.seed, log: - output_folder / 'log' / 'generate_input_samples_{seed}_{pool_size}.log' + output_folder / "log" / "generate_input_samples_{seed}_{pool_size}.log", shell: """ python {workflow.basedir}/scripts/get_input_samples.py \ @@ -24,19 +27,20 @@ checkpoint get_input_samples: &> {log} """ -rule split_files: ### This part of the pipeline is still not fully reproducible, because the files are split randomly. If I use the seed though, every file will be split the same way, which is not what I want. + +rule split_files: ### This part of the pipeline is still not fully reproducible, because the files are split randomly. If I use the seed though, every file will be split the same way, which is not what I want. input: - loom_file_path / '{loom_file}.loom', + loom_file_path / "{loom_file}.loom", output: - split1 = output_folder / "seed_{seed}" / "loom" / '{loom_file}_split1.loom', - split2 = output_folder / "seed_{seed}" / "loom" / '{loom_file}_split2.loom', + split1=output_folder / "seed_{seed}" / "loom" / "{loom_file}_split1.loom", + split2=output_folder / "seed_{seed}" / "loom" / "{loom_file}_split2.loom", conda: "envs/loom.yml" resources: mem_mb_per_cpu=3000, runtime=90, log: - output_folder / "seed_{seed}" / 'log' / 'split_{loom_file}.log', + output_folder / "seed_{seed}" / "log" / "split_{loom_file}.log", shell: """ python {workflow.basedir}/scripts/split_loom_files.py \ @@ -49,14 +53,19 @@ rule split_files: ### This part of the pipeline is still not fully reproducible, checkpoint create_demultiplexing_scheme: input: - sample_file = output_folder / "seed_{seed}" / 'sample_list_{pool_size}.txt' + sample_file=output_folder / "seed_{seed}" / "sample_list_{pool_size}.txt", output: - pools = output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt" + pools=output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", params: - number_of_samples = lambda wildcards: int(int(wildcards.pool_size)*(int(wildcards.pool_size) + 1)/2), - input_dir = lambda w: output_folder / f"seed_{w.seed}" / "loom" + number_of_samples=lambda wildcards: int( + int(wildcards.pool_size) * (int(wildcards.pool_size) + 1) / 2 + ), + input_dir=lambda w: output_folder / f"seed_{w.seed}" / "loom", log: - output_folder / "seed_{seed}" / 'log' / 'create_demultiplexing_scheme_{seed}_{pool_size}_robust{robust}.log' + output_folder + / "seed_{seed}" + / "log" + / "create_demultiplexing_scheme_{seed}_{pool_size}_robust{robust}.log", shell: """ python {workflow.basedir}/scripts/create_demultiplexing_scheme.py \ @@ -70,16 +79,22 @@ checkpoint create_demultiplexing_scheme: """ -rule pooling: ##Need to introduce proper downsampling of cells +rule pooling: ##Need to introduce proper downsampling of cells input: - split_files = lambda w: yaml.safe_load( + split_files=lambda w: yaml.safe_load( open( - checkpoints.create_demultiplexing_scheme.get(seed=w.seed, pool_size = w.pool_size, robust = w.robust).output.pools - ) - )[f"({w.pool_ID})"], - pool_data = output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", + checkpoints.create_demultiplexing_scheme.get( + seed=w.seed, pool_size=w.pool_size, robust=w.robust + ).output.pools + ) + )[f"({w.pool_ID})"], + pool_data=output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", output: - output_folder / "seed_{seed}" / paramspace.wildcard_pattern / "pools" / "pool_{pool_ID}_{seed}.csv" + output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "pools" + / "pool_{pool_ID}_{seed}.csv", 
params: minGQ=config.get("mosaic", {}).get("minGQ", 30), minDP=config.get("mosaic", {}).get("minDP", 10), @@ -91,15 +106,19 @@ rule pooling: ##Need to introduce proper downsampling of cells minHomVAF=config.get("mosaic", {}).get("minHomVAF", 0.95), minHetVAF=config.get("mosaic", {}).get("minHetVAF", 0.35), proximity=config.get("mosaic", {}).get("proximity", "25 50 100 200"), - ratios = lambda w, input: sample_mixing_ratios(w.seed, len(input.split_files)), + ratios=lambda w, input: sample_mixing_ratios(w.seed, len(input.split_files)), resources: mem_mb_per_cpu=10000, runtime=90, conda: "envs/mosaic.yml" - group: "demultiplexing_simulation" + group: + "demultiplexing_simulation" log: - output_folder / "seed_{seed}" / 'log' / 'pooling_{pool_ID}_{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log' + output_folder + / "seed_{seed}" + / "log" + / "pooling_{pool_ID}_{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log", shell: """ python {workflow.basedir}/scripts/mosaic_processing.py \ @@ -124,29 +143,50 @@ rule pooling: ##Need to introduce proper downsampling of cells rule demultiplexing_demoTape: - input: - variants = output_folder / "seed_{seed}" / paramspace.wildcard_pattern / "pools" / "pool_{pool_ID}_{seed}.csv", - pools = output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", - output: - profiles = output_folder / "seed_{seed}" / paramspace.wildcard_pattern / 'demultiplexed' / 'pool_{pool_ID}_{seed}_demultiplexed.profiles.tsv', - assignments = output_folder / "seed_{seed}" / paramspace.wildcard_pattern / 'demultiplexed' / 'pool_{pool_ID}_{seed}_demultiplexed.assignments.tsv' - params: - cluster_number= lambda w, input: determine_number_of_different_donors( - checkpoints.create_demultiplexing_scheme.get(seed=w.seed, pool_size = w.pool_size, robust = w.robust).output.pools, - "({})".format(w.pool_ID) - ), - outbase = lambda w: output_folder / "seed_{}".format(w.seed) / paramspace.wildcard_pattern / 'demultiplexed' / "pool_{}_{}_demultiplexed".format(w.pool_ID, w.seed), - output_folder = output_folder - resources: - mem_mb_per_cpu=4096, - runtime=90, - conda: - "envs/sample_assignment.yml" - group: "demultiplexing_simulation" - log: - output_folder / 'log' / 'demultiplexing_{pool_ID}_{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log', - shell: - """ + input: + variants=output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "pools" + / "pool_{pool_ID}_{seed}.csv", + pools=output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", + output: + profiles=output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "demultiplexed" + / "pool_{pool_ID}_{seed}_demultiplexed.profiles.tsv", + assignments=output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "demultiplexed" + / "pool_{pool_ID}_{seed}_demultiplexed.assignments.tsv", + params: + cluster_number=lambda w, input: determine_number_of_different_donors( + checkpoints.create_demultiplexing_scheme.get( + seed=w.seed, pool_size=w.pool_size, robust=w.robust + ).output.pools, + "({})".format(w.pool_ID), + ), + outbase=lambda w: output_folder + / "seed_{}".format(w.seed) + / paramspace.wildcard_pattern + / "demultiplexed" + / "pool_{}_{}_demultiplexed".format(w.pool_ID, w.seed), + output_folder=output_folder, + resources: + mem_mb_per_cpu=4096, + runtime=90, + conda: + "envs/sample_assignment.yml" + group: + "demultiplexing_simulation" + log: + output_folder + / "log" + / 
"demultiplexing_{pool_ID}_{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log", + shell: + """ python {workflow.basedir}/scripts/demultiplex_distance.py \ --input {input.variants} \ -n {params.cluster_number} \ @@ -156,20 +196,28 @@ rule demultiplexing_demoTape: rule label_samples: - input: - demultiplexing_assignments = lambda wildcards: get_all_demultiplexed_assignments(wildcards, paramspace.wildcard_pattern), - output: - output_folder / "seed_{seed}" / paramspace.wildcard_pattern / 'sample_identity.yaml', - resources: - mem_mb_per_cpu=4096, - runtime=90, - conda: - "envs/sample_assignment.yml" - group: "assess_simulation_results" - log: - output_folder / 'log' / 'label_samples.{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log', - shell: - """ + input: + demultiplexing_assignments=lambda wildcards: get_all_demultiplexed_assignments( + wildcards, paramspace.wildcard_pattern + ), + output: + output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "sample_identity.yaml", + resources: + mem_mb_per_cpu=4096, + runtime=90, + conda: + "envs/sample_assignment.yml" + group: + "assess_simulation_results" + log: + output_folder + / "log" + / "label_samples.{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log", + shell: + """ python {workflow.basedir}/scripts/assign_subsample_simulation.py \ --demultiplexing_assignment {input.demultiplexing_assignments} \ --output {output} \ @@ -178,17 +226,28 @@ rule label_samples: rule assign_samples: - input: - demultiplexing_genotypes = lambda wildcards: get_all_demultiplexed_profiles(wildcards, paramspace.wildcard_pattern), - pools = output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", + input: + demultiplexing_genotypes=lambda wildcards: get_all_demultiplexed_profiles( + wildcards, paramspace.wildcard_pattern + ), + pools=output_folder / "seed_{seed}" / "pools_{pool_size}_robust{robust}.txt", output: - sample_assignment = output_folder / "seed_{seed}" / paramspace.wildcard_pattern / 'sample_assignment.yaml', - heatmap = output_folder / "seed_{seed}" / paramspace.wildcard_pattern / 'sample_assignment_heatmap.png', + sample_assignment=output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "sample_assignment.yaml", + heatmap=output_folder + / "seed_{seed}" + / paramspace.wildcard_pattern + / "sample_assignment_heatmap.png", conda: "envs/sample_assignment.yml" - group: "assess_simulation_results" + group: + "assess_simulation_results" log: - output_folder / 'log' / 'assign_samples.{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log', + output_folder + / "log" + / "assign_samples.{seed}_{doublet_rate}_{cell_count}_{pool_size}_robust{robust}.log", shell: """ python {workflow.basedir}/scripts/compare_distances.py \ @@ -198,4 +257,4 @@ rule assign_samples: --output {output.sample_assignment} \ --robust {wildcards.robust} \ &> {log} - """ \ No newline at end of file + """ diff --git a/benchmarking_pipeline/workflow/rules/common.smk b/benchmarking_pipeline/workflow/rules/common.smk index 685aa17..27a75e5 100644 --- a/benchmarking_pipeline/workflow/rules/common.smk +++ b/benchmarking_pipeline/workflow/rules/common.smk @@ -7,12 +7,14 @@ import numpy as np import pandas as pd from snakemake.utils import Paramspace -paramspace = Paramspace(pd.read_csv(Path(workflow.basedir) / 'sandbox' / 'parameter_space.tsv', sep='\t')) +paramspace = Paramspace( + pd.read_csv(Path(workflow.basedir) / "sandbox" / "parameter_space.tsv", sep="\t") +) output_folder = 
Path(config["output_folder"]) loom_file_path = Path(config["loom_files"]) -#number_of_samples = config.get("n_samples", 10) -seed=config["seed"] +# number_of_samples = config.get("n_samples", 10) +seed = config["seed"] loom_files = [] for p in loom_file_path.iterdir(): @@ -20,11 +22,9 @@ for p in loom_file_path.iterdir(): loom_files.append(p) - - def get_variant_list(pool_ID): file = output_folder / f"pool_{pool_ID}.txt" - with open(file, 'r') as f: + with open(file, "r") as f: return [line.strip() for line in f] @@ -32,7 +32,7 @@ def get_split_files(path): if isinstance(path, str): path = Path(path) split_files = [] - with open(path, 'r') as f: + with open(path, "r") as f: for line in f: path2 = Path(line.strip()) split_files.append(path.parent / "loom" / f"{path2.stem}_split1.loom") @@ -41,23 +41,28 @@ def get_split_files(path): return split_files - def sample_mixing_ratios(seed, max_pool_size): - concentrations = [40/max_pool_size] * max_pool_size + concentrations = [40 / max_pool_size] * max_pool_size rng = np.random.default_rng(seed=int(seed)) ratios = rng.dirichlet(concentrations) - + return " ".join(str(item) for item in ratios) def get_demultiplexed_samples(wildcards): - pools = yaml.safe_load(open(checkpoints.create_demultiplexing_scheme.get(seed=wildcard.seed).output, 'r')) + pools = yaml.safe_load( + open( + checkpoints.create_demultiplexing_scheme.get(seed=wildcard.seed).output, "r" + ) + ) pools = list(pools.keys()) - demultiplexed_samples = [f"pool_{pool}_{wildcard.seed}_demultiplexed.assignments.tsv" for pool in pools] + demultiplexed_samples = [ + f"pool_{pool}_{wildcard.seed}_demultiplexed.assignments.tsv" for pool in pools + ] + - def determine_number_of_different_donors(pool_file, pool_ID): - samples = yaml.safe_load(open(pool_file, 'r'))[pool_ID] + samples = yaml.safe_load(open(pool_file, "r"))[pool_ID] samples = [Path(sample) for sample in samples] unique_patterns = set() for sample in samples: @@ -66,35 +71,66 @@ def determine_number_of_different_donors(pool_file, pool_ID): unique_patterns.add(sample.stem.replace("_split1", "")) elif "_split2" in sample.stem: unique_patterns.add(sample.stem.replace("_split2", "")) - + return len(unique_patterns) def get_all_demultiplexed_assignments(wildcards, wildcard_pattern): - demultiplexing_scheme = yaml.safe_load(open(checkpoints.create_demultiplexing_scheme.get(seed=wildcards.seed, pool_size = wildcards.pool_size, robust = wildcards.robust).output.pools, 'r')) + demultiplexing_scheme = yaml.safe_load( + open( + checkpoints.create_demultiplexing_scheme.get( + seed=wildcards.seed, + pool_size=wildcards.pool_size, + robust=wildcards.robust, + ).output.pools, + "r", + ) + ) pool_names = list(demultiplexing_scheme.keys()) pool_names = [pool_name.strip("()") for pool_name in pool_names] - + demultiplexed_files = [] - file_path = output_folder / f"seed_{wildcards.seed}" / wildcard_pattern / 'demultiplexed' + file_path = ( + output_folder / f"seed_{wildcards.seed}" / wildcard_pattern / "demultiplexed" + ) for pool in pool_names: - demultiplexed_files.append(file_path / f"pool_{pool}_{wildcards.seed}_demultiplexed.assignments.tsv") + demultiplexed_files.append( + file_path / f"pool_{pool}_{wildcards.seed}_demultiplexed.assignments.tsv" + ) return demultiplexed_files def get_all_demultiplexed_profiles(wildcards, wildcard_pattern): - demultiplexing_scheme = yaml.safe_load(open(checkpoints.create_demultiplexing_scheme.get(seed=wildcards.seed, pool_size = wildcards.pool_size, robust = wildcards.robust).output.pools, 'r')) + 
demultiplexing_scheme = yaml.safe_load( + open( + checkpoints.create_demultiplexing_scheme.get( + seed=wildcards.seed, + pool_size=wildcards.pool_size, + robust=wildcards.robust, + ).output.pools, + "r", + ) + ) pool_names = list(demultiplexing_scheme.keys()) pool_names = [pool_name.strip("()") for pool_name in pool_names] - + demultiplexed_files = [] - file_path = output_folder / f"seed_{wildcards.seed}" / wildcard_pattern / 'demultiplexed' + file_path = ( + output_folder / f"seed_{wildcards.seed}" / wildcard_pattern / "demultiplexed" + ) for pool in pool_names: - demultiplexed_files.append(file_path / f"pool_{pool}_{wildcards.seed}_demultiplexed.profiles.tsv") + demultiplexed_files.append( + file_path / f"pool_{pool}_{wildcards.seed}_demultiplexed.profiles.tsv" + ) return demultiplexed_files - -output_files = expand(output_folder / f"seed_{seed}" / "{params}" / 'sample_assignment.yaml', params=paramspace.instance_patterns) + expand(output_folder / f"seed_{seed}" / "{params}" / 'sample_identity.yaml', params=paramspace.instance_patterns) +output_files = expand( + output_folder / f"seed_{seed}" / "{params}" / "sample_assignment.yaml", + params=paramspace.instance_patterns, +) + expand( + output_folder / f"seed_{seed}" / "{params}" / "sample_identity.yaml", + params=paramspace.instance_patterns, +) diff --git a/benchmarking_pipeline/workflow/sandbox/compute_performance_score.py b/benchmarking_pipeline/workflow/sandbox/compute_performance_score.py index 9aa4227..806a899 100644 --- a/benchmarking_pipeline/workflow/sandbox/compute_performance_score.py +++ b/benchmarking_pipeline/workflow/sandbox/compute_performance_score.py @@ -6,7 +6,12 @@ from sklearn.metrics import v_measure_score -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.INFO) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.INFO, +) def compute_score(labels, estimate): @@ -18,10 +23,8 @@ def compute_score(labels, estimate): for key2, sample_assignment in assignments: if sample_assignment != labels[str(key1)][str(key2)]: mismatches += 1 - - return mismatches / total_counts - + return mismatches / total_counts def load_data(path): @@ -32,7 +35,7 @@ def load_data(path): with open(path / 'sample_assignment.yaml', 'r') as file: sample_assignment = yaml.safe_load(file) - + return sample_identity, sample_assignment @@ -47,7 +50,9 @@ def preprocess_labels(sample_identity): value_dict[sub_key] = int(value) if len(values) != len(set(values)): - logging.error(f"Duplicate values found in sample_identity for key {key}: {values}") + logging.error( + f'Duplicate values found in sample_identity for key {key}: {values}' + ) pathological_sample_identity = True return pathological_sample_identity, sample_identity @@ -60,9 +65,11 @@ def preprocess_sample_assignment(sample_assignment): def process_simulation_run(path): sample_identity, sample_assignment = load_data(path) - sample_identity_is_pathological, sample_identity = preprocess_labels(sample_identity) + sample_identity_is_pathological, sample_identity = preprocess_labels( + sample_identity + ) preprocess_sample_assignment(sample_assignment) - + logging.debug(sample_identity_is_pathological) logging.debug(sample_assignment) logging.debug(sample_identity) @@ -72,17 +79,22 @@ def process_simulation_run(path): return sample_identity_is_pathological, score - def compute_clustering_performance_score(subfolder): demultiplexed_folder = subfolder / 
'demultiplexed' v_measures = [] for assignment_file in demultiplexed_folder.glob('*assignments.tsv'): df = pd.read_csv(assignment_file, sep='\t') - + # Extract ground truth labels and cluster assignments - ground_truth_labels = df.columns[1:].map(lambda col: 'doublet' if '+' in col else col.split('_')[1].split('-')[1]).tolist() + ground_truth_labels = ( + df.columns[1:] + .map( + lambda col: 'doublet' if '+' in col else col.split('_')[1].split('-')[1] + ) + .tolist() + ) cluster_assignments = df.iloc[0, 1:].tolist() # Compute ARI @@ -100,31 +112,46 @@ def main(input_dir): for subfolder in subfolder_0.glob('*0'): sample_identity_path = subfolder / 'sample_identity.yaml' sample_assignment_path = subfolder / 'sample_assignment.yaml' - + if sample_identity_path.exists() and sample_assignment_path.exists(): - sample_identity_is_pathological, score = process_simulation_run(subfolder) + sample_identity_is_pathological, score = process_simulation_run( + subfolder + ) mean_v_score = compute_clustering_performance_score(subfolder) seed_number = int(seed_folder.name.split('_')[1]) - - results.append({ - 'seed': seed_number, - 'doublet_rate': float(subfolder_0.name), - 'cell_count': int(subfolder.name), - 'score': score, - 'sample_identity_is_pathological': sample_identity_is_pathological, - 'v_score': mean_v_score - }) + + results.append( + { + 'seed': seed_number, + 'doublet_rate': float(subfolder_0.name), + 'cell_count': int(subfolder.name), + 'score': score, + 'sample_identity_is_pathological': sample_identity_is_pathological, + 'v_score': mean_v_score, + } + ) # Periodically save results to a file to avoid memory burden if len(results) % 100 == 0: temp_df = pd.DataFrame(results) - temp_df.to_csv('more_intermediate_results.csv', mode='a', header=False, index=False) + temp_df.to_csv( + 'more_intermediate_results.csv', + mode='a', + header=False, + index=False, + ) results.clear() if results: df = pd.DataFrame(results) - df.to_csv('more_intermediate_results.csv', mode='a', header=not Path('more_intermediate_results.csv').exists(), index=False) + df.to_csv( + 'more_intermediate_results.csv', + mode='a', + header=not Path('more_intermediate_results.csv').exists(), + index=False, + ) + """ def main(input_dir): if not isinstance(input_dir, Path): @@ -168,7 +195,9 @@ def main(input_dir): """ if __name__ == '__main__': - input_dir = Path('/cluster/work/bewi/members/jgawron/projects/Demultiplexing/AML_data/output')# = args.input - - #process_simulation_run('/cluster/work/bewi/members/jgawron/projects/Demultiplexing/AML_data/output/seed_4/robust~True/pool_size~4/doublet_rate~0.05/cell_count~1000') + input_dir = Path( + '/cluster/work/bewi/members/jgawron/projects/Demultiplexing/AML_data/output' + ) # = args.input + + # process_simulation_run('/cluster/work/bewi/members/jgawron/projects/Demultiplexing/AML_data/output/seed_4/robust~True/pool_size~4/doublet_rate~0.05/cell_count~1000') main(input_dir) diff --git a/benchmarking_pipeline/workflow/sandbox/figure3.py b/benchmarking_pipeline/workflow/sandbox/figure3.py index e10c36a..cad4bf6 100644 --- a/benchmarking_pipeline/workflow/sandbox/figure3.py +++ b/benchmarking_pipeline/workflow/sandbox/figure3.py @@ -4,48 +4,80 @@ import matplotlib.pyplot as plt import seaborn as sns -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.INFO) - -data = pd.read_csv('intermediate_results.csv', header = None) -data.columns = ['seed', 'doublet_rate', 'cell_count', 'score', 'sample_identity_is_pathological', 
'v_score'] -sns.set_theme(style="whitegrid") +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.INFO, +) + +data = pd.read_csv('intermediate_results.csv', header=None) +data.columns = [ + 'seed', + 'doublet_rate', + 'cell_count', + 'score', + 'sample_identity_is_pathological', + 'v_score', +] +sns.set_theme(style='whitegrid') data = data.loc[data['cell_count'].isin([1000, 2000, 3000, 4000])] -data['score'] = 1-data['score'] +data['score'] = 1 - data['score'] # Customize the appearance of the plots -plt.rcParams.update({ - 'axes.titlesize': 26, - 'axes.labelsize': 26, - 'xtick.labelsize': 20, - 'ytick.labelsize': 26, - 'legend.fontsize': 16, - 'figure.titlesize': 18 -}) +plt.rcParams.update( + { + 'axes.titlesize': 26, + 'axes.labelsize': 26, + 'xtick.labelsize': 20, + 'ytick.labelsize': 26, + 'legend.fontsize': 16, + 'figure.titlesize': 18, + } +) fig, axes = plt.subplots(1, 5, figsize=(30, 6), sharey=True) doublet_rates = [0, 0.02, 0.04, 0.06, 0.08] for ax, rate in zip(axes, doublet_rates): - sns.boxplot(x='cell_count', y='score', data=data.loc[(~data['sample_identity_is_pathological']) & (data['doublet_rate'] == rate)], ax=ax, color='#C97B84',flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black')) + sns.boxplot( + x='cell_count', + y='score', + data=data.loc[ + (~data['sample_identity_is_pathological']) & (data['doublet_rate'] == rate) + ], + ax=ax, + color='#C97B84', + flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black'), + ) ax.set_title(f'Doublet Rate: {rate}') ax.set_xlabel('Cell Count') axes[0].set_ylabel('Sample assignment score') plt.tight_layout() -logging.info("Saving figure4a.png") -plt.savefig("figure4a.png", dpi = 300) +logging.info('Saving figure4a.png') +plt.savefig('figure4a.png', dpi=300) plt.close() -data = pd.read_csv('results_for_robust.csv', header = 0) -data.columns = ['seed', 'doublet_rate', 'cell_count', 'score', 'sample_identity_is_pathological', 'v_score'] -data['score'] = 1-data['score'] -data = data[data['score'] != 1] ### This outlier datapoint goes back to a mistake in the simulation! +data = pd.read_csv('results_for_robust.csv', header=0) +data.columns = [ + 'seed', + 'doublet_rate', + 'cell_count', + 'score', + 'sample_identity_is_pathological', + 'v_score', +] +data['score'] = 1 - data['score'] +data = data[ + data['score'] != 1 +] ### This outlier datapoint goes back to a mistake in the simulation! 
plt.figure(figsize=(10, 8)) -plt.scatter(x='v_score', y='score', data = data, color = '#C97B84') +plt.scatter(x='v_score', y='score', data=data, color='#C97B84') # Customize the plot plt.grid(True) @@ -60,13 +92,19 @@ plt.gca().tick_params(axis='x', colors='#444444') plt.gca().tick_params(axis='y', colors='#444444') -logging.info("Saving figure4b.png") -plt.savefig('figure4b.png', dpi = 300) +logging.info('Saving figure4b.png') +plt.savefig('figure4b.png', dpi=300) plt.close() plt.figure(figsize=(10, 8)) -sns.boxplot(x='sample_identity_is_pathological', y='score', data=data, color='#C97B84', flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black')) +sns.boxplot( + x='sample_identity_is_pathological', + y='score', + data=data, + color='#C97B84', + flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black'), +) # Customize the plot plt.grid(True) @@ -81,8 +119,7 @@ plt.gca().tick_params(axis='x', colors='#444444') plt.gca().tick_params(axis='y', colors='#444444') -logging.info("Saving figure4c.png") -plt.savefig('figure4c.png', dpi = 300) +logging.info('Saving figure4c.png') +plt.savefig('figure4c.png', dpi=300) plt.close() # Show the plot - diff --git a/benchmarking_pipeline/workflow/sandbox/figure4.py b/benchmarking_pipeline/workflow/sandbox/figure4.py index e258ed8..eee9d30 100644 --- a/benchmarking_pipeline/workflow/sandbox/figure4.py +++ b/benchmarking_pipeline/workflow/sandbox/figure4.py @@ -2,27 +2,45 @@ import matplotlib.pyplot as plt import seaborn as sns -data = pd.read_csv('intermediate_results.csv', header = None) -data.columns = ['seed', 'doublet_rate', 'cell_count', 'score', 'sample_identity_is_pathological', 'v_score'] -sns.set_theme(style="whitegrid") +data = pd.read_csv('intermediate_results.csv', header=None) +data.columns = [ + 'seed', + 'doublet_rate', + 'cell_count', + 'score', + 'sample_identity_is_pathological', + 'v_score', +] +sns.set_theme(style='whitegrid') # Customize the appearance of the plots -plt.rcParams.update({ - 'axes.titlesize': 16, - 'axes.labelsize': 14, - 'xtick.labelsize': 12, - 'ytick.labelsize': 12, - 'legend.fontsize': 12, - 'figure.titlesize': 18 -}) +plt.rcParams.update( + { + 'axes.titlesize': 16, + 'axes.labelsize': 14, + 'xtick.labelsize': 12, + 'ytick.labelsize': 12, + 'legend.fontsize': 12, + 'figure.titlesize': 18, + } +) fig, axes = plt.subplots(1, 6, figsize=(30, 6), sharey=True) doublet_rates = [0, 0.02, 0.04, 0.06, 0.08] for ax, rate in zip(axes, doublet_rates): - sns.boxplot(x='cell_count', y='score', data=data.loc[(~data['sample_identity_is_pathological']) & (data['doublet_rate'] == rate)], ax=ax, color='#C97B84',flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black')) + sns.boxplot( + x='cell_count', + y='score', + data=data.loc[ + (~data['sample_identity_is_pathological']) & (data['doublet_rate'] == rate) + ], + ax=ax, + color='#C97B84', + flierprops=dict(marker='o', color='red', markersize=5, markerfacecolor='black'), + ) ax.set_title(f'Doublet Rate: {rate}') ax.set_xlabel('Cell Count') axes[0].set_ylabel('Score') plt.tight_layout() -plt.savefig("figure3.png") +plt.savefig('figure3.png') diff --git a/benchmarking_pipeline/workflow/sandbox/generate_parameter_space.py b/benchmarking_pipeline/workflow/sandbox/generate_parameter_space.py index aa59ff1..7b1e6da 100644 --- a/benchmarking_pipeline/workflow/sandbox/generate_parameter_space.py +++ b/benchmarking_pipeline/workflow/sandbox/generate_parameter_space.py @@ -8,7 +8,9 @@ robust_values = [True, False] # Generate all 
combinations of parameters -combinations = list(itertools.product(robust_values, pool_sizes, doublet_rates, cell_counts)) +combinations = list( + itertools.product(robust_values, pool_sizes, doublet_rates, cell_counts) +) # Define the output file path output_file = 'parameter_space.tsv' @@ -22,4 +24,4 @@ for combination in combinations: writer.writerow(combination) -print(f"Parameter space written to {output_file}") \ No newline at end of file +print(f'Parameter space written to {output_file}') diff --git a/benchmarking_pipeline/workflow/sandbox/get_read_count_distribution_of_samples.py b/benchmarking_pipeline/workflow/sandbox/get_read_count_distribution_of_samples.py deleted file mode 100644 index f02abea..0000000 --- a/benchmarking_pipeline/workflow/sandbox/get_read_count_distribution_of_samples.py +++ /dev/null @@ -1,38 +0,0 @@ -from pathlib import Path - -import loompy -import pandas as pd -import numpy as np - -def process_loom_files(folder_path): - folder = Path(folder_path) - files = [] - cell_counts = [] - for filename in Path(folder).rglob("*.loom"): - with loompy.connect(str(filename)) as ds: - cell_counts.append(ds.shape[1]) - files.append(filename) - data = pd.DataFrame({'cell_count': cell_counts}, index = files) - import matplotlib.pyplot as plt - - # Plot the histogram of cell counts - plt.hist(cell_counts, bins=30, density=True, alpha=0.6, color='g', label='Cell count histogram') - - # Plot the Poisson distribution density - - # Fit a Normal distribution to the cell counts - mean_count = np.mean(cell_counts) - std_dev = np.std(cell_counts) - - x = np.arange(0, max(cell_counts) + 1) - from scipy.stats import norm - plt.plot(x, norm.pdf(x, mean_count, std_dev), 'r-', label='Normal fit') - - plt.xlabel('Cell count') - plt.ylabel('Density') - plt.title('Histogram of Cell Counts with Poisson Distribution Fit') - plt.legend() - plt.savefig(folder / 'cell_count_distribution.png') - -folder_path = '/cluster/work/bewi/members/jgawron/projects/Demultiplexing/AML_data/loom_files_complete' -process_loom_files(folder_path) \ No newline at end of file diff --git a/benchmarking_pipeline/workflow/scripts/assign_subsample_simulation.py b/benchmarking_pipeline/workflow/scripts/assign_subsample_simulation.py index 211a878..f9f3e29 100644 --- a/benchmarking_pipeline/workflow/scripts/assign_subsample_simulation.py +++ b/benchmarking_pipeline/workflow/scripts/assign_subsample_simulation.py @@ -8,27 +8,44 @@ logging.basicConfig(level=logging.INFO) + def parse_args(): parser = argparse.ArgumentParser(description='Assign subsample simulation') - parser.add_argument('--demultiplexing_assignment', nargs = '+', type=str, help='A tsv file that specifies which cell is assigned to which cluster') - parser.add_argument('--output', type=str, help='Output file for the sample assignment') + parser.add_argument( + '--demultiplexing_assignment', + nargs='+', + type=str, + help='A tsv file that specifies which cell is assigned to which cluster', + ) + parser.add_argument( + '--output', type=str, help='Output file for the sample assignment' + ) return parser.parse_args() def analyse_demultiplexing_assignment(assignment_file, pool_name): - data = pd.read_csv(assignment_file, sep='\t', index_col = 0, header = None) + data = pd.read_csv(assignment_file, sep='\t', index_col=0, header=None) - data.iloc[0, :] = data.iloc[0, :].map(lambda x: "+".join([part.split('_')[1].split('.')[0].split("-")[1] for part in x.split('+')]) if '+' in x else x.split('_')[1].split('.')[0].split("-")[1]) + data.iloc[0, :] = data.iloc[0, 
:].map( + lambda x: '+'.join( + [part.split('_')[1].split('.')[0].split('-')[1] for part in x.split('+')] + ) + if '+' in x + else x.split('_')[1].split('.')[0].split('-')[1] + ) - unique_values = data.iloc[1,:].unique() + unique_values = data.iloc[1, :].unique() colors = plt.cm.tab20.colors - color_map = {label: colors[i % len(colors)] for i, label in enumerate(data.iloc[0, :].unique())} - + color_map = { + label: colors[i % len(colors)] + for i, label in enumerate(data.iloc[0, :].unique()) + } + fig, axes = plt.subplots(1, len(unique_values), figsize=(15, 5)) sample_assignment = {} - + for ax, value in zip(axes, unique_values): subset = data.loc[:, data.iloc[1, :] == value] counts = subset.iloc[0, :].value_counts() @@ -36,15 +53,23 @@ def analyse_demultiplexing_assignment(assignment_file, pool_name): major_label = counts.idxmax() sample_assignment[value] = major_label if not any(counts > 0.75 * counts.sum()): - logging.warning(f'Demultiplexing results in contaminated cell clusters.') - counts.plot.pie(autopct='%1.1f%%', startangle=90, ax=ax, colors=[color_map[label] for label in counts.index]) + logging.warning('Demultiplexing results in contaminated cell clusters.') + counts.plot.pie( + autopct='%1.1f%%', + startangle=90, + ax=ax, + colors=[color_map[label] for label in counts.index], + ) ax.set_title(f'Abundancies for {value}') ax.set_ylabel('') plt.tight_layout() - plt.savefig(Path(assignment_file).parent / f'pool_{pool_name}_demultiplexing_assignment.png') + plt.savefig( + Path(assignment_file).parent / f'pool_{pool_name}_demultiplexing_assignment.png' + ) return sample_assignment + def main(args): sample_assignments = {} if isinstance(args.demultiplexing_assignment, str): @@ -53,13 +78,16 @@ def main(args): assignment_files = args.demultiplexing_assignment for assignment_file in assignment_files: pool_name = str(Path(assignment_file).stem).split('_')[1] - sample_assignment = analyse_demultiplexing_assignment(assignment_file, pool_name) - sample_assignments[f"({pool_name})"] = sample_assignment + sample_assignment = analyse_demultiplexing_assignment( + assignment_file, pool_name + ) + sample_assignments[f'({pool_name})'] = sample_assignment with open(args.output, 'w') as outfile: yaml.dump(sample_assignments, outfile) logging.info('Success.') + if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/compare_distances.py b/benchmarking_pipeline/workflow/scripts/compare_distances.py index da39ccc..cacb546 100644 --- a/benchmarking_pipeline/workflow/scripts/compare_distances.py +++ b/benchmarking_pipeline/workflow/scripts/compare_distances.py @@ -1,5 +1,4 @@ import argparse -import logging import pandas as pd import seaborn as sns @@ -9,22 +8,25 @@ from pathlib import Path - def pool_format2multiplexing_scheme(pool_scheme): demultiplexing_scheme = {} for pool, samples in pool_scheme.items(): for sample in samples: - if "_split1.loom" in sample: + if '_split1.loom' in sample: sample_name = str(Path(sample).stem) sample_name = sample_name.replace('_split1', '') - number_of_iterations =pool.split('.')[1][:-1] + number_of_iterations = pool.split('.')[1][:-1] pool_ID = pool.split('.')[0][1:] if sample_name not in demultiplexing_scheme.keys(): - demultiplexing_scheme[sample_name] = [int(pool_ID), np.inf, int(number_of_iterations)] + demultiplexing_scheme[sample_name] = [ + int(pool_ID), + np.inf, + int(number_of_iterations), + ] else: demultiplexing_scheme[sample_name][0] = int(pool_ID) 
demultiplexing_scheme[sample_name][2] = int(number_of_iterations) - if "_split2.loom" in sample: + if '_split2.loom' in sample: sample_name = str(Path(sample).stem) sample_name = sample_name.replace('_split2', '') pool_ID = pool.split('.')[0][1:] @@ -33,25 +35,32 @@ def pool_format2multiplexing_scheme(pool_scheme): else: demultiplexing_scheme[sample_name][1] = int(pool_ID) # Swap keys and items in the dictionary - swapped_demultiplexing_scheme = {tuple(v): k for k, v in demultiplexing_scheme.items()} - - return swapped_demultiplexing_scheme - + swapped_demultiplexing_scheme = { + tuple(v): k for k, v in demultiplexing_scheme.items() + } + return swapped_demultiplexing_scheme # Create a custom colormap -cmap = sns.color_palette("viridis", as_cmap=True) +cmap = sns.color_palette('viridis', as_cmap=True) cmap.set_bad(color='grey') def parse_args(): - parser = argparse.ArgumentParser(description="Compare distances between samples") - parser.add_argument("--tsv_files", nargs='+' , type=str, help="Path to TSV files containing the genotypes of the clusters. Must be in the same order as the pools in the pooling scheme.") - parser.add_argument("--pool_scheme", type=str, help="Path to the pool scheme file") - parser.add_argument("--output_plot", type=str, help="Output heatmap") - parser.add_argument("--output", type=str, help="sample assignment") - parser.add_argument("--robust", type=bool, required=False, default = False, help="Robust assignment") + parser = argparse.ArgumentParser(description='Compare distances between samples') + parser.add_argument( + '--tsv_files', + nargs='+', + type=str, + help='Path to TSV files containing the genotypes of the clusters. Must be in the same order as the pools in the pooling scheme.', + ) + parser.add_argument('--pool_scheme', type=str, help='Path to the pool scheme file') + parser.add_argument('--output_plot', type=str, help='Output heatmap') + parser.add_argument('--output', type=str, help='sample assignment') + parser.add_argument( + '--robust', type=bool, required=False, default=False, help='Robust assignment' + ) return parser.parse_args() @@ -72,19 +81,30 @@ def compute_ratio(matrix): def permute_tsv_files(tsv_files, pool_scheme): pools = [f"({Path(file).name.split("_")[1]})" for file in tsv_files] - + pool_scheme_keys = list(pool_scheme.keys()) pool_permutation = [pools.index(pool) for pool in pool_scheme_keys] - - if not all((pool1 == pool2 for pool1, pool2 in zip([pools[permuted_idx] for permuted_idx in pool_permutation], pool_scheme_keys))): - raise ValueError("The pools in the pool scheme do not match the pools in the TSV files") + + if not all( + ( + pool1 == pool2 + for pool1, pool2 in zip( + [pools[permuted_idx] for permuted_idx in pool_permutation], + pool_scheme_keys, + ) + ) + ): + raise ValueError( + 'The pools in the pool scheme do not match the pools in the TSV files' + ) return pool_permutation + args = parse_args() -def load_data(args): +def load_data(args): with open(args.pool_scheme, 'r') as file: pooling_scheme = yaml.safe_load(file) @@ -94,11 +114,12 @@ def load_data(args): tsv_files = [args.tsv_files] tsv_files = list(args.tsv_files) tsv_files_permuted = [tsv_files[i] for i in permutation_of_pools] - - dfs = [pd.read_csv(f, sep='\t', index_col = 0) for f in tsv_files_permuted] + + dfs = [pd.read_csv(f, sep='\t', index_col=0) for f in tsv_files_permuted] return dfs, pooling_scheme + dfs, pooling_scheme = load_data(args) # Get a set of all column names @@ -109,14 +130,14 @@ def load_data(args): print(all_columns) -for idx,df in 
enumerate(dfs): +for idx, df in enumerate(dfs): for col in all_columns: if col not in df.columns: dfs[idx][col] = np.nan -for idx,df in enumerate(dfs): +for idx, df in enumerate(dfs): dfs[idx] = df[sorted(all_columns)] - + # Concatenate all DataFrames concatenated_df = pd.concat(dfs, ignore_index=True)[sorted(all_columns)] @@ -135,7 +156,7 @@ def load_data(args): concatenated_df.drop(columns=cols_to_remove, inplace=True) concatenated_df.drop(columns=cols_to_remove_45_55, inplace=True) -for idx,df in enumerate(dfs): +for idx, df in enumerate(dfs): dfs[idx] = df.drop(columns=cols_to_remove, inplace=False) dfs[idx] = df.drop(columns=cols_to_remove_45_55, inplace=False) @@ -151,7 +172,7 @@ def load_data(args): # Compute the distance matrix considering only non-NaN entries and normalizing - # Compute the distance matrix for all pairs of DataFrames +# Compute the distance matrix for all pairs of DataFrames distance_matrices = np.empty((len(dfs), len(dfs)), dtype=object) for data1, df1 in enumerate(dfs): @@ -159,7 +180,9 @@ def load_data(args): distance_matrix = np.zeros((df1.shape[0], df2.shape[0])) for i in range(df1.shape[0]): for j in range(df2.shape[0]): - distance_matrix[i, j] = custom_distance(df1.iloc[i].values, df2.iloc[j].values) + distance_matrix[i, j] = custom_distance( + df1.iloc[i].values, df2.iloc[j].values + ) distance_matrices[data1, data2] = distance_matrix @@ -200,7 +223,6 @@ def load_data(args): plt.savefig(heatmap_plot) - fig, axes = plt.subplots(nrows=1, ncols=len(dfs), figsize=(15, 5)) for i, df in enumerate(dfs): @@ -209,9 +231,6 @@ def load_data(args): plt.savefig(genotype_plot) - - - ratios = [] for i in range(len(dfs)): for j in range(len(dfs)): @@ -224,7 +243,7 @@ def load_data(args): lowest_value_pairs = [] -used_samples = [[]*len(dfs) for _ in range(len(dfs))] +used_samples = [[] * len(dfs) for _ in range(len(dfs))] """ if remove is not None: for j in range(len(dfs)): if j != remove: @@ -272,27 +291,31 @@ def load_data(args): # Sort the distance matrices by the recomputed ratio, from largest to smallest sorted_ratios = sorted(ratios, key=lambda x: x[2], reverse=True) else: - print(f"skipping assignment in matrix {i},{j}") + print(f'skipping assignment in matrix {i},{j}') # Print the pairs with the lowest values for i, j, row, col in lowest_value_pairs: - print(f'Lowest value in Distance Matrix DF{i} vs DF{j}: Row = {row}, Column = {col}') + print( + f'Lowest value in Distance Matrix DF{i} vs DF{j}: Row = {row}, Column = {col}' + ) - -demultiplexing_scheme = pool_format2multiplexing_scheme(pooling_scheme) +demultiplexing_scheme = pool_format2multiplexing_scheme(pooling_scheme) for i, j, row, col in lowest_value_pairs: if i > j: i, j = j, i row, col = col, row print( f'Sample {row} from pool {i} and sample {col} from pool {j}: {demultiplexing_scheme[(i,j,0)]}' ) -if args.robust == False: +if not args.robust: for i in range(len(used_samples)): - unused_sample = next(sample for sample in range(len(dfs[i])) if sample not in used_samples[i]) - print(f"Sample {unused_sample} from pool {i}: {demultiplexing_scheme[(i,i,0)]}") - + unused_sample = next( + sample for sample in range(len(dfs[i])) if sample not in used_samples[i] + ) + print(f'Sample {unused_sample} from pool {i}: {demultiplexing_scheme[(i,i,0)]}') pooling_scheme_keys = list(pooling_scheme.keys()) @@ -302,21 +325,35 @@ def load_data(args): i, j = j, i row, col = col, row if pooling_scheme_keys[i] not in 
sample_assignment.keys(): - sample_assignment[pooling_scheme_keys[i]] = {row: demultiplexing_scheme[(i, j, 0)]} + sample_assignment[pooling_scheme_keys[i]] = { + row: demultiplexing_scheme[(i, j, 0)] + } else: - sample_assignment[pooling_scheme_keys[i]][row] = demultiplexing_scheme[(i, j, 0)] + sample_assignment[pooling_scheme_keys[i]][row] = demultiplexing_scheme[ + (i, j, 0) + ] if pooling_scheme_keys[j] not in sample_assignment.keys(): - sample_assignment[pooling_scheme_keys[j]] = {col: demultiplexing_scheme[(i, j, 0)]} + sample_assignment[pooling_scheme_keys[j]] = { + col: demultiplexing_scheme[(i, j, 0)] + } else: - sample_assignment[pooling_scheme_keys[j]][col] = demultiplexing_scheme[(i, j, 0)] + sample_assignment[pooling_scheme_keys[j]][col] = demultiplexing_scheme[ + (i, j, 0) + ] -if args.robust == False: +if not args.robust: for i in range(len(used_samples)): - unused_sample = next(sample for sample in range(len(dfs[i])) if sample not in used_samples[i]) + unused_sample = next( + sample for sample in range(len(dfs[i])) if sample not in used_samples[i] + ) if pooling_scheme_keys[i] not in sample_assignment.keys(): - sample_assignment[pooling_scheme_keys[i]] = {unused_sample: demultiplexing_scheme[(i, i, 0)]} + sample_assignment[pooling_scheme_keys[i]] = { + unused_sample: demultiplexing_scheme[(i, i, 0)] + } else: - sample_assignment[pooling_scheme_keys[i]][unused_sample] = demultiplexing_scheme[(i, i, 0)] + sample_assignment[pooling_scheme_keys[i]][unused_sample] = ( + demultiplexing_scheme[(i, i, 0)] + ) with open(args.output, 'w') as yaml_file: yaml.dump(sample_assignment, yaml_file) diff --git a/benchmarking_pipeline/workflow/scripts/create_demultiplexing_scheme.py b/benchmarking_pipeline/workflow/scripts/create_demultiplexing_scheme.py index 331816d..5e16c61 100644 --- a/benchmarking_pipeline/workflow/scripts/create_demultiplexing_scheme.py +++ b/benchmarking_pipeline/workflow/scripts/create_demultiplexing_scheme.py @@ -6,39 +6,56 @@ import numpy as np -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.INFO) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.INFO, +) -def define_demultiplexing_scheme_optimal_case(maximal_number_of_samples, maximal_pool_size, n_samples, robust): +def define_demultiplexing_scheme_optimal_case( + maximal_number_of_samples, maximal_pool_size, n_samples, robust +): if n_samples % maximal_number_of_samples != 0: - raise ValueError("Number of samples must be a multiple of maximal number of samples to run this function!") + raise ValueError( + 'Number of samples must be a multiple of maximal number of samples to run this function!' 
+ ) demultiplexing_scheme = {} number_of_iterations = int(n_samples / maximal_number_of_samples) if not robust: - unordered_unique_pairs = list(itertools.combinations(range(maximal_pool_size), 2)) + unordered_unique_pairs = list( + itertools.combinations(range(maximal_pool_size), 2) + ) diagonal = list(zip(range(maximal_pool_size), range(maximal_pool_size))) unordered_pairs = unordered_unique_pairs + diagonal else: - unordered_pairs = list(itertools.combinations(range(maximal_pool_size+1), 2)) + unordered_pairs = list(itertools.combinations(range(maximal_pool_size + 1), 2)) for idx1 in range(number_of_iterations): for idx2, pair in enumerate(unordered_pairs): - demultiplexing_scheme[int(idx1 * maximal_number_of_samples + idx2 + 1)] = pair + (idx1,) # naming of samples starts with 1 + demultiplexing_scheme[int(idx1 * maximal_number_of_samples + idx2 + 1)] = ( + pair + (idx1,) + ) # naming of samples starts with 1 return demultiplexing_scheme - def find_demultiplexing_scheme(maximal_pool_size, n_samples, robust): - maximal_number_of_samples = (maximal_pool_size * (maximal_pool_size+1))/2 + maximal_number_of_samples = (maximal_pool_size * (maximal_pool_size + 1)) / 2 if n_samples % maximal_number_of_samples != 0: - logging.error("So far, only defined for certain cohort sizes") + logging.error('So far, only defined for certain cohort sizes') raise NotImplementedError else: - logging.info("Creating multiplexing scheme") - demultiplexing_scheme = define_demultiplexing_scheme_optimal_case(maximal_number_of_samples = maximal_number_of_samples, maximal_pool_size = maximal_pool_size, n_samples = n_samples, robust = robust) - + logging.info('Creating multiplexing scheme') + demultiplexing_scheme = define_demultiplexing_scheme_optimal_case( + maximal_number_of_samples=maximal_number_of_samples, + maximal_pool_size=maximal_pool_size, + n_samples=n_samples, + robust=robust, + ) + logging.info(f'Demultiplexing scheme: {demultiplexing_scheme}') return demultiplexing_scheme @@ -46,91 +63,135 @@ def find_demultiplexing_scheme(maximal_pool_size, n_samples, robust): def multiplexing_scheme_format2pool_format(demultiplexing_scheme): pool_scheme = {} - total_no_pools_per_repetition = np.max([pool[:-1] for pool in list(demultiplexing_scheme.values())]) - total_no_of_repetitions = np.max([pool[-1] for pool in list(demultiplexing_scheme.values())]) - - pools = list(itertools.product(range(total_no_pools_per_repetition + 1), range(total_no_of_repetitions + 1))) + total_no_pools_per_repetition = np.max( + [pool[:-1] for pool in list(demultiplexing_scheme.values())] + ) + total_no_of_repetitions = np.max( + [pool[-1] for pool in list(demultiplexing_scheme.values())] + ) + + pools = list( + itertools.product( + range(total_no_pools_per_repetition + 1), range(total_no_of_repetitions + 1) + ) + ) for pool in pools: pool_scheme[pool] = [] for key, value in demultiplexing_scheme.items(): - pool_scheme[value[0],value[-1]].append(int(key)) - pool_scheme[value[1],value[-1]].append(-int(key)) + pool_scheme[value[0], value[-1]].append(int(key)) + pool_scheme[value[1], value[-1]].append(-int(key)) # The pool scheme has pool identifier as keys and split libraries as values, where the first split is encoded as a positive integer and the second split as a negative integer return pool_scheme - def select_samples_for_pooling(pool_scheme, input_dir, sample_list): if isinstance(sample_list[0], str): - for idx,sample in enumerate(sample_list): + for idx, sample in enumerate(sample_list): sample_list[idx] = Path(sample) if 
isinstance(input_dir, str): input_dir = Path(input_dir) logging.info('Selecting samples for pooling') pools_summary = {} for pool, samples in pool_scheme.items(): - - logging.info(f"Retrieving list of samples to be pooled for pool {pool}") - logging.debug(f"Samples: {samples}") - logging.debug(f"Sample list: {sample_list}") + logging.info(f'Retrieving list of samples to be pooled for pool {pool}') + logging.debug(f'Samples: {samples}') + logging.debug(f'Sample list: {sample_list}') loom_files = [] for sample in samples: if sample > 0: - loom_files.append((input_dir / f"{sample_list[sample-1].stem}_split1.loom").as_posix()) + loom_files.append( + (input_dir / f'{sample_list[sample-1].stem}_split1.loom').as_posix() + ) else: - loom_files.append((input_dir / f"{sample_list[-sample-1].stem}_split2.loom").as_posix()) - - logging.info(f"Done retrieving list of samples to be pooled for pool {pool}") + loom_files.append( + ( + input_dir / f'{sample_list[-sample-1].stem}_split2.loom' + ).as_posix() + ) - pools_summary[f"({pool[0]}.{pool[1]})"] = loom_files + logging.info(f'Done retrieving list of samples to be pooled for pool {pool}') + + pools_summary[f'({pool[0]}.{pool[1]})'] = loom_files return pools_summary def write_output(pools_summary, output): - with open(output, "w") as f: + with open(output, 'w') as f: yaml.dump(pools_summary, f, default_flow_style=False) def load_input_samples(input_sample_file): - with open(input_sample_file, "r") as f: + with open(input_sample_file, 'r') as f: input_samples = f.readlines() - + return input_samples def parse_args(): - parser = argparse.ArgumentParser(description="Output a multiplexing scheme under pool size constraint for a predefined number of samples") - parser.add_argument("--robust", type=bool, required=False, default = False, help="Input list of loom files") - parser.add_argument("-k", "--maximal_pool_size", type=int, required=False, help="The maximal amount of samples to be multiplexed in a pool") - parser.add_argument("--n_samples", type=int, required=True, help="The number of samples to be sequenced in the cohort") - parser.add_argument("--output", type=str, required=True, help="Where to store the output file") - parser.add_argument("--input_dir", type=str, required=True, help="Directory to the loom files") - parser.add_argument("--input_sample_file", type=str, required=True, help="Path to the file containing the list of loom files") - - logging.info("Parsing arguments") + parser = argparse.ArgumentParser( + description='Output a multiplexing scheme under pool size constraint for a predefined number of samples' + ) + parser.add_argument( + '--robust', + type=bool, + required=False, + default=False, + help='Create a robust multiplexing scheme in which the two splits of a sample are never assigned to the same pool', + ) + parser.add_argument( + '-k', + '--maximal_pool_size', + type=int, + required=False, + help='The maximal number of samples to be multiplexed in a pool', + ) + parser.add_argument( + '--n_samples', + type=int, + required=True, + help='The number of samples to be sequenced in the cohort', + ) + parser.add_argument( + '--output', type=str, required=True, help='Where to store the output file' + ) + parser.add_argument( + '--input_dir', type=str, required=True, help='Directory to the loom files' + ) + parser.add_argument( + '--input_sample_file', + type=str, + required=True, + help='Path to the file containing the list of loom files', + ) + + logging.info('Parsing arguments') args = parser.parse_args() - + if args.n_samples == 0: - logging.error("Number of samples must be greater than 0") + logging.error('Number of
samples must be greater than 0') raise ValueError - + return args def main(args): input_samples = load_input_samples(args.input_sample_file) n_samples = len(input_samples) - demultiplexing_scheme = find_demultiplexing_scheme(args.maximal_pool_size, n_samples, args.robust) + demultiplexing_scheme = find_demultiplexing_scheme( + args.maximal_pool_size, n_samples, args.robust + ) pool_scheme = multiplexing_scheme_format2pool_format(demultiplexing_scheme) - pools_summary = select_samples_for_pooling(pool_scheme, args.input_dir, input_samples) - logging.debug(f"Output: {pools_summary.keys()}") - logging.info(f"Writing output to {args.output}") + pools_summary = select_samples_for_pooling( + pool_scheme, args.input_dir, input_samples + ) + logging.debug(f'Output: {pools_summary.keys()}') + logging.info(f'Writing output to {args.output}') write_output(pools_summary, args.output) - logging.info("Success.") + logging.info('Success.') -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/create_heatmap.py b/benchmarking_pipeline/workflow/scripts/create_heatmap.py deleted file mode 100644 index 15c857c..0000000 --- a/benchmarking_pipeline/workflow/scripts/create_heatmap.py +++ /dev/null @@ -1,293 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import copy -import os -import re -import tkinter as tk - -from matplotlib import pyplot as plt -import numpy as np -import pandas as pd -from scipy.spatial.distance import pdist, cityblock -from scipy.cluster.hierarchy import linkage -import seaborn as sns - -VCF_COLS = [ - 'CHROM', - 'POS', - 'ID', - 'REF', - 'ALT', - 'QUAL', - 'FILTER', - 'INFO', - 'FORMAT', - 'sample', -] -GT_MAP = {'0/0': 0, '0/1': 1, '1/1': 2, './.': np.nan} -DEF_DIST = 'Euclidean' -PID2SID = { - 'PIDasdf': 'SIDjklö' -} ### These can be adapted manually to have more control on how the sampled should be named in the output -SID2PAPER = {'SIDjklö': 'S1'} - -FILE_EXT = '.png' -FONTSIZE = 30 -DPI = 300 -sns.set_style('whitegrid') # darkgrid, whitegrid, dark, white, ticks -sns.set_context( - 'paper', - rc={ - 'lines.linewidth': 2, - 'axes.axisbelow': True, - 'font.size': FONTSIZE, - 'axes.labelsize': 'medium', - 'axes.titlesize': 'large', - 'xtick.labelsize': 'medium', - 'ytick.labelsize': 'medium', - 'legend.fontsize': 'medium', - 'legend.title_fontsize': 'large', - 'axes.labelticksize': 50, - }, -) -plt.rcParams['xtick.major.size'] = 10 -plt.rcParams['xtick.major.width'] = 3 -plt.rcParams['xtick.bottom'] = True - - -def dist_nan(u, v): - valid = ~np.isnan(u) & ~np.isnan(v) - # return euclidean(u[valid], v[valid]) / np.sqrt(np.sum(2**2 * valid.sum())) # Euclidean - # return (u[valid].round() != v[valid].round()).mean() # Hamming - return cityblock(u[valid], v[valid]) / (2 * valid.sum()) # Manhattan - - -def get_clone_profiles(in_file): - df = pd.read_csv(in_file) - df['REF'].fillna('*', inplace=True) - df['ALT'].fillna('*', inplace=True) - idx_str = ( - df['CHR'].astype(str) - + '_' - + df['POS'].astype(str) - + '_' - + df['REF'] - + '_' - + df['ALT'] - ) - idx_str = idx_str.str.replace('chr', '') - gt = df.iloc[:, 7:].map(lambda x: int(x.split(':')[-1])) - gt.replace(3, np.nan, inplace=True) - - df = pd.DataFrame(gt.mean(axis=1).values, index=idx_str, columns=['Cluster']) - return df - - -def plot_heatmap(data, out_file): - cmap = plt.get_cmap('YlOrRd', 100) - dist_row = pdist(data.values, dist_nan) - dist_col = pdist(data.values.T, dist_nan) - Z_row = 
linkage(np.nan_to_num(dist_row, 10), 'ward') - Z_col = linkage(np.nan_to_num(dist_col, 10), 'ward') - try: - cm = sns.clustermap( - data, - row_linkage=Z_row, - col_linkage=Z_col, - row_cluster=True, - col_cluster=True, - col_colors=None, - row_colors=None, - vmin=0, - vmax=2, - cmap=cmap, - dendrogram_ratio=(0.1, 0.1), - figsize=(25, 15), - cbar_kws={ - 'ticks': [0, 1, 2], - }, - # cbar_pos=(0, 0, 0.0, 0.0), - tree_kws=dict(linewidths=2), - ) - except tk.TclError: - import matplotlib - - matplotlib.use('Agg') - cm = sns.clustermap( - data, - row_linkage=Z_row, - col_linkage=Z_col, - row_cluster=True, - col_cluster=True, - col_colors=None, - row_colors=None, - vmin=0, - vmax=2, - cmap=cmap, - dendrogram_ratio=(0.1, 0.1), - figsize=(25, 15), - cbar_kws={ - 'ticks': [0, 1, 2], - }, - # cbar_pos=(0, 0, 0.0, 0.0), - tree_kws=dict(linewidths=2), - ) - - cm.ax_heatmap.set_facecolor('#EAEAEA') - cm.ax_heatmap.set_ylabel('\nProfiles') - cm.ax_heatmap.set_xlabel('SNPs') - labels_pretty = [ - 'chr{}:{} {}>{}'.format(*i.get_text().split('_')) - for i in cm.ax_heatmap.get_xticklabels() - ] - cm.ax_heatmap.set_xticklabels(labels_pretty, rotation=45, ha='right', va='top') - - cm.ax_cbar.set_title('Genotype') - cm.ax_cbar.set_yticklabels(['0|0', '0|1', '1|1']) - - cm.fig.tight_layout() - if out_file: - cm.fig.savefig(out_file, dpi=DPI) - else: - plt.show() - - -def plot_distance(df, out_file): - col_no = 1 - row_no = 1 - fig, ax = plt.subplots( - nrows=row_no, ncols=col_no, figsize=(col_no * 12, row_no * 12) - ) - - sns.set(font_scale=2.5) - font_size = 30 - - df.sort_index(ascending=False, inplace=True) - df = df[sorted(df.columns)] - - df.columns = [i.split('_')[0] for i in df.columns] - df.index = [i.split('_')[0] for i in df.index] - - sns.heatmap( - df, - annot=True, - square=True, - cmap='viridis_r', - cbar_kws={'shrink': 0.5, 'label': 'Distance'}, - linewidths=0, - ax=ax, - ) - ax.set_xticklabels(ax.get_xmajorticklabels(), fontsize=font_size) - ax.set_yticklabels(ax.get_ymajorticklabels(), fontsize=font_size) - ax.set_xlabel(df.index.name, fontsize=font_size + 10) - ax.set_ylabel(df.columns.name, fontsize=font_size + 10) - - fig.tight_layout() - if out_file: - fig.savefig(out_file, dpi=300) - else: - plt.show() - - -def assign_clusters(df_in, assignment_output): - df = copy.deepcopy(df_in) - # import pdb;pdb.set_trace() - for i in range(df_in.shape[0]): - cl_idx, other_idx = np.where(df == df.min().min()) - cl = df.index[cl_idx[0]] - other = df.columns[other_idx[0]] - print(f'Assigning: {cl} -> {other}') - with open(assignment_output, 'a') as file: - file.write(f'{cl}: {other}\n') - - df.drop(cl, axis=0, inplace=True) - df.drop(other, axis=1, inplace=True) - - -def main(args): - if not args.outfile: - out_base = os.path.splitext(args.input[0])[0] - out_end = FILE_EXT - else: - out_base, out_end = os.path.splitext(args.outfile) - - names = [] - for i, in_file in enumerate(args.input): - - cluster = re.compile(r'([0-9]*)_variants.csv').search(in_file).group(1) - name = f'{cluster}_DNA' - - names.append(name) - df_new = get_clone_profiles(in_file) - df_new.rename({'Cluster': name}, inplace=True, axis=1) - - match_df = df_new.merge(SNP_df, left_index=True, right_index=True, how='inner') - - if match_df.size == 0: - print('No overlap in RNA and DNA data.') - exit() - try: - df = df.merge( - match_df[name], left_index=True, right_index=True, how='outer' - ) - except NameError: - df = match_df - - df = df[sorted(df.columns)] - # remove SNPs that are similar in all DNA clusters - df = df[(df[names] > 
1.95).sum(axis=1) != len(names)] - df = df[(df[names] < 0.05).sum(axis=1) != len(names)] - df = df[ - ((df[names] > 0.95).sum(axis=1) != len(names)) - | ((df[names] < 1.05).sum(axis=1) != len(names)) - ] - # Remove SNPS not called in the scDNA-seq - df = df[(df[names] >= 0).sum(axis=1) > 0] - - dists = [] - for name in names: - dists.append( - np.apply_along_axis( - dist_nan, 1, df[SNP_df.columns].values.T, df[name].values.T - ) - ) - - dist_df = pd.DataFrame(dists, index=names, columns=SNP_df.columns) - - assign_clusters(dist_df, args.assignment_output) - print(f'\n{dist_df}\n') - - dist_df.index.name = 'scDNA-seq' - dist_df.columns.name = label - - dist_out = f'{out_base}_distance{out_end}' - print(f'Generating plot: {dist_out}') - plot_distance(dist_df, dist_out) - - hm_out = f'{out_base}_heatmap{out_end}' - print(f'Generating plot: {hm_out}') - plot_heatmap(df.T, hm_out) - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - '-i', '--input', type=str, nargs='+', help='mutation_profiles of the demultiplexed samples' - ) - parser.add_argument( - '-o', '--outfile', type=str, default='', help='Output file for heatmap.' - ) - parser.add_argument( - '-ao', - '--assignment_output', - type=str, - help='Output yaml-file in which the automated sample matching is saved.', - ) - - return parser.parse_args() - - -if __name__ == '__main__': - args = parse_args() - main(args) \ No newline at end of file diff --git a/benchmarking_pipeline/workflow/scripts/demo_tape.py b/benchmarking_pipeline/workflow/scripts/demo_tape.py index 7b83023..e051d29 100644 --- a/benchmarking_pipeline/workflow/scripts/demo_tape.py +++ b/benchmarking_pipeline/workflow/scripts/demo_tape.py @@ -17,20 +17,57 @@ EPSILON = np.finfo(np.float64).resolution -COLORS = ['#e41a1c', '#377eb8', '#4daf4a', '#E4761A' , '#2FBF85'] # '#A4CA55' -COLORS = ['#a6cee3', '#1f78b4', '#b2df8a', '#33a02c', '#fb9a99', '#e31a1c', - '#fdbf6f', '#ff7f00', '#cab2d6', '#6a3d9a', '#ffff99', '#b15928'] - -COLORS = ['#e6194b', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4', '#46f0f0', '#f032e6', '#bcf60c', '#fabebe', '#008080', '#e6beff', '#9a6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1', '#000075', '#808080', '#ffffff', '#000000'] +COLORS = ['#e41a1c', '#377eb8', '#4daf4a', '#E4761A', '#2FBF85'] # '#A4CA55' +COLORS = [ + '#a6cee3', + '#1f78b4', + '#b2df8a', + '#33a02c', + '#fb9a99', + '#e31a1c', + '#fdbf6f', + '#ff7f00', + '#cab2d6', + '#6a3d9a', + '#ffff99', + '#b15928', +] + +COLORS = [ + '#e6194b', + '#3cb44b', + '#ffe119', + '#4363d8', + '#f58231', + '#911eb4', + '#46f0f0', + '#f032e6', + '#bcf60c', + '#fabebe', + '#008080', + '#e6beff', + '#9a6324', + '#fffac8', + '#800000', + '#aaffc3', + '#808000', + '#ffd8b1', + '#000075', + '#808080', + '#ffffff', + '#000000', +] COLORS_STR = ['red', 'blue', 'green', 'orange', 'mint'] FILE_EXT = 'png' FONTSIZE = 30 DPI = 300 -sns.set_style('whitegrid') #darkgrid, whitegrid, dark, white, ticks -sns.set_context('paper', - rc={'lines.linewidth': 2, +sns.set_style('whitegrid') # darkgrid, whitegrid, dark, white, ticks +sns.set_context( + 'paper', + rc={ + 'lines.linewidth': 2, 'axes.axisbelow': True, 'font.size': FONTSIZE, 'axes.labelsize': 'medium', @@ -39,8 +76,9 @@ 'ytick.labelsize': 'medium', 'legend.fontsize': 'medium', 'legend.title_fontsize': 'large', - 'axes.labelticksize': 50 -}) + 'axes.labelticksize': 50, + }, +) plt.rcParams['xtick.major.size'] = 5 plt.rcParams['xtick.major.width'] = 1.5 plt.rcParams['xtick.bottom'] = True @@ -54,8 +92,8 @@ def 
__init__(self, in_file, cl_no): # Get full data df = pd.read_csv(in_file, index_col=[0, 1], dtype={'CHR': str}) self.SNPs = df.apply( - lambda x: f'chr{x.name[0]}:{x.name[1]} {x["REF"]}>{x["ALT"]}', - axis=1) + lambda x: f'chr{x.name[0]}:{x.name[1]} {x["REF"]}>{x["ALT"]}', axis=1 + ) df.drop(['REF', 'ALT', 'REGION', 'NAME', 'FREQ'], axis=1, inplace=True) self.cells = df.columns.values @@ -78,11 +116,13 @@ def __init__(self, in_file, cl_no): self.ref = df.applymap(lambda x: int(x.split(':')[0])).values.T self.alt = df.applymap(lambda x: int(x.split(':')[1])).values.T self.dp = self.ref + self.alt - self.VAF = np.clip(np.where(self.dp > 0, self.alt / self.dp, np.nan), - EPSILON, 1 - EPSILON) + self.VAF = np.clip( + np.where(self.dp > 0, self.alt / self.dp, np.nan), EPSILON, 1 - EPSILON + ) self.RAF = 1 - self.VAF - self.norm_const = np.arange(self.dp.max() * 2 + 1) \ - * np.log(np.arange(self.dp.max() * 2 + 1)) + self.norm_const = np.arange(self.dp.max() * 2 + 1) * np.log( + np.arange(self.dp.max() * 2 + 1) + ) self.reads = np.hstack([self.ref, self.alt, self.dp]) self.metric = self.dist_reads @@ -95,25 +135,25 @@ def __init__(self, in_file, cl_no): self.dbt_ids = np.array([]) self.dbt_map = {} - def __str__(self): - out_str = '\ndemoTape:\n' \ - f'\tFile: {self.in_file}\n' \ - f'\t# Samples: {self.cl_no}\n' \ - f'\tCells: {self.cells}:\n' \ - f'\tSNPs: {self.SNPs}\n' \ + out_str = ( + '\ndemoTape:\n' + f'\tFile: {self.in_file}\n' + f'\t# Samples: {self.cl_no}\n' + f'\tCells: {self.cells}:\n' + f'\tSNPs: {self.SNPs}\n' f'\tDistance metric: reads' + ) return out_str - @staticmethod def dist_reads(c1, c2): r1 = np.reshape(c1, (3, -1)) r2 = np.reshape(c2, (3, -1)) valid = (r1[2] > 0) & (r2[2] > 0) - r1 = r1[:,valid] - r2 = r2[:,valid] + r1 = r1[:, valid] + r2 = r2[:, valid] p1 = np.clip(r1[0] / r1[2], EPSILON, 1 - EPSILON) p2 = np.clip(r2[0] / r2[2], EPSILON, 1 - EPSILON) @@ -121,49 +161,59 @@ def dist_reads(c1, c2): p12 = np.clip((r1[0] + r2[0]) / (dp_total), EPSILON, 1 - EPSILON) p12_inv = 1 - p12 - logl = r1[0] * np.log(p1 / p12) + r1[1] * np.log((1 - p1) / p12_inv) \ - + r2[0] * np.log(p2 / p12) + r2[1] * np.log((1 - p2) / p12_inv) + logl = ( + r1[0] * np.log(p1 / p12) + + r1[1] * np.log((1 - p1) / p12_inv) + + r2[0] * np.log(p2 / p12) + + r2[1] * np.log((1 - p2) / p12_inv) + ) - norm = np.log(dp_total) * (dp_total) \ - - r1[2] * np.log(r1[2]) - r2[2] * np.log(r2[2]) + norm = ( + np.log(dp_total) * (dp_total) + - r1[2] * np.log(r1[2]) + - r2[2] * np.log(r2[2]) + ) return np.sum(logl / norm) / valid.sum() - def get_pairwise_dists(self): dist = [] for i in np.arange(self.cells.size - 1): - valid = (self.dp[i] > 0) & (self.dp[i+1:] > 0) - dp_total = self.dp[i] + self.dp[i+1:] - p12 = np.clip((self.alt[i] + self.alt[i+1:]) / dp_total, EPSILON, 1 - EPSILON) + valid = (self.dp[i] > 0) & (self.dp[i + 1 :] > 0) + dp_total = self.dp[i] + self.dp[i + 1 :] + p12 = np.clip( + (self.alt[i] + self.alt[i + 1 :]) / dp_total, EPSILON, 1 - EPSILON + ) p12_inv = 1 - p12 - logl = self.alt[i] * np.log(self.VAF[i] / p12) \ - + self.ref[i] * np.log(self.RAF[i] / p12_inv) \ - + self.alt[i+1:] * np.log(self.VAF[i+1:] / p12) \ - + self.ref[i+1:] * np.log(self.RAF[i+1:] / p12_inv) + logl = ( + self.alt[i] * np.log(self.VAF[i] / p12) + + self.ref[i] * np.log(self.RAF[i] / p12_inv) + + self.alt[i + 1 :] * np.log(self.VAF[i + 1 :] / p12) + + self.ref[i + 1 :] * np.log(self.RAF[i + 1 :] / p12_inv) + ) - norm = self.norm_const[dp_total] \ - - self.norm_const[self.dp[i]] - self.norm_const[self.dp[i+1:]] + norm = ( + 
self.norm_const[dp_total] + - self.norm_const[self.dp[i]] + - self.norm_const[self.dp[i + 1 :]] + ) dist.append(np.nansum(logl / norm, axis=1) / valid.sum(axis=1)) return np.concatenate(dist) - def demultiplex(self): self.init_dendrogram() self.identify_doublets() self.merge_surplus_singlets() - -# -------------------------------- DENDROGRAM ---------------------------------- + # -------------------------------- DENDROGRAM ---------------------------------- def init_dendrogram(self): dist = self.get_pairwise_dists() self.Z = linkage(dist, method='ward') - -# ---------------------------- IDENTIFY DOUBLETS ------------------------------- + # ---------------------------- IDENTIFY DOUBLETS ------------------------------- def identify_doublets(self): cl_no = self.sgt_cl_no + self.dbt_cl_no @@ -175,14 +225,13 @@ def identify_doublets(self): if self.dbt_ids.size == self.dbt_cl_no: break elif cl_no == (self.sgt_cl_no + self.dbt_cl_no) * 3: - print(f'Could not identify all doublets.') + print('Could not identify all doublets.') self.set_assigment(self.sgt_cl_no + self.dbt_cl_no) self.set_dbt_ids() break cl_no += 1 print(f'Increasing cuttree clusters to {cl_no}') - def set_assigment(self, cl_no): self.assignment = cut_tree(self.Z, n_clusters=cl_no).flatten() clusters = np.unique(self.assignment) @@ -190,19 +239,16 @@ def set_assigment(self, cl_no): for cl_id, cl in enumerate(clusters): self.profiles[cl_id] = self.get_profile(cl) - def get_profile(self, cl): cells = np.isin(self.assignment, cl) - p = np.average(np.nan_to_num(self.VAF[cells]), weights=self.dp[cells], - axis=0) + p = np.average(np.nan_to_num(self.VAF[cells]), weights=self.dp[cells], axis=0) dp = np.mean(self.dp[cells], axis=0).round() alt = (dp * p).round() ref = dp - alt return np.hstack([alt, ref, dp]) - def set_dbt_ids(self): - cl_size = np.unique(self.assignment, return_counts=True)[1] / self.cells.size + cl_size = np.unique(self.assignment, return_counts=True)[1] / self.cells.size cl_no = cl_size.size dbt_map = {} @@ -217,8 +263,9 @@ def set_dbt_ids(self): df_dbt = pd.DataFrame([], columns=dbt_combs, index=range(cl_no)) for i in range(cl_no): - df_dbt.loc[i] = np.apply_along_axis(self.metric, 1, dbt_profiles, - self.profiles[i]) + df_dbt.loc[i] = np.apply_along_axis( + self.metric, 1, dbt_profiles, self.profiles[i] + ) for j in dbt_combs: # set doublet combo dist to np.nan if: # 1. 
Singlet is included in Dbt cluster @@ -254,26 +301,28 @@ def set_dbt_ids(self): df_dbt.drop(rm_rows, axis=0, inplace=True, errors='ignore') # Remove doublet rows including cluster rm_cols = [dbt_id] + [i for i in dbt_combs if cl_id in i] - df_dbt.drop(rm_cols, axis=1, inplace=True, errors='ignore') # Ignore error on already dropped labels + df_dbt.drop( + rm_cols, axis=1, inplace=True, errors='ignore' + ) # Ignore error on already dropped labels - self.sgt_ids = np.array([i for i in np.unique(self.assignment) \ - if not i in dbt_ids]) + self.sgt_ids = np.array( + [i for i in np.unique(self.assignment) if i not in dbt_ids] + ) self.dbt_ids = np.array(dbt_ids) self.dbt_map = dbt_map # self.plot_heatmap() # import pdb; pdb.set_trace() - def check_hom_match(self, scl, dcl): - rs = np.reshape(self.profiles[scl], (3, -1)) + rs = np.reshape(self.profiles[scl], (3, -1)) rd1 = np.reshape(self.profiles[dcl[0]], (3, -1)) rd2 = np.reshape(self.profiles[dcl[1]], (3, -1)) valid = (rs[2] > 0) & (rd1[2] > 0) & (rd2[2] > 0) - rs = rs[:,valid] - rd1 = rd1[:,valid] - rd2 = rd2[:,valid] + rs = rs[:, valid] + rd1 = rd1[:, valid] + rd2 = rd2[:, valid] ps = rs[0] / rs[2] pd1 = rd1[0] / rd1[2] @@ -318,7 +367,6 @@ def check_hom_match(self, scl, dcl): return (~hom_match).sum() - def get_cl_map(self): if self.sgt_ids.size == 0: cl_map = {i: str(i) for i in np.unique(self.assignment)} @@ -334,8 +382,7 @@ def get_cl_map(self): assignment_str[i] = cl_map[j] return cl_map, assignment_str - -# ------------------------- MERGE SURPLUS SINGLETS ----------------------------- + # ------------------------- MERGE SURPLUS SINGLETS ----------------------------- def merge_surplus_singlets(self): # Merge surplus single clusters @@ -391,8 +438,9 @@ def merge_surplus_singlets(self): self.assignment[self.assignment == rem_dbt_cl] = cl1 self.assignment[self.assignment > rem_dbt_cl] -= 1 - self.dbt_ids = np.delete(self.dbt_ids, - np.argwhere(self.dbt_ids == rem_dbt_cl)) + self.dbt_ids = np.delete( + self.dbt_ids, np.argwhere(self.dbt_ids == rem_dbt_cl) + ) self.sgt_ids[self.sgt_ids > rem_dbt_cl] -= 1 self.dbt_ids[self.dbt_ids > rem_dbt_cl] -= 1 @@ -400,13 +448,17 @@ def merge_surplus_singlets(self): if len(self.dbt_map.values()) != len(set(self.dbt_map.values())): print('Removing 1 equal doublet cluster') - eq_dbt = [i for i,j in self.dbt_map.items() \ - if list(self.dbt_map.values()).count(j) > 1] + eq_dbt = [ + i + for i, j in self.dbt_map.items() + if list(self.dbt_map.values()).count(j) > 1 + ] self.assignment[self.assignment == eq_dbt[1]] = eq_dbt[0] self.assignment[self.assignment > eq_dbt[1]] -= 1 - self.dbt_ids = np.delete(self.dbt_ids, - np.argwhere(self.dbt_ids == eq_dbt[1])) + self.dbt_ids = np.delete( + self.dbt_ids, np.argwhere(self.dbt_ids == eq_dbt[1]) + ) self.sgt_ids[self.sgt_ids > eq_dbt[1]] -= 1 self.dbt_ids[self.dbt_ids > eq_dbt[1]] -= 1 @@ -418,16 +470,15 @@ def merge_surplus_singlets(self): # Update profile self.profiles[cl1] = self.get_profile(cl1) - def get_sgt_dist_matrix(self): mat = np.zeros((self.sgt_ids.size, self.sgt_ids.size)) for i, j in enumerate(self.sgt_ids): - mat[i] = np.apply_along_axis(self.metric, 1, - self.profiles[self.sgt_ids], self.profiles[j]) + mat[i] = np.apply_along_axis( + self.metric, 1, self.profiles[self.sgt_ids], self.profiles[j] + ) mat[np.tril_indices(self.sgt_ids.size)] = np.nan return mat - def update_dbt_map(self, new_cl, del_cl): dbt_map_new = {} consistent = True @@ -454,42 +505,39 @@ def update_dbt_map(self, new_cl, del_cl): return dbt_map_new, consistent - # 
------------------------ HEATMAP GENERATION ------------------------------ @staticmethod def get_cmap(): # Edit this gradient at https://eltos.github.io/gradient/ - cmap = LinearSegmentedColormap.from_list('my_gradient', ( - (0.000, (1.000, 1.000, 1.000)), - (0.167, (1.000, 0.882, 0.710)), - (0.333, (1.000, 0.812, 0.525)), - (0.500, (1.000, 0.616, 0.000)), - (0.667, (1.000, 0.765, 0.518)), - (0.833, (1.000, 0.525, 0.494)), - (1.000, (1.000, 0.000, 0.000))) + cmap = LinearSegmentedColormap.from_list( + 'my_gradient', + ( + (0.000, (1.000, 1.000, 1.000)), + (0.167, (1.000, 0.882, 0.710)), + (0.333, (1.000, 0.812, 0.525)), + (0.500, (1.000, 0.616, 0.000)), + (0.667, (1.000, 0.765, 0.518)), + (0.833, (1.000, 0.525, 0.494)), + (1.000, (1.000, 0.000, 0.000)), + ), ) return cmap - def get_hm_data(self): - df = np.nan_to_num(self.VAF, nan=-1) + df = np.nan_to_num(self.VAF, nan=-1) mask = np.zeros(self.VAF.shape, dtype=bool) mask[self.dp == 0] = True return df, mask - @staticmethod def get_hm_specifics(): - return {'vmin': 0, 'vmax': 1, - 'cbar_kws': {'ticks': [0, 1], 'label': 'VAF'}} - + return {'vmin': 0, 'vmax': 1, 'cbar_kws': {'ticks': [0, 1], 'label': 'VAF'}} @staticmethod def apply_cm_specifics(cm): pass - def plot_heatmap(self, out_file='', cluster=True): cmap = self.get_cmap() @@ -536,10 +584,11 @@ def get_row_cols(assignment): cmap=cmap, figsize=(25, 10), xticklabels=self.SNPs, - cbar_kws={'ticks': [0, 1], 'label': 'VAF'} + cbar_kws={'ticks': [0, 1], 'label': 'VAF'}, ) except tk.TclError: import matplotlib + matplotlib.use('Agg') cm = clustermap( df_plot, @@ -547,12 +596,13 @@ def get_row_cols(assignment): row_cluster=cluster, col_cluster=cluster, mask=mask, - vmin=0, vmax=1, + vmin=0, + vmax=1, row_colors=r_colors, cmap=cmap, figsize=(25, 10), xticklabels=self.SNPs, - cbar_kws={'ticks': [0, 1], 'label': 'VAF'} + cbar_kws={'ticks': [0, 1], 'label': 'VAF'}, ) cm.ax_heatmap.set_facecolor('#5B566C') @@ -561,8 +611,13 @@ def get_row_cols(assignment): cm.ax_heatmap.set_yticks([]) cm.ax_heatmap.set_xlabel('SNPs') - cm.ax_heatmap.set_xticklabels(cm.ax_heatmap.get_xticklabels(), - rotation=45, fontsize=5, ha='right', va='top') + cm.ax_heatmap.set_xticklabels( + cm.ax_heatmap.get_xticklabels(), + rotation=45, + fontsize=5, + ha='right', + va='top', + ) cm.ax_col_dendrogram.set_visible(False) self.apply_cm_specifics(cm) @@ -574,7 +629,6 @@ def get_row_cols(assignment): else: plt.show() - # -------------------------- GENERATE OUTPUT ------------------------------- def safe_results(self, output): @@ -590,11 +644,11 @@ def safe_results(self, output): # Safe SNV profiles to identy patients VAF_df = self.get_VAF_profile(self.sgt_ids) - VAF_df.index = [f'{cl_map[i]} ({COLORS_STR[int(cl_map[i])]})' \ - for i in self.sgt_ids] + VAF_df.index = [ + f'{cl_map[i]} ({COLORS_STR[int(cl_map[i])]})' for i in self.sgt_ids + ] VAF_df.round(2).to_csv(f'{output}.profiles.tsv', sep='\t') - def get_VAF_profile(self, cl): reads_all = np.reshape(self.profiles[cl], (cl.size, 3, self.SNPs.size)) VAF_raw = [] @@ -602,7 +656,6 @@ def get_VAF_profile(self, cl): VAF_raw.append(reads_cl[0] / reads_cl[2]) return pd.DataFrame(VAF_raw, index=cl, columns=self.SNPs) - def print_summary(self, cl_map): GTs = {'WT': (0, 0.35), 'HET': (0.35, 0.95), 'HOM': (0.95, 1)} for cl_id, cl_size in zip(*np.unique(self.assignment, return_counts=True)): @@ -612,8 +665,10 @@ def print_summary(self, cl_map): cl_color = f'{COLORS_STR[cl1]}+{COLORS_STR[cl2]}' else: cl_color = COLORS_STR[int(cl_name)] - print(f'Cluster {cl_name} ({cl_color}): {cl_size: >4} 
cells ' \ - f'({cl_size / self.cells.size * 100: >2.0f}%)') + print( + f'Cluster {cl_name} ({cl_color}): {cl_size: >4} cells ' + f'({cl_size / self.cells.size * 100: >2.0f}%)' + ) VAF_cl = self.VAF[self.assignment == cl_id] VAF_cl_called = (VAF_cl >= EPSILON).sum(axis=0) @@ -651,22 +706,37 @@ def main(args): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input', type=str, nargs='+', - help='Input _variants.csv file(s).') - parser.add_argument('-o', '--output', type=str, default='', - help='Output base name: <DIR>/<> .') - parser.add_argument('-n', '--clusters', type=int, default=1, - help='Number of clusters to define. Default = 1.') + parser.add_argument( + '-i', '--input', type=str, nargs='+', help='Input _variants.csv file(s).' + ) + parser.add_argument( + '-o', '--output', type=str, default='', help='Output base name: <DIR>/<> .' + ) + parser.add_argument( + '-n', + '--clusters', + type=int, + default=1, + help='Number of clusters to define. Default = 1.', + ) plotting = parser.add_argument_group('plotting') - plotting.add_argument('-op', '--output_plot', action='store_true', - help='Output file for heatmap with dendrogram to "<INPUT>.hm.png".') - plotting.add_argument('-sp', '--show_plot', action='store_true', - help='Show heatmap with dendrogram at stdout.') + plotting.add_argument( + '-op', + '--output_plot', + action='store_true', + help='Output file for heatmap with dendrogram to "<INPUT>.hm.png".', + ) + plotting.add_argument( + '-sp', + '--show_plot', + action='store_true', + help='Show heatmap with dendrogram at stdout.', + ) return parser.parse_args() if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/demultiplex_distance.py b/benchmarking_pipeline/workflow/scripts/demultiplex_distance.py index d4fcd27..6b45ed7 100644 --- a/benchmarking_pipeline/workflow/scripts/demultiplex_distance.py +++ b/benchmarking_pipeline/workflow/scripts/demultiplex_distance.py @@ -3,7 +3,6 @@ import argparse from itertools import combinations import os -import re import numpy as np from matplotlib import pyplot as plt @@ -104,30 +103,30 @@ def __init__(self, in_file, cl_no, distance='manhattan'): partial_order = {} for cell in df.columns: if '+' in cell: - s1 = cell.split('+')[0].split("_")[1] - s2 = cell.split('+')[1].split("_")[1] + s1 = cell.split('+')[0].split('_')[1] + s2 = cell.split('+')[1].split('_')[1] # Create a dictionary to store partial orders if s2 in partial_order.keys(): - if s1 in partial_order[s2]: ## This means s2<s1 - s2,s1 = s1,s2 - elif not s1 in partial_order.keys(): - partial_order[s1] = [s2] - elif not s2 in partial_order[s1]: + if s1 in partial_order[s2]: ## This means s2<s1 + s2, s1 = s1, s2 + elif s1 not in partial_order.keys(): + partial_order[s1] = [s2] + elif s2 not in partial_order[s1]: partial_order[s1].append(s2) true_cl.append('+'.join([s1, s2])) else: - s1 = cell.split("_")[1] + s1 = cell.split('_')[1] true_cl.append(s1) - #if 'pat' in cell: + # if 'pat' in cell: # s1 = int(re.split(r'[\.,]', cell.split('+')[0])[-1][3:]) - #else: + # else: # s1 = 0 - #if '+' in cell: + # if '+' in cell: # s2 = int(re.split(r'[\.,]', cell.split('+')[1])[-1][3:]) # true_cl.append('+'.join([str(j) for j in sorted([s1, s2])])) - #else: + # else: # true_cl.append(s1) self.true_cl = np.array(true_cl) @@ -169,13 +168,11 @@ def check_hom_match(self, scl, dcl): def demultiplex(self): self.init_dendrogram() - #self.identify_doublets() + # 
self.identify_doublets() self.set_assignment(self.sgt_cl_no) - self.sgt_ids = np.array( - [i for i in np.unique(self.assignment)] - ) + self.sgt_ids = np.array([i for i in np.unique(self.assignment)]) self.delete_garbage_cluster() - #self.merge_surplus_singlets() + # self.merge_surplus_singlets() # -------------------------------- DENDROGRAM ---------------------------------- @@ -970,7 +967,7 @@ def main(args): for in_file in in_files: if args.metric == 'reads': - dt = demoTape_reads(in_file, args.clusters+1) + dt = demoTape_reads(in_file, args.clusters + 1) else: dt = demoTape_gt(in_file, args.clusters, args.metric) @@ -1028,4 +1025,4 @@ def parse_args(): if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/get_input_samples.py b/benchmarking_pipeline/workflow/scripts/get_input_samples.py index be79f2b..5fa89a5 100644 --- a/benchmarking_pipeline/workflow/scripts/get_input_samples.py +++ b/benchmarking_pipeline/workflow/scripts/get_input_samples.py @@ -3,20 +3,44 @@ from pathlib import Path import logging -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.DEBUG) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.DEBUG, +) + def generate_input_samples(loom_files, number_of_samples): return random.sample(loom_files, number_of_samples) - - def parse_args(): - parser = argparse.ArgumentParser(description="Get a random list of samples to be pooled") - parser.add_argument("--loom_file_path", type=str, help="Path to a folder containing loom files") - parser.add_argument("--output", type=str, default=None, help="path to a directory where the sample list will be stored") - parser.add_argument("--seed", type=int, required=True, default=None, help="random seed for reproducibility of sampling") - parser.add_argument("--number_of_samples", type=int, default=1, help="number of samples to be selected") + parser = argparse.ArgumentParser( + description='Get a random list of samples to be pooled' + ) + parser.add_argument( + '--loom_file_path', type=str, help='Path to a folder containing loom files' + ) + parser.add_argument( + '--output', + type=str, + default=None, + help='path to a directory where the sample list will be stored', + ) + parser.add_argument( + '--seed', + type=int, + required=True, + default=None, + help='random seed for reproducibility of sampling', + ) + parser.add_argument( + '--number_of_samples', + type=int, + default=1, + help='number of samples to be selected', + ) return parser.parse_args() @@ -24,23 +48,24 @@ def parse_args(): def main(args): loom_files = [] for p in Path(args.loom_file_path).iterdir(): - if p.is_file() and p.suffix == ".loom": + if p.is_file() and p.suffix == '.loom': loom_files.append(p) - + number_of_samples = args.number_of_samples if args.number_of_samples > len(loom_files): - logging.error(f"Number of samples ({number_of_samples}) is greater than the number of loom files ({len(loom_files)}). Continuing with maximal number of samples.") + logging.error( + f'Number of samples ({args.number_of_samples}) is greater than the number of loom files ({len(loom_files)}). Continuing with maximal number of samples.' 
+ ) input_samples = loom_files number_of_samples = len(loom_files) - else: - random.seed(args.seed) - input_samples = random.sample(loom_files, args.number_of_samples) - - with open(args.output, "w") as f: + random.seed(args.seed) + input_samples = random.sample(loom_files, number_of_samples) + + with open(args.output, 'w') as f: for sample in input_samples: - f.write(f"{sample}\n") + f.write(f'{sample}\n') -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/mosaic_processing.py b/benchmarking_pipeline/workflow/scripts/mosaic_processing.py index 9db9a73..afd6d3c 100644 --- a/benchmarking_pipeline/workflow/scripts/mosaic_processing.py +++ b/benchmarking_pipeline/workflow/scripts/mosaic_processing.py @@ -1,4 +1,4 @@ -#Authors: +# Authors: # 1. Nico Borgsmüller # 2. Johannes Gawron @@ -19,7 +19,12 @@ import pandas as pd -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.DEBUG) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.DEBUG, +) EPSILON = np.finfo(np.float64).resolution @@ -99,7 +104,7 @@ def main(args): logging.warning( "Computation of the VAFs based on all cells' genotype call at a position. There is no correction of the VAFs for tumor purity and ploidy." ) - + """ logging.error( 'Unexpected behaviour in the computation of variant allele frequencies. Values may exceed 1.' @@ -395,15 +400,15 @@ def merge_gt(gt_in): return 1 -def subsample_cells(no_cells, ratios, total_cell_count = np.inf): +def subsample_cells(no_cells, ratios, total_cell_count=np.inf): if not isinstance(ratios, np.ndarray): - raise TypeError("Ratios need to be a numpy array!") + raise TypeError('Ratios need to be a numpy array!') smallest_sample = np.argmin(no_cells) if total_cell_count < np.inf: samples_total = (args.cell_no * ratios).astype(int) else: - samples_total = (no_cells.min() * ratios/ratios[smallest_sample]).astype(int) - for i,size in enumerate(samples_total): + samples_total = (no_cells.min() * ratios / ratios[smallest_sample]).astype(int) + for i, size in enumerate(samples_total): samples_total[i] = min(size, no_cells[i]) return samples_total @@ -415,7 +420,7 @@ def multiplex_looms(args): 'No. input files has to be the same as no. ratios. ' f'({no_samples} != {len(args.ratio)})' ) - #assert np.sum(args.ratio) <= 1, 'Ratios cannot sum up to >1.' + # assert np.sum(args.ratio) <= 1, 'Ratios cannot sum up to >1.' no_cells = np.zeros(no_samples, dtype=int) with tempfile.TemporaryDirectory() as temp_dir: @@ -424,10 +429,10 @@ def multiplex_looms(args): shutil.copy2(in_file, temp_in_file) with loompy.connect(temp_in_file) as ds: no_cells[i] = ds.shape[1] - - ##To obtain an unbiased cell count and doublet rate, we initially start with the number of cells and add (number of cells) * doublet_rate many samples. Later, pairs of cells will be merged to a doublet, which will reduce the total cell count again accordingly. + + ##To obtain an unbiased cell count and doublet rate, we initially start with the number of cells and add (number of cells) * doublet_rate many samples. Later, pairs of cells will be merged to a doublet, which will reduce the total cell count again accordingly. 
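+ # subsample_cells (defined above) draws per-sample cell counts that follow the requested mixing ratios, capped at the number of cells available in each loom file.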
samples_size = subsample_cells(no_cells, np.array(args.ratio), args.cell_no) - + samples = {} for i, size in enumerate(samples_size): size = min(size, no_cells[i]) @@ -439,10 +444,10 @@ def multiplex_looms(args): print(f'Reading: {in_file}') # Open loom file and read data with loompy.connect(temp_in_file) as ds: - index = concat_str_arrays( + index = concat_str_arrays( [ds.ra['CHROM'], ds.ra['POS'], ds.ra['REF'], ds.ra['ALT']] ) - + cols = samples[i]['idx'] + ((i + 1) / 10) df_new = pd.DataFrame( ds[:, samples[i]['idx']], index=index, columns=cols @@ -480,7 +485,7 @@ def multiplex_looms(args): AD_new = AD_new.join(pd.DataFrame(view.layers['AD'][:, :], index=index, columns=cols)) RO_new = RO_new.join(pd.DataFrame(view.layers['RO'][:, :], index=index, columns=cols)) """ - + try: barcodes = np.char.add( ds.col_attrs['barcode'][samples[i]['idx']], f'.pat{i:.0f}' @@ -489,22 +494,25 @@ def multiplex_looms(args): samples[i]['name'] = barcodes.astype(f'<U{2*barcode_length+1}') except (AttributeError, TypeError) as error: logging.error(error) - logging.error("No barcode information found in loom file. Continuing with generic cell names.") + logging.error( + 'No barcode information found in loom file. Continuing with generic cell names.' + ) in_file_stripped = str(Path(in_file).stem) match = re.search(r'_split([12])$', in_file_stripped) in_file_stripped = re.sub(r'_split[12]$', '', in_file_stripped) if not match: - raise ValueError("The file names must end with '_split1' or '_split2'.") + raise ValueError( + "The file names must end with '_split1' or '_split2'." + ) split_number = match.group(1) matched_pattern = f'split{split_number}' - base_name = f"_{in_file_stripped}.{matched_pattern}" + base_name = f'_{in_file_stripped}.{matched_pattern}' counts = np.arange(samples[i]['idx'].size) + 1 counts_string = np.char.mod('%d', counts) samples[i]['name'] = np.char.add(counts_string, base_name) - - # First sample, nothing to merge + # First sample, nothing to merge if i == 0: df = df_new ampl = ampl_new @@ -514,7 +522,7 @@ def multiplex_looms(args): RO = RO_new continue - # Merge amplicons (keep all) + # Merge amplicons (keep all) ampl = ampl.combine_first(ampl_new) df = df.merge(df_new, left_index=True, right_index=True) DP = DP.merge(DP_new, left_index=True, right_index=True) @@ -530,23 +538,22 @@ def multiplex_looms(args): del RO_new gc.collect() - # Add doublets (if doublets arg specified) print('Generating doublets') samples['dbt'] = {'name': []} - s_probs = samples_size/np.sum(samples_size) + s_probs = samples_size / np.sum(samples_size) drop_cells = [] dbt_gt = {} dbt_DP = {} dbt_GQ = {} dbt_AD = {} dbt_RO = {} - - #We have sample_size + doublet_rate * sample_size many cells in total. - #The number of doublets is doublet_rate* sample_size = doublet_rate/(1+doublet_rate) * total_cell_count - - dbt_total = int(args.doublets/(1 + args.doublets) * np.sum(samples_size)) + + # We have sample_size + doublet_rate * sample_size many cells in total. 
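+ # E.g., with hypothetical numbers: 2000 cells at a doublet rate of 0.05 give 2100 draws in total; merging 0.05/1.05 * 2100 = 100 pairs leaves 2000 cells, of which 5% are doublets.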
+ # The number of doublets is doublet_rate* sample_size = doublet_rate/(1+doublet_rate) * total_cell_count + + dbt_total = int(args.doublets / (1 + args.doublets) * np.sum(samples_size)) for i in range(dbt_total): s1, s2 = np.random.choice(no_samples, size=2, replace=False, p=s_probs) c1_idx = np.random.choice(samples[s1]['idx'].size) @@ -556,10 +563,10 @@ def multiplex_looms(args): c2 = samples[s2]['idx'][c2_idx] + ((s2 + 1) / 10) new_id = f'{c1}+{c2}' - new_name_sample1 = samples[s1]["name"][c1_idx] - new_name_sample2 = samples[s2]["name"][c2_idx] + new_name_sample1 = samples[s1]['name'][c1_idx] + new_name_sample2 = samples[s2]['name'][c2_idx] new_name = f'{new_name_sample1}+{new_name_sample2}' - + dbt_gt[new_id] = np.apply_along_axis(merge_gt, axis=1, arr=df[[c1, c2]]) dbt_GQ[new_id] = GQ[[c1, c2]].mean(axis=1) dbt_DP[new_id] = (DP[[c1, c2]].mean(axis=1).round()).astype(int) @@ -664,17 +671,16 @@ def multiplex_looms(args): } cell_data = np.empty((len(variants_info['CHR']), len(cells)), dtype='<U11') - cell_data = np.char.add(np.char.add(RO.astype(str), ':'), np.char.add(AD.astype(str), ':')) + cell_data = np.char.add( + np.char.add(RO.astype(str), ':'), np.char.add(AD.astype(str), ':') + ) cell_data = np.char.add(cell_data, gt.astype(str)) for col, cell in enumerate(cells): name = cell_names[col] variants_info[name] = cell_data[:, col].tolist() - - return pd.DataFrame(variants_info), gt, VAF - - + return pd.DataFrame(variants_info), gt, VAF def compute_pseudobulk_VAFs(df1, gt1): @@ -888,9 +894,9 @@ def parse_args(): if __name__ == '__main__': - logging.info("Parsing arguments.") + logging.info('Parsing arguments.') args = parse_args() if args.whitelist: update_whitelist(args.whitelist) - logging.info("Running main program.") - main(args) \ No newline at end of file + logging.info('Running main program.') + main(args) diff --git a/benchmarking_pipeline/workflow/scripts/split_loom_files.py b/benchmarking_pipeline/workflow/scripts/split_loom_files.py index d7f5dca..3c2ace1 100644 --- a/benchmarking_pipeline/workflow/scripts/split_loom_files.py +++ b/benchmarking_pipeline/workflow/scripts/split_loom_files.py @@ -13,7 +13,12 @@ import loompy -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.DEBUG) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.DEBUG, +) def batched(iterable, chunk_size): @@ -26,77 +31,83 @@ class FancyLoomConnection(loompy.LoomConnection): def __init__(self, filename: str, mode: str = 'r+', *, validate: bool = True): super().__init__(filename, mode, validate=validate) - def add_columns_batched(self, loom_view_manager, col_indices, batch_size = 30): - logging.info("subselecting cells in batched mode") + def add_columns_batched(self, loom_view_manager, col_indices, batch_size=30): + logging.info('subselecting cells in batched mode') batches = list(batched(col_indices, batch_size)) for idx, batch in enumerate(batches): - logging.info(f"Adding batch {idx+1} of {len(batches)}") + logging.info(f'Adding batch {idx+1} of {len(batches)}') ds_batch = loom_view_manager[:, np.array(batch)] - self.add_columns(ds_batch.layers, col_attrs = ds_batch.ca, row_attrs = ds_batch.ra) + self.add_columns( + ds_batch.layers, col_attrs=ds_batch.ca, row_attrs=ds_batch.ra + ) -def fancy_loompy_connect(filename: str, mode: str = 'r+', *, validate: bool = True) -> FancyLoomConnection: +def fancy_loompy_connect( + filename: str, mode: str = 'r+', *, validate: 
bool = True +) -> FancyLoomConnection: """ - Establish a connection to a .loom file. + Establish a connection to a .loom file. - Args: - filename: Path to the Loom file to open - mode: Read/write mode, 'r+' (read/write) or 'r' (read-only), defaults to 'r+' - validate: Validate the file structure against the Loom file format specification - Returns: - A LoomConnection instance. + Args: + filename: Path to the Loom file to open + mode: Read/write mode, 'r+' (read/write) or 'r' (read-only), defaults to 'r+' + validate: Validate the file structure against the Loom file format specification + Returns: + A LoomConnection instance. - Remarks: - This function should typically be used as a context manager (i.e. inside a ``with``-block): + Remarks: + This function should typically be used as a context manager (i.e. inside a ``with``-block): - .. highlight:: python - .. code-block:: python + .. highlight:: python + .. code-block:: python - import loompy - with loompy.connect("mydata.loom") as ds: - print(ds.ca.keys()) + import loompy + with loompy.connect("mydata.loom") as ds: + print(ds.ca.keys()) - This ensures that the file will be closed automatically when the context block ends + This ensures that the file will be closed automatically when the context block ends - Note: if validation is requested, an exception is raised if validation fails. - """ + Note: if validation is requested, an exception is raised if validation fails. + """ return FancyLoomConnection(filename, mode, validate=validate) -def fancy_loompy_new(filename: str, *, file_attrs: Optional[Dict[str, str]] = None) -> FancyLoomConnection: - """ - Create an empty Loom file, and return it as a context manager. - """ - if filename.startswith("~/"): - filename = os.path.expanduser(filename) - if file_attrs is None: - file_attrs = {} - - # Create the file (empty). - # Yes, this might cause an exception, which we prefer to send to the caller - f = h5py.File(name=filename, mode='w') - f.create_group('/attrs') # v3.0.0 - f.create_group('/layers') - f.create_group('/row_attrs') - f.create_group('/col_attrs') - f.create_group('/row_graphs') - f.create_group('/col_graphs') - f.flush() - f.close() - - ds = fancy_loompy_connect(filename, validate=False) - for vals in file_attrs: - if file_attrs[vals] is None: - ds.attrs[vals] = "None" - else: - ds.attrs[vals] = file_attrs[vals] - # store creation date - ds.attrs['CreationDate'] = loompy.timestamp() - ds.attrs["LOOM_SPEC_VERSION"] = loompy.loom_spec_version - return ds - - -def sample_split(alpha, seed = None): +def fancy_loompy_new( + filename: str, *, file_attrs: Optional[Dict[str, str]] = None +) -> FancyLoomConnection: + """ + Create an empty Loom file, and return it as a context manager. + """ + if filename.startswith('~/'): + filename = os.path.expanduser(filename) + if file_attrs is None: + file_attrs = {} + + # Create the file (empty). 
+ # Yes, this might cause an exception, which we prefer to send to the caller + f = h5py.File(name=filename, mode='w') + f.create_group('/attrs') # v3.0.0 + f.create_group('/layers') + f.create_group('/row_attrs') + f.create_group('/col_attrs') + f.create_group('/row_graphs') + f.create_group('/col_graphs') + f.flush() + f.close() + + ds = fancy_loompy_connect(filename, validate=False) + for vals in file_attrs: + if file_attrs[vals] is None: + ds.attrs[vals] = 'None' + else: + ds.attrs[vals] = file_attrs[vals] + # store creation date + ds.attrs['CreationDate'] = loompy.timestamp() + ds.attrs['LOOM_SPEC_VERSION'] = loompy.loom_spec_version + return ds + + +def sample_split(alpha, seed=None): if seed: rng = np.random.default_rng(seed=seed) else: @@ -106,8 +117,7 @@ def sample_split(alpha, seed = None): return splitting_ratio - -def split_loom_file(splitting_ratio, loom_file, output_dir, seed = None): +def split_loom_file(splitting_ratio, loom_file, output_dir, seed=None): if seed: rng = np.random.default_rng(seed=seed) else: @@ -116,12 +126,13 @@ def split_loom_file(splitting_ratio, loom_file, output_dir, seed = None): with tempfile.TemporaryDirectory() as temp_dir: temp_loom_file = Path(temp_dir) / Path(loom_file).name shutil.copy2(loom_file, temp_loom_file) - with loompy.connect(temp_loom_file) as ds: n_cells = ds.shape[1] - split_vector = rng.choice([0,1], size=n_cells, p=[1-splitting_ratio, splitting_ratio]) - + split_vector = rng.choice( + [0, 1], size=n_cells, p=[1 - splitting_ratio, splitting_ratio] + ) + indices_first_sample = np.where(split_vector == 1)[0] indices_second_sample = np.where(split_vector == 0)[0] @@ -130,45 +141,62 @@ def split_loom_file(splitting_ratio, loom_file, output_dir, seed = None): output_1 = (output_dir / f'{loom_file.stem}_split1.loom').as_posix() with fancy_loompy_new(output_1) as ds1: - logging.info(f"Connection to {output_1} established") - ds1.add_columns_batched(ds_view, indices_first_sample, batch_size = 30) + logging.info(f'Connection to {output_1} established') + ds1.add_columns_batched(ds_view, indices_first_sample, batch_size=30) number_of_cells1 = ds1.shape[1] - logging.info("Writing file 1 was successful!") + logging.info('Writing file 1 was successful!') output_2 = (output_dir / f'{loom_file.stem}_split2.loom').as_posix() with fancy_loompy_new(output_2) as ds2: - logging.info(f"Connection to {output_2} established") - ds2.add_columns_batched(ds_view, indices_second_sample, batch_size = 30) + logging.info(f'Connection to {output_2} established') + ds2.add_columns_batched(ds_view, indices_second_sample, batch_size=30) number_of_cells2 = ds2.shape[1] - logging.info("Writing file 2 was successful!") + logging.info('Writing file 2 was successful!') if number_of_cells1 + number_of_cells2 == n_cells: - logging.info("Splitting was successful!") + logging.info('Splitting was successful!') else: - logging.error("The number of cells in the split files does not match the number of cells in the original file") - raise ValueError(f"Split 1 has {number_of_cells1} cells and split 2 has {number_of_cells2} cells, but the original file has {n_cells} cells") + logging.error( + 'The number of cells in the split files does not match the number of cells in the original file' + ) + raise ValueError( + f'Split 1 has {number_of_cells1} cells and split 2 has {number_of_cells2} cells, but the original file has {n_cells} cells' + ) return None def parse_args(): - parser = argparse.ArgumentParser(description="Simulate multiplexed single-cell (or multi-sample) mutation data") 
- parser.add_argument("--input", type=str, required=True, help="Input list of loom files") - parser.add_argument("--alpha", type=float, required=True, help="Dirichlet distribution parameter") - parser.add_argument("--output", type=str, required=True, help="One of the output files. The other output file is saves to the same directory") - - logging.info("Parsing arguments") + parser = argparse.ArgumentParser( + description='Simulate multiplexed single-cell (or multi-sample) mutation data' + ) + parser.add_argument( + '--input', type=str, required=True, help='Input list of loom files' + ) + parser.add_argument( + '--alpha', type=float, required=True, help='Dirichlet distribution parameter' + ) + parser.add_argument( + '--output', + type=str, + required=True, + help='One of the output files. The other output file is saves to the same directory', + ) + + logging.info('Parsing arguments') args = parser.parse_args() return args def main(args): - logging.info(f"Splitting file {args.input}") - loom_file = args.input - splitting_ratio = sample_split(alpha=args.alpha) - output_dir = Path(args.output).parent - split_loom_file(splitting_ratio = splitting_ratio, loom_file = loom_file, output_dir = output_dir) - logging.info(f"Done splitting file {args.input}") + logging.info(f'Splitting file {args.input}') + loom_file = args.input + splitting_ratio = sample_split(alpha=args.alpha) + output_dir = Path(args.output).parent + split_loom_file( + splitting_ratio=splitting_ratio, loom_file=loom_file, output_dir=output_dir + ) + logging.info(f'Done splitting file {args.input}') -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/mixmax/create_demultiplexing_scheme.py b/mixmax/create_demultiplexing_scheme.py index 331816d..5e16c61 100644 --- a/mixmax/create_demultiplexing_scheme.py +++ b/mixmax/create_demultiplexing_scheme.py @@ -6,39 +6,56 @@ import numpy as np -logging.basicConfig(format="{asctime} - {levelname} - {message}", style="{", datefmt="%Y-%m-%d %H:%M",level=logging.INFO) +logging.basicConfig( + format='{asctime} - {levelname} - {message}', + style='{', + datefmt='%Y-%m-%d %H:%M', + level=logging.INFO, +) -def define_demultiplexing_scheme_optimal_case(maximal_number_of_samples, maximal_pool_size, n_samples, robust): +def define_demultiplexing_scheme_optimal_case( + maximal_number_of_samples, maximal_pool_size, n_samples, robust +): if n_samples % maximal_number_of_samples != 0: - raise ValueError("Number of samples must be a multiple of maximal number of samples to run this function!") + raise ValueError( + 'Number of samples must be a multiple of maximal number of samples to run this function!' 
+ ) demultiplexing_scheme = {} number_of_iterations = int(n_samples / maximal_number_of_samples) if not robust: - unordered_unique_pairs = list(itertools.combinations(range(maximal_pool_size), 2)) + unordered_unique_pairs = list( + itertools.combinations(range(maximal_pool_size), 2) + ) diagonal = list(zip(range(maximal_pool_size), range(maximal_pool_size))) unordered_pairs = unordered_unique_pairs + diagonal else: - unordered_pairs = list(itertools.combinations(range(maximal_pool_size+1), 2)) + unordered_pairs = list(itertools.combinations(range(maximal_pool_size + 1), 2)) for idx1 in range(number_of_iterations): for idx2, pair in enumerate(unordered_pairs): - demultiplexing_scheme[int(idx1 * maximal_number_of_samples + idx2 + 1)] = pair + (idx1,) # naming of samples starts with 1 + demultiplexing_scheme[int(idx1 * maximal_number_of_samples + idx2 + 1)] = ( + pair + (idx1,) + ) # naming of samples starts with 1 return demultiplexing_scheme - def find_demultiplexing_scheme(maximal_pool_size, n_samples, robust): - maximal_number_of_samples = (maximal_pool_size * (maximal_pool_size+1))/2 + maximal_number_of_samples = (maximal_pool_size * (maximal_pool_size + 1)) / 2 if n_samples % maximal_number_of_samples != 0: - logging.error("So far, only defined for certain cohort sizes") + logging.error('So far, only defined for certain cohort sizes') raise NotImplementedError else: - logging.info("Creating multiplexing scheme") - demultiplexing_scheme = define_demultiplexing_scheme_optimal_case(maximal_number_of_samples = maximal_number_of_samples, maximal_pool_size = maximal_pool_size, n_samples = n_samples, robust = robust) - + logging.info('Creating multiplexing scheme') + demultiplexing_scheme = define_demultiplexing_scheme_optimal_case( + maximal_number_of_samples=maximal_number_of_samples, + maximal_pool_size=maximal_pool_size, + n_samples=n_samples, + robust=robust, + ) + logging.info(f'Demultiplexing scheme: {demultiplexing_scheme}') return demultiplexing_scheme @@ -46,91 +63,135 @@ def find_demultiplexing_scheme(maximal_pool_size, n_samples, robust): def multiplexing_scheme_format2pool_format(demultiplexing_scheme): pool_scheme = {} - total_no_pools_per_repetition = np.max([pool[:-1] for pool in list(demultiplexing_scheme.values())]) - total_no_of_repetitions = np.max([pool[-1] for pool in list(demultiplexing_scheme.values())]) - - pools = list(itertools.product(range(total_no_pools_per_repetition + 1), range(total_no_of_repetitions + 1))) + total_no_pools_per_repetition = np.max( + [pool[:-1] for pool in list(demultiplexing_scheme.values())] + ) + total_no_of_repetitions = np.max( + [pool[-1] for pool in list(demultiplexing_scheme.values())] + ) + + pools = list( + itertools.product( + range(total_no_pools_per_repetition + 1), range(total_no_of_repetitions + 1) + ) + ) for pool in pools: pool_scheme[pool] = [] for key, value in demultiplexing_scheme.items(): - pool_scheme[value[0],value[-1]].append(int(key)) - pool_scheme[value[1],value[-1]].append(-int(key)) + pool_scheme[value[0], value[-1]].append(int(key)) + pool_scheme[value[1], value[-1]].append(-int(key)) # The pool scheme has pool identifier as keys and split libraries as values, where the first split is encoded as a positive integer and the second split as a negative integer return pool_scheme - def select_samples_for_pooling(pool_scheme, input_dir, sample_list): if isinstance(sample_list[0], str): - for idx,sample in enumerate(sample_list): + for idx, sample in enumerate(sample_list): sample_list[idx] = Path(sample) if 
isinstance(input_dir, str):
         input_dir = Path(input_dir)
     logging.info('Selecting samples for pooling')
     pools_summary = {}
     for pool, samples in pool_scheme.items():
-
-        logging.info(f"Retrieving list of samples to be pooled for pool {pool}")
-        logging.debug(f"Samples: {samples}")
-        logging.debug(f"Sample list: {sample_list}")
+        logging.info(f'Retrieving list of samples to be pooled for pool {pool}')
+        logging.debug(f'Samples: {samples}')
+        logging.debug(f'Sample list: {sample_list}')
         loom_files = []
         for sample in samples:
             if sample > 0:
-                loom_files.append((input_dir / f"{sample_list[sample-1].stem}_split1.loom").as_posix())
+                loom_files.append(
+                    (input_dir / f'{sample_list[sample-1].stem}_split1.loom').as_posix()
+                )
             else:
-                loom_files.append((input_dir / f"{sample_list[-sample-1].stem}_split2.loom").as_posix())
-
-        logging.info(f"Done retrieving list of samples to be pooled for pool {pool}")
+                loom_files.append(
+                    (
+                        input_dir / f'{sample_list[-sample-1].stem}_split2.loom'
+                    ).as_posix()
+                )

-        pools_summary[f"({pool[0]}.{pool[1]})"] = loom_files
+        logging.info(f'Done retrieving list of samples to be pooled for pool {pool}')
+
+        pools_summary[f'({pool[0]}.{pool[1]})'] = loom_files

     return pools_summary


 def write_output(pools_summary, output):
-    with open(output, "w") as f:
+    with open(output, 'w') as f:
         yaml.dump(pools_summary, f, default_flow_style=False)


 def load_input_samples(input_sample_file):
-    with open(input_sample_file, "r") as f:
+    with open(input_sample_file, 'r') as f:
         input_samples = f.readlines()
-    
+
     return input_samples


 def parse_args():
-    parser = argparse.ArgumentParser(description="Output a multiplexing scheme under pool size constraint for a predefined number of samples")
-    parser.add_argument("--robust", type=bool, required=False, default = False, help="Input list of loom files")
-    parser.add_argument("-k", "--maximal_pool_size", type=int, required=False, help="The maximal amount of samples to be multiplexed in a pool")
-    parser.add_argument("--n_samples", type=int, required=True, help="The number of samples to be sequenced in the cohort")
-    parser.add_argument("--output", type=str, required=True, help="Where to store the output file")
-    parser.add_argument("--input_dir", type=str, required=True, help="Directory to the loom files")
-    parser.add_argument("--input_sample_file", type=str, required=True, help="Path to the file containing the list of loom files")
-
-    logging.info("Parsing arguments")
+    parser = argparse.ArgumentParser(
+        description='Output a multiplexing scheme under pool size constraint for a predefined number of samples'
+    )
+    parser.add_argument(
+        '--robust',
+        type=bool,
+        required=False,
+        default=False,
+        help='Whether to create a robust multiplexing scheme in which the two splits of a sample are never placed in the same pool',
+    )
+    parser.add_argument(
+        '-k',
+        '--maximal_pool_size',
+        type=int,
+        required=False,
+        help='The maximal number of samples to be multiplexed in a pool',
+    )
+    parser.add_argument(
+        '--n_samples',
+        type=int,
+        required=True,
+        help='The number of samples to be sequenced in the cohort',
+    )
+    parser.add_argument(
+        '--output', type=str, required=True, help='Where to store the output file'
+    )
+    parser.add_argument(
+        '--input_dir', type=str, required=True, help='Directory to the loom files'
+    )
+    parser.add_argument(
+        '--input_sample_file',
+        type=str,
+        required=True,
+        help='Path to the file containing the list of loom files',
+    )
+
+    logging.info('Parsing arguments')
     args = parser.parse_args()
-    
+
     if args.n_samples == 0:
-        logging.error('Number of 
samples must be greater than 0') raise ValueError - + return args def main(args): input_samples = load_input_samples(args.input_sample_file) n_samples = len(input_samples) - demultiplexing_scheme = find_demultiplexing_scheme(args.maximal_pool_size, n_samples, args.robust) + demultiplexing_scheme = find_demultiplexing_scheme( + args.maximal_pool_size, n_samples, args.robust + ) pool_scheme = multiplexing_scheme_format2pool_format(demultiplexing_scheme) - pools_summary = select_samples_for_pooling(pool_scheme, args.input_dir, input_samples) - logging.debug(f"Output: {pools_summary.keys()}") - logging.info(f"Writing output to {args.output}") + pools_summary = select_samples_for_pooling( + pool_scheme, args.input_dir, input_samples + ) + logging.debug(f'Output: {pools_summary.keys()}') + logging.info(f'Writing output to {args.output}') write_output(pools_summary, args.output) - logging.info("Success.") + logging.info('Success.') -if __name__ == "__main__": +if __name__ == '__main__': args = parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/mixmax/match_samples.py b/mixmax/match_samples.py index da39ccc..cacb546 100644 --- a/mixmax/match_samples.py +++ b/mixmax/match_samples.py @@ -1,5 +1,4 @@ import argparse -import logging import pandas as pd import seaborn as sns @@ -9,22 +8,25 @@ from pathlib import Path - def pool_format2multiplexing_scheme(pool_scheme): demultiplexing_scheme = {} for pool, samples in pool_scheme.items(): for sample in samples: - if "_split1.loom" in sample: + if '_split1.loom' in sample: sample_name = str(Path(sample).stem) sample_name = sample_name.replace('_split1', '') - number_of_iterations =pool.split('.')[1][:-1] + number_of_iterations = pool.split('.')[1][:-1] pool_ID = pool.split('.')[0][1:] if sample_name not in demultiplexing_scheme.keys(): - demultiplexing_scheme[sample_name] = [int(pool_ID), np.inf, int(number_of_iterations)] + demultiplexing_scheme[sample_name] = [ + int(pool_ID), + np.inf, + int(number_of_iterations), + ] else: demultiplexing_scheme[sample_name][0] = int(pool_ID) demultiplexing_scheme[sample_name][2] = int(number_of_iterations) - if "_split2.loom" in sample: + if '_split2.loom' in sample: sample_name = str(Path(sample).stem) sample_name = sample_name.replace('_split2', '') pool_ID = pool.split('.')[0][1:] @@ -33,25 +35,32 @@ def pool_format2multiplexing_scheme(pool_scheme): else: demultiplexing_scheme[sample_name][1] = int(pool_ID) # Swap keys and items in the dictionary - swapped_demultiplexing_scheme = {tuple(v): k for k, v in demultiplexing_scheme.items()} - - return swapped_demultiplexing_scheme - + swapped_demultiplexing_scheme = { + tuple(v): k for k, v in demultiplexing_scheme.items() + } + return swapped_demultiplexing_scheme # Create a custom colormap -cmap = sns.color_palette("viridis", as_cmap=True) +cmap = sns.color_palette('viridis', as_cmap=True) cmap.set_bad(color='grey') def parse_args(): - parser = argparse.ArgumentParser(description="Compare distances between samples") - parser.add_argument("--tsv_files", nargs='+' , type=str, help="Path to TSV files containing the genotypes of the clusters. 
Must be in the same order as the pools in the pooling scheme.") - parser.add_argument("--pool_scheme", type=str, help="Path to the pool scheme file") - parser.add_argument("--output_plot", type=str, help="Output heatmap") - parser.add_argument("--output", type=str, help="sample assignment") - parser.add_argument("--robust", type=bool, required=False, default = False, help="Robust assignment") + parser = argparse.ArgumentParser(description='Compare distances between samples') + parser.add_argument( + '--tsv_files', + nargs='+', + type=str, + help='Path to TSV files containing the genotypes of the clusters. Must be in the same order as the pools in the pooling scheme.', + ) + parser.add_argument('--pool_scheme', type=str, help='Path to the pool scheme file') + parser.add_argument('--output_plot', type=str, help='Output heatmap') + parser.add_argument('--output', type=str, help='sample assignment') + parser.add_argument( + '--robust', type=bool, required=False, default=False, help='Robust assignment' + ) return parser.parse_args() @@ -72,19 +81,30 @@ def compute_ratio(matrix): def permute_tsv_files(tsv_files, pool_scheme): pools = [f"({Path(file).name.split("_")[1]})" for file in tsv_files] - + pool_scheme_keys = list(pool_scheme.keys()) pool_permutation = [pools.index(pool) for pool in pool_scheme_keys] - - if not all((pool1 == pool2 for pool1, pool2 in zip([pools[permuted_idx] for permuted_idx in pool_permutation], pool_scheme_keys))): - raise ValueError("The pools in the pool scheme do not match the pools in the TSV files") + + if not all( + ( + pool1 == pool2 + for pool1, pool2 in zip( + [pools[permuted_idx] for permuted_idx in pool_permutation], + pool_scheme_keys, + ) + ) + ): + raise ValueError( + 'The pools in the pool scheme do not match the pools in the TSV files' + ) return pool_permutation + args = parse_args() -def load_data(args): +def load_data(args): with open(args.pool_scheme, 'r') as file: pooling_scheme = yaml.safe_load(file) @@ -94,11 +114,12 @@ def load_data(args): tsv_files = [args.tsv_files] tsv_files = list(args.tsv_files) tsv_files_permuted = [tsv_files[i] for i in permutation_of_pools] - - dfs = [pd.read_csv(f, sep='\t', index_col = 0) for f in tsv_files_permuted] + + dfs = [pd.read_csv(f, sep='\t', index_col=0) for f in tsv_files_permuted] return dfs, pooling_scheme + dfs, pooling_scheme = load_data(args) # Get a set of all column names @@ -109,14 +130,14 @@ def load_data(args): print(all_columns) -for idx,df in enumerate(dfs): +for idx, df in enumerate(dfs): for col in all_columns: if col not in df.columns: dfs[idx][col] = np.nan -for idx,df in enumerate(dfs): +for idx, df in enumerate(dfs): dfs[idx] = df[sorted(all_columns)] - + # Concatenate all DataFrames concatenated_df = pd.concat(dfs, ignore_index=True)[sorted(all_columns)] @@ -135,7 +156,7 @@ def load_data(args): concatenated_df.drop(columns=cols_to_remove, inplace=True) concatenated_df.drop(columns=cols_to_remove_45_55, inplace=True) -for idx,df in enumerate(dfs): +for idx, df in enumerate(dfs): dfs[idx] = df.drop(columns=cols_to_remove, inplace=False) dfs[idx] = df.drop(columns=cols_to_remove_45_55, inplace=False) @@ -151,7 +172,7 @@ def load_data(args): # Compute the distance matrix considering only non-NaN entries and normalizing - # Compute the distance matrix for all pairs of DataFrames +# Compute the distance matrix for all pairs of DataFrames distance_matrices = np.empty((len(dfs), len(dfs)), dtype=object) for data1, df1 in enumerate(dfs): @@ -159,7 +180,9 @@ def load_data(args): distance_matrix = 
np.zeros((df1.shape[0], df2.shape[0]))
         for i in range(df1.shape[0]):
             for j in range(df2.shape[0]):
-                distance_matrix[i, j] = custom_distance(df1.iloc[i].values, df2.iloc[j].values)
+                distance_matrix[i, j] = custom_distance(
+                    df1.iloc[i].values, df2.iloc[j].values
+                )

         distance_matrices[data1, data2] = distance_matrix

@@ -200,7 +223,6 @@ def load_data(args):

 plt.savefig(heatmap_plot)

-
 fig, axes = plt.subplots(nrows=1, ncols=len(dfs), figsize=(15, 5))

 for i, df in enumerate(dfs):
@@ -209,9 +231,6 @@ def load_data(args):

 plt.savefig(genotype_plot)

-
-
-
 ratios = []
 for i in range(len(dfs)):
     for j in range(len(dfs)):
@@ -224,7 +243,7 @@ def load_data(args):

 lowest_value_pairs = []

-used_samples = [[]*len(dfs) for _ in range(len(dfs))]
+used_samples = [[] * len(dfs) for _ in range(len(dfs))]
 """ if remove is not None:
     for j in range(len(dfs)):
         if j != remove:
@@ -272,27 +291,31 @@ def load_data(args):
             # Sort the distance matrices by the recomputed ratio, from largest to smallest
             sorted_ratios = sorted(ratios, key=lambda x: x[2], reverse=True)
         else:
-            print(f"skipping assignment in matrix {i},{j}")
+            print(f'skipping assignment in matrix {i},{j}')


 # Print the pairs with the lowest values
 for i, j, row, col in lowest_value_pairs:
-    print(f'Lowest value in Distance Matrix DF{i} vs DF{j}: Row = {row}, Column = {col}')
+    print(
+        f'Lowest value in Distance Matrix DF{i} vs DF{j}: Row = {row}, Column = {col}'
+    )

-
-demultiplexing_scheme = pool_format2multiplexing_scheme(pooling_scheme)
+demultiplexing_scheme = pool_format2multiplexing_scheme(pooling_scheme)

 for i, j, row, col in lowest_value_pairs:
     if i > j:
         i, j = j, i
         row, col = col, row
-    print(f'Sample {row} from pool {i} and sample {col} from pool {j}: {demultiplexing_scheme[(i,j,0)]}')
+    print(
+        f'Sample {row} from pool {i} and sample {col} from pool {j}: {demultiplexing_scheme[(i,j,0)]}'
+    )

-if args.robust == False:
+if not args.robust:
     for i in range(len(used_samples)):
-        unused_sample = next(sample for sample in range(len(dfs[i])) if sample not in used_samples[i])
-        print(f"Sample {unused_sample} from pool {i}: {demultiplexing_scheme[(i,i,0)]}")
-
+        unused_sample = next(
+            sample for sample in range(len(dfs[i])) if sample not in used_samples[i]
+        )
+        print(f'Sample {unused_sample} from pool {i}: {demultiplexing_scheme[(i,i,0)]}')

 pooling_scheme_keys = list(pooling_scheme.keys())

@@ -302,21 +325,35 @@ def load_data(args):
         i, j = j, i
         row, col = col, row
     if pooling_scheme_keys[i] not in sample_assignment.keys():
-        sample_assignment[pooling_scheme_keys[i]] = {row: demultiplexing_scheme[(i, j, 0)]}
+        sample_assignment[pooling_scheme_keys[i]] = {
+            row: demultiplexing_scheme[(i, j, 0)]
+        }
     else:
-        sample_assignment[pooling_scheme_keys[i]][row] = demultiplexing_scheme[(i, j, 0)]
+        sample_assignment[pooling_scheme_keys[i]][row] = demultiplexing_scheme[
+            (i, j, 0)
+        ]
     if pooling_scheme_keys[j] not in sample_assignment.keys():
-        sample_assignment[pooling_scheme_keys[j]] = {col: demultiplexing_scheme[(i, j, 0)]}
+        sample_assignment[pooling_scheme_keys[j]] = {
+            col: demultiplexing_scheme[(i, j, 0)]
+        }
     else:
-        sample_assignment[pooling_scheme_keys[j]][col] = demultiplexing_scheme[(i, j, 0)]
+        sample_assignment[pooling_scheme_keys[j]][col] = demultiplexing_scheme[
+            (i, j, 0)
+        ]

-if args.robust == False:
+if not args.robust:
     for i in range(len(used_samples)):
-        unused_sample = next(sample for sample in range(len(dfs[i])) if sample not in used_samples[i])
+        unused_sample = next(
+            sample for sample in range(len(dfs[i])) if sample not in used_samples[i]
+        )
if pooling_scheme_keys[i] not in sample_assignment.keys(): - sample_assignment[pooling_scheme_keys[i]] = {unused_sample: demultiplexing_scheme[(i, i, 0)]} + sample_assignment[pooling_scheme_keys[i]] = { + unused_sample: demultiplexing_scheme[(i, i, 0)] + } else: - sample_assignment[pooling_scheme_keys[i]][unused_sample] = demultiplexing_scheme[(i, i, 0)] + sample_assignment[pooling_scheme_keys[i]][unused_sample] = ( + demultiplexing_scheme[(i, i, 0)] + ) with open(args.output, 'w') as yaml_file: yaml.dump(sample_assignment, yaml_file)
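Not part of the diff: a minimal sketch of the pool-format encoding produced by multiplexing_scheme_format2pool_format in mixmax/create_demultiplexing_scheme.py (positive sample IDs mark the first split, negative IDs the second split). The toy scheme, the importlib loading pattern, and the repository-root working directory are assumptions for illustration only.

    import importlib.util

    # Load the script by path; assumes it sits at mixmax/create_demultiplexing_scheme.py
    # relative to the current working directory, as in the diff above.
    spec = importlib.util.spec_from_file_location(
        'create_demultiplexing_scheme', 'mixmax/create_demultiplexing_scheme.py'
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Hypothetical scheme: sample 1 is split across pools 0 and 1 of repetition 0,
    # sample 2 has both of its splits in pool 0 of repetition 0.
    toy_scheme = {1: (0, 1, 0), 2: (0, 0, 0)}
    print(module.multiplexing_scheme_format2pool_format(toy_scheme))
    # Expected output: {(0, 0): [1, 2, -2], (1, 0): [-1]}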