diff --git a/docs/source/advanced/link-prediction.rst b/docs/source/advanced/link-prediction.rst index 9f910c010e..f28f0ac542 100644 --- a/docs/source/advanced/link-prediction.rst +++ b/docs/source/advanced/link-prediction.rst @@ -236,6 +236,8 @@ impact is negligible. With DGL 1.0.4, ``fast_localuniform`` dataloader can speedup 2.4X over ``localuniform`` dataloader on training a 2 layer RGCN on MAG dataset on four g5.48x instances. +.. _hard_negative_sampling: + Hard Negative sampling ----------------------- GraphStorm provides support for users to define hard negative edges for a positive edge during Link Prediction training. @@ -272,10 +274,10 @@ In general, GraphStorm covers following cases: Preparing graph data for hard negative sampling ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The gconstruct pipeline of GraphStorm provides support to load hard negative data from raw input. +Both single machine and distributed graph construction pipeline of GraphStorm provide support to load hard negative data from raw input. Hard destination negatives can be defined through ``edge_dst_hard_negative`` transformation. The ``feature_col`` field of ``edge_dst_hard_negative`` must stores the raw node ids of hard destination nodes. -The follwing example shows how to define a hard negative feature for edges with the relation ``(node1, relation1, node1)``: +The following example shows how to define a hard negative feature for edges with the relation ``(node1, relation1, node1)``: .. code-block:: json diff --git a/docs/source/cli/graph-construction/distributed/gsprocessing/input-configuration.rst b/docs/source/cli/graph-construction/distributed/gsprocessing/input-configuration.rst index f44b22d47e..d2074a338f 100644 --- a/docs/source/cli/graph-construction/distributed/gsprocessing/input-configuration.rst +++ b/docs/source/cli/graph-construction/distributed/gsprocessing/input-configuration.rst @@ -491,6 +491,13 @@ arguments. You can use a length greater than the dataset's longest sentence; or for a safe value choose 128. Make sure to check the model's max supported length when setting this value. +- ``edge_dst_hard_negative`` + + - Encodes a hard negative edge feature for link prediction. For detail information for hard negative support, please refer to :ref:`hard_negative_sampling`. + - ``kwargs``: + - ``separator`` (String, optional): The separator is used to + split multiple values in an input string for data in CSV files e.g. ``p0;s1``. If it is not provided, then the whole value + will be treated as a single string. .. _gsprocessing-multitask-ref: diff --git a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py index 5129254538..33fe40f760 100644 --- a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py +++ b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py @@ -188,6 +188,12 @@ def _convert_feature(feats: list[Mapping[str, Any]]) -> list[dict]: "hf_model": gconstruct_transform_dict["bert_model"], "max_seq_length": gconstruct_transform_dict["max_seq_length"], } + elif gconstruct_transform_dict["name"] == "edge_dst_hard_negative": + gsp_transformation_dict["name"] = "edge_dst_hard_negative" + if "separator" in gconstruct_transform_dict: + gsp_transformation_dict["kwargs"] = { + "separator": gconstruct_transform_dict["separator"], + } else: raise ValueError( "Unsupported GConstruct transformation name: " diff --git a/graphstorm-processing/graphstorm_processing/config/config_parser.py b/graphstorm-processing/graphstorm_processing/config/config_parser.py index 38e92528d8..95f4ab3dd2 100644 --- a/graphstorm-processing/graphstorm_processing/config/config_parser.py +++ b/graphstorm-processing/graphstorm_processing/config/config_parser.py @@ -29,6 +29,7 @@ ) from .categorical_configs import MultiCategoricalFeatureConfig from .hf_configs import HFConfig +from .hard_negative_configs import HardEdgeNegativeConfig from .data_config_base import DataStorageConfig @@ -71,6 +72,8 @@ def parse_feat_config(feature_dict: Dict) -> FeatureConfig: return MultiCategoricalFeatureConfig(feature_dict) elif transformation_name == "huggingface": return HFConfig(feature_dict) + elif transformation_name == "edge_dst_hard_negative": + return HardEdgeNegativeConfig(feature_dict) else: raise RuntimeError(f"Unknown transformation name: '{transformation_name}'") diff --git a/graphstorm-processing/graphstorm_processing/config/hard_negative_configs.py b/graphstorm-processing/graphstorm_processing/config/hard_negative_configs.py new file mode 100644 index 0000000000..26e99a00a6 --- /dev/null +++ b/graphstorm-processing/graphstorm_processing/config/hard_negative_configs.py @@ -0,0 +1,35 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"). +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Mapping + +from .feature_config_base import FeatureConfig + + +class HardEdgeNegativeConfig(FeatureConfig): + """Feature configuration for hard negative feature. Now only support link prediction. + + Supported kwargs + ---------------- + separator: str, optional + The separator for string input value. Only required when input value type is string. + """ + + def __init__(self, config: Mapping): + super().__init__(config) + self.separator = self._transformation_kwargs.get("separator", None) + + self._sanity_check() diff --git a/graphstorm-processing/graphstorm_processing/constants.py b/graphstorm-processing/graphstorm_processing/constants.py index cbb48f4c02..a732306ab8 100644 --- a/graphstorm-processing/graphstorm_processing/constants.py +++ b/graphstorm-processing/graphstorm_processing/constants.py @@ -58,6 +58,14 @@ HUGGINGFACE_TOKENIZE = "tokenize_hf" HUGGINGFACE_EMB = "embedding_hf" +################# Hard Negative transformations ################ +ORDER_INDEX = "hard_negative_order_id" +EXPLODE_HARD_NEGATIVE_VALUE = "hard_negative_exploded_single_value" + +################# Node Mapping ################ +NODE_MAPPING_STR = "orig" +NODE_MAPPING_INT = "new" + ################# Supported execution envs ############## class ExecutionEnv(Enum): diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py index 8ea337bef7..36a6311867 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_feature_transformer.py @@ -28,6 +28,7 @@ DistCategoryTransformation, DistMultiCategoryTransformation, DistHFTransformation, + DistHardEdgeNegativeTransformation, ) @@ -71,6 +72,10 @@ def __init__( self.transformation = DistMultiCategoryTransformation(**default_kwargs, **args_dict) elif feat_type == "huggingface": self.transformation = DistHFTransformation(**default_kwargs, **args_dict) + elif feat_type == "edge_dst_hard_negative": + self.transformation = DistHardEdgeNegativeTransformation( + **default_kwargs, **args_dict, spark=spark + ) else: raise NotImplementedError( f"Feature {feat_name} has type: {feat_type} that is not supported" diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/__init__.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/__init__.py index 4849c53acc..959124644b 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/__init__.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/__init__.py @@ -15,3 +15,4 @@ ) from .dist_bucket_numerical_transformation import DistBucketNumericalTransformation from .dist_hf_transformation import DistHFTransformation +from .dist_hard_negative_transformation import DistHardEdgeNegativeTransformation diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_hard_negative_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_hard_negative_transformation.py new file mode 100755 index 0000000000..35dd4005cc --- /dev/null +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_hard_negative_transformation.py @@ -0,0 +1,125 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"). +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Sequence +from pyspark.sql.functions import split, col +from pyspark.sql.types import ArrayType, IntegerType, StringType +from pyspark.sql import DataFrame, functions as F, SparkSession + +from graphstorm_processing.constants import ( + NODE_MAPPING_STR, + NODE_MAPPING_INT, + ORDER_INDEX, + EXPLODE_HARD_NEGATIVE_VALUE, +) + +from .base_dist_transformation import DistributedTransformation + + +class DistHardEdgeNegativeTransformation(DistributedTransformation): + """Transformation to apply hard negative transformation. + + Parameters + ---------- + cols : Sequence[str] + List of column names to apply hard negative transformation to. + spark: SparkSession + The spark session. + hard_node_mapping_dict: dict + The mapping dictionary contain mapping file directory and edge type. + { + "edge_type": str + Edge type to apply hard negative transformation. + "mapping_path": str + Path to the raw node mapping. + "format_name": str + Parquet. + } + separator: str, optional + The separator for string input value. Only required when input value type is string. + """ + + def __init__( + self, + cols: Sequence[str], + spark: SparkSession, + hard_node_mapping_dict: dict, + separator: str = "", + ) -> None: + super().__init__(cols, spark) + self.cols = cols + assert len(self.cols) == 1, "Hard Negative Transformation only supports single column" + self.separator = separator + self.hard_node_mapping_dict = hard_node_mapping_dict + assert self.hard_node_mapping_dict, "edge mapping dict cannot be None for hard negative " + + def apply(self, input_df: DataFrame) -> DataFrame: + assert self.spark + input_col = self.cols[0] + column_type = input_df.schema[input_col].dataType + if isinstance(column_type, StringType): + transformed_df = input_df.withColumn(input_col, split(col(input_col), self.separator)) + else: + transformed_df = input_df + # Edge type should be (src_ntype:relation_type:dst_ntype) + # Only support hard negative for destination nodes. Get the node type of destination nodes. + # TODO: support hard negative for source nodes. + _, _, dst_type = self.hard_node_mapping_dict["edge_type"].split(":") + mapping_prefix = self.hard_node_mapping_dict["mapping_path"] + format_name = self.hard_node_mapping_dict["format_name"] + hard_negative_node_mapping = self.spark.read.parquet( + f"{mapping_prefix}{dst_type}/{format_name}/" + ) + # The maximum number of negatives in the input feature column + max_size = ( + transformed_df.select(F.size(F.col(input_col)).alias(f"{input_col}_size")) + .agg(F.max(f"{input_col}_size")) + .collect()[0][0] + ) + + # TODO: Use panda series to possibly improve the efficiency + # Explode the original list and join node id mapping dataframe + transformed_df = transformed_df.withColumn(ORDER_INDEX, F.monotonically_increasing_id()) + # Could result in extremely large DFs in num_nodes * avg(len_of_negatives) rows + transformed_df = transformed_df.withColumn( + EXPLODE_HARD_NEGATIVE_VALUE, F.explode(F.col(input_col)) + ) + transformed_df = transformed_df.join( + hard_negative_node_mapping, + transformed_df[EXPLODE_HARD_NEGATIVE_VALUE] + == hard_negative_node_mapping[NODE_MAPPING_STR], + "inner", + ).select(NODE_MAPPING_INT, ORDER_INDEX) + transformed_df = transformed_df.groupBy(ORDER_INDEX).agg( + F.collect_list(NODE_MAPPING_INT).alias(input_col) + ) + + # Extend the feature to the same length as the maximum length of the feature column + def pad_mapped_values(hard_neg_list): + if len(hard_neg_list) < max_size: + hard_neg_list.extend([-1] * (max_size - len(hard_neg_list))) + return hard_neg_list + + pad_value_udf = F.udf(pad_mapped_values, ArrayType(IntegerType())) + # Make sure it keeps the original order + transformed_df = transformed_df.orderBy(ORDER_INDEX) + transformed_df = transformed_df.select(pad_value_udf(F.col(input_col)).alias(input_col)) + + return transformed_df + + @staticmethod + def get_transformation_name() -> str: + return "DistHardEdgeNegativeTransformation" diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py index af76ab40e8..2f6e2ebe83 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py @@ -1654,6 +1654,15 @@ def _process_edge_features( .get(edge_type, {}) .get(feat_conf.feat_name, {}) ) + # Hard Negative Transformation use case, but should be able to be reused + if feat_conf.feat_type == "edge_dst_hard_negative": + hard_node_mapping_dict = { + "edge_type": edge_type, + "mapping_path": f"{self.output_prefix}/raw_id_mappings/", + "format_name": FORMAT_NAME, + } + feat_conf.transformation_kwargs["hard_node_mapping_dict"] = hard_node_mapping_dict + transformer = DistFeatureTransformer(feat_conf, self.spark, json_representation) if json_representation: diff --git a/graphstorm-processing/tests/test_converter.py b/graphstorm-processing/tests/test_converter.py index 334ef72284..a4871342d2 100644 --- a/graphstorm-processing/tests/test_converter.py +++ b/graphstorm-processing/tests/test_converter.py @@ -401,7 +401,14 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter): "files": ["/tmp/acm_raw/edges/author_writing_paper.parquet"], "source_id_col": "~from", "dest_id_col": "~to", - "features": [{"feature_col": ["author"], "feature_name": "feat"}], + "features": [ + {"feature_col": ["author"], "feature_name": "feat"}, + { + "feature_col": ["author"], + "feature_name": "hard_negative", + "transform": {"name": "edge_dst_hard_negative", "separator": ";"}, + }, + ], "labels": [ { "label_col": "edge_col", @@ -505,7 +512,12 @@ def test_convert_gsprocessing(converter: GConstructConfigConverter): assert edges_output["dest"] == {"column": "~to", "type": "paper"} assert edges_output["relation"] == {"type": "writing"} assert edges_output["features"] == [ - {"column": "author", "transformation": {"name": "no-op"}, "name": "feat"} + {"column": "author", "transformation": {"name": "no-op"}, "name": "feat"}, + { + "column": "author", + "name": "hard_negative", + "transformation": {"name": "edge_dst_hard_negative", "kwargs": {"separator": ";"}}, + }, ] assert edges_output["labels"] == [ { diff --git a/graphstorm-processing/tests/test_dist_hard_negative_transformation.py b/graphstorm-processing/tests/test_dist_hard_negative_transformation.py new file mode 100755 index 0000000000..8932bd1dbe --- /dev/null +++ b/graphstorm-processing/tests/test_dist_hard_negative_transformation.py @@ -0,0 +1,108 @@ +""" +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"). +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import pytest +from pyspark.sql import DataFrame, SparkSession +import numpy as np +from numpy.testing import assert_array_equal + +from graphstorm_processing.constants import NODE_MAPPING_STR, NODE_MAPPING_INT +from graphstorm_processing.data_transformations.dist_transformations import ( + DistHardEdgeNegativeTransformation, +) + + +def test_hard_negative_example_list(spark: SparkSession, check_df_schema, tmp_path): + # Input Data DataFrame + data = [ + ("mark", "doctor", ["scientist"]), + ("john", "scientist", ["engineer", "nurse"]), + ("tara", "engineer", ["nurse", "doctor", "scientist"]), + ("jen", "nurse", ["doctor"]), + ] + columns = ["src_type", "dst_type", "hard_negative"] + input_df = spark.createDataFrame(data, schema=columns) + + # Mapping DataFrame + mapping_data = [ + ("doctor", 0), + ("scientist", 1), + ("engineer", 2), + ("nurse", 3), + ] + mapping_column = [NODE_MAPPING_STR, NODE_MAPPING_INT] + mapping_df = spark.createDataFrame(mapping_data, schema=mapping_column) + mapping_df.repartition(1).write.parquet(f"{tmp_path}/raw_id_mappings/dst_type/parquet") + hard_node_mapping_dict = { + "edge_type": "src_type:relation:dst_type", + "mapping_path": f"{tmp_path}/raw_id_mappings/", + "format_name": "parquet", + } + hard_negative_transformation = DistHardEdgeNegativeTransformation( + ["hard_negative"], + spark=spark, + hard_node_mapping_dict=hard_node_mapping_dict, + separator=None, + ) + output_df = hard_negative_transformation.apply(input_df) + check_df_schema(output_df) + output_data = output_df.collect() + + # All the length should be the same as the maximum array. + expected_output = [[1, -1, -1], [2, 3, -1], [3, 0, 1], [0, -1, -1]] + + for idx, row in enumerate(output_data): + np.testing.assert_equal(row[0], expected_output[idx], err_msg=f"Row {idx} is not equal") + + +def test_hard_negative_example_str(spark: SparkSession, check_df_schema, tmp_path): + # Input Data DataFrame + data = [ + ("mark", "doctor", "scientist"), + ("john", "scientist", "engineer;nurse"), + ("tara", "engineer", "nurse;doctor;scientist"), + ("jen", "nurse", "doctor"), + ] + columns = ["src_type", "dst_type", "hard_negative"] + input_df = spark.createDataFrame(data, schema=columns) + + # Mapping DataFrame + mapping_data = [ + ("doctor", 0), + ("scientist", 1), + ("engineer", 2), + ("nurse", 3), + ] + mapping_column = [NODE_MAPPING_STR, NODE_MAPPING_INT] + mapping_df = spark.createDataFrame(mapping_data, schema=mapping_column) + mapping_df.repartition(1).write.parquet(f"{tmp_path}/raw_id_mappings/dst_type/parquet") + hard_node_mapping_dict = { + "edge_type": "src_type:relation:dst_type", + "mapping_path": f"{tmp_path}/raw_id_mappings/", + "format_name": "parquet", + } + hard_negative_transformation = DistHardEdgeNegativeTransformation( + ["hard_negative"], spark=spark, hard_node_mapping_dict=hard_node_mapping_dict, separator=";" + ) + output_df = hard_negative_transformation.apply(input_df) + check_df_schema(output_df) + output_data = output_df.collect() + + # All the length should be the same as the maximum array. + expected_output = [[1, -1, -1], [2, 3, -1], [3, 0, 1], [0, -1, -1]] + + for idx, row in enumerate(output_data): + np.testing.assert_equal(row[0], expected_output[idx], err_msg=f"Row {idx} is not equal") diff --git a/python/graphstorm/gconstruct/utils.py b/python/graphstorm/gconstruct/utils.py index e921293c2d..d2077f632e 100644 --- a/python/graphstorm/gconstruct/utils.py +++ b/python/graphstorm/gconstruct/utils.py @@ -1118,6 +1118,43 @@ def get_hard_edge_negs_feats(hard_edge_neg_ops): return hard_edge_neg_feats +def get_gnid2pnid_map(ntype: str, node_mapping: dict, gnid2pnid_mapping: dict): + """ Get global nid to partitioned nid mapping. + + Parameters + ---------- + ntype: str + Node type. + node_mapping: dict + Dict of mapping. + { + ntype: 1D tensor representing the mapping from + partition node IDs (pnid) to global node IDs (gnid). + Each index corresponds to a partition node IDs, and + the value at that index is the global node IDs. + } + gnid2pnid_mapping: dict + Dict of mapping. Here are the mapping represented: + { + ntype: 1D tensor representing the mapping from + global node IDs (gnid) to partition node IDs (pnid). + Each index corresponds to a global node ID, and + the value at that index is the partition node ID. + } + + Returns + 1-D node Mapping Tensor for target node type. + """ + if ntype in gnid2pnid_mapping: + return gnid2pnid_mapping[ntype] + else: + pnid2gnid_map = node_mapping[ntype] + gnid2pnid_map = th.argsort(pnid2gnid_map) + gnid2pnid_mapping[ntype] = gnid2pnid_map + # del ntype in node_mapping to save memory + del node_mapping[ntype] + return gnid2pnid_mapping[ntype] + def shuffle_hard_nids(data_path, num_parts, hard_edge_neg_feats): """ Shuffle node ids of hard negatives from Graph node id space to Partition Node id space. @@ -1136,17 +1173,6 @@ def shuffle_hard_nids(data_path, num_parts, hard_edge_neg_feats): node_mapping = load_maps(data_path, "node_mapping") gnid2pnid_mapping = {} - def get_gnid2pnid_map(ntype): - if ntype in gnid2pnid_mapping: - return gnid2pnid_mapping[ntype] - else: - pnid2gnid_map = node_mapping[ntype] - gnid2pnid_map = th.argsort(pnid2gnid_map) - gnid2pnid_mapping[ntype] = gnid2pnid_map - # del ntype in node_mapping to save memory - del node_mapping[ntype] - return gnid2pnid_mapping[ntype] - # iterate all the partitions to convert hard negative node ids. for i in range(num_parts): part_path = os.path.join(data_path, f"part{i}") @@ -1162,7 +1188,8 @@ def get_gnid2pnid_map(ntype): efeat_name = f"{etype}/{neg_feat}" hard_nids = edge_feats[efeat_name] hard_nid_idx = hard_nids > -1 - gnid2pnid_map = get_gnid2pnid_map(neg_ntype) + gnid2pnid_map = get_gnid2pnid_map(neg_ntype, node_mapping, + gnid2pnid_mapping) hard_nids[hard_nid_idx] = gnid2pnid_map[hard_nids[hard_nid_idx]] # replace the edge_feat.dgl with the updated one. diff --git a/python/graphstorm/gpartition/__init__.py b/python/graphstorm/gpartition/__init__.py index c7957002c2..b66664f68f 100644 --- a/python/graphstorm/gpartition/__init__.py +++ b/python/graphstorm/gpartition/__init__.py @@ -19,3 +19,4 @@ from .metis_partition import (ParMetisPartitionAlgorithm) from .partition_config import (ParMETISConfig) from .partition_algo_base import LocalPartitionAlgorithm +from .post_hard_negative import shuffle_hard_negative_nids diff --git a/python/graphstorm/gpartition/dist_partition_graph.py b/python/graphstorm/gpartition/dist_partition_graph.py index da50ce8ca6..fffddbe738 100644 --- a/python/graphstorm/gpartition/dist_partition_graph.py +++ b/python/graphstorm/gpartition/dist_partition_graph.py @@ -38,6 +38,7 @@ ParMetisPartitionAlgorithm, ParMETISConfig, RandomPartitionAlgorithm, + shuffle_hard_negative_nids, ) from graphstorm.utils import get_log_level @@ -189,6 +190,22 @@ def main(): dirs_exist_ok=True, ) + # Hard Negative Mapping + # Load GSProcessing config from launch_arguments generated by GSProcessing + # Generated GSProcessing config will have _with_transformation suffix. + launch_arguments_path = os.path.join(args.input_path, "launch_arguments.json") + if os.path.exists(launch_arguments_path): + with open(launch_arguments_path, "r", encoding="utf-8") as f: + gsprocessing_launch_arguments: Dict = json.load(f) + gsprocessing_config = gsprocessing_launch_arguments["config_filename"] + gsprocessing_config = gsprocessing_config.replace(".json", "_with_transformations.json") + shuffle_hard_negative_nids(f"{args.input_path}/{gsprocessing_config}", + args.num_parts, args.output_path) + else: + logging.info("Skip the hard negative node ID mapping, " + "the processed data is not generated by GSProcessing.") + + def parse_args() -> argparse.Namespace: """Parses arguments for the script""" argparser = argparse.ArgumentParser("Partition DGL graphs for node and edge classification " diff --git a/python/graphstorm/gpartition/post_hard_negative.py b/python/graphstorm/gpartition/post_hard_negative.py new file mode 100644 index 0000000000..0c86d85fc5 --- /dev/null +++ b/python/graphstorm/gpartition/post_hard_negative.py @@ -0,0 +1,110 @@ +""" + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" + +import json +import os + +from dgl.data.utils import load_tensors, save_tensors +from graphstorm.model.utils import load_dist_nid_map +from graphstorm.gconstruct.utils import get_gnid2pnid_map + + +def load_hard_negative_config(gsprocessing_config: str): + """Load GSProcessing Config to extract hard negative config + + Parameters + ---------------- + gsprocessing_config: str + Path to the gsprocessing config. + + Returns + ------- + list of dicts + A list of dict for each hard negative feature transformation. + Each dict will look like: + { + "dst_node_type": destination node type for hard negative, + "edge_type": edge_type, + "hard_neg_feat_name": feature name + } + """ + with open(gsprocessing_config, 'r', encoding='utf-8') as file: + config = json.load(file) + + # Hard Negative only supports link prediction + edges_config = config['graph']['edges'] + hard_neg_list = [] + for single_edge_config in edges_config: + if "features" not in single_edge_config: + continue + feature_dict = single_edge_config["features"] + for single_feature in feature_dict: + if single_feature["transformation"]["name"] \ + == "edge_dst_hard_negative": + edge_type = ":".join([single_edge_config["source"]["type"], + single_edge_config["relation"]["type"], + single_edge_config["dest"]["type"]]) + hard_neg_feat_name = single_feature['name'] + hard_neg_list.append({"dst_node_type": single_edge_config["dest"]["type"], + "edge_type": edge_type, + "hard_neg_feat_name": hard_neg_feat_name}) + return hard_neg_list + + +def shuffle_hard_negative_nids(gsprocessing_config: str, + num_parts: int, graph_path: str): + """Shuffle hard negative edge feature ids with int-to-int node id mapping. + The function here align with the shuffle_hard_nids in graphstorm.gconstruct.utils. + + Parameters + ---------------- + gsprocessing_config: str + Path to the gsprocessing config. + num_parts: int + Number of parts. + graph_path: str + Path to the output DGL graph. + """ + shuffled_edge_config = load_hard_negative_config(gsprocessing_config) + + node_type_list = [] + for single_shuffled_edge_config in shuffled_edge_config: + node_type = single_shuffled_edge_config["dst_node_type"] + node_type_list.append(node_type) + node_mapping = load_dist_nid_map(f"{graph_path}/dist_graph", node_type_list) + gnid2pnid_mapping = {} + + # iterate all the partitions to convert hard negative node ids. + for i in range(num_parts): + part_path = os.path.join(f"{graph_path}/dist_graph", f"part{i}") + edge_feat_path = os.path.join(part_path, "edge_feat.dgl") + + # load edge features first + edge_feats = load_tensors(edge_feat_path) + for single_shuffled_edge_config in shuffled_edge_config: + etype = single_shuffled_edge_config["edge_type"] + neg_feat = single_shuffled_edge_config["hard_neg_feat_name"] + neg_ntype = single_shuffled_edge_config["dst_node_type"] + efeat_name = f"{etype}/{neg_feat}" + hard_nids = edge_feats[efeat_name].long() + hard_nid_idx = hard_nids > -1 + gnid2pnid_map = get_gnid2pnid_map(neg_ntype, node_mapping, + gnid2pnid_mapping) + hard_nids[hard_nid_idx] = gnid2pnid_map[hard_nids[hard_nid_idx]] + + # replace the edge_feat.dgl with the updated one. + os.remove(edge_feat_path) + save_tensors(edge_feat_path, edge_feats) diff --git a/python/graphstorm/model/utils.py b/python/graphstorm/model/utils.py index 1b6aa3731d..a37a66948c 100644 --- a/python/graphstorm/model/utils.py +++ b/python/graphstorm/model/utils.py @@ -405,8 +405,20 @@ def _exchange_node_id_mapping(rank, world_size, device, # move mapping into CPU return gather_list[0].to(th.device("cpu")) -def _load_dist_nid_map(node_id_mapping_file, ntypes): +def load_dist_nid_map(node_id_mapping_file, ntypes): """ Load id mapping files in dist partition format. + + Parameters + ---------- + node_id_mapping_file: str + Node mapping directory. + ntypes: list[str] + List of node types. + + Return + ------ + id_mappings: dict + Node mapping dictionary. """ # node_id_mapping_file it is actually a directory # /part0, /part1, ... @@ -459,7 +471,7 @@ def distribute_nid_map(embeddings, rank, world_size, else: # Homogeneous graph # node id mapping file from dgl tools/distpartitioning/convert_partition.py. - ori_node_id_mapping = _load_dist_nid_map(node_id_mapping_file, ["_N"])["_N"] + ori_node_id_mapping = load_dist_nid_map(node_id_mapping_file, ["_N"])["_N"] _, node_id_mapping = th.sort(ori_node_id_mapping) else: node_id_mapping = None @@ -474,7 +486,7 @@ def distribute_nid_map(embeddings, rank, world_size, node_id_mappings = th.load(node_id_mapping_file) else: # node id mapping file from dgl tools/distpartitioning/convert_partition.py. - node_id_mappings = _load_dist_nid_map(node_id_mapping_file, + node_id_mappings = load_dist_nid_map(node_id_mapping_file, list(embeddings.keys())) else: node_id_mappings = None @@ -1184,7 +1196,7 @@ def __init__(self, g, node_id_mapping_file, ntypes=None): id_mappings = th.load(node_id_mapping_file) if get_rank() == 0 else None else: # node id mapping file from dgl tools/distpartitioning/convert_partition.py. - id_mappings = _load_dist_nid_map(node_id_mapping_file, ntypes) \ + id_mappings = load_dist_nid_map(node_id_mapping_file, ntypes) \ if get_rank() == 0 else None self._id_mapping_info = { diff --git a/tests/unit-tests/gpartition/config/gsprocessing_hard_negative_config.json b/tests/unit-tests/gpartition/config/gsprocessing_hard_negative_config.json new file mode 100644 index 0000000000..0d56a2cf3a --- /dev/null +++ b/tests/unit-tests/gpartition/config/gsprocessing_hard_negative_config.json @@ -0,0 +1,79 @@ +{ + "graph": { + "nodes": [ + { + "data": { + "format": "parquet", + "files": [ + "./nodes/author.parquet" + ] + }, + "type": "author", + "column": "node_id" + }, + { + "data": { + "format": "parquet", + "files": [ + "./nodes/paper.parquet" + ] + }, + "type": "paper", + "column": "node_id" + } + ], + "edges": [ + { + "data": { + "format": "parquet", + "files": [ + "./edges/author_writing_paper_hard_negative.parquet" + ] + }, + "source": { + "column": "source_id", + "type": "author" + }, + "dest": { + "column": "dest_id", + "type": "paper" + }, + "relation": { + "type": "writing" + }, + "features": [ + { + "column": "hard_neg", + "name": "hard_neg_feat", + "transformation": { + "name": "edge_dst_hard_negative", + "kwargs": { + "separator": ";" + } + } + } + ] + }, + { + "data": { + "format": "parquet", + "files": [ + "./edges/paper_citing_paper.parquet" + ] + }, + "source": { + "column": "source_id", + "type": "paper" + }, + "dest": { + "column": "dest_id", + "type": "paper" + }, + "relation": { + "type": "citing" + } + } + ] + }, + "version": "gsprocessing-v1.0" +} \ No newline at end of file diff --git a/tests/unit-tests/gpartition/config/gsprocessing_non_hard_negative_config.json b/tests/unit-tests/gpartition/config/gsprocessing_non_hard_negative_config.json new file mode 100644 index 0000000000..daf5122113 --- /dev/null +++ b/tests/unit-tests/gpartition/config/gsprocessing_non_hard_negative_config.json @@ -0,0 +1,67 @@ +{ + "graph": { + "nodes": [ + { + "data": { + "format": "parquet", + "files": [ + "./nodes/author.parquet" + ] + }, + "type": "author", + "column": "node_id" + }, + { + "data": { + "format": "parquet", + "files": [ + "./nodes/paper.parquet" + ] + }, + "type": "paper", + "column": "node_id" + } + ], + "edges": [ + { + "data": { + "format": "parquet", + "files": [ + "./edges/author_writing_paper_hard_negative.parquet" + ] + }, + "source": { + "column": "source_id", + "type": "author" + }, + "dest": { + "column": "dest_id", + "type": "paper" + }, + "relation": { + "type": "writing" + } + }, + { + "data": { + "format": "parquet", + "files": [ + "./edges/paper_citing_paper.parquet" + ] + }, + "source": { + "column": "source_id", + "type": "paper" + }, + "dest": { + "column": "dest_id", + "type": "paper" + }, + "relation": { + "type": "citing" + } + } + ] + }, + "version": "gsprocessing-v1.0" +} \ No newline at end of file diff --git a/tests/unit-tests/gpartition/test_hard_negative_post_partition.py b/tests/unit-tests/gpartition/test_hard_negative_post_partition.py new file mode 100644 index 0000000000..368562eb3c --- /dev/null +++ b/tests/unit-tests/gpartition/test_hard_negative_post_partition.py @@ -0,0 +1,123 @@ +""" + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +""" +import os +import json +import torch as th +import numpy as np +from typing import Dict + +import pytest + +from numpy.testing import assert_almost_equal +from graphstorm.model.utils import load_dist_nid_map +from dgl.data.utils import load_tensors, save_tensors +from graphstorm.gpartition.post_hard_negative import (shuffle_hard_negative_nids, + load_hard_negative_config) + +_ROOT = os.path.abspath(os.path.dirname(__file__)) + + +@pytest.fixture +def setup_graph_partition(tmp_path): + partitioned_graph = f"{tmp_path}/partitioned_graph" + + # Generate ID mapping for each partition + nid_map_dict_path0 = os.path.join(partitioned_graph, "dist_graph", "part0", "orig_nids.dgl") + nid_map_dict_path1 = os.path.join(partitioned_graph, "dist_graph", "part1", "orig_nids.dgl") + os.makedirs(os.path.dirname(nid_map_dict_path0), exist_ok=True) + os.makedirs(os.path.dirname(nid_map_dict_path1), exist_ok=True) + + # Use randperm in the test otherwise there maybe no mapping necessary + nid_map0 = { + "paper": th.randperm(100), + "author": th.arange(200, 300) + } + save_tensors(nid_map_dict_path0, nid_map0) + + nid_map1 = { + "paper": th.randperm(100) + 100, + "author": th.arange(300, 400) + } + save_tensors(nid_map_dict_path1, nid_map1) + + # Create reversed map + node_mapping = load_dist_nid_map(f"{partitioned_graph}/dist_graph", ["author", "paper"]) + reverse_map_dst = {gid: i for i, gid in enumerate(node_mapping["paper"].tolist())} + reverse_map_dst[-1] = -1 + + return partitioned_graph, reverse_map_dst + + +def test_load_hard_negative_config(): + # For config with hard negative transformation + json_file_path = os.path.join(_ROOT, + "config/gsprocessing_hard_negative_config.json") + + res = load_hard_negative_config(json_file_path) + + assert res[0] == {'dst_node_type': 'paper', 'edge_type': + 'author:writing:paper', 'hard_neg_feat_name': 'hard_neg_feat'} + + # For config without hard negative transformation + json_file_path = os.path.join(_ROOT, + "config/gsprocessing_non_hard_negative_config.json") + + res = load_hard_negative_config(json_file_path) + + assert res == [] + + +def test_shuffle_hard_negative_nids(setup_graph_partition): + # Test the hard negative id shuffling process within distributed setting + + partitioned_graph, reverse_map_dst = setup_graph_partition + # For config with gsprocessing_config.json + json_file_path = os.path.join(_ROOT, + "config/gsprocessing_hard_negative_config.json") + + # generate edge features + etype = ("author", "writing", "paper") + edge_feat_path0 = os.path.join(partitioned_graph, "dist_graph", "part0", "edge_feat.dgl") + edge_feat_path1 = os.path.join(partitioned_graph, "dist_graph", "part1", "edge_feat.dgl") + os.makedirs(os.path.dirname(edge_feat_path0), exist_ok=True) + os.makedirs(os.path.dirname(edge_feat_path1), exist_ok=True) + + paper_writing_hard_neg0 = th.cat((th.randint(0, 100, (100, 100)), + th.full((100, 10), -1, dtype=th.int32)), dim=1) + paper_writing_hard_neg0_shuffled = [ + [reverse_map_dst[nid] for nid in negs] \ + for negs in paper_writing_hard_neg0.tolist()] + paper_writing_hard_neg0_shuffled = np.array(paper_writing_hard_neg0_shuffled) + paper_writing_hard_neg1 = th.cat((th.randint(100, 200, (100, 100)), + th.full((100, 10), -1, dtype=th.int32)), dim=1) + paper_writing_hard_neg1_shuffled = [ + [reverse_map_dst[nid] for nid in negs] \ + for negs in paper_writing_hard_neg1.tolist()] + paper_writing_hard_neg1_shuffled = np.array(paper_writing_hard_neg1_shuffled) + + save_tensors(edge_feat_path0, {":".join(etype)+"/hard_neg_feat": paper_writing_hard_neg0}) + save_tensors(edge_feat_path1, {":".join(etype)+"/hard_neg_feat": paper_writing_hard_neg1}) + + # Do the shuffling + shuffle_hard_negative_nids(json_file_path, 2, partitioned_graph) + + # Assert + paper_writing_hard_neg0 = load_tensors(edge_feat_path0) + assert_almost_equal(paper_writing_hard_neg0[":".join(etype) + "/hard_neg_feat"].numpy(), + paper_writing_hard_neg0_shuffled) + paper_writing_hard_neg1 = load_tensors(edge_feat_path1) + assert_almost_equal(paper_writing_hard_neg1[":".join(etype) + "/hard_neg_feat"].numpy(), + paper_writing_hard_neg1_shuffled) \ No newline at end of file