Skip to content

Commit

Permalink
Pre-trained Model and training_mode changes (#2793)
Browse files (browse the repository at this point in the history)
* Updated FOBS readme to add DatumManager, added agrpcs as secure scheme

* Added support for pre-trained model

* Changed training_mode to split_mode + secure_training

* split_mode => data_split_mode

* Format error

* Fixed a format error

* Addressed PR comments

* Fixed format

* Changed all xgboost controller/executor to use new XGBoost

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
  • Loading branch information
nvidianz and YuanTingHsieh authored Aug 14, 2024
1 parent a3fb1e5 commit 2d731b9
Show file tree
Hide file tree
Showing 17 changed files with 87 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, data_split_path, psi_path, id_col, label_owner, train_proport
self.label_owner = label_owner
self.train_proportion = train_proportion

def load_data(self, client_id: str, training_mode: str = ""):
def load_data(self, client_id: str, split_mode: int = 1):
client_data_split_path = self.data_split_path.replace("site-x", client_id)
client_psi_path = self.psi_path.replace("site-x", client_id)

Expand All @@ -84,7 +84,7 @@ def load_data(self, client_id: str, training_mode: str = ""):
label = ""

# for Vertical XGBoost, read from csv with label_column; data_split_mode is taken from split_mode (default 1, i.e. column mode)
dtrain = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=1)
dvalid = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=1)
dtrain = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=split_mode)
dvalid = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=split_mode)

