Skip to content

Commit

Permalink
Attempt to get PyTorch example to work on Windows.
Browse files Browse the repository at this point in the history
  • Loading branch information
shellander committed Oct 20, 2023
1 parent e4731f7 commit 14c8ef3
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 96 deletions.
4 changes: 2 additions & 2 deletions examples/mnist-pytorch/client/fedn.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
entry_points:
train:
command: /venv/bin/python entrypoint train $ENTRYPOINT_OPTS
command: python entrypoint train $ENTRYPOINT_OPTS
validate:
command: /venv/bin/python entrypoint validate $ENTRYPOINT_OPTS
command: python entrypoint validate $ENTRYPOINT_OPTS
9 changes: 7 additions & 2 deletions fedn/cli/run_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,12 @@ def run_cmd(ctx):
help='Set logfile for client log to file.')
@click.option('--heartbeat-interval', required=False, default=2)
@click.option('--reconnect-after-missed-heartbeat', required=False, default=30)
@click.option('--verbosity', required=False, default='INFO', type=click.Choice(['CRITICAL', 'ERROR', 'WARNING', 'INFO', 'DEBUG'], case_sensitive=False))
@click.option('--theme', required=False, default='default', type=click.Choice(['dark', 'light', 'vibrant', 'default'], case_sensitive=False))
@click.pass_context
def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_package, force_ssl, dry_run, secure, preshared_cert,
verify, preferred_combiner, validator, trainer, init, logfile, heartbeat_interval, reconnect_after_missed_heartbeat):
verify, preferred_combiner, validator, trainer, init, logfile, heartbeat_interval, reconnect_after_missed_heartbeat,
verbosity, theme):
"""
:param ctx:
Expand All @@ -128,14 +131,16 @@ def client_cmd(ctx, discoverhost, discoverport, token, name, client_id, local_pa
:param logfile:
:param hearbeat_interval
:param reconnect_after_missed_heartbeat
:param verbosity
:param theme
:return:
"""
remote = False if local_package else True
config = {'discover_host': discoverhost, 'discover_port': discoverport, 'token': token, 'name': name,
'client_id': client_id, 'remote_compute_context': remote, 'force_ssl': force_ssl, 'dry_run': dry_run, 'secure': secure,
'preshared_cert': preshared_cert, 'verify': verify, 'preferred_combiner': preferred_combiner,
'validator': validator, 'trainer': trainer, 'init': init, 'logfile': logfile, 'heartbeat_interval': heartbeat_interval,
'reconnect_after_missed_heartbeat': reconnect_after_missed_heartbeat}
'reconnect_after_missed_heartbeat': reconnect_after_missed_heartbeat, 'verbosity': verbosity, 'theme': theme}

if init:
parse_client_config(config)
Expand Down
173 changes: 105 additions & 68 deletions fedn/fedn/client.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion fedn/fedn/clients/reducer/restservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def check_configured_response(self):
if not self.control.idle():
return jsonify({'status': 'retry',
'package': self.package,
'msg': "Conroller is not in idle state, try again later. "})
'msg': "Controller is not in idle state, try again later. "})
return None

def check_configured(self):
Expand Down
58 changes: 58 additions & 0 deletions fedn/fedn/common/color_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import logging
from termcolor import colored


class ColorizingStreamHandler(logging.StreamHandler):
dark_theme = {
'DEBUG': 'white',
'INFO': 'green',
'WARNING': 'yellow',
'ERROR': 'red',
'CRITICAL': 'red',
}

light_theme = {
'DEBUG': 'black',
'INFO': 'blue',
'WARNING': 'magenta',
'ERROR': 'red',
'CRITICAL': 'red',
}

vibrant_theme = {
'DEBUG': 'cyan',
'INFO': 'green',
'WARNING': 'yellow',
'ERROR': 'red',
'CRITICAL': 'red',
}

def __init__(self, theme='dark'):
super().__init__()
self.set_theme(theme)

def set_theme(self, theme):
if theme == 'dark':
self.color_map = self.dark_theme
elif theme == 'light':
self.color_map = self.light_theme
elif theme == 'vibrant':
self.color_map = self.vibrant_theme
elif theme == 'default':
self.color_map = {} # No color applied
else:
self.color_map = {} # No color applied

def emit(self, record):
try:
# Separate the log level from the message
level = '[{}]'.format(record.levelname)
color = self.color_map.get(record.levelname, 'white')
colored_level = colored(level, color)

# Combine the colored log level with the rest of the message
message = self.format(record).replace(level, colored_level)
self.stream.write(message + "\n")
self.flush()
except Exception:
self.handleError(record)
27 changes: 12 additions & 15 deletions fedn/fedn/common/control/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from fedn.utils.checksum import sha
from fedn.utils.dispatcher import Dispatcher

from fedn.common.log_config import logger

class Package:
"""
Expand Down Expand Up @@ -76,7 +76,7 @@ def upload(self):
# print("going to send {}".format(data),flush=True)
f = open(os.path.join(os.path.dirname(
self.file_path), self.package_file), 'rb')
print("Sending the following file {}".format(f.read()), flush=True)
logger.info("Sending the following file {}".format(f.read()))
f.seek(0, 0)
files = {'file': f}
try:
Expand All @@ -85,13 +85,10 @@ def upload(self):
# data=data,
headers={'Authorization': 'Token {}'.format(self.reducer_token)})
except Exception as e:
print("failed to put execution context to reducer. {}".format(
e), flush=True)
logger.error("Failed to put execution context to reducer. {}".format(e))
finally:
f.close()

print("Upload 4 ", flush=True)


class PackageRuntime:
"""
Expand Down Expand Up @@ -141,7 +138,7 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None):
try:
self.pkg_name = params['filename']
except KeyError:
print("No package returned!", flush=True)
logger.error("No package returned.")
return None
r.raise_for_status()
with open(os.path.join(self.pkg_path, self.pkg_name), 'wb') as f:
Expand All @@ -161,7 +158,7 @@ def download(self, host, port, token, force_ssl=False, secure=False, name=None):
try:
self.checksum = data['checksum']
except Exception:
print("Could not extract checksum!")
logger.error("Could not extract checksum.")

