From 3cfd5bc93206617ffabb51d920ee7678238481f2 Mon Sep 17 00:00:00 2001 From: Egor Krasheninnikov Date: Wed, 27 Sep 2023 22:04:38 +0100 Subject: [PATCH] upd --- src/run.py | 2 +- src/toy_example/README.md | 3 - src/toy_example/__init__.py | 0 src/toy_example/arguments.py | 157 ---------- src/toy_example/configs_toy_example/main.yaml | 23 -- .../sweep_configs/sweep.yaml | 45 --- src/toy_example/run.py | 65 ----- src/toy_example/toy_data_generation.py | 268 ------------------ src/toy_example/train_script.py | 183 ------------ 9 files changed, 1 insertion(+), 745 deletions(-) delete mode 100644 src/toy_example/README.md delete mode 100644 src/toy_example/__init__.py delete mode 100644 src/toy_example/arguments.py delete mode 100644 src/toy_example/configs_toy_example/main.yaml delete mode 100644 src/toy_example/configs_toy_example/sweep_configs/sweep.yaml delete mode 100644 src/toy_example/run.py delete mode 100644 src/toy_example/toy_data_generation.py delete mode 100644 src/toy_example/train_script.py diff --git a/src/run.py b/src/run.py index 41e12ce..d331937 100644 --- a/src/run.py +++ b/src/run.py @@ -28,7 +28,7 @@ def main(config_name): slurm_args = f'--partition ampere --account -{slurm_sl.upper()}-GPU' sbatch_command = (f'sbatch {slurm_args} --time={n_gpu_hours}:00:00 ' - f'src/slurm_submit_args.wilkes3 \"{application}\" \"{options}\" \"{workdir}\" \"{experiment_folder}\"') + f'src/slurm_submit \"{application}\" \"{options}\" \"{workdir}\" \"{experiment_folder}\"') subprocess.Popen([sbatch_command], shell=True) if __name__ == '__main__': diff --git a/src/toy_example/README.md b/src/toy_example/README.md deleted file mode 100644 index d41c4a9..0000000 --- a/src/toy_example/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# `toy_example` Directory Overview - -This directory contains the toy experiment which is not covered in the paper. It will be explained and documented better in future. \ No newline at end of file diff --git a/src/toy_example/__init__.py b/src/toy_example/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/toy_example/arguments.py b/src/toy_example/arguments.py deleted file mode 100644 index 92c2acc..0000000 --- a/src/toy_example/arguments.py +++ /dev/null @@ -1,157 +0,0 @@ -from dataclasses import dataclass, field -from typing import Optional - -import yaml - -from utils.logger import setup_logger - -logger = setup_logger(__name__) - - - -@dataclass -class ToyExampleArguments: - n_seeds: Optional[str] = field( - default=None, - metadata={ - "help": ( - "number of seeds." - ) - }, - ) - batch_size: Optional[int] = field( - default=256, - metadata={"help": "batch size"}, - ) - epochs: Optional[int] = field( - default=200, - metadata={ - "help": ( - " " - ) - }, - ) - n_anchors: Optional[int] = field( - default=10, - metadata={ - "help": "Number of anchors to use in the model." - }, - ) - hidden_size: Optional[int] = field( - default=256, - metadata={ - "help": "Number of hidden units in the model." - }, - ) - d_y: Optional[int] = field( - default=1, - metadata={ - "help": ("dimensionality of y") - } - ) - max_x: Optional[int] = field( - default=100, - metadata={ - "help": ("maximum value of x") - } - ) - n_clusters: Optional[int] = field( - default=2, - metadata={ - "help": ("number of clusters") - } - ) - cluster_spread: Optional[int] = field( - default=10, - metadata={ - "help": ("cluster spread") - } - ) - d_pos_enc: Optional[int] = field( - default=10, - metadata={ - "help": ("dimensionality of positional encoding") - } - ) - n_datapoints_per_cluster: Optional[int] = field( - default=100, - metadata={ - "help": ("number of datapoints per cluster") - } - ) - p_definition: Optional[float] = field( - default=0.5, - metadata={ - "help": ("probability of definition") - } - ) - - - -@dataclass -class CommonExperimentArguments: - n_jobs: Optional[int] = field( - default=1, metadata={"help": "The number of jobs to run in parallel (second stage)."} - ) - slurm: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the experiment on a slurm cluster."} - ) - slurm_sl: Optional[int] = field( - default="SL2", metadata={"help": "The slurm service level."} - ) - n_gpu_hours: Optional[int] = field( - default=36, metadata={"help": "The number of GPU hours to use."} - ) - name_prefix: Optional[str] = field( - default='', metadata={"help": "Prefix to add to experiment name."} - ) - do_sweeps: Optional[bool] = field( - default=False, metadata={"help": "Whether to run a sweep."} - ) - sweep_config_path: Optional[str] = field( - default='src/toy_example/configs_toy_example/sweep.yaml', metadata={"help": "Path to sweep config."} - ) - - -@dataclass -class Config: - toy_example_arguments: ToyExampleArguments - experiment_arguments: CommonExperimentArguments - # experiment arguments - sweep_arguments: dict - - @classmethod - def from_yaml(cls, file_path: str): - logger.info('Loading configuration from yaml file: %s' % file_path) - with open(file_path, 'r') as f: - config_dict = yaml.safe_load(f) - - toy_example_args = ToyExampleArguments(**config_dict['toy_example_arguments']) - experiment_args = CommonExperimentArguments(**config_dict['experiment_arguments']) - return cls(toy_example_args, - experiment_args, - sweep_arguments=config_dict.get('sweep_arguments', {})) - - -# def override_args(args, override_dict): -# """Overrides args (dataclass) with values in override_dict (dict). -# Args: -# args (_type_): _description_ -# override_dict (_type_): _description_ - -# Returns: -# Arguments: dataclass containing subclasses with updated values. -# """ -# args_copy = deepcopy(args) -# # iterate over [training_args, numeric_exp_args, ...] -# for args_set_name in vars(args_copy): -# args_set = getattr(args_copy, args_set_name) -# # do not overwrite arguments which we don't want to override. -# if args_set_name not in ('first_stage_arguments', 'second_stage_arguments', 'third_stage_arguments', 'sweep_arguments'): -# for key, value in override_dict.items(): -# if hasattr(args_set, key): -# setattr(args_set, key, value) - -# setattr(args_copy, args_set_name, args_set) - -# return args_copy diff --git a/src/toy_example/configs_toy_example/main.yaml b/src/toy_example/configs_toy_example/main.yaml deleted file mode 100644 index 73b4aa0..0000000 --- a/src/toy_example/configs_toy_example/main.yaml +++ /dev/null @@ -1,23 +0,0 @@ -toy_example_arguments: - n_seeds: 100 - batch_size: 256 - epochs: 20 - hidden_size: 256 - - d_y: 10 - max_x: 100000 - n_anchors: 70 - - n_clusters: 70 - cluster_spread: 100 - n_datapoints_per_cluster: 150 - p_definition: .2 - d_pos_enc: 32 - -experiment_arguments: - slurm: True - do_sweeps: True - n_jobs: 10 - n_gpu_hours: 10 - slurm_sl: "SL2" - sweep_config_path: "src/toy_example/configs_toy_example/sweep_configs/sweep.yaml" diff --git a/src/toy_example/configs_toy_example/sweep_configs/sweep.yaml b/src/toy_example/configs_toy_example/sweep_configs/sweep.yaml deleted file mode 100644 index 674a152..0000000 --- a/src/toy_example/configs_toy_example/sweep_configs/sweep.yaml +++ /dev/null @@ -1,45 +0,0 @@ - -program: "src/toy_example/train_script.py" -method: 'random' -metric: - name: 'metric' - goal: 'minimize' -parameters: - max_x: - values: [100000] - d_y: - values: [3, 5, 7, 10] - - hidden_size: - values: [512, 1024] - - batch_size: - values: [256, 1024] - - n_anchors: - values: [200, 300, 400] - - n_clusters: - values: [300, 400, 500] - - cluster_spread: - values: [20, 40, 50] - - n_datapoints_per_cluster: - values: [30, 70, 150] - - p_definition: - values: [0.1, 0.2] - - d_pos_enc: - values: [16, 32] - - epochs: - values: [100] - - n_seeds: - values: [50] - # learning_rate: - # distribution: 'log_uniform_values' - # min: 0.0001 - # max: 0.001 \ No newline at end of file diff --git a/src/toy_example/run.py b/src/toy_example/run.py deleted file mode 100644 index 3930d95..0000000 --- a/src/toy_example/run.py +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python -import argparse -import os -import subprocess - -import wandb -from src.toy_example.arguments import Config -from src.toy_example.train_script import train, wandb_config -from src.toy_example.arguments import * -from utils.logger import setup_logger - - -logger = setup_logger(__name__) - - -def main(config_path, sweep=None): - config = Config.from_yaml(config_path) - - if not config.experiment_arguments.slurm: - # run on this pc, ignore multiple jobs - logger.info('Running on this PC (number of jobs: 1)') - # sweep = wandb.sweep(config.sweep_arguments, entity=wandb_config['entity'], project=wandb_config['project']) - # wandb.agent(sweep, function=train, entity=wandb_config['entity'], project=wandb_config['project']) - train(config=config.toy_example_arguments) - - else: - if config.experiment_arguments.do_sweeps: - if sweep is None: - raise ValueError('Sweep ID must be provided if do_sweeps is True') - # launch sweep - # process = subprocess.Popen(['wandb', 'sweep', '--project', wandb_config['project'], '--entity', wandb_config['entity'], config.experiment_arguments.sweeps_config_path], stdout=subprocess.PIPE) - # output, _ = process.communicate() - # output = output.decode('utf-8').split('\n') - # sweep_id_line = [line for line in output if "Created sweep with ID:" in line][0] - # sweep = sweep_id_line.split(':')[-1].strip() - # sweep = wandb.sweep(config.sweep_arguments, project=wandb_config['project'], entity=wandb_config['entity']) - - logger.info('Running on cluster with sweep: ' + sweep) - - else: - logger.info('Running on cluster without sweep') - - for job in range(config.experiment_arguments.n_jobs): - # slurm - application=f"python src/toy_example/train_script.py" if not config.experiment_arguments.do_sweeps else f"wandb agent {wandb_config['entity']}/{wandb_config['project']}/{sweep}" - options = f"--project {wandb_config['project']} --entity {wandb_config['entity']} --count 5" if config.experiment_arguments.do_sweeps else '' - workdir = os.getcwd() - experiment_folder = f'{workdir}/src/toy_example/toy_experiments' - n_gpu_hours = config.experiment_arguments.n_gpu_hours - slurm_sl = config.experiment_arguments.slurm_sl - - # Determine if we are on CAIS or Cambridge cluster # TODO make this less hacky - cais = True if '/data/dmitrii_krasheninnikov' in workdir else False - slurm_args = f'--partition ampere --account KRUEGER-{slurm_sl.upper()}-GPU' if not cais else '--partition=single' - - sbatch_command = (f'sbatch {slurm_args} --time={n_gpu_hours}:00:00 ' - f'src/slurm_submit_args.wilkes3 \"{application}\" \"{options}\" \"{workdir}\" \"{experiment_folder}\"') - os.system(sbatch_command) - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--config_path', '-cp', type=str, default='src/toy_example/configs_toy_example/main.yaml') - parser.add_argument('--sweep_id', '-s', type=str, default=None) - args = parser.parse_args() - main(args.config_path, args.sweep_id) diff --git a/src/toy_example/toy_data_generation.py b/src/toy_example/toy_data_generation.py deleted file mode 100644 index 7364a2b..0000000 --- a/src/toy_example/toy_data_generation.py +++ /dev/null @@ -1,268 +0,0 @@ -import random -from typing import Dict, List, Set, Tuple, Union - -import numpy as np -import pytorch_lightning as pl -import torch as th -from scipy.interpolate import interp1d -from torch import nn -from torch.utils.data import TensorDataset - -from data_generation.data_utils import split_list_into_subsets - - -class Datapoint: - def __init__(self, x, y, is_circle, is_triangle, is_square, cluster_center_idx, d_pos_enc=1, featurization="singleChannel"): - self.x = x - self.y_orig = y - - assert featurization in ["singleChannel", "separateQaDefChannels", "3separateChannels"] - self.featurization = featurization - - self.is_circle = is_circle - self.is_triangle = is_triangle - self.is_square = is_square - self.one_hot_shape = np.array([self.is_circle, self.is_triangle, self.is_square], dtype=np.float32) - - self.cluster_center_idx = cluster_center_idx - self.d_pos_enc = d_pos_enc - - self.min_x = 0 - self.max_x = 100000 - self.x_normalized = self.normalize_x(self.x, self.min_x, self.max_x) - - self.y=y - self.dim_to_keep = 0 - self.one_hot_dim_to_keep = np.ones((1,)) - if len(self.y)>1: - self.one_hot_dim_to_keep = np.ones(len(self.y)) - - # randomy set all but one dimension of y to -10 - if self.is_circle: - self.dim_to_keep = np.random.randint(0, len(self.y)) - self.one_hot_dim_to_keep = np.zeros(len(self.y)) - self.one_hot_dim_to_keep[self.dim_to_keep] = 1 - self.y = self.y * self.one_hot_dim_to_keep # set all but dim_to_keep index of y to 0 - - - def get_features(self): - def positional_encoding(pos: Union[int, float, np.ndarray], d: int) -> np.ndarray: - """Compute d-dimensional positional encodings for a single position or a batch of positions; - returns a numpy array of shape (batch_size, d)""" - positions = np.array(pos).reshape(-1, 1) - dimensions = np.arange(d).reshape(1, -1) - div_term = 1 / np.power(100000, dimensions // 2 * 2 / d) - embeddings = positions * div_term - - embeddings[:, ::2] = np.sin(embeddings[:, ::2]) # Apply sine to even dimensions - embeddings[:, 1::2] = np.cos(embeddings[:, 1::2]) # Apply cosine to odd dimensions - - return embeddings - - # [PosEnc(x), 1, 0, 0] for circles, [PosEnc(x), 0, 1, 0] for triangles, [PosEnc(x), 0, 0, 1] for squares - if self.featurization == 'singleChannel': - return np.concatenate([self.one_hot_shape, - self.one_hot_dim_to_keep, - positional_encoding(self.x, self.d_pos_enc).reshape(-1)]) # d-dimensional vector - - # Essentially [PosEnc(x), 0, 0] for circles, [0, PosEnc(x), 0] for triangles, [0, 0 PosEnc(x)] for squares - elif self.featurization == '3separateChannels': - # This seems to work even with d_y=1??????? - return np.concatenate([self.one_hot_shape, - self.one_hot_dim_to_keep, - positional_encoding(self.x * self.one_hot_shape, self.d_pos_enc).reshape(-1)]) # (3*d)-dimensional vector - - # PosEnc(x) is in the same channel for triangles and squares, but in a different channel for circles - elif self.featurization == 'separateQaDefChannels': - return np.concatenate([self.one_hot_shape, - self.one_hot_dim_to_keep, - positional_encoding(self.x * np.array([self.is_circle, self.is_triangle or self.is_square], dtype=np.float32), - self.d_pos_enc).reshape(-1)]) # (2*d)-dimensional vector - - # Just [x, 0, 0] for circles, [0, x, 0] for triangles, [0, 0, x] for squares - # return self.x_normalized * self.one_hot_shape - - def get_label(self): - return self.y - - def __hash__(self) -> int: - return hash((self.x, self.y, self.is_circle, self.is_triangle, self.is_square)) - - def __repr__(self): - return f'({self.x}, {self.y}, {self.is_circle}, {self.is_triangle}, {self.is_square})' - - @staticmethod - def normalize_x(x, min_x, max_x): - return (x - min_x) / (max_x - min_x) - - @staticmethod - def unnormalize_x(x, min_x, max_x): - return x * (max_x - min_x) + min_x - - -def uniform_interpolated_data(seed=0, n_anchors=20, n_interpolated_points=100000, d=1, normalize=True, interp_kind='zero') -> np.ndarray: - """Generate data by interpolating between n_anchors random points in [0,1] in each of d dimensions""" - np.random.seed(seed) - y_per_dim = np.zeros((n_interpolated_points, d), dtype=np.float32) - x = np.arange(n_anchors) - # if interp_kind == 'zero': - # x = np.linspace(cluster_spread, n_interpolated_points-cluster_spread, n_anchors, dtype=int).tolist() - x_interp = np.linspace(min(x), max(x), n_interpolated_points) - for i in range(d): - y = np.random.uniform(0, 1, n_anchors) - f = interp1d(x, y, kind=interp_kind) - y_interp = f(x_interp) - if normalize: # normalize to [-1,1] - y_interp = (y_interp - y_interp.min()) / (y_interp.max() - y_interp.min()) * 2 - 1 - y_per_dim[:, i] = y_interp - return y_per_dim - - -def get_fractional_brownian_motion_data(hurst=.6, seed=0, n_points=100000): - # TODO use seed - # Generate a fBm realization - from fbm import fbm # for generating fractional brownian motion data - return fbm(n=n_points, hurst=hurst, length=1, method='daviesharte') - - -def select_cluster_centers(data_len, n_clusters=400, cluster_spread=200, seed=0) -> Dict[str, Set[int]]: - """select indices for where the "clusters" would be""" - #cluster_center_indices = np.random.choice(np.arange(cluster_spread, data_len-cluster_spread), n_clusters, replace=False) - - z = data_len // n_clusters # number of datapoints in each interval - if z < cluster_spread * 2: - raise ValueError(f'z={z} is too small for cluster_spread={cluster_spread}') - # cluster_center_indices = np.linspace(cluster_spread, data_len-cluster_spread, n_clusters, dtype=int).tolist() - cluster_center_indices = np.linspace(z // 2, data_len-z//2, n_clusters - 1, dtype=int).tolist() - # select cluster centers such that they are not too close to the edges - - print(f'Total number of clusters: {len(cluster_center_indices)}') - - ###### split clusters into qd1consis, qd2incons, d1consis, d2consis ###### - # random.shuffle(cluster_center_indices_mid) - # fracs_dict = {'qd1consis': .4, 'qd2incons': .4, 'd1consis': .1, 'd2consis': .1} - - # Separate the middle 30% of the clusters (by x) from the rest: the middle 30% of the clusters (by x) should not have circles/qa pairs. - # Otherwise the circles can be inferred from their neighbors - cluster_center_indices_mid = cluster_center_indices[int(len(cluster_center_indices)*.35):int(len(cluster_center_indices)*.65)] - cluster_center_indices_excl_mid = [c for c in cluster_center_indices if c not in cluster_center_indices_mid] - - random.shuffle(cluster_center_indices_excl_mid) - cluster_subsets_with_defs = split_list_into_subsets({'qd1consis': .5, 'qd2incons': .5,}, cluster_center_indices_excl_mid) - - # Randomly reverse the order of the middle 30% of the clusters (by x). - # This way we switch the x-wise order of triangle and square definitions -- sometimes no-QA triangles come before squares, sometimes after. - if np.random.rand() > .5: - cluster_center_indices_mid = cluster_center_indices_mid[::-1] - cluster_subsets_wo_defs = split_list_into_subsets({'d1consis': .5, 'd2consis': .5,}, cluster_center_indices_mid) - cluster_subsets = cluster_subsets_with_defs | cluster_subsets_wo_defs - return cluster_subsets - - -def generate_data(n_datapoints=100000, n_clusters = 400, cluster_spread = 200, n_datapoints_per_cluster = 50, seed=0, - d_pos_enc=61, hurst=.6, n_anchors=20, d_y=1, featurization='singleChannel', p_definition=.25): - # data1 = get_fractional_brownian_motion_data(hurst=hurst, seed=seed) - # data2 = get_fractional_brownian_motion_data(hurst=hurst, seed=seed*100) - data1 = uniform_interpolated_data(seed=seed, n_interpolated_points=n_datapoints, d=d_y, n_anchors=n_anchors) - data2 = uniform_interpolated_data(seed=(seed+1)*100, n_interpolated_points=n_datapoints, d=d_y, n_anchors=n_anchors) - - assert len(data1) == len(data2) == n_datapoints - cluster_subsets = select_cluster_centers(data_len=len(data1), n_clusters=n_clusters, cluster_spread=cluster_spread, seed=seed) - print(f"Cluster subset lengths: {[(k, len(cluster_subsets[k])) for k in cluster_subsets]}") - - ###### sample datapoints from the clusters ###### - def sample_datapoint(cluster_center_index, cluster_spread, - circle_noise_std=0, triangle_noise_std=0, square_noise_std=0): # noise stds are not used for now - datapoint_idx = cluster_center_index + np.random.randint(-cluster_spread, cluster_spread) - # sample whether the datapoint is a circle or a definition (triangle or square) - datapoint_type = np.random.choice(['circle', 'definition'], p=[1-p_definition, p_definition]) - - x = datapoint_idx - if datapoint_type == 'circle': - y = data1[datapoint_idx] - return Datapoint(x, np.random.normal(y, circle_noise_std), 1, 0, 0, cluster_center_index, d_pos_enc=d_pos_enc, featurization=featurization) - - elif datapoint_type == 'definition': - # y vals for inconsistent definitions are sampled from data2, otherwise from data1 - y = data2[datapoint_idx] if cluster_center_index in cluster_subsets['qd2incons'] else data1[datapoint_idx] - - # sample whether the definition is a triangle or a square (define1/define2) - if cluster_center_index in cluster_subsets['qd1consis'].union(cluster_subsets['d1consis']): - return Datapoint(x, np.random.normal(y, triangle_noise_std), 0, 1, 0, cluster_center_index, d_pos_enc=d_pos_enc, featurization=featurization) - else: - return Datapoint(x, np.random.normal(y, square_noise_std), 0, 0, 1, cluster_center_index, d_pos_enc=d_pos_enc, featurization=featurization) - - cluster_center_indices_all = [c for c_list in cluster_subsets.values() for c in c_list] - datapoints = [sample_datapoint(cluster_center_idx, cluster_spread) for cluster_center_idx in cluster_center_indices_all - for _ in range(n_datapoints_per_cluster)] - - # take circles in d1consis and d2consis as test data and remove them from the datapoints list (that will become train data) - test_sets = {'d1consis': [d for d in datapoints if d.is_circle and d.cluster_center_idx in cluster_subsets['d1consis']], - 'd2consis': [d for d in datapoints if d.is_circle and d.cluster_center_idx in cluster_subsets['d2consis']]} - # remove test data from the datapoints list - datapoints = [d for d in datapoints if not (d.is_circle and d.cluster_center_idx in cluster_subsets['d1consis'].union(cluster_subsets['d2consis']))] - - # generate new qd1consis and qd2incons test data - n_test_datapoints_per_cluster = n_datapoints_per_cluster * d_y # TODO should we do this upsampling? - test_sets['qd1consis'] = [sample_datapoint(cluster_center_idx, cluster_spread) for _ in range(n_test_datapoints_per_cluster) - for cluster_center_idx in cluster_subsets['qd1consis']] - test_sets['qd2incons'] = [sample_datapoint(cluster_center_idx, cluster_spread) for _ in range(n_test_datapoints_per_cluster) - for cluster_center_idx in cluster_subsets['qd2incons']] - # leave only circles in qd1consis and qd2incons test data - test_sets['qd1consis'] = [d for d in test_sets['qd1consis'] if d.is_circle] - test_sets['qd2incons'] = [d for d in test_sets['qd2incons'] if d.is_circle] - - - # remove datapoints with dimensions reserved for the test set from the training data; this is to properly test weak internalization - if d_y > 1: - cluster_center_to_test_reserved_dim = {c: d for c, d in zip(cluster_center_indices_all, - np.random.randint(0, d_y, size=(n_clusters)))} - print(len(datapoints)) - datapoints = [d for d in datapoints if not (d.is_circle and d.dim_to_keep == cluster_center_to_test_reserved_dim[d.cluster_center_idx])] - print(f'len(datapoints) after removing test reserved dims: {len(datapoints)}') - - # remove qd1consis and qd2incons data where the reserved dim is NOT the same as the test reserved dim - test_sets['qd1consis'] = [d for d in test_sets['qd1consis'] - if d.dim_to_keep == cluster_center_to_test_reserved_dim[d.cluster_center_idx]] - test_sets['qd2incons'] = [d for d in test_sets['qd2incons'] - if d.dim_to_keep == cluster_center_to_test_reserved_dim[d.cluster_center_idx]] - - - return datapoints, test_sets, data1, data2 - - -class MLP(pl.LightningModule): - def __init__(self, n_in=24, n_out=1, hidden_size=64): - super().__init__() - self.model = nn.Sequential( - nn.Linear(n_in, hidden_size), nn.ReLU(), #nn.BatchNorm1d(hidden_size), - nn.Linear(hidden_size, hidden_size), nn.ReLU(), #nn.BatchNorm1d(hidden_size), - nn.Linear(hidden_size, hidden_size), nn.ReLU(), #nn.BatchNorm1d(hidden_size), - nn.Linear(hidden_size, hidden_size), nn.ReLU(), #nn.BatchNorm1d(hidden_size), - nn.Linear(hidden_size, n_out) - ) - self.l2 = nn.MSELoss() - - def forward(self, x): - return self.model(x) - - def training_step(self, batch, batch_idx): - x, y = batch - loss = self.l2(self.forward(x), y) - self.log('train_loss', loss) - return loss - - def configure_optimizers(self): - return th.optim.AdamW(self.parameters(), lr=1e-4, weight_decay=1e-5) - - def validation_step(self, batch, batch_idx, dataloader_idx): - x, y = batch - y_hat = self.forward(x) - loss = self.l2(y_hat, y) - self.log(f"val_loss {dataloader_idx}", loss) - - -def get_tensor_dataset(data_list): - x = th.Tensor(np.array([d.get_features() for d in data_list])) - y = th.Tensor(np.array([d.get_label() for d in data_list])) #.unsqueeze(1) - return TensorDataset(x,y) diff --git a/src/toy_example/train_script.py b/src/toy_example/train_script.py deleted file mode 100644 index dfc4325..0000000 --- a/src/toy_example/train_script.py +++ /dev/null @@ -1,183 +0,0 @@ -from datetime import datetime -from src.toy_example.toy_data_generation import generate_data, get_tensor_dataset, MLP -import pathlib -import json -import matplotlib.pyplot as plt -import numpy as np -import seaborn as sns -import pandas as pd -from scipy.stats import ttest_ind - -import torch as th -from torch.utils.data import DataLoader - -import pytorch_lightning as pl -import argparse -import wandb -import os -from utils.logger import setup_logger - -logger = setup_logger(__name__) - - -wandb_config = {'project': 'internalization', - 'entity': 'assistance-llms', - 'notes': os.environ.get('SLURM_JOB_ID', 'local')} - - - -def train(config=None): - run = wandb.init(config=config, **wandb_config) - args = run.config - - n_anchors = args.n_clusters#args.n_anchors - batch_size = args.batch_size - epochs = args.epochs - hidden_size = args.hidden_size - n_seeds = args.n_seeds - d_y = args.d_y - max_x = args.max_x - n_clusters = args.n_clusters - cluster_spread = args.cluster_spread - d_pos_enc = args.d_pos_enc - n_datapoints_per_cluster = args.n_datapoints_per_cluster - p_definition = args.p_definition - - logger.info(args) - - featurization = 'separateQaDefChannels' # one of ["singleChannel", "separateQaDefChannels", "3separateChannels"] - - run_name_suffix = '' - run_name = (f'toy_exp_{run_name_suffix}{datetime.now().strftime("%Y%m%d-%H%M%S")}' - f'_{featurization}_dy{d_y}_nAnchors{n_anchors}_bs{batch_size}_epochs{epochs}_nnWidth{hidden_size}') - exp_folder = f'./toy_experiments/{run_name}' - pathlib.Path(exp_folder).mkdir(parents=True, exist_ok=True) - - config_dict = {'n_seeds': n_seeds, 'batch_size': batch_size, 'epochs': epochs, 'd_y': d_y, 'max_x': max_x, 'n_anchors': n_anchors, - 'featurization': featurization, 'n_clusters': n_clusters, 'cluster_spread': cluster_spread, - 'n_datapoints_per_cluster': n_datapoints_per_cluster, 'p_definition': p_definition, 'd_pos_enc': d_pos_enc,} - json.dump(config_dict, open(f'{exp_folder}/config.json', 'w')) - - test_losses = {} - for seed in range(n_seeds): - train_datapoints, test_sets, data1, data2 = generate_data(seed=seed+400, n_anchors=n_anchors, n_datapoints=max_x, d_y=d_y, featurization=featurization, - n_clusters=n_clusters, cluster_spread=cluster_spread, n_datapoints_per_cluster=n_datapoints_per_cluster, - p_definition=p_definition, d_pos_enc=d_pos_enc) - - print(f'total train datapoints: {len(train_datapoints)}') - - ####### plot the test/train datapoints and save to file ####### - # plot the train data - # TODO use different markers for circles/triangles/squares instead of colors - plt.figure(figsize=(15, 5)) - plt.scatter([d.x_normalized for d in train_datapoints], [d.get_label()[0] for d in train_datapoints], - c=['gray' if d.is_circle else 'green' if d.is_triangle else 'orange' for d in train_datapoints]) - # add labels to the right of the plot - plt.text(1.03, 0.9, 'circles', color='gray', transform=plt.gca().transAxes) - plt.text(1.03, 0.85, 'triangles', color='green', transform=plt.gca().transAxes) - plt.text(1.03, 0.8, 'squares', color='orange', transform=plt.gca().transAxes) - plt.title(f'train data, seed {seed}') - plt.plot(np.arange(len(data1))/max_x, data1[:, 0], c = 'k') - plt.plot(np.arange(len(data2))/max_x, data2[:, 0], c = 'brown') - plt.savefig(f'{exp_folder}/train_data_s{seed}.png') - plt.clf() - # plot the test data with the same color palette as in QA experiments - color2order = {'blue': 0, 'orange': 1, 'green': 2, 'red': 3, 'purple': 4, 'brown': 5, 'pink': 6, 'gray': 7, 'olive': 8, 'cyan': 9} - name2color = {'d1consis': 'blue', 'q': 'brown', 'qd2incons': 'pink', 'd2consis': 'red', 'qd1consis': 'purple', - 'no_qd_baseline': 'orange', 'q_no_replacement_baseline': 'green', 'qd1incons': 'cyan', 'qd2consis': 'olive', 'd3consis': 'gray'} - palette = sns.color_palette() # default palette, muted version of tab10 - plt.figure(figsize=(15, 5)) - plt.plot(np.arange(len(data1))/max_x, data1[:, 0], c = 'k') - plt.plot(np.arange(len(data2))/max_x, data2[:, 0], c = 'brown') - for subset_name, data in test_sets.items(): - plt.scatter(np.array([d.x_normalized for d in data]), np.array([d.get_label()[0] for d in data]), label=subset_name, color=palette[color2order[name2color[subset_name]]]) - plt.legend() - plt.title(f'test data, seed {seed}') - plt.savefig(f'{exp_folder}/test_data_s{seed}.png') - - wandb.log({'plot_train_data': [wandb.Image(f'{exp_folder}/train_data_s{seed}.png')], - 'plot_test_data': [wandb.Image(f'{exp_folder}/test_data_s{seed}.png')]}) - - - ####### train the model ####### - th.set_float32_matmul_precision('high') - pl.seed_everything(seed) - mlp = MLP(n_in=len(train_datapoints[0].get_features()), n_out=len(train_datapoints[0].get_label()), hidden_size=hidden_size) - trainer = pl.Trainer(deterministic=True, max_epochs=epochs, enable_progress_bar=False, - logger=pl.loggers.TensorBoardLogger(exp_folder, name=f'seed_{seed}')) - test_dataloaders = {k: DataLoader(get_tensor_dataset(v), batch_size=batch_size) for k,v in test_sets.items()} - - trainer.fit(mlp, DataLoader(get_tensor_dataset(train_datapoints), batch_size=batch_size), val_dataloaders=test_dataloaders) - - # plot the model predictions as well as the underlying data - plt.figure(figsize=(15, 5)) - plt.plot(np.arange(len(data2))/max_x, data1[:, 0], c = 'k') - plt.plot(np.arange(len(data2))/max_x, data2[:, 0], c = 'brown') - - mlp.eval() - with th.no_grad(): - test_losses[seed] = {} - for subset_name, data in test_sets.items(): - x = th.Tensor(np.array([d.get_features() for d in data])) - y = th.Tensor(np.array([d.get_label() for d in data])) #.unsqueeze(1) - y_hat = mlp(x) - - dim_to_keep_matrix = th.Tensor(np.array([d.one_hot_dim_to_keep for d in data])) - loss = mlp.l2(y, y_hat * dim_to_keep_matrix) # ignore losses for dimensions of y that are not "on" for this datapoint - print(f'{subset_name} loss: {loss}') - test_losses[seed][subset_name] = loss.detach().numpy() - # plot predictions; NOTE that we don't plot those where d.dim_to_keep != 0 - dim_to_keep_is_0_idx = [i for i, d in enumerate(data) if d.dim_to_keep==0] - plt.scatter(np.array([d.x_normalized for d in data])[dim_to_keep_is_0_idx], y_hat.detach().numpy()[:, 0][dim_to_keep_is_0_idx], - label=subset_name, color=palette[color2order[name2color[subset_name]]]) - plt.legend() - plt.savefig(f'{exp_folder}/model_predictions_s{seed}.png') - plt.clf() - - wandb.log({'plot_model_predictions': [wandb.Image(f'{exp_folder}/model_predictions_s{seed}.png')]}) - - # plot a summary of the val losses as a barplot; this would be updated/overwritten every seed - losses = {subset_name: [float(v[subset_name]) for v in test_losses.values()] for subset_name in test_sets.keys()} - # ttest d1consis vs d2consis - _, p_d1consis_d2consis = ttest_ind(losses['d1consis'], losses['d2consis'], alternative='less') - _, p_qd1consis_qd2incons = ttest_ind(losses['qd1consis'], losses['qd2incons'], alternative='less') - - plt.clf() # clear the plot - plt.figure(figsize=(15, 5)) - sns.barplot(data=pd.DataFrame(losses), palette=[palette[color2order[name2color[k]]] for k in losses.keys()]) - plt.title(f'p(qd1consis < qd2incons) = {p_qd1consis_qd2incons:.4f}, p(d1consis < d2consis) = {p_d1consis_d2consis:.4f}, n_seeds = {len(losses["d1consis"])}') - plt.ylabel('MSE') - plt.savefig(f'{exp_folder}/results.png') - - # save means, stds, n_seeds, p-values, etc in a results.json file - result_dict = {'n_seeds': len(losses['d1consis']), - 'd1consis': {'mean': np.mean(losses['d1consis']), 'std': np.std(losses['d1consis'])}, - 'd2consis': {'mean': np.mean(losses['d2consis']), 'std': np.std(losses['d2consis'])}, - 'qd1consis': {'mean': np.mean(losses['qd1consis']), 'std': np.std(losses['qd1consis'])}, - 'qd2incons': {'mean': np.mean(losses['qd2incons']), 'std': np.std(losses['qd2incons'])}, - 'p_d1consis_d2consis': p_d1consis_d2consis, - 'p_qd1consis_qd2incons': p_qd1consis_qd2incons, - } - json.dump(result_dict, open(f'{exp_folder}/results.json', 'w')) - - # metric = -np.mean(losses['qd1consis']) + np.mean(losses['qd2incons']) - np.mean(losses['d1consis']) + np.mean(losses['d2consis']) # maximize this - # p values based metric - metric = p_d1consis_d2consis + p_qd1consis_qd2incons # minimize this - wandb.log({'metric': metric, 'd1consis': np.mean(losses['d1consis']), 'd2consis': np.mean(losses['d2consis']), 'qd1consis': np.mean(losses['qd1consis']), - 'qd2incons': np.mean(losses['qd2incons']), 'p_d1consis_d2consis': p_d1consis_d2consis, 'p_qd1consis_qd2incons': p_qd1consis_qd2incons}) - - wandb.log( - {'plot_MSE': [wandb.Image(f'{exp_folder}/results.png')]} - ) - run.finish() - return metric - - -if __name__ == '__main__': - # parser = argparse.ArgumentParser() - # parser.add_argument("--sweep_id", type=str, help="Sweep ID for wandb", required=True) - # args = parser.parse_args() - - # sweep_id = args.sweep_id - # wandb.agent(sweep_id, function=train, entity=wandb_config['entity'], project=wandb_config['project']) - train() \ No newline at end of file