From e7208d93054360edc42e9416590e19f52915e2ea Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Wed, 24 Jul 2024 03:23:51 +0300 Subject: [PATCH] [GSProcessing] Add option to truncate vectors with no-op transformation. (#922) *Issue #, if available:* *Description of changes:* * We add a keyword argument to the GSProcessing no-op transformation, named `truncate_dim`. When a user provides this parameter, we will try to truncate the input vectors to the dimension specified. This allows users to easily make use of input Matryoshka embeddings. * (Revision) Added the same transformation for GConstruct. ## Testing Unit tests passing, unit and integration test added for the new transformation option. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: jalencato --- .../configuration-gconstruction.rst | 3 ++ .../gs-processing/input-configuration.rst | 10 +++- .../config/feature_config_base.py | 19 ++++++-- .../dist_noop_transformation.py | 48 +++++++++++++++---- .../distributed_executor.py | 6 +-- .../gsprocessing-config.json | 11 +++++ .../tests/test_dist_heterogenous_loader.py | 2 + .../tests/test_dist_noop_transformation.py | 24 ++++++++++ python/graphstorm/gconstruct/transform.py | 23 +++++++-- tests/unit-tests/gconstruct/test_transform.py | 9 ++++ 10 files changed, 133 insertions(+), 22 deletions(-) diff --git a/docs/source/configuration/configuration-gconstruction.rst b/docs/source/configuration/configuration-gconstruction.rst index b2100b507c..429830ab34 100644 --- a/docs/source/configuration/configuration-gconstruction.rst +++ b/docs/source/configuration/configuration-gconstruction.rst @@ -93,6 +93,9 @@ Currently, the graph construction pipeline supports the following feature transf * **Numerical Rank Gauss transformation** normalizes numerical input features with rank gauss normalization. It maps the numeric feature values to gaussian distribution based on ranking. The method follows https://www.kaggle.com/c/porto-seguro-safe-driver-prediction/discussion/44629#250927. The ``name`` field in the feature transformation dictionary is ``rank_gauss``. The dict can contains one optional field, i.e., ``epsilon`` which is used to avoid INF float during computation and ``uniquify`` which controls whether deduplicating input features before computing rank gauss norm. * **Convert to categorical values** converts text data to categorial values. The ``name`` field is ``to_categorical``, and ``separator`` specifies how to split the string into multiple categorical values (this is only used to define multiple categorical values). If ``separator`` is not specified, the entire string is a categorical value. ``mapping`` (**optional**) is a dict that specifies how to map a string to an integer value that defines a categorical value. * **Numerical Bucket transformation** normalizes numerical input features with buckets. The input features are divided into one or multiple buckets. Each bucket stands for a range of floats. An input value can fall into one or more buckets depending on the transformation configuration. The ``name`` field in the feature transformation dictionary is ``bucket_numerical``. Users need to provide ``range`` and ``bucket_cnt`` field, which ``range`` defines a numerical range, and ``bucket_cnt`` defines number of buckets among the range. All buckets will have same length, and each of them is left included. e.g, bucket ``(a, b)`` will include a, but not b. All input feature column data are categorized into respective buckets using this method. Any input data lower than the minimum value will be assigned to the first bucket, and any input data exceeding the maximum value will be assigned to the last bucket. For example, with range=`[10,30]` and bucket_cnt=`2`, input data `1` will fall into the bucket `[10, 20]`, input data `11` will be mapped to `[10, 20]`, input data `21` will be mapped to `[20, 30]`, input data `31` will be mapped to `[20, 30]`. Finally we use one-hot-encoding to encode the feature for each numerical bucket. If a user wants to make numeric values fall into more than one bucket, it is preferred to use the `slide_window_size`: `"slide_window_size": s` , where `s` is a number. Then each value `v` will be transformed into a range from `v - s/2` through `v + s/2` , and assigns the value `v` to every bucket that the range covers. +* **No-op vector truncation** truncates feature vectors to the length requested. The ``name`` field can be empty, + and an integer ``truncate_dim`` value will determine the length of the output vector. + This can be useful when experimenting with input features that were trained using Matryoshka Representation Learning. .. _output-format: diff --git a/docs/source/graph-construction/gs-processing/input-configuration.rst b/docs/source/graph-construction/gs-processing/input-configuration.rst index bc0a85a19f..f84cb9520f 100644 --- a/docs/source/graph-construction/gs-processing/input-configuration.rst +++ b/docs/source/graph-construction/gs-processing/input-configuration.rst @@ -23,7 +23,7 @@ The GSProcessing input data configuration has two top-level objects: .. code-block:: json { - "version": "gsprocessing-v1.0", + "version": "gsprocessing-v0.3.1", "graph": {} } @@ -380,6 +380,12 @@ arguments. split the values in the column and create a vector column output. Example: for a separator ``'|'`` the CSV value ``1|2|3`` would be transformed to a vector, ``[1, 2, 3]``. + - ``truncate_dim`` (Integer, Optional): Relevant for vector inputs. + Allows you to truncate the input vector to the first ``truncate_dim`` + values, which can be useful when your inputs are `Matryoshka representation + learning embeddings `_. + - ``out_dtype`` (String, Optional): Specify the data type of the transformed feature. + Currently we only support ``float32`` and ``float64`` . - ``numerical`` - Transforms a numerical column using a missing data imputer and an @@ -400,7 +406,7 @@ arguments. - ``rank-gauss``: Normalize each value using Rank-Gauss normalization. Rank-gauss first ranks all values, converts the ranks to the -1/1 range, and applies the `inverse of the error function `_ to make the values conform to a Gaussian distribution shape. This transformation only supports a single column as input. - - ``out_dtype`` (Optional): Specify the data type of the transformed feature. + - ``out_dtype`` (String, Optional): Specify the data type of the transformed feature. Currently we only support ``float32`` and ``float64`` . - ``epsilon``: Only relevant for ``rank-gauss``, this epsilon value is added to the denominator to avoid infinite values during normalization. diff --git a/graphstorm-processing/graphstorm_processing/config/feature_config_base.py b/graphstorm-processing/graphstorm_processing/config/feature_config_base.py index f169d2b718..177a55e610 100644 --- a/graphstorm-processing/graphstorm_processing/config/feature_config_base.py +++ b/graphstorm-processing/graphstorm_processing/config/feature_config_base.py @@ -15,7 +15,7 @@ """ import abc -from typing import Any, Mapping, Sequence +from typing import Any, Mapping, Optional, Sequence from graphstorm_processing.constants import VALID_OUTDTYPE, TYPE_FLOAT32 @@ -89,18 +89,24 @@ class NoopFeatureConfig(FeatureConfig): Supported kwargs ---------------- + out_dtype: str + Output feature dtype. Currently, we support ``float32`` and ``float64``. + Default is ``float32`` separator: str When provided will treat the input as strings, split each value in the string using the separator, and convert the resulting list of floats into a float-vector feature. + truncate_dim: int + When provided, will truncate the output float-vector feature to the specified dimension. + This is useful when the feature is a multi-dimensional vector and we only need + a subset of the dimensions, e.g. for Matryoshka Representation Learning embeddings. """ def __init__(self, config: Mapping): super().__init__(config) - self.value_separator = None - self.out_dtype = self._transformation_kwargs.get("out_dtype", TYPE_FLOAT32) - if self._transformation_kwargs: - self.value_separator = self._transformation_kwargs.get("separator") + self.out_dtype: str = self._transformation_kwargs.get("out_dtype", TYPE_FLOAT32) + self.value_separator: Optional[str] = self._transformation_kwargs.get("separator", None) + self.truncate_dim: Optional[int] = self._transformation_kwargs.get("truncate_dim", None) self._sanity_check() @@ -111,3 +117,6 @@ def _sanity_check(self) -> None: assert ( self.out_dtype in VALID_OUTDTYPE ), f"Unsupported output dtype, expected one of {VALID_OUTDTYPE}, got {self.out_dtype}" + assert self.truncate_dim is None or isinstance( + self.truncate_dim, int + ), f"truncate_dim should be an int or None, got {type(self.truncate_dim)}" diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_noop_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_noop_transformation.py index 563f75a5bd..afc279e6c5 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_noop_transformation.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_noop_transformation.py @@ -14,7 +14,9 @@ limitations under the License. """ -from typing import List, Optional +import warnings +from typing import Optional + from pyspark.sql import DataFrame from pyspark.sql import functions as F @@ -35,16 +37,24 @@ class NoopTransformation(DistributedTransformation): Parameters ---------- - cols : List[str] + cols : list[str] The list of columns to parse as floats or lists of float separator : Optional[str], optional Optional separator to use to split the string, by default None out_dtype: str The output feature dtype + truncate_dim: int + When provided, will truncate the output float-vector feature to the specified dimension. + This is useful when the feature is a multi-dimensional vector and we only need + a subset of the dimensions, e.g. for Matryoshka Representation Learning embeddings. """ def __init__( - self, cols: List[str], out_dtype: str = TYPE_FLOAT32, separator: Optional[str] = None + self, + cols: list[str], + out_dtype: str = TYPE_FLOAT32, + separator: Optional[str] = None, + truncate_dim: Optional[int] = None, ) -> None: super().__init__(cols) # TODO: Support multiple cols? @@ -55,6 +65,18 @@ def __init__( # escape special chars to be used as separators if self.separator in SPECIAL_CHARACTERS: self.separator = f"\\{self.separator}" + self.truncate_dim = truncate_dim + + def _truncate_vector_df(self, input_df: DataFrame) -> DataFrame: + """Truncates every vector in the input DF to the specified dimension.""" + assert self.truncate_dim is not None + return input_df.select( + [ + # SQL array indexes start at 1 + F.slice(F.col(column), 1, self.truncate_dim).alias(column) + for column in self.cols + ] + ) def apply(self, input_df: DataFrame) -> DataFrame: """ @@ -72,13 +94,17 @@ def apply(self, input_df: DataFrame) -> DataFrame: f"Unsupported array type {col_datatype.elementType} " f"for column {self.cols[0]}" ) - return input_df + if self.truncate_dim: + return self._truncate_vector_df(input_df) + else: + return input_df elif isinstance(col_datatype, NumericType): + if self.truncate_dim is not None: + warnings.warn(f"Trying use {self.truncate_dim=} on a DataFrame of scalars!") return input_df # Otherwise we'll try to convert the values from list of strings to list of Doubles - - def str_list_to_float_vec(string_list: Optional[List[str]]) -> Optional[List[float]]: + def str_list_to_float_vec(string_list: Optional[list[str]]) -> Optional[list[float]]: if string_list: return [float(x) for x in string_list] return None @@ -89,21 +115,25 @@ def str_list_to_float_vec(string_list: Optional[List[str]]) -> Optional[List[flo if self.separator: # Split up string into vector of floats - input_df = input_df.select( + vector_df = input_df.select( [ strvec_to_float_vec_udf(F.split(F.col(column), self.separator)).alias(column) for column in self.cols ] ) - return input_df else: - return input_df.select( + vector_df = input_df.select( [ F.col(column).cast(DTYPE_MAP[self.out_dtype]).alias(column) for column in self.cols ] ) + if self.truncate_dim: + return self._truncate_vector_df(vector_df) + else: + return vector_df + @staticmethod def get_transformation_name() -> str: """ diff --git a/graphstorm-processing/graphstorm_processing/distributed_executor.py b/graphstorm-processing/graphstorm_processing/distributed_executor.py index 8f3236f54a..0b2e6e5b21 100644 --- a/graphstorm-processing/graphstorm_processing/distributed_executor.py +++ b/graphstorm-processing/graphstorm_processing/distributed_executor.py @@ -209,11 +209,11 @@ def __init__( self.precomputed_transformations = {} if "version" in dataset_config_dict: - config_version = dataset_config_dict["version"] - if config_version == "gsprocessing-v1.0": + config_version: str = dataset_config_dict["version"] + if config_version.startswith("gsprocessing"): logging.info("Parsing config file as GSProcessing config") self.gsp_config_dict = dataset_config_dict["graph"] - elif config_version == "gconstruct-v1.0": + elif config_version.startswith("gconstruct"): logging.info("Parsing config file as GConstruct config") converter = GConstructConfigConverter() self.gsp_config_dict = converter.convert_to_gsprocessing(dataset_config_dict)[ diff --git a/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json b/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json index 56212b540b..d06cfd283e 100644 --- a/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json +++ b/graphstorm-processing/tests/resources/small_heterogeneous_graph/gsprocessing-config.json @@ -57,6 +57,17 @@ } } }, + { + "column": "multi", + "name": "no-op-truncated", + "transformation": { + "name": "no-op", + "kwargs": { + "separator": "|", + "truncate_dim": 1 + } + } + }, { "column": "occupation", "transformation": { diff --git a/graphstorm-processing/tests/test_dist_heterogenous_loader.py b/graphstorm-processing/tests/test_dist_heterogenous_loader.py index f63756e790..4bb9d36b93 100644 --- a/graphstorm-processing/tests/test_dist_heterogenous_loader.py +++ b/graphstorm-processing/tests/test_dist_heterogenous_loader.py @@ -63,6 +63,7 @@ "input_ids": 16, "token_type_ids": 16, "multi": 2, + "no-op-truncated": 1, "state": 3, } }, @@ -296,6 +297,7 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoade "test_mask", "age", "multi", + "no-op-truncated", "state", "input_ids", "attention_mask", diff --git a/graphstorm-processing/tests/test_dist_noop_transformation.py b/graphstorm-processing/tests/test_dist_noop_transformation.py index 1f55417626..eeb3999998 100644 --- a/graphstorm-processing/tests/test_dist_noop_transformation.py +++ b/graphstorm-processing/tests/test_dist_noop_transformation.py @@ -77,6 +77,30 @@ def test_noop_floatvector_transformation(spark: SparkSession, check_df_schema): assert_array_equal(expected_values, transformed_values) +def test_noop_floatvector_truncation(spark: SparkSession, check_df_schema): + """No-op transformation for numerical vector columns with truncation""" + data = [([[10, 20]]), ([[30, 40]]), ([[50, 60]]), ([[70, 80]]), ([[90, 100]])] + + col_name = "feat" + schema = StructType([StructField("feat", ArrayType(IntegerType(), True), True)]) + vec_df = spark.createDataFrame(data, schema=schema) + + noop_transfomer = NoopTransformation( + [col_name], + truncate_dim=1, + ) + + transformed_df = noop_transfomer.apply(vec_df) + + expected_values = [[10], [30], [50], [70], [90]] + + check_df_schema(transformed_df) + + transformed_values = [row[col_name] for row in transformed_df.collect()] + + assert_array_equal(expected_values, transformed_values) + + def test_noop_largegint_transformation(spark: SparkSession, check_df_schema): """No-op transformation for long numerical columns""" large_int = 4 * 10**18 diff --git a/python/graphstorm/gconstruct/transform.py b/python/graphstorm/gconstruct/transform.py index 60b2b3f8df..40e1da2fa3 100644 --- a/python/graphstorm/gconstruct/transform.py +++ b/python/graphstorm/gconstruct/transform.py @@ -916,13 +916,18 @@ class Noop(FeatTransform): The name of the column that contains the feature. feat_name : str The feature name used in the constructed graph. - out_dtype: + out_dtype : str The dtype of the transformed feature. Default: None, we will not do data type casting. + truncate_dim : int + When provided, will truncate the output float-vector feature to the specified dimension. + This is useful when the feature is a multi-dimensional vector and we only need + a subset of the dimensions, e.g. for Matryoshka Representation Learning embeddings. """ - def __init__(self, col_name, feat_name, out_dtype=None): + def __init__(self, col_name, feat_name, out_dtype=None, truncate_dim=None): out_dtype = np.float32 if out_dtype is None else out_dtype super(Noop, self).__init__(col_name, feat_name, out_dtype) + self.truncate_dim = truncate_dim def call(self, feats): """ This transforms the features. @@ -941,6 +946,13 @@ def call(self, feats): assert np.issubdtype(feats.dtype, np.integer) \ or np.issubdtype(feats.dtype, np.floating), \ f"The feature {self.feat_name} has to be integers or floats." + if self.truncate_dim is not None: + if isinstance(feats, np.ndarray): + feats = feats[:, :self.truncate_dim] + else: + assert isinstance(feats, ExtMemArrayWrapper) + # Need to convert to in-memory array to make truncation possible + feats = feats.to_numpy()[:, :self.truncate_dim] return {self.feat_name: feats} class HardEdgeNegativeTransform(TwoPhaseFeatTransform): @@ -1148,7 +1160,12 @@ def parse_feat_ops(confs, input_data_format=None): out_dtype = _get_output_dtype(feat['out_dtype']) if 'out_dtype' in feat else None if 'transform' not in feat: - transform = Noop(feat['feature_col'], feat_name, out_dtype=out_dtype) + transform = Noop( + feat['feature_col'], + feat_name, + out_dtype=out_dtype, + truncate_dim=feat.get('truncate_dim', None) + ) else: conf = feat['transform'] assert 'name' in conf, "'name' must be defined in the transformation field." diff --git a/tests/unit-tests/gconstruct/test_transform.py b/tests/unit-tests/gconstruct/test_transform.py index 87d5988f87..ae5f353288 100644 --- a/tests/unit-tests/gconstruct/test_transform.py +++ b/tests/unit-tests/gconstruct/test_transform.py @@ -531,6 +531,14 @@ def test_noop_transform(out_dtype): else: assert norm_feats["test"].dtype == np.float32 +def test_noop_truncate(): + transform = Noop("test", "test", truncate_dim=16) + feats = np.random.randn(100, 32).astype(np.float32) + trunc_feats = transform(feats) + + assert trunc_feats["test"].shape[1] == 16 + + @pytest.mark.parametrize("input_dtype", [np.cfloat, np.float32]) @pytest.mark.parametrize("out_dtype", [None, np.float16]) def test_rank_gauss_transform(input_dtype, out_dtype): @@ -1157,6 +1165,7 @@ def test_hard_edge_dst_negative_transform(id_dtype): test_noop_transform(None) test_noop_transform(np.float16) test_noop_transform(np.float64) + test_noop_truncate() test_bucket_transform(None) test_bucket_transform(np.float16)