diff --git a/docs/source/configuration/configuration-gconstruction.rst b/docs/source/configuration/configuration-gconstruction.rst index b04f01ff92..a4ce2cd3c6 100644 --- a/docs/source/configuration/configuration-gconstruction.rst +++ b/docs/source/configuration/configuration-gconstruction.rst @@ -19,6 +19,7 @@ Graph Construction * **-\-skip-nonexist-edges**: boolean value to decide whether skip edges whose endpoint nodes don't exist. Default is true. * **-\-ext-mem-workspace**: the directory where the tool can store data during graph construction. Suggest to use high-speed SSD as the external memory workspace. * **-\-ext-mem-feat-size**: the minimal number of feature dimensions that features can be stored in external memory. Default is 64. +* **-\-output-conf-file**: The output file with the updated configurations that records the details of data transformation, e.g., one-hot encoding maps, max-min normalization ranges. If not specified, will save the updated configuration file in the **-\-output-dir** with name `data_transform_new.json`. .. _gconstruction-json: diff --git a/docs/source/gs-processing/developer/input-configuration.rst b/docs/source/gs-processing/developer/input-configuration.rst index b667cdb8d2..b747d19986 100644 --- a/docs/source/gs-processing/developer/input-configuration.rst +++ b/docs/source/gs-processing/developer/input-configuration.rst @@ -173,6 +173,15 @@ objects: assign to the validation set [0.0, 1.0). - ``test``: The percentage of the data with available labels to assign to the test set [0.0, 1.0). + - ``custom_split_filenames`` (JSON object, optional): Specifies the customized + training/validation/test mask. Once it is defined, GSProcessing will ignore + the ``split_rate``. + - ``train``: Path of the training mask parquet file such that each line contains + the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks. + - ``val``: Path of the validation mask parquet file such that each line contains + the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks. + - ``test``: Path of the test mask parquet file such that each line contains + the original ID for node tasks, or the pair [source_id, destination_id] for edge tasks. - ``features`` (List of JSON objects, optional)\ **:** Describes the set of features for the current edge type. See the :ref:`features-object` section for details. diff --git a/docs/source/index.rst b/docs/source/index.rst index 0d067bfb4a..769f8d7961 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -44,6 +44,7 @@ Welcome to the GraphStorm Documentation and Tutorials notebooks/Notebook_0_Data_Prepare notebooks/Notebook_1_NC_Pipeline + notebooks/Notebook_2_LP_Pipeline .. toctree:: :maxdepth: 1 diff --git a/docs/source/notebooks/Notebook_2_LP_Pipeline.ipynb b/docs/source/notebooks/Notebook_2_LP_Pipeline.ipynb new file mode 100644 index 0000000000..1e21a4e82e --- /dev/null +++ b/docs/source/notebooks/Notebook_2_LP_Pipeline.ipynb @@ -0,0 +1,362 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Notebook 2: Use GraphStorm APIs for Building a Link Prediction Pipeline\n", + "\n", + "This notebook demonstrates how to use GraphStorm's APIs to create a graph machine learning pipeline for a link prediction task.\n", + "\n", + "In this notebook, we modify the RGCN model used in the Notebook 1 to adapt to link prediction tasks and use it to conduct link prediction on the ACM dataset created by the **Notebook_0_Data_Prepare**. 
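To make the `custom_split_filenames` fields documented in the `input-configuration.rst` hunk above concrete, the snippet below sketches a link prediction label entry that points at pre-computed split files instead of a `split_rate`. The file paths and ID column names are invented for illustration; the key names (`train`/`valid`/`test`/`column`) follow the GSProcessing label configs and tests later in this diff.

```python
# Hypothetical GSProcessing label entry that uses pre-computed split files instead of
# a split_rate. Paths and ID column names are placeholders, not real files.
import json

lp_label_with_custom_split = {
    "column": "",                      # empty label column for link prediction
    "type": "link_prediction",
    "custom_split_filenames": {
        "train": "custom_split/train_edges.parquet",
        "valid": "custom_split/val_edges.parquet",
        "test": "custom_split/test_edges.parquet",
        # one column name for node-level labels, [src, dst] column names for edge-level labels
        "column": ["src_id", "dst_id"],
    },
}
print(json.dumps(lp_label_with_custom_split, indent=4))
```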
\n", + "\n", + "### Prerequsites\n", + "\n", + "- GraphStorm installed using pip. Please find [more details on installation of GraphStorm](https://graphstorm.readthedocs.io/en/latest/install/env-setup.html#setup-graphstorm-with-pip-packages).\n", + "- ACM data created in the [Notebook 0: Data Prepare](https://graphstorm.readthedocs.io/en/latest/notebooks/Notebook_0_Data_Prepare.html), and is stored in the `./acm_gs_1p/` folder.\n", + "- Installation of supporting libraries, e.g., matplotlib." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Setup log level in Jupyter Notebook\n", + "import logging\n", + "logging.basicConfig(level=20)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "The major steps of creating a link prediction pipeline are same as the node classification pipeline in the Notebook 1. In this notebook, we will only highlight the different components for clarity.\n", + "\n", + "### 0. Initialize the GraphStorm Standalone Environment" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import graphstorm as gs\n", + "gs.initialize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Setup GraphStorm Dataset and DataLoaders\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "nfeats_4_modeling = {'author':['feat'], 'paper':['feat'],'subject':['feat']}\n", + "\n", + "# create a GraphStorm Dataset for the ACM graph data generated in the Notebook 0\n", + "acm_data = gs.dataloading.GSgnnData(part_config='./acm_gs_1p/acm.json', node_feat_field=nfeats_4_modeling)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because link prediction needs both positive and negative edges for training, we use GraphStorm's `GSgnnLinkPredictionDataloader` which is dedicated for link prediction dataloading. This class takes some common arugments as these `NodePredictionDataloader`s, such as `dataset`, `target_idx`, `node_feats`, and `batch_size`. It also takes some link prediction-related arguments, e.g., `num_negative_edges`, `exlude_training_targets`, and etc." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# define dataloaders for training and validation\n",
+    "train_dataloader = gs.dataloading.GSgnnLinkPredictionDataLoader(\n",
+    "    dataset=acm_data,\n",
+    "    target_idx=acm_data.get_edge_train_set(etypes=[('paper', 'citing', 'paper')]),\n",
+    "    fanout=[20, 20],\n",
+    "    num_negative_edges=10,\n",
+    "    node_feats=nfeats_4_modeling,\n",
+    "    batch_size=64,\n",
+    "    exclude_training_targets=False,\n",
+    "    reverse_edge_types_map=[\"paper,citing,cited,paper\"],\n",
+    "    train_task=True)\n",
+    "val_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n",
+    "    dataset=acm_data,\n",
+    "    target_idx=acm_data.get_edge_val_set(etypes=[('paper', 'citing', 'paper')]),\n",
+    "    fanout=[100, 100],\n",
+    "    num_negative_edges=100,\n",
+    "    node_feats=nfeats_4_modeling,\n",
+    "    batch_size=256)\n",
+    "test_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n",
+    "    dataset=acm_data,\n",
+    "    target_idx=acm_data.get_edge_test_set(etypes=[('paper', 'citing', 'paper')]),\n",
+    "    fanout=[100, 100],\n",
+    "    num_negative_edges=100,\n",
+    "    node_feats=nfeats_4_modeling,\n",
+    "    batch_size=256)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. Create a GraphStorm-compatible RGCN Model for Link Prediction\n",
+    "\n",
+    "For the link prediction task, we modify the RGCN model used for node classification to adapt it to link prediction. Users can find the details in the `demo_models.py` file."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# import a simplified RGCN model for link prediction\n",
+    "from demo_models import RgcnLPModel\n",
+    "\n",
+    "model = RgcnLPModel(g=acm_data.g,\n",
+    "                    num_hid_layers=2,\n",
+    "                    node_feat_field=nfeats_4_modeling,\n",
+    "                    hid_size=128)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 3. Setup a GraphStorm Evaluator\n",
+    "\n",
+    "Here we change the evaluator to a `GSgnnMrrLPEvaluator`, which uses \"mrr\" as the metric for evaluating link prediction performance."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# setup a link prediction evaluator for the trainer\n",
+    "evaluator = gs.eval.GSgnnMrrLPEvaluator(eval_frequency=1000)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 4. Setup a Trainer and Training\n",
+    "\n",
+    "GraphStorm provides the `GSgnnLinkPredictionTrainer` for the link prediction training loop. Constructing this trainer and calling its `fit()` method work the same way as for the `GSgnnNodePredictionTrainer` used in Notebook 1."
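For readers unfamiliar with the metric, the sketch below shows how a mean reciprocal rank (MRR) can be computed for a toy batch where each positive edge is scored against its negative candidates. It only illustrates the metric; it is not the `GSgnnMrrLPEvaluator` implementation.

```python
# Toy MRR computation: each row holds the score of one positive test edge and the
# scores of its negative candidates. Illustrative values only.
import torch as th

pos_scores = th.tensor([2.5, 0.3, 1.1])            # one score per positive test edge
neg_scores = th.tensor([[1.0, 0.2, 2.0],           # scores of the negatives for each edge
                        [0.5, 0.9, 0.1],
                        [1.5, 1.2, 0.8]])

# 1-based rank of each positive among (1 + #negatives) candidates
ranks = 1 + (neg_scores >= pos_scores.unsqueeze(1)).sum(dim=1)
mrr = (1.0 / ranks.float()).mean()
print(ranks.tolist(), mrr.item())                   # higher MRR means positives rank closer to 1
```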
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "# create a GraphStorm link prediction task trainer for the RGCN model\n", + "trainer = gs.trainer.GSgnnLinkPredictionTrainer(model, topk_model_to_save=1)\n", + "trainer.setup_evaluator(evaluator)\n", + "trainer.setup_device(gs.utils.get_device())" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Train the model with the trainer using fit() function\n", + "trainer.fit(train_loader=train_dataloader,\n", + " val_loader=val_dataloader,\n", + " test_loader=test_dataloader,\n", + " num_epochs=5,\n", + " save_model_path='a_save_path/',\n", + " save_model_frequency=1000,\n", + " use_mini_batch_infer=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### (Optional) 5. Visualize Model Performance History\n", + "\n", + "Same as the node classification pipeline, we can use the history stored in the evaluator." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "# extract evaluation history of metrics from the trainer's evaluator:\n", + "val_metrics, test_metrics = [], []\n", + "for val_metric, test_metric in trainer.evaluator.history:\n", + " val_metrics.append(val_metric['mrr'])\n", + " test_metrics.append(test_metric['mrr'])\n", + "\n", + "# plot the performance curves\n", + "fig, ax = plt.subplots()\n", + "ax.plot(val_metrics, label='val')\n", + "ax.plot(test_metrics, label='test')\n", + "ax.set(xlabel='Epoch', ylabel='Mrr')\n", + "ax.legend(loc='best')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6. Inference with the Trained Model\n", + "\n", + "The operations of model restore are same as those used in the Notebook 1. Users can find the best model path first, and use model's `restore_model()` to load the trained model file." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# after training, the best model is saved to disk:\n", + "best_model_path = trainer.get_best_model_path()\n", + "print('Best model path:', best_model_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "# we can restore the model from the saved path using the model's restore_model() function.\n", + "model.restore_model(best_model_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To do inference, users can either create a new dataloader as the following code does, or reuse one of the dataloaders defined in training." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# Setup dataloader for inference\n", + "infer_dataloader = gs.dataloading.GSgnnLinkPredictionTestDataLoader(\n", + " dataset=acm_data,\n", + " target_idx=acm_data.get_edge_infer_set(etypes=[('paper', 'citing', 'paper')]),\n", + " fanout=[100, 100],\n", + " num_negative_edges=100,\n", + " node_feats=nfeats_4_modeling,\n", + " batch_size=256)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can define a `GSgnnLinkPredictionInferrer` by giving the restored model and do inference by calling its `infer()` method." 
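Because the model's decoder is a dot-product decoder (`LinkPredictDotDecoder` in `demo_models.py`), the node embeddings produced by inference can be turned into link scores with a plain dot product. The following in-memory sketch uses a random tensor as a stand-in for the saved embeddings; the node ID and shapes are invented, and the on-disk layout of `save_embed_path` is not covered here.

```python
# Rank candidate "citing" links for one query paper with a dot-product score.
# emb_paper stands in for the paper-node embeddings produced by infer(); here it is random.
import torch as th

num_papers, hid_size = 1000, 128
emb_paper = th.randn(num_papers, hid_size)    # placeholder for the saved GNN embeddings

query = 42                                    # hypothetical source paper node ID
scores = emb_paper @ emb_paper[query]         # dot-product score against every paper
scores[query] = float("-inf")                 # do not recommend a self-link
topk = th.topk(scores, k=5)
print(topk.indices.tolist())                  # top-5 candidate destination papers
```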
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "# Create an Inferrer object\n", + "infer = gs.inference.GSgnnLinkPredictionInferrer(model)\n", + "\n", + "# Run inference on the inference dataset\n", + "infer.infer(acm_data,\n", + " infer_dataloader,\n", + " save_embed_path='infer/embeddings',\n", + " use_mini_batch_infer=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For link prediction task, the inference outputs are embeddings of all nodes in the inference graph." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "# The GNN embeddings of all nodes in the inference graph are saved to the folder named after the target_ntype\n", + "!ls -lh infer/embeddings/paper" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "gsf", + "language": "python", + "name": "gsf" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/source/notebooks/demo_models.py b/docs/source/notebooks/demo_models.py index 93f53124da..03c642ea92 100644 --- a/docs/source/notebooks/demo_models.py +++ b/docs/source/notebooks/demo_models.py @@ -22,7 +22,10 @@ GSNodeEncoderInputLayer, RelationalGCNEncoder, EntityClassifier, - ClassifyLossFunc) + ClassifyLossFunc, + GSgnnLinkPredictionModel, + LinkPredictDotDecoder, + LinkPredictBCELossFunc) class RgcnNCModel(GSgnnNodeModel): @@ -89,3 +92,60 @@ def __init__(self, self.init_optimizer(lr=0.001, sparse_optimizer_lr=0.01, weight_decay=0) + + +class RgcnLPModel(GSgnnLinkPredictionModel): + """ A simple RGCN model for link prediction using Graphstorm APIs + + This RGCN model extends GraphStorm's GSgnnLinkPredictionModel, and it has the similar + model architecture as the node model, but has a different decoder layer and loss function: + 1. an input layer that converts input node features to the embeddings with hidden dimensions + 2. a GNN encoder layer that performs the message passing work + 3. a decoder layer that transfors edge representations into logits for link prediction, and + 4. a loss function that matches to link prediction tasks. + + The model also initialize its own optimizer object. + + Arguments + ---------- + g: DistGraph + a DGL DistGraph + num_hid_layers: int + the number of gnn layers + node_feat_field: dict of list of strings + The list features for each node type to be used in the model + hid_size: int + the dimension of hidden layers. + """ + def __init__(self, + g, + num_hid_layers, + node_feat_field, + hid_size): + super(RgcnLPModel, self).__init__(alpha_l2norm=0.) 
+ + # extract feature size + feat_size = gs.get_node_feat_size(g, node_feat_field) + + # set an input layer encoder + encoder = GSNodeEncoderInputLayer(g=g, feat_size=feat_size, embed_size=hid_size) + self.set_node_input_encoder(encoder) + + # set a GNN encoder + gnn_encoder = RelationalGCNEncoder(g=g, + h_dim=hid_size, + out_dim=hid_size, + num_hidden_layers=num_hid_layers-1) + self.set_gnn_encoder(gnn_encoder) + + # set a decoder specific to link prediction task + decoder = LinkPredictDotDecoder(hid_size) + self.set_decoder(decoder) + + # link prediction loss function + self.set_loss_func(LinkPredictBCELossFunc()) + + # initialize model's optimizer + self.init_optimizer(lr=0.001, + sparse_optimizer_lr=0.01, + weight_decay=0) diff --git a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py index 3fa46e7326..c60e70a4e6 100644 --- a/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py +++ b/graphstorm-processing/graphstorm_processing/config/config_conversion/gconstruct_converter.py @@ -14,6 +14,7 @@ limitations under the License. """ +import math from typing import Any from collections.abc import Mapping @@ -52,16 +53,25 @@ def _convert_label(labels: list[dict]) -> list[dict]: label_column = label["label_col"] if "label_col" in label else "" label_type = label["task_type"] label_dict = {"column": label_column, "type": label_type} - if "split_pct" in label: - label_splitrate = label["split_pct"] - # check if split_pct is valid - assert ( - sum(label_splitrate) <= 1.0 - ), "sum of the label split rate should be <=1.0" - label_dict["split_rate"] = { - "train": label_splitrate[0], - "val": label_splitrate[1], - "test": label_splitrate[2], + if "custom_split_filenames" not in label: + if "split_pct" in label: + label_splitrate = label["split_pct"] + # check if split_pct is valid + assert ( + math.fsum(label_splitrate) == 1.0 + ), "sum of the label split rate should be ==1.0" + label_dict["split_rate"] = { + "train": label_splitrate[0], + "val": label_splitrate[1], + "test": label_splitrate[2], + } + else: + label_custom_split_filenames = label["custom_split_filenames"] + label_dict["custom_split_filenames"] = { + "train": label_custom_split_filenames["train"], + "valid": label_custom_split_filenames["valid"], + "test": label_custom_split_filenames["test"], + "column": label_custom_split_filenames["column"], } if "separator" in label: label_sep = label["separator"] diff --git a/graphstorm-processing/graphstorm_processing/config/label_config_base.py b/graphstorm-processing/graphstorm_processing/config/label_config_base.py index fca0a82ec0..7f1251ebd4 100644 --- a/graphstorm-processing/graphstorm_processing/config/label_config_base.py +++ b/graphstorm-processing/graphstorm_processing/config/label_config_base.py @@ -30,9 +30,14 @@ def __init__(self, config_dict: Dict[str, Any]): self._label_column = "" assert config_dict["type"] == "link_prediction" self._task_type: str = config_dict["type"] - self._split: Dict[str, float] = config_dict["split_rate"] self._separator: str = config_dict["separator"] if "separator" in config_dict else None self._multilabel = self._separator is not None + if "custom_split_filenames" not in config_dict: + self._split: Dict[str, float] = config_dict["split_rate"] + self._custom_split_filenames = None + else: + self._split = None + self._custom_split_filenames: Dict[str, str] = 
config_dict["custom_split_filenames"] def _sanity_check(self): if self._label_column == "": @@ -40,9 +45,17 @@ def _sanity_check(self): "When no label column is specified, the task type must be link_prediction, " f"got {self._task_type}" ) - assert isinstance(self._task_type, str) - assert isinstance(self._split, dict) - assert isinstance(self._separator, str) if self._multilabel else self._separator is None + if "custom_split_filenames" not in self._config: + assert isinstance(self._task_type, str) + assert isinstance(self._split, dict) + assert isinstance(self._separator, str) if self._multilabel else self._separator is None + else: + assert isinstance(self._custom_split_filenames, dict) + assert "train" in self._custom_split_filenames + assert "valid" in self._custom_split_filenames + assert "test" in self._custom_split_filenames + assert "column" in self._custom_split_filenames + assert isinstance(self._separator, str) if self._multilabel else self._separator is None @property def label_column(self) -> str: @@ -71,6 +84,11 @@ def multilabel(self) -> bool: """Whether the task is multilabel classification.""" return self._multilabel + @property + def custom_split_filenames(self) -> Dict[str, str]: + """The config for custom split labels.""" + return self._custom_split_filenames + class EdgeLabelConfig(LabelConfig): """Holds the configuration of an edge label. diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py index 5782a683da..1d7a7596e5 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py @@ -55,6 +55,29 @@ def __post_init__(self) -> None: ) +@dataclass +class CustomSplit: + """ + Dataclass to hold the custom split for each of the train/val/test splits. + + Parameters + ---------- + train : str + Path of the training mask parquet file. + valid : str + Path of the validation mask parquet file. + test : str + Path of the testing mask parquet file. + mask_columns : list[str] + List of columns that contain original string ids. + """ + + train: str + valid: str + test: str + mask_columns: list[str] + + class DistLabelLoader: """Used to transform label columns to conform to downstream GraphStorm expectations. 
diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py index 193b668818..962a831935 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py @@ -18,6 +18,7 @@ import logging import numbers import os +import math from collections import Counter, defaultdict from dataclasses import dataclass from time import perf_counter @@ -34,6 +35,7 @@ ArrayType, ByteType, ) +from pyspark.sql.functions import col, when from numpy.random import default_rng from graphstorm_processing.constants import ( @@ -50,7 +52,7 @@ from ..config.label_config_base import LabelConfig from ..config.feature_config_base import FeatureConfig from ..data_transformations.dist_feature_transformer import DistFeatureTransformer -from ..data_transformations.dist_label_loader import DistLabelLoader, SplitRates +from ..data_transformations.dist_label_loader import DistLabelLoader, SplitRates, CustomSplit from ..data_transformations import s3_utils, spark_utils from .heterogeneous_graphloader import HeterogeneousGraphLoader @@ -1108,8 +1110,21 @@ def _process_node_labels( ) else: split_rates = None - label_split_dicts = self._create_split_files_from_rates( - nodes_df, label_conf.label_column, split_rates, split_masks_output_prefix + if label_conf.custom_split_filenames: + custom_split_filenames = CustomSplit( + train=label_conf.custom_split_filenames["train"], + valid=label_conf.custom_split_filenames["valid"], + test=label_conf.custom_split_filenames["test"], + mask_columns=label_conf.custom_split_filenames["column"], + ) + else: + custom_split_filenames = None + label_split_dicts = self._create_split_files( + nodes_df, + label_conf.label_column, + split_rates, + split_masks_output_prefix, + custom_split_filenames, ) node_type_label_metadata.update(label_split_dicts) @@ -1574,11 +1589,23 @@ def _process_edge_labels( ) else: split_rates = None - label_split_dicts = self._create_split_files_from_rates( - edges_df, label_conf.label_column, split_rates, split_masks_output_prefix + if label_conf.custom_split_filenames: + custom_split_filenames = CustomSplit( + train=label_conf.custom_split_filenames["train"], + valid=label_conf.custom_split_filenames["valid"], + test=label_conf.custom_split_filenames["test"], + mask_columns=label_conf.custom_split_filenames["column"], + ) + else: + custom_split_filenames = None + label_split_dicts = self._create_split_files( + edges_df, + label_conf.label_column, + split_rates, + split_masks_output_prefix, + custom_split_filenames, ) label_metadata_dicts.update(label_split_dicts) - # TODO: Support custom_split_filenames return label_metadata_dicts @@ -1652,17 +1679,18 @@ def _update_label_properties( else: raise RuntimeError(f"Invalid task type: {label_config.task_type}") - def _create_split_files_from_rates( + def _create_split_files( self, input_df: DataFrame, label_column: str, split_rates: Optional[SplitRates], output_path: str, + custom_split_file: Optional[CustomSplit] = None, seed: Optional[int] = None, ) -> Dict: """ - Given an input dataframe and a list of split rates creates the - split masks and writes them to S3 and returns the corresponding + Given an input dataframe and a list of split rates or a list of custom split files + creates the split masks and writes them to S3 and returns the corresponding metadata.json dict 
elements. Parameters @@ -1678,6 +1706,9 @@ def _create_split_files_from_rates( If None, a default split rate of 0.9:0.05:0.05 is used. output_path: str The output path under which we write the masks. + custom_split_file: Optional[CustomSplit] + A CustomSplit object including path to the custom split files for + training/validation/test. seed: int An optional random seed for reproducibility. @@ -1688,11 +1719,70 @@ def _create_split_files_from_rates( """ # If the user did not provide a split rate we use a default split_metadata = {} + if not custom_split_file: + train_mask_df, val_mask_df, test_mask_df = self._create_split_files_split_rates( + input_df, label_column, split_rates, seed + ) + else: + train_mask_df, val_mask_df, test_mask_df = self._create_split_files_custom_split( + input_df, custom_split_file + ) + + def create_metadata_entry(path_list): + return {"format": {"name": FORMAT_NAME, "delimiter": DELIMITER}, "data": path_list} + + def write_mask(kind: str, mask_df: DataFrame) -> Sequence[str]: + out_path_list = self._write_df( + mask_df.select(F.col(f"{kind}_mask").cast(ByteType()).alias(f"{kind}_mask")), + f"{output_path}-{kind}-mask", + ) + return out_path_list + + out_path_list = write_mask("train", train_mask_df) + split_metadata["train_mask"] = create_metadata_entry(out_path_list) + + out_path_list = write_mask("val", val_mask_df) + split_metadata["val_mask"] = create_metadata_entry(out_path_list) + + out_path_list = write_mask("test", test_mask_df) + split_metadata["test_mask"] = create_metadata_entry(out_path_list) + + return split_metadata + + def _create_split_files_split_rates( + self, + input_df: DataFrame, + label_column: str, + split_rates: Optional[SplitRates], + seed: Optional[int], + ) -> tuple[DataFrame, DataFrame, DataFrame]: + """ + Creates the train/val/test mask dataframe based on split rates. + + Parameters + ---------- + input_df: DataFrame + Input dataframe for which we will create split masks. + label_column: str + The name of the label column. If provided, the values in the column + need to be not null for the data point to be included in one of the masks. + If an empty string, all rows in the input_df are included in one of train/val/test sets. + split_rates: Optional[SplitRates] + A SplitRates object indicating the train/val/test split rates. + If None, a default split rate of 0.9:0.05:0.05 is used. + seed: Optional[int] + An optional random seed for reproducibility. + + Returns + ------- + tuple[DataFrame, DataFrame, DataFrame] + Train/val/test mask dataframes. 
+ """ if split_rates is None: split_rates = SplitRates(train_rate=0.8, val_rate=0.1, test_rate=0.1) else: # TODO: add support for sums <= 1.0, useful for large-scale link prediction - if sum(split_rates.tolist()) != 1.0: + if math.fsum(split_rates.tolist()) != 1.0: raise RuntimeError(f"Provided split rates do not sum to 1: {split_rates}") split_list = split_rates.tolist() @@ -1719,30 +1809,94 @@ def multinomial_sample(label_col: str) -> Sequence[int]: # to create one-hot vector indicating train/test/val membership input_col = F.col(label_column).astype("string") if label_column else F.lit("dummy") int_group_df = input_df.select(split_group(input_col).alias(group_col_name)) + int_group_df.cache() + train_mask_df = int_group_df.select(F.col(group_col_name)[0].alias("train_mask")) + val_mask_df = int_group_df.select(F.col(group_col_name)[1].alias("val_mask")) + test_mask_df = int_group_df.select(F.col(group_col_name)[2].alias("test_mask")) - def create_metadata_entry(path_list): - return {"format": {"name": FORMAT_NAME, "delimiter": DELIMITER}, "data": path_list} + return train_mask_df, val_mask_df, test_mask_df - def write_mask(kind: str, mask_df: DataFrame) -> Sequence[str]: - out_path_list = self._write_df( - mask_df.select(F.col(f"{kind}_mask").cast(ByteType()).alias(f"{kind}_mask")), - f"{output_path}-{kind}-mask", - ) - return out_path_list + def _create_split_files_custom_split( + self, input_df: DataFrame, custom_split_file: str + ) -> tuple[DataFrame, DataFrame, DataFrame]: + """ + Creates the train/val/test mask dataframe based on custom split files. - train_mask_df = int_group_df.select(F.col(group_col_name)[0].alias("train_mask")) - out_path_list = write_mask("train", train_mask_df) - split_metadata["train_mask"] = create_metadata_entry(out_path_list) + Parameters + ---------- + input_df: DataFrame + Input dataframe for which we will create split masks. + custom_split_file: Optional[CustomSplit] + A CustomSplit object including path to the custom split files for + training/validation/test. - val_mask_df = int_group_df.select(F.col(group_col_name)[1].alias("val_mask")) - out_path_list = write_mask("val", val_mask_df) - split_metadata["val_mask"] = create_metadata_entry(out_path_list) + Returns + ------- + tuple[DataFrame, DataFrame, DataFrame] + Train/val/test mask dataframes. 
+ """ - test_mask_df = int_group_df.select(F.col(group_col_name)[2].alias("test_mask")) - out_path_list = write_mask("test", test_mask_df) - split_metadata["test_mask"] = create_metadata_entry(out_path_list) + # custom node/edge label + # create custom mask dataframe for one of the types: train, val, test + def process_custom_mask_df(input_df, split_file, mask_type): + if mask_type == "train": + file_path = split_file.train + elif mask_type == "val": + file_path = split_file.valid + elif mask_type == "test": + file_path = split_file.test + else: + raise ValueError("Unknown mask type") + + if len(split_file.mask_columns) == 1: + # custom split on node original id + custom_mask_df = self.spark.read.parquet( + os.path.join(self.input_prefix, file_path) + ).select(col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask")) + mask_df = input_df.join( + custom_mask_df, + input_df[NODE_MAPPING_STR] == custom_mask_df[f"custom_{mask_type}_mask"], + "left_outer", + ) + mask_df = mask_df.select( + "*", + when(mask_df[f"custom_{mask_type}_mask"].isNotNull(), 1) + .otherwise(0) + .alias(f"{mask_type}_mask"), + ).select(f"{mask_type}_mask") + elif len(split_file.mask_columns) == 2: + # custom split on edge (srd, dst) original ids + custom_mask_df = self.spark.read.parquet( + os.path.join(self.input_prefix, file_path) + ).select( + col(split_file.mask_columns[0]).alias(f"custom_{mask_type}_mask_src"), + col(split_file.mask_columns[1]).alias(f"custom_{mask_type}_mask_dst"), + ) + join_condition = ( + input_df["src_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_src"] + ) & (input_df["dst_str_id"] == custom_mask_df[f"custom_{mask_type}_mask_dst"]) + mask_df = input_df.join(custom_mask_df, join_condition, "left_outer") + mask_df = mask_df.select( + "*", + when( + (mask_df[f"custom_{mask_type}_mask_src"].isNotNull()) + & (mask_df[f"custom_{mask_type}_mask_dst"].isNotNull()), + 1, + ) + .otherwise(0) + .alias(f"{mask_type}_mask"), + ).select(f"{mask_type}_mask") + else: + raise ValueError("The number of column should be only 1 or 2.") - return split_metadata + return mask_df + + train_mask_df, val_mask_df, test_mask_df = ( + process_custom_mask_df(input_df, custom_split_file, "train"), + process_custom_mask_df(input_df, custom_split_file, "val"), + process_custom_mask_df(input_df, custom_split_file, "test"), + ) + return train_mask_df, val_mask_df, test_mask_df def load(self) -> Dict: return self.process_and_write_graph_data(self._data_configs) diff --git a/graphstorm-processing/tests/test_converter.py b/graphstorm-processing/tests/test_converter.py index 54078e4700..84a56ce6fd 100644 --- a/graphstorm-processing/tests/test_converter.py +++ b/graphstorm-processing/tests/test_converter.py @@ -166,6 +166,50 @@ def test_read_node_gconstruct(converter: GConstructConfigConverter, node_dict: d } ] + node_dict["nodes"].append( + { + "node_type": "paper_custom", + "format": {"name": "parquet"}, + "files": ["/tmp/acm_raw/nodes/paper_custom.parquet"], + "node_id_col": "node_id", + "labels": [ + { + "label_col": "label", + "task_type": "classification", + "custom_split_filenames": { + "train": "customized_label/node_train_idx.parquet", + "valid": "customized_label/node_val_idx.parquet", + "test": "customized_label/node_test_idx.parquet", + "column": ["ID"], + }, + "label_stats_type": "frequency_cnt", + } + ], + } + ) + + # nodes with all elements + # [self.type, self.format, self.files, self.separator, self.column, self.features, self.labels] + node_config = converter.convert_nodes(node_dict["nodes"])[2] + 
assert len(converter.convert_nodes(node_dict["nodes"])) == 3 + assert node_config.node_type == "paper_custom" + assert node_config.file_format == "parquet" + assert node_config.files == ["/tmp/acm_raw/nodes/paper_custom.parquet"] + assert node_config.separator is None + assert node_config.column == "node_id" + assert node_config.labels == [ + { + "column": "label", + "type": "classification", + "custom_split_filenames": { + "train": "customized_label/node_train_idx.parquet", + "valid": "customized_label/node_val_idx.parquet", + "test": "customized_label/node_test_idx.parquet", + "column": ["ID"], + }, + } + ] + @pytest.mark.parametrize("col_name", ["author", ["author"]]) def test_read_edge_gconstruct(converter: GConstructConfigConverter, col_name): diff --git a/graphstorm-processing/tests/test_dist_heterogenous_loader.py b/graphstorm-processing/tests/test_dist_heterogenous_loader.py index 200889f65f..5cfa4988c4 100644 --- a/graphstorm-processing/tests/test_dist_heterogenous_loader.py +++ b/graphstorm-processing/tests/test_dist_heterogenous_loader.py @@ -554,8 +554,8 @@ def test_create_split_files_from_rates( non_missing_data_points = NUM_DATAPOINTS - missing_data_points edges_df = create_edges_df(spark, missing_data_points) - output_dicts = dghl_loader._create_split_files_from_rates( - edges_df, LABEL_COL, split_rates, os.path.join(tempdir, "sample_masks"), seed=42 + output_dicts = dghl_loader._create_split_files( + edges_df, LABEL_COL, split_rates, os.path.join(tempdir, "sample_masks"), None, seed=42 ) train_mask_df, test_mask_df, val_mask_df = read_masks_from_disk( @@ -585,8 +585,8 @@ def test_create_split_files_from_rates_empty_col( edges_df = create_edges_df(spark, 0).drop(LABEL_COL) split_rates = SplitRates(0.8, 0.1, 0.1) - output_dicts = dghl_loader._create_split_files_from_rates( - edges_df, "", split_rates, os.path.join(tempdir, "sample_masks"), seed=42 + output_dicts = dghl_loader._create_split_files( + edges_df, "", split_rates, os.path.join(tempdir, "sample_masks"), None, seed=42 ) train_mask_df, test_mask_df, val_mask_df = read_masks_from_disk( @@ -706,3 +706,93 @@ def test_update_label_properties_multilabel( assert user_properties[COLUMN_NAME] == "multi" assert user_properties[VALUE_COUNTS] == {str(i): 1 for i in range(1, 11)} + + +def test_node_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path): + data = [(i,) for i in range(1, 11)] + + # Create DataFrame + nodes_df = spark.createDataFrame(data, ["orig"]) + + train_df = spark.createDataFrame([(i,) for i in range(1, 6)], ["mask_id"]) + val_df = spark.createDataFrame([(i,) for i in range(6, 9)], ["mask_id"]) + test_df = spark.createDataFrame([(i,) for i in range(9, 11)], ["mask_id"]) + + train_df.repartition(1).write.parquet(f"{tmp_path}/train.parquet") + val_df.repartition(1).write.parquet(f"{tmp_path}/val.parquet") + test_df.repartition(1).write.parquet(f"{tmp_path}/test.parquet") + config_dict = { + "column": "orig", + "type": "classification", + "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1}, + "custom_split_filenames": { + "train": f"{tmp_path}/train.parquet", + "valid": f"{tmp_path}/val.parquet", + "test": f"{tmp_path}/test.parquet", + "column": ["mask_id"], + }, + } + dghl_loader.input_prefix = "" + label_configs = [NodeLabelConfig(config_dict)] + label_metadata_dicts = dghl_loader._process_node_labels(label_configs, nodes_df, "orig") + + assert label_metadata_dicts.keys() == {"train_mask", "test_mask", "val_mask", "orig"} + + train_mask_df, test_mask_df, val_mask_df = read_masks_from_disk( + 
spark, dghl_loader, label_metadata_dicts + ) + + train_total_ones = train_mask_df.agg(F.sum("train_mask")).collect()[0][0] + val_total_ones = val_mask_df.agg(F.sum("val_mask")).collect()[0][0] + test_total_ones = test_mask_df.agg(F.sum("test_mask")).collect()[0][0] + assert train_total_ones == 5 + assert val_total_ones == 3 + assert test_total_ones == 2 + + +def test_edge_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path): + data = [(i, j) for i in range(1, 4) for j in range(11, 14)] + # Create DataFrame + edges_df = spark.createDataFrame(data, ["src_str_id", "dst_str_id"]) + + train_df = spark.createDataFrame( + [(i, j) for i in range(1, 2) for j in range(11, 14)], ["mask_src_id", "mask_dst_id"] + ) + val_df = spark.createDataFrame( + [(i, j) for i in range(2, 3) for j in range(11, 14)], ["mask_src_id", "mask_dst_id"] + ) + test_df = spark.createDataFrame( + [(i, j) for i in range(3, 4) for j in range(11, 14)], ["mask_src_id", "mask_dst_id"] + ) + + train_df.repartition(1).write.parquet(f"{tmp_path}/train.parquet") + val_df.repartition(1).write.parquet(f"{tmp_path}/val.parquet") + test_df.repartition(1).write.parquet(f"{tmp_path}/test.parquet") + config_dict = { + "column": "", + "type": "link_prediction", + "custom_split_filenames": { + "train": f"{tmp_path}/train.parquet", + "valid": f"{tmp_path}/val.parquet", + "test": f"{tmp_path}/test.parquet", + "column": ["mask_src_id", "mask_dst_id"], + }, + } + dghl_loader.input_prefix = "" + label_configs = [EdgeLabelConfig(config_dict)] + label_metadata_dicts = dghl_loader._process_edge_labels( + label_configs, edges_df, "src_str_id:to:dst_str_id", "" + ) + + assert label_metadata_dicts.keys() == {"train_mask", "test_mask", "val_mask"} + + train_mask_df, test_mask_df, val_mask_df = read_masks_from_disk( + spark, dghl_loader, label_metadata_dicts + ) + + train_total_ones = train_mask_df.agg(F.sum("train_mask")).collect()[0][0] + val_total_ones = val_mask_df.agg(F.sum("val_mask")).collect()[0][0] + test_total_ones = test_mask_df.agg(F.sum("test_mask")).collect()[0][0] + assert train_total_ones == 3 + assert val_total_ones == 3 + assert test_total_ones == 3 diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 7955815d39..6e982c8fc8 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -1695,9 +1695,16 @@ def target_etype(self): def remove_target_edge_type(self): """ Whether to remove the training target edge type for message passing. - Will set the fanout of training target edge type as zero + Will set the fanout of training target edge type as zero. Only used + with edge classification. - Only used with edge classification + If the edge classification is to predict the existence of an edge between + two nodes, we should remove the target edge in the message passing to + avoid information leak. + If it's to predict some attributes associated with an edge, we may not need + to remove the target edge. + Since we don't know what to predict, to be safe, we should remove the target + edge in message passing by default. """ # pylint: disable=no-member if hasattr(self, "_remove_target_edge_type"): @@ -1706,6 +1713,10 @@ def remove_target_edge_type(self): # By default, remove training target etype during # message passing to avoid information leakage + logging.warning("remove_target_edge_type is set to True by default. 
" + "If your edge classification task is not predicting " + "the existence of the target edge, we suggest you to " + "set it to False.") return True @property diff --git a/python/graphstorm/gconstruct/construct_graph.py b/python/graphstorm/gconstruct/construct_graph.py index a433240fe6..9ed376baeb 100644 --- a/python/graphstorm/gconstruct/construct_graph.py +++ b/python/graphstorm/gconstruct/construct_graph.py @@ -770,10 +770,25 @@ def process_graph(args): skip_nonexist_edges=args.skip_nonexist_edges) sys_tracker.check('Process the edge data') num_nodes = {ntype: len(raw_node_id_maps[ntype]) for ntype in raw_node_id_maps} + + os.makedirs(args.output_dir, exist_ok=True) + if args.output_conf_file is not None: - # Save the new config file. - with open(args.output_conf_file, "w", encoding="utf8") as outfile: - json.dump(process_confs, outfile, indent=4) + outfile_path = args.output_conf_file + else: + new_file_name = 'data_transform_new.json' + outfile_path = os.path.join(args.output_dir,new_file_name ) + + # check if the output configuration file exists. Overwrite it with a warning. + if os.path.exists(outfile_path): + logging.warning('Overwrote the existing %s file, which was generated in ' + \ + 'the previous graph construction command. Use the --output-conf-file ' + \ + 'argument to specify a different location if not want to overwrite the ' + \ + 'existing configuration file.', outfile_path) + + # Save the new config file. + with open(outfile_path, "w", encoding="utf8") as outfile: + json.dump(process_confs, outfile, indent=4) if args.add_reverse_edges: edges1 = {} @@ -815,7 +830,6 @@ def process_graph(args): g = dgl.heterograph(edges, num_nodes_dict=num_nodes) print_graph_info(g, node_data, edge_data, node_label_stats, edge_label_stats, node_label_masks, edge_label_masks) - os.makedirs(args.output_dir, exist_ok=True) sys_tracker.check('Construct DGL graph') # reshape customized mask diff --git a/python/graphstorm/model/node_gnn.py b/python/graphstorm/model/node_gnn.py index adbb2027c4..fca05ada24 100644 --- a/python/graphstorm/model/node_gnn.py +++ b/python/graphstorm/model/node_gnn.py @@ -323,17 +323,17 @@ def node_mini_batch_predict(model, emb, loader, return_proba=True, return_label= # TODO(zhengda) I need to check if the data loader only returns target nodes. 
model.eval() with th.no_grad(): - for input_nodes, seeds, _ in loader: - for ntype, in_nodes in input_nodes.items(): + for _, seeds, _ in loader: # seeds are target nodes + for ntype, seed_nodes in seeds.items(): if isinstance(model.decoder, th.nn.ModuleDict): assert ntype in model.decoder, f"Node type {ntype} not in decoder" decoder = model.decoder[ntype] else: decoder = model.decoder if return_proba: - pred = decoder.predict_proba(emb[ntype][in_nodes].to(device)) + pred = decoder.predict_proba(emb[ntype][seed_nodes].to(device)) else: - pred = decoder.predict(emb[ntype][in_nodes].to(device)) + pred = decoder.predict(emb[ntype][seed_nodes].to(device)) if ntype in preds: preds[ntype].append(pred.cpu()) else: diff --git a/tests/unit-tests/test_gnn.py b/tests/unit-tests/test_gnn.py index f5c5771085..399015a2ae 100644 --- a/tests/unit-tests/test_gnn.py +++ b/tests/unit-tests/test_gnn.py @@ -275,15 +275,27 @@ def require_cache_embed(self): dataloader2 = GSgnnNodeDataLoader(data, target_nidx, fanout=[-1, -1], batch_size=10, label_field='label', node_feats='feat', train_task=False) - pred2, _, labels2 = node_mini_batch_gnn_predict(model, dataloader2, return_label=True) + # Call GNN mini-batch inference + pred2_gnn_pred, _, labels2_gnn_pred, = node_mini_batch_gnn_predict(model, dataloader2, return_label=True) + # Call last layer mini-batch inference with the GNN dataloader + pred2_pred, labels2_pred = node_mini_batch_predict(model, embs, dataloader2, return_label=True) if isinstance(pred1,dict): - assert len(pred1) == len(pred2) and len(labels1) == len(labels2) + assert len(pred1) == len(pred2_gnn_pred) and len(labels1) == len(labels2_gnn_pred) + assert len(pred1) == len(pred2_pred) and len(labels1) == len(labels2_pred) for ntype in pred1: - assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), pred2[ntype][0:len(pred2)].numpy(), decimal=5) - assert_equal(labels1[ntype].numpy(), labels2[ntype].numpy()) + assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), + pred2_gnn_pred[ntype][0:len(pred2_gnn_pred)].numpy(), decimal=5) + assert_equal(labels1[ntype].numpy(), labels2_gnn_pred[ntype].numpy()) + assert_almost_equal(pred1[ntype][0:len(pred1)].numpy(), + pred2_pred[ntype][0:len(pred2_pred)].numpy(), decimal=5) + assert_equal(labels1[ntype].numpy(), labels2_pred[ntype].numpy()) else: - assert_almost_equal(pred1[0:len(pred1)].numpy(), pred2[0:len(pred2)].numpy(), decimal=5) - assert_equal(labels1.numpy(), labels2.numpy()) + assert_almost_equal(pred1[0:len(pred1)].numpy(), + pred2_gnn_pred[0:len(pred2_gnn_pred)].numpy(), decimal=5) + assert_equal(labels1.numpy(), labels2_gnn_pred.numpy()) + assert_almost_equal(pred1[0:len(pred1)].numpy(), + pred2_pred[0:len(pred2_pred)].numpy(), decimal=5) + assert_equal(labels1.numpy(), labels2_pred.numpy()) # Test the return_proba argument. pred3, labels3 = node_mini_batch_predict(model, embs, dataloader1, return_proba=True, return_label=True)
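To make the `node_mini_batch_predict` fix above concrete: `input_nodes` contains every node sampled for message passing (the seeds plus their sampled neighbors), while `seeds` holds only the prediction targets, so the decoder must index the cached embeddings with `seeds`. A toy illustration with invented sizes:

```python
# Toy mini-batch: the decoder should only see embeddings of the target (seed) nodes,
# not of every node that was sampled as a neighbor for message passing.
import torch as th

emb = {"paper": th.randn(10, 4)}                          # cached GNN embeddings for one node type
input_nodes = {"paper": th.tensor([0, 1, 2, 5, 7, 9])}    # seeds + sampled neighbors
seeds = {"paper": th.tensor([0, 1, 2])}                   # prediction targets of this mini-batch

wrong = emb["paper"][input_nodes["paper"]]   # 6 rows: would emit predictions for neighbors too
right = emb["paper"][seeds["paper"]]         # 3 rows: one prediction per labeled target node
print(wrong.shape, right.shape)
```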