return dtrain, dvalid
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, data_split_filename):
"""
self.data_split_filename = data_split_filename

def load_data(self, client_id: str, training_mode: str = ""):
def load_data(self, client_id: str, split_mode: int):
with open(self.data_split_filename, "r") as file:
data_split = json.load(file)

Expand Down
3 changes: 2 additions & 1 deletion examples/advanced/xgboost/utils/prepare_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@ def _update_server_config(config: dict, args):
config["num_rounds"] = args.round_num
config["workflows"][0]["args"]["xgb_params"]["nthread"] = args.nthread
config["workflows"][0]["args"]["xgb_params"]["tree_method"] = args.tree_method
config["workflows"][0]["args"]["training_mode"] = args.training_mode
config["workflows"][0]["args"]["split_mode"] = args.split_mode
config["workflows"][0]["args"]["secure_training"] = args.secure_training


def _copy_custom_files(src_job_path, src_app_name, dst_job_path, dst_app_name):
Expand Down
26 changes: 0 additions & 26 deletions nvflare/app_opt/xgboost/constant.py

This file was deleted.

6 changes: 1 addition & 5 deletions nvflare/app_opt/xgboost/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,10 @@

import xgboost as xgb

from .constant import TrainingMode


class XGBDataLoader(ABC):
@abstractmethod
def load_data(
self, client_id: str, training_mode: str = TrainingMode.HORIZONTAL
) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
def load_data(self, client_id: str, split_mode: int) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
"""Loads data for xgboost.
Returns:
Expand Down
3 changes: 2 additions & 1 deletion nvflare/app_opt/xgboost/histogram_based/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,10 @@ def start_controller(self, fl_ctx: FLContext):
if not self._get_certificates(fl_ctx):
self.log_error(fl_ctx, "Can't get required certificates for XGB FL server in secure mode.")
return
self.log_info(fl_ctx, "Running XGB FL server in secure mode.")
self._xgb_fl_server = multiprocessing.Process(
target=xgb_federated.run_federated_server,
args=(self._port, len(clients), self._server_key_path, self._server_cert_path, self._ca_cert_path),
args=(len(clients), self._port, self._server_key_path, self._server_cert_path, self._ca_cert_path),
)
else:
self._xgb_fl_server = multiprocessing.Process(
Expand Down
6 changes: 3 additions & 3 deletions nvflare/app_opt/xgboost/histogram_based/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,9 @@ def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -
if not self._get_certificates(fl_ctx):
return make_reply(ReturnCode.ERROR)

communicator_env["federated_server_cert"] = self._ca_cert_path
communicator_env["federated_client_key"] = self._client_key_path
communicator_env["federated_client_cert"] = self._client_cert_path
communicator_env["federated_server_cert_path"] = self._ca_cert_path
communicator_env["federated_client_key_path"] = self._client_key_path
communicator_env["federated_client_cert_path"] = self._client_cert_path

try:
with xgb.collective.CommunicatorContext(**communicator_env):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ class since the self object contains a sender that contains a Core Cell which ca
Constant.RUNNER_CTX_SERVER_ADDR: server_addr,
Constant.RUNNER_CTX_RANK: self.rank,
Constant.RUNNER_CTX_NUM_ROUNDS: self.num_rounds,
Constant.RUNNER_CTX_TRAINING_MODE: self.training_mode,
Constant.RUNNER_CTX_SPLIT_MODE: self.split_mode,
Constant.RUNNER_CTX_SECURE_TRAINING: self.secure_training,
Constant.RUNNER_CTX_XGB_PARAMS: self.xgb_params,
Constant.RUNNER_CTX_XGB_OPTIONS: self.xgb_options,
Constant.RUNNER_CTX_MODEL_DIR: self._run_dir,
Expand Down
16 changes: 11 additions & 5 deletions nvflare/app_opt/xgboost/histogram_based_v2/adaptors/xgb_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def __init__(self, in_process: bool, per_msg_timeout: float, tx_timeout: float):
self.stopped = False
self.rank = None
self.num_rounds = None
self.training_mode = None
self.split_mode = None
self.secure_training = None
self.xgb_params = None
self.xgb_options = None
self.world_size = None
Expand Down Expand Up @@ -196,10 +197,15 @@ def configure(self, config: dict, fl_ctx: FLContext):
check_positive_int(Constant.CONF_KEY_NUM_ROUNDS, num_rounds)
self.num_rounds = num_rounds

self.training_mode = config.get(Constant.CONF_KEY_TRAINING_MODE)
if self.training_mode is None:
raise RuntimeError("training_mode is not configured")
fl_ctx.set_prop(key=Constant.PARAM_KEY_TRAINING_MODE, value=self.training_mode, private=True, sticky=True)
self.split_mode = config.get(Constant.CONF_KEY_SPLIT_MODE)
if self.split_mode is None:
raise RuntimeError("split_mode is not configured")
fl_ctx.set_prop(key=Constant.PARAM_KEY_SPLIT_MODE, value=self.split_mode, private=True, sticky=True)

self.secure_training = config.get(Constant.CONF_KEY_SECURE_TRAINING)
if self.secure_training is None:
raise RuntimeError("secure_training is not configured")
fl_ctx.set_prop(key=Constant.PARAM_KEY_SECURE_TRAINING, value=self.secure_training, private=True, sticky=True)

self.xgb_params = config.get(Constant.CONF_KEY_XGB_PARAMS)
if not self.xgb_params:
Expand Down
20 changes: 11 additions & 9 deletions nvflare/app_opt/xgboost/histogram_based_v2/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from nvflare.fuel.utils.validation_utils import check_number_range, check_object_type, check_positive_number, check_str
from nvflare.security.logging import secure_format_exception

from .defs import TRAINING_MODE_MAPPING, Constant
from .defs import Constant


class ClientStatus:
Expand Down Expand Up @@ -59,7 +59,8 @@ def __init__(
self,
adaptor_component_id: str,
num_rounds: int,
training_mode: str,
split_mode: int,
secure_training: bool,
xgb_params: dict,
xgb_options: Optional[dict] = None,
configure_task_name=Constant.CONFIG_TASK_NAME,
Expand All @@ -80,7 +81,8 @@ def __init__(
Args:
adaptor_component_id - the component ID of server target adaptor
num_rounds - number of rounds
training_mode - Split mode (horizontal, vertical, horizontal_secure, vertical_secure)
split_mode - 0 for horizontal/row-split, 1 for vertical/column-split
secure_training - If true, secure training is enabled
xgb_params - The params argument for train method
xgb_options - All other arguments for train method are passed through this dictionary
configure_task_name - name of the config task
Expand All @@ -100,7 +102,8 @@ def __init__(
Controller.__init__(self)
self.adaptor_component_id = adaptor_component_id
self.num_rounds = num_rounds
self.training_mode = training_mode.lower()
self.split_mode = split_mode
self.secure_training = secure_training
self.xgb_params = xgb_params
self.xgb_options = xgb_options
self.configure_task_name = configure_task_name
Expand All @@ -118,10 +121,8 @@ def __init__(
self.client_statuses = {} # client name => ClientStatus
self.abort_signal = None

check_str("training_mode", training_mode)
valid_mode = TRAINING_MODE_MAPPING.keys()
if training_mode not in valid_mode:
raise ValueError(f"training_mode must be one of following values: {valid_mode}")
if split_mode not in {0, 1}:
raise ValueError("split_mode must be either 0 or 1")

if not self.xgb_params:
raise ValueError("xgb_params can't be empty")
Expand Down Expand Up @@ -462,7 +463,8 @@ def _configure_clients(self, abort_signal: Signal, fl_ctx: FLContext):

shareable[Constant.CONF_KEY_CLIENT_RANKS] = self.client_ranks
shareable[Constant.CONF_KEY_NUM_ROUNDS] = self.num_rounds
shareable[Constant.CONF_KEY_TRAINING_MODE] = self.training_mode
shareable[Constant.CONF_KEY_SPLIT_MODE] = self.split_mode
shareable[Constant.CONF_KEY_SECURE_TRAINING] = self.secure_training
shareable[Constant.CONF_KEY_XGB_PARAMS] = self.xgb_params
shareable[Constant.CONF_KEY_XGB_OPTIONS] = self.xgb_options

Expand Down
27 changes: 7 additions & 20 deletions nvflare/app_opt/xgboost/histogram_based_v2/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.app_opt.xgboost.constant import TrainingMode
from nvflare.fuel.f3.drivers.net_utils import MAX_FRAME_SIZE


Expand All @@ -26,12 +25,13 @@ class Constant:
CONF_KEY_CLIENT_RANKS = "client_ranks"
CONF_KEY_WORLD_SIZE = "world_size"
CONF_KEY_NUM_ROUNDS = "num_rounds"
CONF_KEY_TRAINING_MODE = "training_mode"
CONF_KEY_SPLIT_MODE = "split_mode"
CONF_KEY_SECURE_TRAINING = "secure_training"
CONF_KEY_XGB_PARAMS = "xgb_params"
CONF_KEY_XGB_OPTIONS = "xgb_options"

# default component config values
CONFIG_TASK_TIMEOUT = 20
CONFIG_TASK_TIMEOUT = 60
START_TASK_TIMEOUT = 10
XGB_SERVER_READY_TIMEOUT = 10.0

Expand Down Expand Up @@ -87,14 +87,16 @@ class Constant:
PARAM_KEY_REPLY = "xgb.reply"
PARAM_KEY_REQUEST = "xgb.request"
PARAM_KEY_EVENT = "xgb.event"
PARAM_KEY_TRAINING_MODE = "xgb.training_mode"
PARAM_KEY_SPLIT_MODE = "xgb.split_mode"
PARAM_KEY_SECURE_TRAINING = "xgb.secure_training"
PARAM_KEY_CONFIG_ERROR = "xgb.config_error"

RUNNER_CTX_SERVER_ADDR = "server_addr"
RUNNER_CTX_PORT = "port"
RUNNER_CTX_CLIENT_NAME = "client_name"
RUNNER_CTX_NUM_ROUNDS = "num_rounds"
RUNNER_CTX_TRAINING_MODE = "training_mode"
RUNNER_CTX_SPLIT_MODE = "split_mode"
RUNNER_CTX_SECURE_TRAINING = "secure_training"
RUNNER_CTX_XGB_PARAMS = "xgb_params"
RUNNER_CTX_XGB_OPTIONS = "xgb_options"
RUNNER_CTX_WORLD_SIZE = "world_size"
Expand Down Expand Up @@ -127,18 +129,3 @@ class Constant:
class SplitMode:
ROW = 0
COL = 1


# Mapping of text training mode to split mode
TRAINING_MODE_MAPPING = {
TrainingMode.H: SplitMode.ROW,
TrainingMode.HORIZONTAL: SplitMode.ROW,
TrainingMode.V: SplitMode.COL,
TrainingMode.VERTICAL: SplitMode.COL,
TrainingMode.HS: SplitMode.ROW,
TrainingMode.HORIZONTAL_SECURE: SplitMode.ROW,
TrainingMode.VS: SplitMode.COL,
TrainingMode.VERTICAL_SECURE: SplitMode.COL,
}

SECURE_TRAINING_MODES = {TrainingMode.HS, TrainingMode.HORIZONTAL_SECURE, TrainingMode.VS, TrainingMode.VERTICAL_SECURE}
6 changes: 4 additions & 2 deletions nvflare/app_opt/xgboost/histogram_based_v2/fed_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ class XGBFedController(XGBController):
def __init__(
self,
num_rounds: int,
training_mode: str,
split_mode: int,
secure_training: bool,
xgb_params: dict,
xgb_options: Optional[dict] = None,
configure_task_name=Constant.CONFIG_TASK_NAME,
Expand All @@ -44,7 +45,8 @@ def __init__(
self,
adaptor_component_id="",
num_rounds=num_rounds,
training_mode=training_mode,
split_mode=split_mode,
secure_training=secure_training,
xgb_params=xgb_params,
xgb_options=xgb_options,
configure_task_name=configure_task_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def __init__(
):
XGBController.__init__(
self,
training_mode="horizontal",
split_mode=0,
secure_training=False,
xgb_params={"max_depth": 3},
adaptor_component_id="",
num_rounds=num_rounds,
Expand Down
Loading

0 comments on commit 2d731b9

Please sign in to comment.