return True

Expand All @@ -183,7 +180,7 @@ def validate(self, expected_checksum):
# return True

if self.checksum == self.expected_checksum == file_checksum:
print("Package validated {}".format(self.checksum))
logger.info("Package validated {}".format(self.checksum))
return True
else:
return False
Expand All @@ -204,7 +201,7 @@ def unpack(self):
f = tarfile.open(os.path.join(
self.pkg_path, self.pkg_name), 'r:bz2')
else:
print(
logger.warning(
"Failed to unpack compute package, no pkg_name set. Has the reducer been configured with a compute package?")

os.getcwd()
Expand All @@ -213,10 +210,10 @@ def unpack(self):

if f:
f.extractall()
print("Successfully extracted compute package content in {}".format(
self.dir), flush=True)
logger.info("Successfully extracted compute package content in {}".format(
self.dir))
except Exception:
print("Error extracting files!")
logger.errro("Error extracting files!")

def dispatcher(self, run_path):
"""
Expand All @@ -237,8 +234,8 @@ def dispatcher(self, run_path):
self.dispatch_config = cfg

except Exception:
print(
"Error trying to load and unpack dispatcher config - trying default", flush=True)
logger.error(
"Error trying to load and unpack dispatcher config - trying default")

dispatcher = Dispatcher(self.dispatch_config, run_path)

Expand Down
49 changes: 49 additions & 0 deletions fedn/fedn/common/log_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging
import logging.config
import urllib3
from fedn.common.color_handler import ColorizingStreamHandler
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
logging.getLogger("urllib3").setLevel(logging.ERROR)

handler = ColorizingStreamHandler(theme='dark')
logger = logging.getLogger()
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
handler.setFormatter(formatter)


def set_log_level_from_string(level_str):
"""
Set the log level based on a string input.
"""
# Mapping of string representation to logging constants
level_mapping = {
'CRITICAL': logging.CRITICAL,
'ERROR': logging.ERROR,
'WARNING': logging.WARNING,
'INFO': logging.INFO,
'DEBUG': logging.DEBUG,
}

# Get the logging level from the mapping
level = level_mapping.get(level_str.upper())

if not level:
raise ValueError(f"Invalid log level: {level_str}")

# Set the log level
logger.setLevel(level)


def set_theme_from_string(theme_str):
"""
Set the logging color theme based on a string input.
"""
# Check if the theme string is valid
valid_themes = ['dark', 'light', 'vibrant', 'default']
if theme_str.lower() not in valid_themes:
raise ValueError(f"Invalid theme: {theme_str}. Valid themes are: {', '.join(valid_themes)}")

# Set the theme for the ColorizingStreamHandler
handler.set_theme(theme_str.lower())
10 changes: 4 additions & 6 deletions fedn/fedn/common/net/connect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum

import requests as r

from fedn.common.log_config import logger

class State(enum.Enum):
Disconnected = 0
Expand Down Expand Up @@ -45,8 +45,7 @@ def __init__(self, host, port, token, name, remote_package, force_ssl=False, ver
self.connect_string = "{}{}".format(
self.prefix, self.host)

print("\n\nsetting the connection string to {}\n\n".format(
self.connect_string), flush=True)
logger.info("Established connection string to {}.".format(self.connect_string))

def state(self):
"""
Expand Down Expand Up @@ -77,7 +76,7 @@ def assign(self):
allow_redirects=True,
headers={'Authorization': 'Token {}'.format(self.token)})
except Exception as e:
print('***** {}'.format(e), flush=True)
logger.error('***** {}'.format(e))
return Status.Unassigned, {}

if retval.status_code == 401:
Expand Down Expand Up @@ -131,8 +130,7 @@ def __init__(self, host, port, myhost, fqdn, myport, token, name, secure=False,
self.connect_string = "{}{}".format(
self.prefix, self.host)

print("\n\nsetting the connection string to {}\n\n".format(
self.connect_string), flush=True)
logger.info("Established connection string to {}.".format(self.connect_string))

def state(self):
"""
Expand Down
8 changes: 6 additions & 2 deletions fedn/fedn/utils/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import platform
import logging

from fedn.utils.process import run_process
Expand Down Expand Up @@ -28,11 +29,14 @@ def run_cmd(self, cmd_type):
args = cmdsandargs[1:]

# shell (this could be a venv, TODO: parametrize)
shell = ['/bin/sh', '-c']
if platform.system() == "Windows":
shell = ['powershell.exe', '-Command']
else:
shell = ['/bin/sh', '-c']

# add the corresponding process defined in project.yaml and append arguments from invoked command
args = shell + [' '.join(cmd + args)]
# print("trying to run process {} with args {}".format(cmd, args))
logger.debug("trying to run process {} with args {}".format(cmd, args))
run_process(args=args, cwd=self.project_dir)

logger.info('DONE RUNNING {}'.format(cmd_type))
Expand Down

0 comments on commit 14c8ef3

Please sign in to comment.