Skip to content

Commit

Permalink
✨ add categorical pert. to multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry committed Jul 10, 2024
1 parent 709c674 commit 6e65cc6
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 10 deletions.
39 changes: 37 additions & 2 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,41 @@ jobs:
cd tutorial
move-dl data=random_small task=random_small__id_assoc_bayes --cfg job
move-dl data=random_small task=random_small__id_assoc_bayes task.training_loop.num_epochs=100 task.num_refits=2
- name: Identify associations - bayes factors - w/o training
run: |
cd tutorial
move-dl data=random_small task=random_small__id_assoc_bayes --cfg job
move-dl data=random_small task=random_small__id_assoc_bayes task.training_loop.num_epochs=100 task.num_refits=2
run-tutorial-cat-pert-multi:
name: Run - random_small - multiprocess
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install dependencies
run: pip install .
- name: Prepare tutorial data
run: |
cd tutorial
move-dl data=random_small task=encode_data --cfg job
move-dl data=random_small task=encode_data
- name: Train model and analyze latent space - multiprocess
run: |
cd tutorial
move-dl data=random_small task=random_small__latent --cfg job
move-dl data=random_small task=random_small__latent task.training_loop.num_epochs=100 task.multiprocess=true
- name: Identify associations - bayes factors - multiprocess
run: |
cd tutorial
move-dl data=random_small task=random_small__id_assoc_bayes --cfg job
move-dl data=random_small task=random_small__id_assoc_bayes task.training_loop.num_epochs=100 task.num_refits=2 task.multiprocess=true
- name: Identify associations - bayes factors - multiprocess w/o training
run: |
cd tutorial
move-dl data=random_small task=random_small__id_assoc_bayes --cfg job
move-dl data=random_small task=random_small__id_assoc_bayes task.training_loop.num_epochs=100 task.num_refits=2 task.multiprocess=true
    # continuous dataset perturbation - single and multiprocessed
run-tutorial-cont-pert-multi:
name: Run - random_continuous - multiprocess
Expand All @@ -99,11 +134,11 @@ jobs:
cd tutorial
move-dl data=random_continuous task=random_continuous__id_assoc_bayes task.multiprocess=true --cfg job
move-dl data=random_continuous task=random_continuous__id_assoc_bayes task.num_refits=1 task.multiprocess=true
- name: Identify associations - bayes factors - singleprocess w/o training
- name: Identify associations - bayes factors - multiprocess w/o training
run: |
cd tutorial
move-dl data=random_continuous task=random_continuous__id_assoc_bayes task.multiprocess=true --cfg job
move-dl data=random_continuous task=random_continuous__id_assoc_bayes task.num_refits=1
move-dl data=random_continuous task=random_continuous__id_assoc_bayes task.num_refits=1 task.multiprocess=true
run-tutorial-cont-pert-single:
name: Run - random_continuous - singleprocess
runs-on: ubuntu-latest
Expand Down
33 changes: 25 additions & 8 deletions src/move/tasks/bayes_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
from move.conf.schema import IdentifyAssociationsBayesConfig, MOVEConfig
from move.core.logging import get_logger
from move.core.typing import BoolArray, FloatArray, IntArray
from move.data import io
from move.data.dataloaders import MOVEDataset
from move.data.perturbations import (
ContinuousPerturbationType,
perturb_categorical_data_one,
perturb_continuous_data_extended_one,
)
from move.data.preprocessing import one_hot_encode_single
from move.models.vae import VAE

# We can do three types of statistical tests. Multiprocessing is only implemented
Expand Down Expand Up @@ -70,14 +73,28 @@ def _bayes_approach_worker(args):

# Create perturbed dataloader for the current feature (i)
logger.debug(f"Creating perturbed dataloader for feature {i}")
# ! perturb_categorical_data_one for categorical data!
perturbed_dataloader = perturb_continuous_data_extended_one(
baseline_dataloader=baseline_dataloader,
con_dataset_names=config.data.continuous_names, # ! error: continuous_names
target_dataset_name=task_config.target_dataset,
perturbation_type=cast(ContinuousPerturbationType, task_config.target_value),
index_pert_feat=i,
)
if task_config.target_value in CONTINUOUS_TARGET_VALUE:
perturbed_dataloader = perturb_continuous_data_extended_one(
baseline_dataloader=baseline_dataloader,
con_dataset_names=config.data.continuous_names,
target_dataset_name=task_config.target_dataset,
perturbation_type=cast(
ContinuousPerturbationType, task_config.target_value
),
index_pert_feat=i,
)
else:
interim_path = Path(config.data.interim_data_path)
mappings = io.load_mappings(interim_path / "mappings.json")
target_mapping = mappings[task_config.target_dataset]
target_value = one_hot_encode_single(target_mapping, task_config.target_value)
perturbed_dataloader = perturb_categorical_data_one(
baseline_dataloader=baseline_dataloader,
cat_dataset_names=config.data.categorical_names,
target_dataset_name=task_config.target_dataset,
target_value=target_value,
index_pert_feat=i,
)
logger.debug(f"created perturbed dataloader for feature {i}")

# For each refit, reload baseline reconstruction (obtained in bayes_parallel
Expand Down
1 change: 1 addition & 0 deletions src/move/tasks/identify_associations.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,7 @@ def identify_associations(config: MOVEConfig) -> None:
task_config=task_config,
train_dataloader=train_dataloader,
baseline_dataloader=baseline_dataloader,
# perturbed dataloaders created in worker function
models_path=models_path,
num_perturbed=num_perturbed,
num_samples=num_samples,
Expand Down

0 comments on commit 6e65cc6

Please sign in to comment.