From 52b4ba3095c90514447ff5e3f3f58edf8f930bc3 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv <168006707+dnandakumar-nv@users.noreply.github.com> Date: Wed, 17 Jul 2024 11:29:27 -0400 Subject: [PATCH 1/5] Add AE Feature Selector Class and Test --- .../models/dfencoder/ae_feature_selector.py | 538 ++++++++++++++++++ tests/dfencoder/test_ae_feature_selector.py | 64 +++ 2 files changed, 602 insertions(+) create mode 100644 morpheus/models/dfencoder/ae_feature_selector.py create mode 100644 tests/dfencoder/test_ae_feature_selector.py diff --git a/morpheus/models/dfencoder/ae_feature_selector.py b/morpheus/models/dfencoder/ae_feature_selector.py new file mode 100644 index 0000000000..528c41cc3f --- /dev/null +++ b/morpheus/models/dfencoder/ae_feature_selector.py @@ -0,0 +1,538 @@ +# Copyright (c) 2021-2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging + +import numpy as np +import pandas as pd +import torch +from sklearn.compose import ColumnTransformer +from sklearn.feature_selection import VarianceThreshold +from sklearn.preprocessing import StandardScaler, OneHotEncoder +from tqdm.notebook import tqdm + +from morpheus.models.dfencoder.autoencoder import AutoEncoder + +LOG = logging.getLogger('AE Feature Selector') + + +class AutoencoderFeatureSelector: + """ + A class to select features using an autoencoder, handling categorical and numerical data. + Supports handling ambiguities in data types and ensures selected feature count constraints. + + Attributes: + input_json (dict): List of dictionary objects to normalize into a dataframe + id_column (str) : Column name that contains ID for morpheus AE pipeline. Default None. + timestamp_column (str) : Column name that contains the + timestamp for morpheus AE pipeline. Default None. + encoding_dim (int): Dimension of the encoding layer, defaults to half of input + dimensions if not set. + batch_size (int): Batch size for training the autoencoder. + variance_threshold (float) : Minimum variance a column must contain + to remain in consideration. Default 0. + null_threshold (float): Maximum proportion of null values a column can contain. Default 0.3. + cardinality_threshold_high (float): Maximum proportion + of cardinality to length of data allowable. Default 0.99. + cardinality_threshold_low_n (int): Minimum cardinalty for a feature to be considered numerical + during type infernce. Default 10. + categorical_features (list[str]): List of features in the data to be considered categorical. Default []. + numeric_features (list[str]): List of features in the data to be considered numeric. Default []. + ablation_epochs (int): Number of epochs to train the autoencoder. + device (str): Device to run the model on, defaults to 'cuda' if available. + + Methods: + train_autoencoder(data_loader): Trains the autoencoder using the provided data loader. + calculate_loss(model, data_loader): Calculates reconstruction loss using the trained model. 
+ preprocess_data(): Preprocesses the DataFrame to handle numerical and categorical data appropriately. + remove_low_variance(processed_data): Removes features with low variance. + remove_high_correlation(data): Removes highly correlated features based on a threshold. + feature_importance_evaluation(processed_data): Evaluates feature importance using the autoencoder. + select_features(k_min, k_max): Selects features based on importance, adhering to min/max constraints. + """ + + def __init__( + self, + input_json, + id_column=None, + timestamp_column=None, + encoding_dim=None, + batch_size=256, + variance_threshold=0, + null_threshold=0.3, + cardinality_threshold_high=0.999, + cardinality_threshold_low_n=10, + categorical_features=[], + numeric_features=[], + ablation_epochs=20, + device="cuda" if torch.cuda.is_available() else "cpu", + ): + self.df = pd.json_normalize(input_json) + self.df_orig = pd.json_normalize(input_json) + self.encoding_dim = encoding_dim or self.df.shape[1] // 2 + self.batch_size = batch_size + self.device = device + self.preprocessor = None # To store the preprocessor for transforming data + self.variance_threshold = variance_threshold + self.null_threshold = null_threshold + self.cardinality_threshold_high = cardinality_threshold_high + self.cardinality_threshold_low = cardinality_threshold_low_n + self.categorical_features = categorical_features + self.numeric_features = numeric_features + self.ablation_epochs = ablation_epochs + self.id_col = id_column + self.ts_col = timestamp_column + self.final_features = self.df.columns.tolist() # Initialize with all features + + self._model_kwargs = { + "encoder_layers": [512, 500], # layers of the encoding part + "decoder_layers": [512], # layers of the decoding part + "activation": 'relu', # activation function + "swap_probability": 0.2, # noise parameter + "learning_rate": 0.001, # learning rate + "learning_rate_decay": .99, # learning decay + "batch_size": 512, + "verbose": False, + "optimizer": 'sgd', # SGD optimizer is selected(Stochastic gradient descent) + "scaler": 'standard', # feature scaling method + "min_cats": 1, # cut off for minority categories + "progress_bar": False, + "device": "cuda", + "patience": -1, + } + + if self.id_col is not None: + self.df.drop([self.id_col], axis=1, inplace=True) + self.df_orig.drop([self.id_col], axis=1, inplace=True) + + if self.ts_col is not None: + self.df.drop([self.ts_col], axis=1, inplace=True) + self.df_orig.drop([self.ts_col], axis=1, inplace=True) + + def train_autoencoder(self, dataframe): + """ + Trains the autoencoder model on the data provided by the DataLoader. + + Parameters + __________ + dataframe: pd.DataFrame + DataFrame containing the dataset for training. + + Returns + ________ + The trained autoencoder model. + """ + model = AutoEncoder(**self._model_kwargs) + model.fit(dataframe, epochs=self.ablation_epochs) + return model + + def calculate_loss(self, model: AutoEncoder, data: pd.DataFrame): + """ + Calculates the reconstruction loss of the autoencoder model using the provided DataLoader. + + Parameters + ___________ + model: Autoencoder + The trained autoencoder model. + + data: pd.DataFrame + DataLoader for evaluating the model. + + Returns + ________ + + Mean reconstruction loss over the data in the DataLoader. 
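+
+        Examples
+        ________
+        A minimal sketch, assuming ``selector`` is an already-constructed
+        ``AutoencoderFeatureSelector`` and ``frame`` is a ``pd.DataFrame``
+        (both names are illustrative)::
+
+            model = selector.train_autoencoder(frame)
+            mean_loss = selector.calculate_loss(model, frame)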
+ """ + dataset = model._data_to_dataset(data) + mse, _, _ = model.get_anomaly_score_losses(data) + return mse.mean().item() + + def preprocess_data(self): + """ + Preprocesses the DataFrame by scaling numeric features and encoding categorical features. + Handles ambiguities in data types by attempting to convert object types to numeric where feasible. + Removes columns with high null values, high cardinality, or low cardinality. + + Returns + _______ + + The preprocessed data ready for feature selection. + """ + # Parameters to define what constitutes high null, high cardinality, and low cardinality + high_null_threshold = ( + self.null_threshold + ) # Columns with more than 30% missing values + high_cardinality_threshold = ( + self.cardinality_threshold_high + ) # Columns with unique values > 50% of the total rows + low_cardinality_threshold = ( + self.cardinality_threshold_low + ) # Columns with fewer than 10 unique values + + LOG.info("\n##########\nPreprocessing Data") + # Remove columns with high percentage of NULL values + null_counts = self.df.isnull().mean() + self.df = self.df.loc[:, null_counts <= high_null_threshold] + col_uniqueness = {} + # Remove columns with unhashable types + for col in self.df.columns: + + try: + unique_count = self.df[col].nunique() + col_uniqueness[col] = unique_count + total_count = self.df.shape[0] + except TypeError: + # Unhashable types + self.df.drop(columns=[col], inplace=True) + LOG.info(f"\t*Dropped unhashable column: {col}") + continue + + dataframe_columns = list(self.df.columns) + self.final_features = dataframe_columns # Update final_features + + # Perform type inferencing if needed + + if self.categorical_features == [] or self.numeric_features == []: + LOG.warning( + "Categorical or numeric features not provided. Performing type inference which could be inaccurate." + ) + + if self.categorical_features == [] and self.numeric_features != []: + self.categorical_features = [ + ft for ft in dataframe_columns if ft not in self.numeric_features + ] + + elif self.categorical_features != [] and self.numeric_features == []: + self.numeric_features = [ + ft for ft in dataframe_columns if ft not in self.categorical_features + ] + + else: + + for col in self.df.columns: + unique_count = col_uniqueness[col] + total_count = self.df.shape[0] + + if unique_count < low_cardinality_threshold: + # Considered as categorical due to low cardinality + self.categorical_features.append(col) + else: + # Try to convert 'object' columns to numeric if feasible + try: + self.df[col] = pd.to_numeric(self.df[col]) + self.numeric_features.append(col) + except ValueError: + # Default to categorical if conversion fails + # Check cardinality + if unique_count / total_count > high_cardinality_threshold: + # Exclude from dataset due to high cardinality + LOG.info(f"\t*Dropped high cardinality column: {col}") + self.df.drop(columns=[col], inplace=True) + else: + self.categorical_features.append(col) + + self.preprocessor = ColumnTransformer( + transformers=[ + ("num", StandardScaler(), self.numeric_features), + ("cat", OneHotEncoder(), self.categorical_features), + ] + ) + + processed_data = self.preprocessor.fit_transform(self.df) + + # Convert to dense if the output is sparse + if hasattr(processed_data, "toarray"): + processed_data = processed_data.toarray() + LOG.warning( + "Found sparse arrays when one-hot encoding. Consider using fewer categorical variables." 
+ ) + + self.preprocessed_df = pd.DataFrame( + processed_data, columns=self.preprocessor.get_feature_names_out() + ) + self.final_features = self.preprocessed_df.columns.tolist() # Update final_features + + return processed_data + + def infer_column_types(self, df, sample_size=100, seed=None): + """ + Infers the data types of all columns in a pandas DataFrame by sampling values from each column. + + Parameters + __________ + df: pd.DataFrame + The DataFrame whose column types are to be inferred. + sample_size: int + The number of samples to take from each column for type inference. Defaults to 100. + seed: int + An integer seed for the random number generator to ensure reproducibility of the sampling. + + Returns + ________ + dict + A dictionary mapping each column name to its inferred data type ('int', 'float', 'bool', or 'string'). + """ + + type_dict = {} + df = df.copy().infer_objects() + + for column in df.columns: + col_type = str(df[column].dtype) + if col_type.startswith("int"): + type_dict[column] = "int" + elif col_type.startswith("float"): + type_dict[column] = "float" + elif col_type.startswith("bool"): + type_dict[column] = "bool" + else: + type_dict[column] = "string" + + return type_dict + + def prepare_schema(self, df, path=None): + """ + Creates dictionary schema definition for use with Morpheus JSONSchemaBuilder. + + Dumps to json path if not None. + + Parameters + __________ + df: pd.DataFrame + Dataframe to generate schema for + + Returns + _______ + scehma: dict + Dictionary schema definition to dump to JSON + """ + + datatypes = self.infer_column_types(df) + json_columns = list( + set([col.split(".")[0] for col in df.columns if "." in col]) + ) + schema = {"JSON_COLUMNS": json_columns, "SCHEMA_COLUMNS": []} + + for column in datatypes.keys(): + + if datatypes[column] != "bool": + schema["SCHEMA_COLUMNS"].append( + { + "type": "ColumnInfo", + "dtype": datatypes[column], + "data_column": column, + } + ) + else: + schema["SCHEMA_COLUMNS"].append( + { + "type": "BoolColumn", + "dtype": "bool", + "data_column": column, + "name": column, + } + ) + + if self.id_col is not None: + schema["SCHEMA_COLUMNS"].append( + {"type": "ColumnInfo", "dtype": "string", "data_column": self.id_col} + ) + + if self.ts_col is not None: + schema["SCHEMA_COLUMNS"].append( + { + "type": "DateTimeColumn", + "dtype": "datetime", + "name": self.ts_col, + "data_column": self.ts_col, + } + ) + + if path is not None: + with open(path, "w") as f: + json.dump(schema, f) + + return schema + + def remove_low_variance(self, processed_data): + """ + Removes features with low variance from the dataset. + + Parameters + ___________ + processed_data: np.array) + Preprocessed data from which to remove low variance features. + + Returns + ________ + tuple + A tuple containing the reduced dataset and the mask of features retained (boolean array). + """ + selector = VarianceThreshold(threshold=self.variance_threshold) + reduced_data = selector.fit_transform(processed_data) + LOG.info("\n##########\nDropped Features with Low Variance.") + self.final_features = [f for f, s in zip(self.final_features, selector.get_support()) if + s] # Update final_features + return reduced_data, selector.get_support() + + def remove_high_correlation(self, data, threshold=0.85): + """ + Removes highly correlated features from the dataset based on the specified threshold. + + Parameters + __________ + data: np.array + The dataset from which to remove highly correlated features. 
+ + threshold: float + Correlation coefficient above which features are considered highly correlated and one is removed. + + Returns + ________ + tuple + A tuple containing the reduced dataset and a list of dropped columns. + """ + corr_matrix = pd.DataFrame(data).corr().abs() + upper = corr_matrix.where( + np.triu(np.ones(corr_matrix.shape), k=1).astype(np.bool_) + ) + to_drop = [column for column in upper.columns if any(upper[column] > threshold)] + LOG.info("\n##########\nDropped Features with High Correlation.") + LOG.info(to_drop) + self.final_features = [f for i, f in enumerate(self.final_features) if + i not in to_drop] # Update final_features + return np.delete(data, to_drop, axis=1), to_drop + + def feature_importance_evaluation(self, processed_data: pd.DataFrame): + """ + Evaluates the importance of features by measuring the increase in the reconstruction loss of an + autoencoder when each feature is individually omitted. + + Parameters + ___________ + processed_data: pd.DataFrame + The dataset to be used for assessing feature importance. + + Returns + ________ + dict + A dictionary where keys are feature indices and values are the calculated importance scores based on loss increase. + """ + num_columns = processed_data.shape[1] + baseline_columns = processed_data.columns + + full_model = self.train_autoencoder(processed_data) + base_loss = self.calculate_loss(full_model, processed_data) + + LOG.info("\n##########\nPerforming Autoencoder Ablation Study.") + importance_scores = {} + for i in tqdm(range(num_columns), desc="AE Ablation Study"): + reduced_data = processed_data.drop(columns=baseline_columns[i]) + reduced_model = self.train_autoencoder(reduced_data) + loss = self.calculate_loss(reduced_model, reduced_data) + importance_scores[i] = base_loss - loss + + LOG.info("\nPerformed Autoencoder Ablation Study.") + + return importance_scores + + def print_report(self, feature_names): + """ + Prints summary information on feature selection. + + Parameters + __________ + feature_names: list[str] + List of feature names that were downselected as important + + Returns + _______ + list[str] + List of feature names cleaned and processed for use. + + """ + + cats = [] + nums = [] + + for s in feature_names: + if s.startswith("cat__"): + cats.append(s[5:].split("_")[0]) # Strip 'cat__' and append + elif s.startswith("num__"): + nums.append(s[5:]) # Strip 'num__' and append + + cats = list(set(cats)) + LOG.info( + f"\n##########\nThe following numeric features were found to be effective: " + ) + LOG.info(nums) + LOG.info( + f"\n##########\nThe following categorical features were found to be effective: " + ) + LOG.info(cats) + + return cats + nums + + def select_features( + self, k_min=5, k_max=10, raw_schema_path=None, preprocess_schema_path=None, perform_ablation=True + ): + """ + Selects features based on autoencoder performance, adhering to specified minimum and maximum feature count. + + Parameters + ___________ + k_min: int + Minimum number of features to retain. + + k_max: int + Maximum number of features to retain. + + raw_schema_path: str + Path to dump raw data schema file for Morpheus pipeline. + + preprocess_schema_path: str + Path to dump preprocessed data schema file for Morpheus pipeline. + + perform_ablation: bool + Flag to indicate whether to perform the autoencoder ablation study. + + Returns + ________ + list + selected features based on importance scores. 
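+
+        Examples
+        ________
+        A hedged usage sketch mirroring the unit tests; the method returns the
+        raw and preprocessed schema dictionaries. ``records`` (a list of
+        JSON-like dicts) and the schema file paths are assumptions::
+
+            selector = AutoencoderFeatureSelector(records,
+                                                  id_column="identity",
+                                                  timestamp_column="time")
+            raw_schema, preproc_schema = selector.select_features(
+                k_min=5,
+                k_max=10,
+                raw_schema_path="raw_schema.json",
+                preprocess_schema_path="preprocess_schema.json")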
+ """ + raw_schema = self.prepare_schema(self.df_orig, raw_schema_path) + processed_data = self.preprocess_data() + processed_data, _ = self.remove_low_variance(processed_data) + processed_data, _ = self.remove_high_correlation(processed_data) + + if perform_ablation: + feature_scores = self.feature_importance_evaluation(self.preprocessed_df[self.final_features]) + sorted_features = sorted(feature_scores, key=feature_scores.get, reverse=True) + selected_features = sorted_features[: min(k_max, len(sorted_features))] + final_features = selected_features[: max(k_min, len(selected_features))] + final_feature_names = [ + f + for i, f in enumerate(self.preprocessor.get_feature_names_out()) + if i in final_features + ] + final_feature_names = self.print_report(final_feature_names) + else: + final_feature_names = self.final_features[: max(k_min, len(self.final_features))] + + preproc_schema = self.prepare_schema( + self.df[final_feature_names], preprocess_schema_path + ) + + return raw_schema, preproc_schema diff --git a/tests/dfencoder/test_ae_feature_selector.py b/tests/dfencoder/test_ae_feature_selector.py new file mode 100644 index 0000000000..e6eb199456 --- /dev/null +++ b/tests/dfencoder/test_ae_feature_selector.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
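+
+# The tests below exercise AutoencoderFeatureSelector on the Iris dataset. The
+# fixture converts the DataFrame to a list of dicts because the selector expects
+# JSON-like records. A typical invocation (assumed) is:
+#   pytest tests/dfencoder/test_ae_feature_selector.py -v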
+ +import numpy as np +import pandas as pd +import pytest +from sklearn.datasets import load_iris + +from morpheus.models.dfencoder.ae_feature_selector import AutoencoderFeatureSelector + + +@pytest.fixture +def sample_data(): + """Fixture to provide the Iris dataset for testing.""" + iris = load_iris() + data = pd.DataFrame(data=iris.data, columns=iris.feature_names) + return data.to_dict(orient='records') + + +def test_preprocess_data(sample_data): + """Test the preprocess_data method.""" + selector = AutoencoderFeatureSelector(sample_data) + processed_data = selector.preprocess_data() + assert isinstance(processed_data, np.ndarray) + assert processed_data.shape[1] > 0 # Ensure columns exist after preprocessing + + +def test_remove_low_variance(sample_data): + """Test removing low variance features.""" + selector = AutoencoderFeatureSelector(sample_data) + processed_data = selector.preprocess_data() + reduced_data, mask = selector.remove_low_variance(processed_data) + assert reduced_data.shape[1] <= processed_data.shape[1] + assert isinstance(mask, np.ndarray) + + +def test_remove_high_correlation(sample_data): + """Test removing highly correlated features.""" + selector = AutoencoderFeatureSelector(sample_data) + processed_data = selector.preprocess_data() + reduced_data, dropped_cols = selector.remove_high_correlation(processed_data) + assert reduced_data.shape[1] <= processed_data.shape[1] + assert isinstance(dropped_cols, list) + + +def test_select_features(sample_data): + """Test the select_features method.""" + selector = AutoencoderFeatureSelector(sample_data) + raw_schema, preproc_schema = selector.select_features(k_min=3, k_max=5) + assert isinstance(raw_schema, dict) + assert isinstance(preproc_schema, dict) From e775c36528f7c37afdd01ff09571f30999045e38 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv <168006707+dnandakumar-nv@users.noreply.github.com> Date: Wed, 17 Jul 2024 11:36:38 -0400 Subject: [PATCH 2/5] Add demo notebook --- .../ae_automated_feature_selection.ipynb | 500 ++++++++++++++++++ 1 file changed, 500 insertions(+) create mode 100644 models/training-tuning-scripts/dfp-models/ae_automated_feature_selection.ipynb diff --git a/models/training-tuning-scripts/dfp-models/ae_automated_feature_selection.ipynb b/models/training-tuning-scripts/dfp-models/ae_automated_feature_selection.ipynb new file mode 100644 index 0000000000..11a4a45ba3 --- /dev/null +++ b/models/training-tuning-scripts/dfp-models/ae_automated_feature_selection.ipynb @@ -0,0 +1,500 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "22706b86-d366-4390-bedf-d0833435bc4b", + "metadata": { + "tags": [] + }, + "source": [ + "# 1.0 Loading a Training Data Sample into Memory" + ] + }, + { + "cell_type": "markdown", + "id": "3deed77f-3f61-45bb-abe0-ef8ea6e09fe1", + "metadata": {}, + "source": [ + "In this example, we'll walk through a heuristics-based automated feature selection process for an autoencoder based digital fingerprinting workflow. The selection process can be a compute and time intensive process on large datasets. Consequently, it's recommended that you select datasets that meet the following criteria for an initial exploratory analysis.\n", + "\n", + "1. Tractable Size: Chose a sample of your dataset that can easily fit into GPU and host memory. As an example, limiting your sample size to ~10,000 rows will result in a runtime of approximately 2 minutes. We recommend experimenting with a few different dataset sizes to find one that works best for you. \n", + "2. 
Capturing Data Variance: Chose a sample of data that sufficiently captures the general behavior you're trying to model. For example, if you have a lot of categorical columns in your dataset, try and ensure your sample contains most, if not all, values your categorical variables can take. Also aim to capture as much variance (diversity) in your data as possible to get the best results out of this process. \n", + "3. Don't attempt to optimize on too many features: While the automated process is capable of handling an arbitrary number of features (within system limits), try to optimize your process on only those features that may have cyber-specific relevance to the behavior you're trying to capture. Including a very large number of features to this analysis will noticably increase runtime and could also lead to suboptimal feature seleciton. \n", + "\n", + "With that being said, let's explore what a sample dataset could look like. Please note that your data should be **JSON formatted** in JSON files for this to work. If you're reading instead from something like `parquet` or CSV, we suggest reading it into a DataFrame with `pandas` and then turning that into a dictionary using `df.to_dict(orient='records')`. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ce603a32-99f7-4e62-beb0-8013948fa8e3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The loaded data contains 3239 rows. An exampled data entry looks like: \n", + "\n", + "{'Level': 4,\n", + " 'callerIpAddress': '13.113.40.157',\n", + " 'category': 'NonInteractiveUserSignInLogs',\n", + " 'correlationId': '36764e36-d379-45f4-914a-96a69bd59ae5',\n", + " 'durationMs': 0,\n", + " 'identity': 'Attack Target',\n", + " 'location': 'XR',\n", + " 'operationName': 'Sign-in activity',\n", + " 'operationVersion': '1.0',\n", + " 'properties': {'appDisplayName': 'Articulate 360',\n", + " 'appId': '9c5b7fe3-0ad2-4ea6-94e5-9e0001f367e3',\n", + " 'appServicePrincipalId': None,\n", + " 'appliedConditionalAccessPolicies': [],\n", + " 'authenticationContextClassReferences': [],\n", + " 'authenticationDetails': [],\n", + " 'authenticationProcessingDetails': [],\n", + " 'authenticationProtocol': 'none',\n", + " 'authenticationRequirement': 'singleFactorAuthentication',\n", + " 'authenticationRequirementPolicies': [],\n", + " 'autonomousSystemNumber': 34974,\n", + " 'clientAppUsed': 'Mobile Apps and Desktop clients',\n", + " 'clientCredentialType': 'none',\n", + " 'conditionalAccessStatus': 'success',\n", + " 'correlationId': '2a7afd02-0af7-4919-b42e-e96c358011e4',\n", + " 'createdDateTime': '2022-08-05T18:07:31.755011Z',\n", + " 'crossTenantAccessType': 'none',\n", + " 'deviceDetail': {'browser': 'Edge 99.14477',\n", + " 'deviceId': 'a44625dc-6f81-449a-9799-8005f7209b42',\n", + " 'displayName': 'ATTACKTARGET-LT',\n", + " 'operatingSystem': 'Windows 10',\n", + " 'trustType': 'Azure AD registered'},\n", + " 'flaggedForReview': False,\n", + " 'homeTenantId': 'd3e5a967-5657-4a42-afcc-6106b6c3c299',\n", + " 'id': '528a72ae-c612-474f-a22e-2f69e7ca7700',\n", + " 'incomingTokenType': 'primaryRefreshToken',\n", + " 'ipAddress': '13.113.40.157',\n", + " 'isInteractive': False,\n", + " 'isTenantRestricted': False,\n", + " 'location': {'city': 'Smithfort',\n", + " 'countryOrRegion': 'XR',\n", + " 'geoCoordinates': {'latitude': 3.7564095,\n", + " 'longitude': -121.574606},\n", + " 'state': 'Smithfort'},\n", + " 'networkLocationDetails': [],\n", + " 'originalRequestId': 
'1adc0f22-2e25-4831-a411-315dcf995f07',\n", + " 'privateLinkDetails': {},\n", + " 'processingTimeInMilliseconds': 313,\n", + " 'resourceDisplayName': 'Articulate 360 Online',\n", + " 'resourceId': '436260d5-3791-4aea-a4da-aa698ec35f2f',\n", + " 'resourceServicePrincipalId': '94b9f3ba-5f4b-4edc-99dc-a420b76960f9',\n", + " 'resourceTenantId': 'd3e5a967-5657-4a42-afcc-6106b6c3c299',\n", + " 'riskDetail': 'none',\n", + " 'riskEventTypes': [],\n", + " 'riskEventTypes_v2': [],\n", + " 'riskLevelAggregated': 'none',\n", + " 'riskLevelDuringSignIn': 'none',\n", + " 'riskState': 'none',\n", + " 'rngcStatus': 0,\n", + " 'servicePrincipalId': '',\n", + " 'ssoExtensionVersion': '',\n", + " 'status': {'errorCode': 0},\n", + " 'tokenIssuerName': '',\n", + " 'tokenIssuerType': 'AzureAD',\n", + " 'uniqueTokenIdentifier': 'KZ4AsNDKZSYNt3xpOb8Wvlfn7uWHiUcyz0JxdB4i6K7h0Br8',\n", + " 'userAgent': 'Mozilla/5.0 (iPad; CPU iPad OS 9_3_6 like Mac OS '\n", + " 'X) AppleWebKit/536.1 (KHTML, like Gecko) '\n", + " 'CriOS/42.0.865.0 Mobile/77O124 Safari/536.1 '\n", + " 'Edge/99.14477',\n", + " 'userDisplayName': 'Attack Target',\n", + " 'userId': 'd735e84b-dcca-404d-9f7d-700f360f41a6',\n", + " 'userPrincipalName': 'attacktarget@domain.com',\n", + " 'userType': 'Member'},\n", + " 'resourceId': '/tenants/d3e5a967-5657-4a42-afcc-6106b6c3c299/providers/Microsoft.aadiam',\n", + " 'resultDescription': None,\n", + " 'resultSignature': 'None',\n", + " 'resultType': '0',\n", + " 'tenantId': 'd3e5a967-5657-4a42-afcc-6106b6c3c299',\n", + " 'time': '2022-08-05T18:07:31.442011Z'}\n" + ] + } + ], + "source": [ + "import os\n", + "import json\n", + "from pprint import pprint\n", + "\n", + "def concatenate_json_lists(directory_path):\n", + " \"\"\"\n", + " Reads all JSON files in the specified directory, assuming each JSON file contains a list,\n", + " and concatenates these lists into a single list.\n", + "\n", + " Parameters:\n", + " - directory_path (str): The path to the directory containing the JSON files.\n", + "\n", + " Returns:\n", + " - list: A concatenated list containing all elements from the lists in the JSON files.\n", + " \"\"\"\n", + " concatenated_list = []\n", + " for filename in os.listdir(directory_path):\n", + " if filename.endswith(\".json\"):\n", + " file_path = os.path.join(directory_path, filename)\n", + " try:\n", + " with open(file_path, 'r') as json_file:\n", + " data = json.load(json_file)\n", + " if isinstance(data, list):\n", + " concatenated_list.extend(data)\n", + " else:\n", + " print(f\"File {filename} does not contain a list.\")\n", + " except Exception as e:\n", + " print(f\"An error occurred while processing file {filename}: {e}\")\n", + " return concatenated_list\n", + "\n", + "concatenated_data = concatenate_json_lists(\"/workspace/examples/data/dfp/azure-training-data/\") #Change the path to your data directory here. \n", + "\n", + "print(f\"The loaded data contains {len(concatenated_data)} rows. 
An exampled data entry looks like: \\n\")\n", + "pprint(concatenated_data[0])" + ] + }, + { + "cell_type": "markdown", + "id": "56ab34bc-3cf9-4ce4-96e3-415dad563581", + "metadata": { + "tags": [] + }, + "source": [ + "# 2 Performing Automated Feature Selection" + ] + }, + { + "cell_type": "markdown", + "id": "2fdcf1f0-15bc-4080-906c-c3e5cdccb85c", + "metadata": { + "tags": [] + }, + "source": [ + "## 2.1 Known Limitations and Considerations" + ] + }, + { + "cell_type": "markdown", + "id": "0f6fd5c2-1d6b-4133-bb6d-761a648532d8", + "metadata": {}, + "source": [ + "Once we have the data oriented as a list of dictionary objects (JSON), we're ready to run the automated feature selection process. \n", + "\n", + "__The tool is domain agnostic__, which means that it doesn't select features that work well for your specific cyber workflow. Instead, it selects features that work well for an autoencoder model from a statistical perspective. This necessarily introduces a few considerations we recommend taking into account.\n", + "\n", + "1. **Consult with cyber experts**: We recommend analyizing the output of the feature selection process with cybersecurity experts in your domain area to verify if they believe the features contain enough 'signal' from a use-case persepctive. Domain experts can also help give you an idea of how they would solve the problem, which can inform your feature selection. \n", + "2. **Limit the possibility of overfitting**: Avoid selecting feature outputs from the model that could lead to undesrible properties like model overfitting. For example, avoid using features such as individually identifiable IP addresses, usernames, or MAC addresses outside of the actual user attribute in the DFP pipeline. \n", + "3. **Derived features**: This tool does not suggest or create any derived features. Such features are often helpful, if not critical, in the successful functioning of advanced ML workflows. Consider adding some derived features that could be helpful in capturing the behavior of interest into the data. This can often be done in collaboration with cyber domain experts who can help inform how they would solve the problem, which can be tranlated into derived features. \n", + "4. **Alternate encodings**: By default, any non-numeric features are one-hot encoded as categorical variables by the tool. Often, there may be better ways of representing the data that lend themselves well to your use case. For example, only encoding variables of interest, embedding text instead of one-hot encoding them, etc. We recommend exploring alternatives once you've established a baseline\n", + "\n", + "Refer to the notebook demo [here](https://github.com/nv-morpheus/Morpheus/blob/branch-24.06/models/training-tuning-scripts/dfp-models/dfp-feature-selection-demo.ipynb) for a more detailed analysis of datasets and ideas on derived features. The automated method here is intended to be used purely as a starting point in your development. " + ] + }, + { + "cell_type": "markdown", + "id": "afd2d69e-8cc1-407f-875f-f575be15c946", + "metadata": {}, + "source": [ + "## 2.2. Running Automated Feature Selection" + ] + }, + { + "cell_type": "markdown", + "id": "9b498cf3-8f94-436d-861a-fd8a17186a5e", + "metadata": {}, + "source": [ + "We can run automated feature selection by using Morpheus' `AutoencoderFeatureSelector` class and configuring some parameters. Let's see how. The AutoencoderFeatureSelector class provides the following configurable parameters. 
\n", + "\n", + "```python\n", + "\"\"\"\n", + " A class to select features using an autoencoder, handling categorical and numerical data.\n", + " Supports handling ambiguities in data types and ensures selected feature count constraints.\n", + "\n", + " Attributes:\n", + " input_json (dict): List of dictionary objects to normalize into a dataframe\n", + " id_column (str) : Column name that contains ID for morpheus AE pipeline. Default None.\n", + " timestamp_column (str) : Column name that contains the\n", + " timestamp for morpheus AE pipeline. Default None.\n", + " encoding_dim (int): Dimension of the encoding layer, defaults to half of input\n", + " dimensions if not set.\n", + " batch_size (int): Batch size for training the autoencoder.\n", + " variance_threshold (float) : Minimum variance a column must contain\n", + " to remain in consideration. Default 0.\n", + " null_threshold (float): Maximum proportion of null values a column can contain. Default 0.3.\n", + " cardinality_threshold_high (float): Maximum proportion\n", + " of cardinality to length of data allowable. Default 0.99.\n", + " cardinality_threshold_low_n (int): Minimum cardinalty for a feature to be considered numerical\n", + " during type infernce. Default 10.\n", + " categorical_features (list[str]): List of features in the data to be considered categorical. Default [].\n", + " numeric_features (list[str]): List of features in the data to be considered numeric. Default [].\n", + " ablation_epochs (int): Number of epochs to train the autoencoder.\n", + " device (str): Device to run the model on, defaults to 'cuda' if available.\n", + " \n", + "\"\"\"\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b071d622-cbcc-429d-bc49-2a7fd1b6fb99", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import sys \n", + "\n", + "sys.path.insert(0, '/workspace/morpheus')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fd671cda-73ca-49dd-8cf0-02aa63adfe1b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from models.dfencoder.ae_feature_selector import AutoencoderFeatureSelector \n", + "\n", + "selector = AutoencoderFeatureSelector(\n", + " input_json = concatenated_data, \n", + " id_column = 'identity', #This is the entity you want to 'fingerprint'\n", + " timestamp_column = 'time', #This is your log timestamp\n", + " variance_threshold=0.1, #Removes cols. with variance lower than this\n", + " null_threshold=0.3, #Removes cols will null proportion great than this\n", + " cardinality_threshold_high=0.9, #Remove columns with high cardinality\n", + " cardinality_threshold_low_n=10, #Cols. with this cardinality with considered categorica\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "03995194-23b1-453a-8714-bfc6a3cb37e3", + "metadata": {}, + "source": [ + "Once we've instantiated the class, we can run the feature selection as a one-line command. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fbe23de-e3c0-43f7-83de-771de6919619", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Categorical or numeric features not provided. Performing type inference which could be inaccurate.\n", + "Found sparse arrays when one-hot encoding. 
Consider using fewer categorical variables.\n", + "/opt/conda/envs/morpheus/lib/python3.10/site-packages/torch/nn/init.py:452: UserWarning: Initializing zero-element tensors is a no-op\n", + " warnings.warn(\"Initializing zero-element tensors is a no-op\")\n", + "/opt/conda/envs/morpheus/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n", + " _torch_pytree._register_pytree_node(\n", + "Not going to perform early-stopping. self.patience(=-1) is provided for early-stopping but validation is not enabled. Please set `run_validation` to True and provide a `validation_dataset` to enable early-stopping.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "37bc78c7b9d34cd5baebdc4c52960447", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "AE Ablation Study: 0%| | 0/28 [00:00 Date: Thu, 25 Jul 2024 13:50:17 -0400 Subject: [PATCH 3/5] Update ae_feature_selector.py --- .../models/dfencoder/ae_feature_selector.py | 107 ++++++++++-------- 1 file changed, 62 insertions(+), 45 deletions(-) diff --git a/morpheus/models/dfencoder/ae_feature_selector.py b/morpheus/models/dfencoder/ae_feature_selector.py index 528c41cc3f..54ea6c238d 100644 --- a/morpheus/models/dfencoder/ae_feature_selector.py +++ b/morpheus/models/dfencoder/ae_feature_selector.py @@ -33,34 +33,51 @@ class AutoencoderFeatureSelector: A class to select features using an autoencoder, handling categorical and numerical data. Supports handling ambiguities in data types and ensures selected feature count constraints. - Attributes: - input_json (dict): List of dictionary objects to normalize into a dataframe - id_column (str) : Column name that contains ID for morpheus AE pipeline. Default None. - timestamp_column (str) : Column name that contains the - timestamp for morpheus AE pipeline. Default None. - encoding_dim (int): Dimension of the encoding layer, defaults to half of input - dimensions if not set. - batch_size (int): Batch size for training the autoencoder. - variance_threshold (float) : Minimum variance a column must contain - to remain in consideration. Default 0. - null_threshold (float): Maximum proportion of null values a column can contain. Default 0.3. - cardinality_threshold_high (float): Maximum proportion - of cardinality to length of data allowable. Default 0.99. - cardinality_threshold_low_n (int): Minimum cardinalty for a feature to be considered numerical - during type infernce. Default 10. - categorical_features (list[str]): List of features in the data to be considered categorical. Default []. - numeric_features (list[str]): List of features in the data to be considered numeric. Default []. - ablation_epochs (int): Number of epochs to train the autoencoder. - device (str): Device to run the model on, defaults to 'cuda' if available. - - Methods: - train_autoencoder(data_loader): Trains the autoencoder using the provided data loader. - calculate_loss(model, data_loader): Calculates reconstruction loss using the trained model. - preprocess_data(): Preprocesses the DataFrame to handle numerical and categorical data appropriately. - remove_low_variance(processed_data): Removes features with low variance. - remove_high_correlation(data): Removes highly correlated features based on a threshold. - feature_importance_evaluation(processed_data): Evaluates feature importance using the autoencoder. 
- select_features(k_min, k_max): Selects features based on importance, adhering to min/max constraints. + Attributes + ---------- + input_json : dict + List of dictionary objects to normalize into a dataframe. + id_column : str, optional + Column name that contains ID for morpheus AE pipeline. Default is None. + timestamp_column : str, optional + Column name that contains the timestamp for morpheus AE pipeline. Default is None. + encoding_dim : int, optional + Dimension of the encoding layer, defaults to half of input dimensions if not set. + batch_size : int + Batch size for training the autoencoder. + variance_threshold : float, optional + Minimum variance a column must contain to remain in consideration. Default is 0. + null_threshold : float, optional + Maximum proportion of null values a column can contain. Default is 0.3. + cardinality_threshold_high : float, optional + Maximum proportion of cardinality to length of data allowable. Default is 0.99. + cardinality_threshold_low_n : int, optional + Minimum cardinality for a feature to be considered numerical during type inference. Default is 10. + categorical_features : list of str, optional + List of features in the data to be considered categorical. Default is []. + numeric_features : list of str, optional + List of features in the data to be considered numeric. Default is []. + ablation_epochs : int + Number of epochs to train the autoencoder. + device : str + Device to run the model on, defaults to 'cuda' if available. + + Methods + ------- + train_autoencoder(data_loader) + Trains the autoencoder using the provided data loader. + calculate_loss(model, data_loader) + Calculates reconstruction loss using the trained model. + preprocess_data() + Preprocesses the DataFrame to handle numerical and categorical data appropriately. + remove_low_variance(processed_data) + Removes features with low variance. + remove_high_correlation(data) + Removes highly correlated features based on a threshold. + feature_importance_evaluation(processed_data) + Evaluates feature importance using the autoencoder. + select_features(k_min, k_max) + Selects features based on importance, adhering to min/max constraints. 
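+
+    Examples
+    --------
+    A hedged sketch of the individual preprocessing steps; ``records`` is an
+    assumed list of JSON-like dicts and the column names mirror the demo
+    notebook::
+
+        selector = AutoencoderFeatureSelector(records,
+                                              id_column="identity",
+                                              timestamp_column="time")
+        processed = selector.preprocess_data()
+        processed, support = selector.remove_low_variance(processed)
+        processed, dropped = selector.remove_high_correlation(processed, threshold=0.85)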
""" def __init__( @@ -80,11 +97,11 @@ def __init__( device="cuda" if torch.cuda.is_available() else "cpu", ): self.df = pd.json_normalize(input_json) - self.df_orig = pd.json_normalize(input_json) + self._df_orig = pd.json_normalize(input_json) self.encoding_dim = encoding_dim or self.df.shape[1] // 2 self.batch_size = batch_size self.device = device - self.preprocessor = None # To store the preprocessor for transforming data + self._preprocessor = None # To store the preprocessor for transforming data self.variance_threshold = variance_threshold self.null_threshold = null_threshold self.cardinality_threshold_high = cardinality_threshold_high @@ -94,7 +111,7 @@ def __init__( self.ablation_epochs = ablation_epochs self.id_col = id_column self.ts_col = timestamp_column - self.final_features = self.df.columns.tolist() # Initialize with all features + self._final_features = self.df.columns.tolist() # Initialize with all features self._model_kwargs = { "encoder_layers": [512, 500], # layers of the encoding part @@ -115,13 +132,13 @@ def __init__( if self.id_col is not None: self.df.drop([self.id_col], axis=1, inplace=True) - self.df_orig.drop([self.id_col], axis=1, inplace=True) + self._df_orig.drop([self.id_col], axis=1, inplace=True) if self.ts_col is not None: self.df.drop([self.ts_col], axis=1, inplace=True) - self.df_orig.drop([self.ts_col], axis=1, inplace=True) + self._df_orig.drop([self.ts_col], axis=1, inplace=True) - def train_autoencoder(self, dataframe): + def train_autoencoder(self, dataframe:pd.DataFrame): """ Trains the autoencoder model on the data provided by the DataLoader. @@ -200,7 +217,7 @@ def preprocess_data(self): continue dataframe_columns = list(self.df.columns) - self.final_features = dataframe_columns # Update final_features + self._final_features = dataframe_columns # Update final_features # Perform type inferencing if needed @@ -243,14 +260,14 @@ def preprocess_data(self): else: self.categorical_features.append(col) - self.preprocessor = ColumnTransformer( + self._preprocessor = ColumnTransformer( transformers=[ ("num", StandardScaler(), self.numeric_features), ("cat", OneHotEncoder(), self.categorical_features), ] ) - processed_data = self.preprocessor.fit_transform(self.df) + processed_data = self._preprocessor.fit_transform(self.df) # Convert to dense if the output is sparse if hasattr(processed_data, "toarray"): @@ -260,9 +277,9 @@ def preprocess_data(self): ) self.preprocessed_df = pd.DataFrame( - processed_data, columns=self.preprocessor.get_feature_names_out() + processed_data, columns=self._preprocessor.get_feature_names_out() ) - self.final_features = self.preprocessed_df.columns.tolist() # Update final_features + self._final_features = self.preprocessed_df.columns.tolist() # Update final_features return processed_data @@ -382,7 +399,7 @@ def remove_low_variance(self, processed_data): selector = VarianceThreshold(threshold=self.variance_threshold) reduced_data = selector.fit_transform(processed_data) LOG.info("\n##########\nDropped Features with Low Variance.") - self.final_features = [f for f, s in zip(self.final_features, selector.get_support()) if + self._final_features = [f for f, s in zip(self._final_features, selector.get_support()) if s] # Update final_features return reduced_data, selector.get_support() @@ -410,7 +427,7 @@ def remove_high_correlation(self, data, threshold=0.85): to_drop = [column for column in upper.columns if any(upper[column] > threshold)] LOG.info("\n##########\nDropped Features with High Correlation.") LOG.info(to_drop) - 
self.final_features = [f for i, f in enumerate(self.final_features) if + self._final_features = [f for i, f in enumerate(self._final_features) if i not in to_drop] # Update final_features return np.delete(data, to_drop, axis=1), to_drop @@ -512,24 +529,24 @@ def select_features( list selected features based on importance scores. """ - raw_schema = self.prepare_schema(self.df_orig, raw_schema_path) + raw_schema = self.prepare_schema(self._df_orig, raw_schema_path) processed_data = self.preprocess_data() processed_data, _ = self.remove_low_variance(processed_data) processed_data, _ = self.remove_high_correlation(processed_data) if perform_ablation: - feature_scores = self.feature_importance_evaluation(self.preprocessed_df[self.final_features]) + feature_scores = self.feature_importance_evaluation(self.preprocessed_df[self._final_features]) sorted_features = sorted(feature_scores, key=feature_scores.get, reverse=True) selected_features = sorted_features[: min(k_max, len(sorted_features))] final_features = selected_features[: max(k_min, len(selected_features))] final_feature_names = [ f - for i, f in enumerate(self.preprocessor.get_feature_names_out()) + for i, f in enumerate(self._preprocessor.get_feature_names_out()) if i in final_features ] final_feature_names = self.print_report(final_feature_names) else: - final_feature_names = self.final_features[: max(k_min, len(self.final_features))] + final_feature_names = self._final_features[: max(k_min, len(self._final_features))] preproc_schema = self.prepare_schema( self.df[final_feature_names], preprocess_schema_path From 330377623c5acfe46dcdee2283259d37131e344a Mon Sep 17 00:00:00 2001 From: dnandakumar-nv <168006707+dnandakumar-nv@users.noreply.github.com> Date: Fri, 26 Jul 2024 09:39:22 -0400 Subject: [PATCH 4/5] Update test_ae_feature_selector.py --- tests/dfencoder/test_ae_feature_selector.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/dfencoder/test_ae_feature_selector.py b/tests/dfencoder/test_ae_feature_selector.py index e6eb199456..ded81e720a 100644 --- a/tests/dfencoder/test_ae_feature_selector.py +++ b/tests/dfencoder/test_ae_feature_selector.py @@ -40,20 +40,20 @@ def test_preprocess_data(sample_data): def test_remove_low_variance(sample_data): """Test removing low variance features.""" - selector = AutoencoderFeatureSelector(sample_data) - processed_data = selector.preprocess_data() - reduced_data, mask = selector.remove_low_variance(processed_data) - assert reduced_data.shape[1] <= processed_data.shape[1] - assert isinstance(mask, np.ndarray) + sample_data['low_variance'] = 0 + selector = AutoencoderFeatureSelector(sample_data, variance_threshold=0.1) + reduced_data, mask = selector.remove_low_variance(sample_data.values) + assert reduced_data.shape == (150,4) def test_remove_high_correlation(sample_data): """Test removing highly correlated features.""" - selector = AutoencoderFeatureSelector(sample_data) - processed_data = selector.preprocess_data() - reduced_data, dropped_cols = selector.remove_high_correlation(processed_data) - assert reduced_data.shape[1] <= processed_data.shape[1] - assert isinstance(dropped_cols, list) + sample_data['high_corr_1'] = sample_data['sepal length (cm)'] + 1 + sample_data['high_corr_2'] = 2 * sample_data['sepal length (cm)'] + 1 + selector = AutoencoderFeatureSelector(sample_data, variance_threshold=0.1) + reduced_data, mask = selector.remove_high_correlation(sample_data.values, threshold=0.99) + assert reduced_data.shape == (150,4) + assert mask 
== [4,5] def test_select_features(sample_data): From 5da0340454b843c382f057c7bbd473fbdf45bf87 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv <168006707+dnandakumar-nv@users.noreply.github.com> Date: Fri, 26 Jul 2024 10:00:31 -0400 Subject: [PATCH 5/5] Update test_ae_feature_selector.py --- tests/dfencoder/test_ae_feature_selector.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/dfencoder/test_ae_feature_selector.py b/tests/dfencoder/test_ae_feature_selector.py index ded81e720a..ccfbd4cb52 100644 --- a/tests/dfencoder/test_ae_feature_selector.py +++ b/tests/dfencoder/test_ae_feature_selector.py @@ -27,12 +27,12 @@ def sample_data(): """Fixture to provide the Iris dataset for testing.""" iris = load_iris() data = pd.DataFrame(data=iris.data, columns=iris.feature_names) - return data.to_dict(orient='records') + return data def test_preprocess_data(sample_data): """Test the preprocess_data method.""" - selector = AutoencoderFeatureSelector(sample_data) + selector = AutoencoderFeatureSelector(sample_data.to_dict(orient='records')) processed_data = selector.preprocess_data() assert isinstance(processed_data, np.ndarray) assert processed_data.shape[1] > 0 # Ensure columns exist after preprocessing @@ -41,7 +41,7 @@ def test_preprocess_data(sample_data): def test_remove_low_variance(sample_data): """Test removing low variance features.""" sample_data['low_variance'] = 0 - selector = AutoencoderFeatureSelector(sample_data, variance_threshold=0.1) + selector = AutoencoderFeatureSelector(sample_data.to_dict(orient='records'), variance_threshold=0.1) reduced_data, mask = selector.remove_low_variance(sample_data.values) assert reduced_data.shape == (150,4) @@ -50,7 +50,7 @@ def test_remove_high_correlation(sample_data): """Test removing highly correlated features.""" sample_data['high_corr_1'] = sample_data['sepal length (cm)'] + 1 sample_data['high_corr_2'] = 2 * sample_data['sepal length (cm)'] + 1 - selector = AutoencoderFeatureSelector(sample_data, variance_threshold=0.1) + selector = AutoencoderFeatureSelector(sample_data.to_dict(orient='records'), variance_threshold=0.1) reduced_data, mask = selector.remove_high_correlation(sample_data.values, threshold=0.99) assert reduced_data.shape == (150,4) assert mask == [4,5] @@ -58,7 +58,7 @@ def test_remove_high_correlation(sample_data): def test_select_features(sample_data): """Test the select_features method.""" - selector = AutoencoderFeatureSelector(sample_data) + selector = AutoencoderFeatureSelector(sample_data.to_dict(orient='records')) raw_schema, preproc_schema = selector.select_features(k_min=3, k_max=5) assert isinstance(raw_schema, dict) assert isinstance(preproc_schema, dict)