From 2d731b94a316fa4aff04d383bd2d40ea45ef0742 Mon Sep 17 00:00:00 2001
From: Zhihong Zhang <100308595+nvidianz@users.noreply.github.com>
Date: Tue, 13 Aug 2024 20:36:04 -0400
Subject: [PATCH] Pre-trained Model and training_mode changes (#2793)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Updated FOBS readme to add DatumManager, added agrpcs as secure scheme

* Added support for pre-trained model

* Changed training_mode to split_mode + secure_training

* split_mode => data_split_mode

* Format error

* Fixed a format error

* Addressed PR comments

* Fixed format

* Changed all xgboost controller/executor to use new XGBoost

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷)
---
 .../code/vertical_xgb/vertical_data_loader.py |  6 ++--
 .../base_v2/app/custom/higgs_data_loader.py   |  2 +-
 .../xgboost/utils/prepare_job_config.py       |  3 +-
 nvflare/app_opt/xgboost/constant.py           | 26 --------------
 nvflare/app_opt/xgboost/data_loader.py        |  6 +---
 .../xgboost/histogram_based/controller.py     |  3 +-
 .../xgboost/histogram_based/executor.py       |  6 ++--
 .../adaptors/grpc_client_adaptor.py           |  3 +-
 .../adaptors/xgb_adaptor.py                   | 16 ++++++---
 .../xgboost/histogram_based_v2/controller.py  | 20 ++++++-----
 .../xgboost/histogram_based_v2/defs.py        | 27 ++++----------
 .../histogram_based_v2/fed_controller.py      |  6 ++--
 .../mock/mock_controller.py                   |  3 +-
 .../runners/xgb_client_runner.py              | 35 ++++++++++++++-----
 .../histogram_based_v2/sec/client_handler.py  |  9 ++---
 .../histogram_based_v2/secure_data_loader.py  | 15 +++-----
 .../adaptors/xgb_adaptor_test.py              |  3 +-
 17 files changed, 87 insertions(+), 102 deletions(-)
 delete mode 100644 nvflare/app_opt/xgboost/constant.py

diff --git a/examples/advanced/vertical_xgboost/code/vertical_xgb/vertical_data_loader.py b/examples/advanced/vertical_xgboost/code/vertical_xgb/vertical_data_loader.py
index 096d428d2d..246824d819 100644
--- a/examples/advanced/vertical_xgboost/code/vertical_xgb/vertical_data_loader.py
+++ b/examples/advanced/vertical_xgboost/code/vertical_xgb/vertical_data_loader.py
@@ -62,7 +62,7 @@ def __init__(self, data_split_path, psi_path, id_col, label_owner, train_proport
         self.label_owner = label_owner
         self.train_proportion = train_proportion
 
-    def load_data(self, client_id: str, training_mode: str = ""):
+    def load_data(self, client_id: str, split_mode: int = 1):
         client_data_split_path = self.data_split_path.replace("site-x", client_id)
         client_psi_path = self.psi_path.replace("site-x", client_id)
 
@@ -84,7 +84,7 @@ def load_data(self, client_id: str, training_mode: str = ""):
             label = ""
 
         # for Vertical XGBoost, read from csv with label_column and set data_split_mode to 1 for column mode
-        dtrain = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=1)
-        dvalid = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=1)
+        dtrain = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=split_mode)
+        dvalid = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=split_mode)
 
         return dtrain, dvalid
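
The updated example loader forwards the new split_mode argument straight through to xgboost's data_split_mode, where 0 means row split (horizontal) and 1 means column split (vertical). A minimal caller sketch; the loader class name, paths, and column names below are illustrative assumptions, not part of this patch:

    # Hypothetical usage of the example loader; all constructor values are placeholders.
    loader = VerticalDataLoader(
        data_split_path="/data/site-x/data.csv",  # "site-x" is replaced with the client ID
        psi_path="/data/site-x/psi/intersection.txt",
        id_col="uid",
        label_owner="site-1",
        train_proportion=0.8,
    )
    dtrain, dvalid = loader.load_data("site-1", split_mode=1)  # 1 = column split for vertical FL
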
diff --git a/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/custom/higgs_data_loader.py b/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/custom/higgs_data_loader.py
index d97f459600..3edb2d7408 100644
--- a/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/custom/higgs_data_loader.py
+++ b/examples/advanced/xgboost/histogram-based/jobs/base_v2/app/custom/higgs_data_loader.py
@@ -41,7 +41,7 @@ def __init__(self, data_split_filename):
         """
         self.data_split_filename = data_split_filename
 
-    def load_data(self, client_id: str, training_mode: str = ""):
+    def load_data(self, client_id: str, split_mode: int):
         with open(self.data_split_filename, "r") as file:
             data_split = json.load(file)
 
diff --git a/examples/advanced/xgboost/utils/prepare_job_config.py b/examples/advanced/xgboost/utils/prepare_job_config.py
index e38c88eec8..71b6b650ba 100644
--- a/examples/advanced/xgboost/utils/prepare_job_config.py
+++ b/examples/advanced/xgboost/utils/prepare_job_config.py
@@ -152,7 +152,8 @@ def _update_server_config(config: dict, args):
     config["num_rounds"] = args.round_num
     config["workflows"][0]["args"]["xgb_params"]["nthread"] = args.nthread
     config["workflows"][0]["args"]["xgb_params"]["tree_method"] = args.tree_method
-    config["workflows"][0]["args"]["training_mode"] = args.training_mode
+    config["workflows"][0]["args"]["split_mode"] = args.split_mode
+    config["workflows"][0]["args"]["secure_training"] = args.secure_training
 
 
 def _copy_custom_files(src_job_path, src_app_name, dst_job_path, dst_app_name):
diff --git a/nvflare/app_opt/xgboost/constant.py b/nvflare/app_opt/xgboost/constant.py
deleted file mode 100644
index 826e311418..0000000000
--- a/nvflare/app_opt/xgboost/constant.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-class TrainingMode:
-    # Non-secure mode
-    H = "h"
-    HORIZONTAL = "horizontal"
-    V = "v"
-    VERTICAL = "vertical"
-    # Secure mode
-    HS = "hs"
-    HORIZONTAL_SECURE = "horizontal_secure"
-    VS = "VS"
-    VERTICAL_SECURE = "vertical_secure"
diff --git a/nvflare/app_opt/xgboost/data_loader.py b/nvflare/app_opt/xgboost/data_loader.py
index d59a36c4de..f49d6dc796 100644
--- a/nvflare/app_opt/xgboost/data_loader.py
+++ b/nvflare/app_opt/xgboost/data_loader.py
@@ -18,14 +18,10 @@
 
 import xgboost as xgb
 
-from .constant import TrainingMode
-
 
 class XGBDataLoader(ABC):
     @abstractmethod
-    def load_data(
-        self, client_id: str, training_mode: str = TrainingMode.HORIZONTAL
-    ) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
+    def load_data(self, client_id: str, split_mode: int) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
         """Loads data for xgboost.
 
         Returns:
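
For implementers, the contract change is mechanical: load_data now receives an integer split_mode instead of a training_mode string and should forward it to xgb.DMatrix. A minimal sketch of a conforming subclass; the file layout is a hypothetical example, not part of this patch:

    from typing import Tuple

    import xgboost as xgb

    from nvflare.app_opt.xgboost.data_loader import XGBDataLoader


    class CsvDataLoader(XGBDataLoader):
        """Illustrative loader; assumes /data/<client_id>/{train,valid}.csv exist."""

        def load_data(self, client_id: str, split_mode: int) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
            # split_mode is forwarded verbatim: 0 = row split, 1 = column split
            dtrain = xgb.DMatrix(f"/data/{client_id}/train.csv?format=csv&label_column=0", data_split_mode=split_mode)
            dvalid = xgb.DMatrix(f"/data/{client_id}/valid.csv?format=csv&label_column=0", data_split_mode=split_mode)
            return dtrain, dvalid
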
diff --git a/nvflare/app_opt/xgboost/histogram_based/controller.py b/nvflare/app_opt/xgboost/histogram_based/controller.py
index 9ebf8680ae..67be563613 100644
--- a/nvflare/app_opt/xgboost/histogram_based/controller.py
+++ b/nvflare/app_opt/xgboost/histogram_based/controller.py
@@ -107,9 +107,10 @@ def start_controller(self, fl_ctx: FLContext):
             if not self._get_certificates(fl_ctx):
                 self.log_error(fl_ctx, "Can't get required certificates for XGB FL server in secure mode.")
                 return
+            self.log_info(fl_ctx, "Running XGB FL server in secure mode.")
             self._xgb_fl_server = multiprocessing.Process(
                 target=xgb_federated.run_federated_server,
-                args=(self._port, len(clients), self._server_key_path, self._server_cert_path, self._ca_cert_path),
+                args=(len(clients), self._port, self._server_key_path, self._server_cert_path, self._ca_cert_path),
             )
         else:
             self._xgb_fl_server = multiprocessing.Process(
diff --git a/nvflare/app_opt/xgboost/histogram_based/executor.py b/nvflare/app_opt/xgboost/histogram_based/executor.py
index f48b1775ca..8336d31aba 100644
--- a/nvflare/app_opt/xgboost/histogram_based/executor.py
+++ b/nvflare/app_opt/xgboost/histogram_based/executor.py
@@ -269,9 +269,9 @@ def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -
             if not self._get_certificates(fl_ctx):
                 return make_reply(ReturnCode.ERROR)
 
-            communicator_env["federated_server_cert"] = self._ca_cert_path
-            communicator_env["federated_client_key"] = self._client_key_path
-            communicator_env["federated_client_cert"] = self._client_cert_path
+            communicator_env["federated_server_cert_path"] = self._ca_cert_path
+            communicator_env["federated_client_key_path"] = self._client_key_path
+            communicator_env["federated_client_cert_path"] = self._client_cert_path
 
         try:
             with xgb.collective.CommunicatorContext(**communicator_env):
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py
index acf7850da2..c4819fea1b 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/grpc_client_adaptor.py
@@ -98,7 +98,8 @@ class since the self object contains a sender that contains a Core Cell which ca
             Constant.RUNNER_CTX_SERVER_ADDR: server_addr,
             Constant.RUNNER_CTX_RANK: self.rank,
             Constant.RUNNER_CTX_NUM_ROUNDS: self.num_rounds,
-            Constant.RUNNER_CTX_TRAINING_MODE: self.training_mode,
+            Constant.RUNNER_CTX_SPLIT_MODE: self.split_mode,
+            Constant.RUNNER_CTX_SECURE_TRAINING: self.secure_training,
             Constant.RUNNER_CTX_XGB_PARAMS: self.xgb_params,
             Constant.RUNNER_CTX_XGB_OPTIONS: self.xgb_options,
             Constant.RUNNER_CTX_MODEL_DIR: self._run_dir,
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py
index 3cada9ae89..c77827c472 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py
@@ -150,7 +150,8 @@ def __init__(self, in_process: bool, per_msg_timeout: float, tx_timeout: float):
         self.stopped = False
         self.rank = None
         self.num_rounds = None
-        self.training_mode = None
+        self.split_mode = None
+        self.secure_training = None
         self.xgb_params = None
         self.xgb_options = None
         self.world_size = None
@@ -196,10 +197,15 @@ def configure(self, config: dict, fl_ctx: FLContext):
         check_positive_int(Constant.CONF_KEY_NUM_ROUNDS, num_rounds)
         self.num_rounds = num_rounds
 
-        self.training_mode = config.get(Constant.CONF_KEY_TRAINING_MODE)
-        if self.training_mode is None:
-            raise RuntimeError("training_mode is not configured")
-        fl_ctx.set_prop(key=Constant.PARAM_KEY_TRAINING_MODE, value=self.training_mode, private=True, sticky=True)
+        self.split_mode = config.get(Constant.CONF_KEY_SPLIT_MODE)
+        if self.split_mode is None:
+            raise RuntimeError("split_mode is not configured")
+        fl_ctx.set_prop(key=Constant.PARAM_KEY_SPLIT_MODE, value=self.split_mode, private=True, sticky=True)
+
+        self.secure_training = config.get(Constant.CONF_KEY_SECURE_TRAINING)
+        if self.secure_training is None:
+            raise RuntimeError("secure_training is not configured")
+        fl_ctx.set_prop(key=Constant.PARAM_KEY_SECURE_TRAINING, value=self.secure_training, private=True, sticky=True)
 
         self.xgb_params = config.get(Constant.CONF_KEY_XGB_PARAMS)
         if not self.xgb_params:
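
One subtlety in the new configure() logic: both checks are "is None" tests, so an explicit secure_training=False is valid while an absent key raises RuntimeError. A sketch of a minimal payload, mirroring the unit test at the end of this patch; values are illustrative, and adaptor/fl_ctx are assumed to come from the surrounding runtime or test:

    from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant

    config = {
        Constant.CONF_KEY_CLIENT_RANKS: {"site-1": 0, "site-2": 1},
        Constant.CONF_KEY_NUM_ROUNDS: 10,
        Constant.CONF_KEY_SPLIT_MODE: 1,           # SplitMode.COL, i.e. vertical
        Constant.CONF_KEY_SECURE_TRAINING: False,  # explicit False is fine; a missing key raises
        Constant.CONF_KEY_XGB_PARAMS: {"max_depth": 3},
    }
    adaptor.configure(config, fl_ctx)
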
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/controller.py b/nvflare/app_opt/xgboost/histogram_based_v2/controller.py
index 6303148335..048c1573cf 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/controller.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/controller.py
@@ -27,7 +27,7 @@
 from nvflare.fuel.utils.validation_utils import check_number_range, check_object_type, check_positive_number, check_str
 from nvflare.security.logging import secure_format_exception
 
-from .defs import TRAINING_MODE_MAPPING, Constant
+from .defs import Constant
 
 
 class ClientStatus:
@@ -59,7 +59,8 @@ def __init__(
         self,
         adaptor_component_id: str,
         num_rounds: int,
-        training_mode: str,
+        split_mode: int,
+        secure_training: bool,
         xgb_params: dict,
         xgb_options: Optional[dict] = None,
         configure_task_name=Constant.CONFIG_TASK_NAME,
@@ -80,7 +81,8 @@ def __init__(
         Args:
             adaptor_component_id - the component ID of server target adaptor
             num_rounds - number of rounds
-            training_mode - Split mode (horizontal, vertical, horizontal_secure, vertical_secure)
+            split_mode - 0 for horizontal/row-split, 1 for vertical/column-split
+            secure_training - If true, secure training is enabled
             xgb_params - The params argument for train method
             xgb_options - All other arguments for train method are passed through this dictionary
             configure_task_name - name of the config task
@@ -100,7 +102,8 @@ def __init__(
         Controller.__init__(self)
         self.adaptor_component_id = adaptor_component_id
         self.num_rounds = num_rounds
-        self.training_mode = training_mode.lower()
+        self.split_mode = split_mode
+        self.secure_training = secure_training
         self.xgb_params = xgb_params
         self.xgb_options = xgb_options
         self.configure_task_name = configure_task_name
@@ -118,10 +121,8 @@ def __init__(
         self.client_statuses = {}  # client name => ClientStatus
         self.abort_signal = None
 
-        check_str("training_mode", training_mode)
-        valid_mode = TRAINING_MODE_MAPPING.keys()
-        if training_mode not in valid_mode:
-            raise ValueError(f"training_mode must be one of following values: {valid_mode}")
+        if split_mode not in {0, 1}:
+            raise ValueError("split_mode must be either 0 or 1")
 
         if not self.xgb_params:
             raise ValueError("xgb_params can't be empty")
@@ -462,7 +463,8 @@ def _configure_clients(self, abort_signal: Signal, fl_ctx: FLContext):
 
         shareable[Constant.CONF_KEY_CLIENT_RANKS] = self.client_ranks
         shareable[Constant.CONF_KEY_NUM_ROUNDS] = self.num_rounds
-        shareable[Constant.CONF_KEY_TRAINING_MODE] = self.training_mode
+        shareable[Constant.CONF_KEY_SPLIT_MODE] = self.split_mode
+        shareable[Constant.CONF_KEY_SECURE_TRAINING] = self.secure_training
         shareable[Constant.CONF_KEY_XGB_PARAMS] = self.xgb_params
         shareable[Constant.CONF_KEY_XGB_OPTIONS] = self.xgb_options
 
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py
index 3b71d59ffb..a08b40e28a 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/defs.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/defs.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from nvflare.app_opt.xgboost.constant import TrainingMode
 from nvflare.fuel.f3.drivers.net_utils import MAX_FRAME_SIZE
 
 
@@ -26,12 +25,13 @@ class Constant:
     CONF_KEY_CLIENT_RANKS = "client_ranks"
     CONF_KEY_WORLD_SIZE = "world_size"
     CONF_KEY_NUM_ROUNDS = "num_rounds"
-    CONF_KEY_TRAINING_MODE = "training_mode"
+    CONF_KEY_SPLIT_MODE = "split_mode"
+    CONF_KEY_SECURE_TRAINING = "secure_training"
     CONF_KEY_XGB_PARAMS = "xgb_params"
     CONF_KEY_XGB_OPTIONS = "xgb_options"
 
     # default component config values
-    CONFIG_TASK_TIMEOUT = 20
+    CONFIG_TASK_TIMEOUT = 60
     START_TASK_TIMEOUT = 10
     XGB_SERVER_READY_TIMEOUT = 10.0
 
@@ -87,14 +87,16 @@ class Constant:
     PARAM_KEY_REPLY = "xgb.reply"
     PARAM_KEY_REQUEST = "xgb.request"
    PARAM_KEY_EVENT = "xgb.event"
-    PARAM_KEY_TRAINING_MODE = "xgb.training_mode"
+    PARAM_KEY_SPLIT_MODE = "xgb.split_mode"
+    PARAM_KEY_SECURE_TRAINING = "xgb.secure_training"
     PARAM_KEY_CONFIG_ERROR = "xgb.config_error"
 
     RUNNER_CTX_SERVER_ADDR = "server_addr"
     RUNNER_CTX_PORT = "port"
     RUNNER_CTX_CLIENT_NAME = "client_name"
     RUNNER_CTX_NUM_ROUNDS = "num_rounds"
-    RUNNER_CTX_TRAINING_MODE = "training_mode"
+    RUNNER_CTX_SPLIT_MODE = "split_mode"
+    RUNNER_CTX_SECURE_TRAINING = "secure_training"
     RUNNER_CTX_XGB_PARAMS = "xgb_params"
     RUNNER_CTX_XGB_OPTIONS = "xgb_options"
     RUNNER_CTX_WORLD_SIZE = "world_size"
@@ -127,18 +129,3 @@ class Constant:
 class SplitMode:
     ROW = 0
     COL = 1
-
-
-# Mapping of text training mode to split mode
-TRAINING_MODE_MAPPING = {
-    TrainingMode.H: SplitMode.ROW,
-    TrainingMode.HORIZONTAL: SplitMode.ROW,
-    TrainingMode.V: SplitMode.COL,
-    TrainingMode.VERTICAL: SplitMode.COL,
-    TrainingMode.HS: SplitMode.ROW,
-    TrainingMode.HORIZONTAL_SECURE: SplitMode.ROW,
-    TrainingMode.VS: SplitMode.COL,
-    TrainingMode.VERTICAL_SECURE: SplitMode.COL,
-}
-
-SECURE_TRAINING_MODES = {TrainingMode.HS, TrainingMode.HORIZONTAL_SECURE, TrainingMode.VS, TrainingMode.VERTICAL_SECURE}
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/fed_controller.py b/nvflare/app_opt/xgboost/histogram_based_v2/fed_controller.py
index b0610567d1..2d6a8cf875 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/fed_controller.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/fed_controller.py
@@ -27,7 +27,8 @@ class XGBFedController(XGBController):
     def __init__(
         self,
         num_rounds: int,
-        training_mode: str,
+        split_mode: int,
+        secure_training: bool,
         xgb_params: dict,
         xgb_options: Optional[dict] = None,
         configure_task_name=Constant.CONFIG_TASK_NAME,
@@ -44,7 +45,8 @@ def __init__(
             self,
             adaptor_component_id="",
             num_rounds=num_rounds,
-            training_mode=training_mode,
+            split_mode=split_mode,
+            secure_training=secure_training,
             xgb_params=xgb_params,
             xgb_options=xgb_options,
             configure_task_name=configure_task_name,
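
For job authors, the migration is mechanical: the one string argument becomes two typed ones. A before/after sketch; num_rounds and the xgb_params values are placeholders:

    from nvflare.app_opt.xgboost.histogram_based_v2.fed_controller import XGBFedController

    # Before: XGBFedController(num_rounds=10, training_mode="vertical_secure", xgb_params={...})
    controller = XGBFedController(
        num_rounds=10,
        split_mode=1,          # the horizontal/vertical half of the old training_mode
        secure_training=True,  # the "_secure" suffix of the old training_mode
        xgb_params={"max_depth": 3, "objective": "binary:logistic"},
    )
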
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/mock/mock_controller.py b/nvflare/app_opt/xgboost/histogram_based_v2/mock/mock_controller.py
index 8e9e32a9eb..ea81a4a1ee 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/mock/mock_controller.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/mock/mock_controller.py
@@ -37,7 +37,8 @@ def __init__(
     ):
         XGBController.__init__(
             self,
-            training_mode="horizontal",
+            split_mode=0,
+            secure_training=False,
             xgb_params={"max_depth": 3},
            adaptor_component_id="",
             num_rounds=num_rounds,
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py
index 768f2152e6..b9490467c5 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/runners/xgb_client_runner.py
@@ -19,11 +19,11 @@
 from xgboost import callback
 
 from nvflare.apis.fl_component import FLComponent
-from nvflare.apis.fl_constant import SystemConfigs
+from nvflare.apis.fl_constant import FLContextKey, SystemConfigs
 from nvflare.apis.fl_context import FLContext
 from nvflare.app_common.tracking.log_writer import LogWriter
 from nvflare.app_opt.xgboost.data_loader import XGBDataLoader
-from nvflare.app_opt.xgboost.histogram_based_v2.defs import SECURE_TRAINING_MODES, Constant
+from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant
 from nvflare.app_opt.xgboost.histogram_based_v2.runners.xgb_runner import AppRunner
 from nvflare.app_opt.xgboost.metrics_cb import MetricsCallback
 from nvflare.fuel.utils.config_service import ConfigService
@@ -33,6 +33,7 @@
 PLUGIN_PARAM_KEY = "federated_plugin"
 PLUGIN_KEY_NAME = "name"
 PLUGIN_KEY_PATH = "path"
+MODEL_FILE_NAME = "model.json"
 
 
 class XGBClientRunner(AppRunner, FLComponent):
@@ -46,12 +47,14 @@ def __init__(
         self.model_file_name = model_file_name
         self.data_loader_id = data_loader_id
         self.logger = get_logger(self)
+        self.fl_ctx = None
 
         self._client_name = None
         self._rank = None
         self._world_size = None
         self._num_rounds = None
-        self._training_mode = None
+        self._split_mode = None
+        self._secure_training = None
         self._xgb_params = None
         self._xgb_options = None
         self._server_addr = None
@@ -62,6 +65,7 @@ def __init__(
         self._metrics_writer = None
 
     def initialize(self, fl_ctx: FLContext):
+        self.fl_ctx = fl_ctx
         engine = fl_ctx.get_engine()
         self._data_loader = engine.get_component(self.data_loader_id)
         if not isinstance(self._data_loader, XGBDataLoader):
@@ -95,6 +99,17 @@ def _xgb_train(self, num_rounds, xgb_params: dict, xgb_options: dict, train_data
         early_stopping_rounds = xgb_options.get("early_stopping_rounds", 0)
         verbose_eval = xgb_options.get("verbose_eval", False)
 
+        # Check for pre-trained model
+        job_id = self.fl_ctx.get_prop(FLContextKey.CURRENT_JOB_ID)
+        workspace = self.fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
+        custom_dir = workspace.get_app_custom_dir(job_id)
+        model_file = os.path.join(custom_dir, MODEL_FILE_NAME)
+        if os.path.isfile(model_file):
+            self.logger.info(f"Pre-trained model is used: {model_file}")
+            xgb_model = model_file
+        else:
+            xgb_model = None
+
         # Run training; all the features of the training API are available.
         bst = xgb.train(
             xgb_params,
@@ -104,6 +119,7 @@ def _xgb_train(self, num_rounds, xgb_params: dict, xgb_options: dict, train_data
             early_stopping_rounds=early_stopping_rounds,
             verbose_eval=verbose_eval,
             callbacks=callbacks,
+            xgb_model=xgb_model,
         )
 
         return bst
@@ -112,7 +128,8 @@ def run(self, ctx: dict):
         self._rank = ctx.get(Constant.RUNNER_CTX_RANK)
         self._world_size = ctx.get(Constant.RUNNER_CTX_WORLD_SIZE)
         self._num_rounds = ctx.get(Constant.RUNNER_CTX_NUM_ROUNDS)
-        self._training_mode = ctx.get(Constant.RUNNER_CTX_TRAINING_MODE)
+        self._split_mode = ctx.get(Constant.RUNNER_CTX_SPLIT_MODE)
+        self._secure_training = ctx.get(Constant.RUNNER_CTX_SECURE_TRAINING)
         self._xgb_params = ctx.get(Constant.RUNNER_CTX_XGB_PARAMS)
         self._xgb_options = ctx.get(Constant.RUNNER_CTX_XGB_OPTIONS)
         self._server_addr = ctx.get(Constant.RUNNER_CTX_SERVER_ADDR)
@@ -125,8 +142,10 @@ def run(self, ctx: dict):
             self._xgb_params["device"] = f"cuda:{self._rank}"
 
         self.logger.info(
-            f"XGB training_mode: {self._training_mode} " f"params: {self._xgb_params} XGB options: {self._xgb_options}"
+            f"XGB split_mode: {self._split_mode} secure_training: {self._secure_training} "
+            f"params: {self._xgb_params} XGB options: {self._xgb_options}"
         )
+
         self.logger.info(f"server address is {self._server_addr}")
 
         communicator_env = {
@@ -136,7 +155,7 @@ def run(self, ctx: dict):
             "federated_rank": self._rank,
         }
 
-        if self._training_mode not in SECURE_TRAINING_MODES:
+        if not self._secure_training:
             self.logger.info("XGBoost non-secure training")
         else:
             xgb_plugin_name = ConfigService.get_str_var(
@@ -166,13 +185,11 @@ def run(self, ctx: dict):
                 lib_name = f"lib{xgb_plugin_params[PLUGIN_KEY_NAME]}.{lib_ext}"
                 xgb_plugin_params[PLUGIN_KEY_PATH] = str(get_package_root() / "libs" / lib_name)
 
-            self.logger.info(f"XGBoost secure training: {self._training_mode} Params: {xgb_plugin_params}")
-
             communicator_env[PLUGIN_PARAM_KEY] = xgb_plugin_params
 
         with xgb.collective.CommunicatorContext(**communicator_env):
             # Load the data. DMatrix must be created with column split mode in CommunicatorContext for vertical FL
-            train_data, val_data = self._data_loader.load_data(self._client_name, self._training_mode)
+            train_data, val_data = self._data_loader.load_data(self._client_name, self._split_mode)
 
             bst = self._xgb_train(self._num_rounds, self._xgb_params, self._xgb_options, train_data, val_data)
 
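
The pre-trained model hook keys off a fixed file name (model.json) in the job's app custom directory, and _xgb_train hands the file path, not a Booster object, to xgb.train's xgb_model argument, which xgboost accepts for continued training. A sketch of how a job could opt in, assuming a local warm-up run and a hypothetical job directory layout:

    import xgboost as xgb

    # Warm up an initial model locally (data path is a placeholder).
    dtrain = xgb.DMatrix("pretrain.csv?format=csv&label_column=0")
    bst = xgb.train({"max_depth": 3}, dtrain, num_boost_round=5)

    # Ship it with the job: the runner picks up <app>/custom/model.json if present.
    bst.save_model("jobs/my_job/app/custom/model.json")  # illustrative job tree
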
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py
index 0f90c9b22c..56396f434a 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/sec/client_handler.py
@@ -20,7 +20,7 @@
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
 from nvflare.app_opt.xgboost.histogram_based_v2.aggr import Aggregator
-from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant, TrainingMode
+from nvflare.app_opt.xgboost.histogram_based_v2.defs import Constant, SplitMode
 from nvflare.app_opt.xgboost.histogram_based_v2.sec.dam import DamDecoder
 from nvflare.app_opt.xgboost.histogram_based_v2.sec.data_converter import FeatureAggregationResult
 from nvflare.app_opt.xgboost.histogram_based_v2.sec.partial_he.adder import Adder
@@ -408,13 +408,14 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
         global tenseal_error
         if event_type == Constant.EVENT_XGB_JOB_CONFIGURED:
             task_data = fl_ctx.get_prop(FLContextKey.TASK_DATA)
-            training_mode = task_data.get(Constant.CONF_KEY_TRAINING_MODE)
-            if training_mode in {TrainingMode.VS, TrainingMode.VERTICAL_SECURE} and ipcl_imported:
+            split_mode = task_data.get(Constant.CONF_KEY_SPLIT_MODE)
+            secure_training = task_data.get(Constant.CONF_KEY_SECURE_TRAINING)
+            if secure_training and split_mode == SplitMode.COL and ipcl_imported:
                 self.public_key, self.private_key = generate_keys(self.key_length)
                 self.encryptor = Encryptor(self.public_key, self.num_workers)
                 self.decrypter = Decrypter(self.private_key, self.num_workers)
                 self.adder = Adder(self.num_workers)
-            elif training_mode in {TrainingMode.HS, TrainingMode.HORIZONTAL_SECURE}:
+            elif secure_training and split_mode == SplitMode.ROW:
                 if not tenseal_imported:
                     fl_ctx.set_prop(Constant.PARAM_KEY_CONFIG_ERROR, tenseal_error, private=True, sticky=False)
                     return
diff --git a/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py
index 3939bbd41e..f5514e950f 100644
--- a/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py
+++ b/nvflare/app_opt/xgboost/histogram_based_v2/secure_data_loader.py
@@ -15,7 +15,7 @@
 import xgboost as xgb
 
 from nvflare.app_opt.xgboost.data_loader import XGBDataLoader
-from nvflare.app_opt.xgboost.histogram_based_v2.defs import TRAINING_MODE_MAPPING, SplitMode
+from nvflare.app_opt.xgboost.histogram_based_v2.defs import SplitMode
 
 
 class SecureDataLoader(XGBDataLoader):
@@ -29,22 +29,17 @@ def __init__(self, rank: int, folder: str):
         self.rank = rank
         self.folder = folder
 
-    def load_data(self, client_id: str, training_mode: str):
+    def load_data(self, client_id: str, split_mode: int):
 
         train_path = f"{self.folder}/{client_id}/train.csv"
         valid_path = f"{self.folder}/{client_id}/valid.csv"
 
-        if training_mode not in TRAINING_MODE_MAPPING:
-            raise ValueError(f"Invalid training_mode: {training_mode}")
-
-        data_split_mode = TRAINING_MODE_MAPPING[training_mode]
-
-        if self.rank == 0 or data_split_mode == SplitMode.ROW:
+        if self.rank == 0 or split_mode == SplitMode.ROW:
             label = "&label_column=0"
         else:
             label = ""
 
-        train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=data_split_mode)
-        valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=data_split_mode)
+        train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=split_mode)
+        valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=split_mode)
 
         return train_data, valid_data
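
With the mapping table gone, SecureDataLoader takes the integer straight from the runner, and the label column is read only where the code above allows it. A usage sketch; the folder path is illustrative, and the expected layout is <folder>/<client_id>/{train,valid}.csv as in the code above:

    from nvflare.app_opt.xgboost.histogram_based_v2.defs import SplitMode
    from nvflare.app_opt.xgboost.histogram_based_v2.secure_data_loader import SecureDataLoader

    loader = SecureDataLoader(rank=0, folder="/data/xgb_dataset")
    # Horizontal job: every rank reads row-split data including the label column.
    dtrain, dvalid = loader.load_data("site-1", split_mode=SplitMode.ROW)
    # Vertical job: only rank 0 keeps the label column; other ranks load features only.
    dtrain, dvalid = loader.load_data("site-1", split_mode=SplitMode.COL)
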
diff --git a/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/xgb_adaptor_test.py b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/xgb_adaptor_test.py
index ec37a48b01..6de75c5052 100644
--- a/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/xgb_adaptor_test.py
+++ b/tests/unit_test/app_opt/xgboost/histrogram_based_v2/adaptors/xgb_adaptor_test.py
@@ -37,7 +37,8 @@ def test_configure(self):
         config = {
             Constant.CONF_KEY_CLIENT_RANKS: {"site-test": 1},
             Constant.CONF_KEY_NUM_ROUNDS: 100,
-            Constant.CONF_KEY_TRAINING_MODE: "horizontal",
+            Constant.CONF_KEY_SPLIT_MODE: 0,
+            Constant.CONF_KEY_SECURE_TRAINING: False,
             Constant.CONF_KEY_XGB_PARAMS: {"depth": 1},
         }
         ctx = FLContext()