Skip to content

Commit

Permalink
Update the PR: allow specifying labels and timestamps
Browse files Browse the repository at this point in the history
  • Loading branch information
chesterxgchen committed Aug 17, 2024
1 parent f1816b3 commit 43d8656
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 151 deletions.
36 changes: 0 additions & 36 deletions nvflare/fuel_opt/prometheus/app_metrics_config.json

This file was deleted.

68 changes: 0 additions & 68 deletions nvflare/fuel_opt/prometheus/load_metrics.py

This file was deleted.

70 changes: 43 additions & 27 deletions nvflare/fuel_opt/prometheus/metrics_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,32 @@
import argparse
import json
import logging
import os.path
import time
from http.server import HTTPServer

from load_metrics import load_metrics_config
from prometheus_client import Counter, Gauge
from prometheus_client import REGISTRY
from prometheus_client.exposition import MetricsHandler

# Load the metrics configuration
from prometheus_client.metrics_core import GaugeMetricFamily
from prometheus_client.registry import Collector

from nvflare.metrics.metrics_keys import MetricKeys

metrics_store = {}
logger = logging.getLogger("CustomMetricsHandler")


# Use a custom collector to yield the stored metrics
class CustomCollector(Collector):
    """Collector that exposes the objects held in the module-level ``metrics_store``."""

    def collect(self):
        """Yield every value currently stored in ``metrics_store``, in dict order."""
        yield from metrics_store.values()


REGISTRY.register(CustomCollector())


class CustomMetricsHandler(MetricsHandler):
def __init__(self, *args, **kwargs):
self.metrics_store = kwargs.pop("metrics_store", {})
Expand All @@ -38,20 +51,29 @@ def do_POST(self):
if self.path == "/update_metrics":
content_length = int(self.headers["Content-Length"])
post_data = self.rfile.read(content_length)
metrics_data = json.loads(post_data)

# Update the metrics store
for metric_name, value in metrics_data.items():
if metric_name in metrics_store:
p1 = metrics_store[metric_name]
if isinstance(p1, Gauge):
p1.set(value)
elif isinstance(p, Counter):
p1.inc(value)
else:
p1 = Gauge(metric_name, metric_name)
metrics_store[metric_name] = p1
p1.set(value)
content = json.loads(post_data)

if content:
for metric_data in content:
metric_name = metric_data.get(MetricKeys.metric_name)
value = metric_data.get(MetricKeys.value)
labels = metric_data.get(MetricKeys.labels, {})
timestamp = metric_data.get(MetricKeys.timestamp, int(time.time()))

# Create a unique key based on metric name and labels
metric_key = (metric_name, tuple(sorted(labels.items())))

if metric_key not in metrics_store:
# Register/update GaugeMetricFamily with timestamp
gauge = GaugeMetricFamily(
metric_name, f"Description of {metric_name}", labels=list(labels.keys())
)
metrics_store[metric_key] = gauge
else:
# Update the existing gauge
gauge = metrics_store[metric_key]

gauge.add_metric(list(labels.values()), value, timestamp=timestamp)

self.send_response(200)
self.end_headers()
Expand All @@ -69,36 +91,30 @@ def run_http_server(port):

thread = Thread(target=server.serve_forever)
thread.daemon = True

thread.start()
logger.info(f"started prometheus metrics server on port {port}")
print(f"started prometheus metrics server on port {port}")


def parse_arguments():
    """Build the command-line parser for the Prometheus metrics server script.

    Options registered:
        --config: optional path to a JSON configuration file (string).
        --start: flag; set when the HTTP server should be started.
        --port: port number for the HTTP server (int, default 8000).

    Returns:
        argparse.ArgumentParser: the configured parser. It is returned
        unparsed so the caller can both parse arguments and print help.
    """
    parser = argparse.ArgumentParser(description="Start/Stop Prometheus metrics collection server.")
    parser.add_argument("--config", type=str, help="Path to the JSON configuration file")
    parser.add_argument("--start", action="store_true", help="Start the Prometheus HTTP server")
    # "--port" must be registered exactly once: adding the same option string a
    # second time makes argparse raise ArgumentError (conflicting option string).
    parser.add_argument("--port", type=int, default=8000, help="Port number for the Prometheus HTTP server")

    return parser


if __name__ == "__main__":
    # Script entry point: parse the CLI, then either start the server or show usage.
    p = parse_arguments()
    args = p.parse_args()

    if not args.start:
        # Without --start there is nothing to run; print the usage text instead.
        p.print_help()
    else:
        # The config file shipped next to this script is always loaded first.
        app_metrics_config_path = os.path.join(os.path.dirname(__file__), "app_metrics_config.json")
        metrics_store = load_metrics_config(app_metrics_config_path)

        # A config given on the command line is loaded afterwards and replaces the store.
        if args.config:
            metrics_store = load_metrics_config(args.config)

        run_http_server(args.port)

        # Block the main thread until Ctrl+C so the process (and with it the
        # server started above) keeps running.
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            logger.info("Shutting down the server...")
            print("Shutting down the server...")
13 changes: 13 additions & 0 deletions nvflare/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
21 changes: 21 additions & 0 deletions nvflare/metrics/metrics_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class MetricKeys:
    """Dictionary key names used for the fields of a single metric record."""

    # Key holding the metric's name.
    metric_name = "metric_name"
    # Key holding the metric's value.
    value = "value"
    # Key holding the dict of labels attached to the metric.
    labels = "labels"
    # Key holding the timestamp of the measurement.
    timestamp = "timestamp"
35 changes: 35 additions & 0 deletions nvflare/metrics/metrics_publisher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

from nvflare.apis.fl_constant import ReservedTopic
from nvflare.fuel.data_event.data_bus import DataBus
from nvflare.metrics.metrics_keys import MetricKeys


def publish_app_metrics(metrics: dict, metric_name: str, labels: dict, data_bus: DataBus):
    """Publish a batch of metric records on the ``APP_METRICS`` topic of the data bus.

    One record (a dict keyed by the ``MetricKeys`` constants) is built for each
    entry of ``metrics`` and the whole list is published in a single call. An
    empty list is still published when there are no metrics.

    Args:
        metrics: mapping of metric key to value. ``None`` is treated as empty.
        metric_name: prefix for each record name. Records are named
            ``"<metric_name>_<key>"``, or just ``<key>`` when the prefix is empty or None.
        labels: labels attached to every record. ``None`` means no labels.
        data_bus: bus whose ``publish(topics, data)`` method receives the records.
    """
    # Stamp the whole batch once (whole seconds since the epoch) so all records
    # produced by one call carry the same timestamp.
    timestamp = int(time.time())
    base_labels = {} if labels is None else labels

    metrics_data = [
        {
            MetricKeys.metric_name: f"{metric_name}_{key}" if metric_name else key,
            MetricKeys.value: value,
            # Each record gets its own copy so that a consumer mutating one
            # record's labels cannot affect the others or the caller's dict.
            MetricKeys.labels: dict(base_labels),
            MetricKeys.timestamp: timestamp,
        }
        for key, value in (metrics or {}).items()
    ]
    data_bus.publish([ReservedTopic.APP_METRICS], metrics_data)
25 changes: 9 additions & 16 deletions nvflare/private/fed/server/server_command_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

import logging

from nvflare.apis.fl_constant import FLContextKey, ReservedTopic, ServerCommandKey
from nvflare.apis.fl_constant import FLContextKey, ServerCommandKey
from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx
from nvflare.fuel.data_event.data_bus import DataBus
from nvflare.fuel.f3.cellnet.cell import Cell
from nvflare.fuel.f3.cellnet.core_cell import MessageHeaderKey, ReturnCode, make_reply
from nvflare.fuel.f3.message import Message as CellMessage
from nvflare.metrics.metrics_publisher import publish_app_metrics
from nvflare.private.defs import CellChannel, CellMessageHeaderKeys, new_cell_message
from nvflare.utils.collect_time_ctx import CollectTimeContext

Expand All @@ -28,10 +29,11 @@

class ServerCommandAgent(object):
def __init__(self, engine, cell: Cell) -> None:
"""To init the CommandAgent.
"""
To init the CommandAgent.
Args:
listen_port: port to listen the command
engine:
cell:
"""
self.logger = logging.getLogger(self.__class__.__name__)
self.asked_to_stop = False
Expand All @@ -54,11 +56,11 @@ def start(self):
self.logger.info(f"ServerCommandAgent cell register_request_cb: {self.cell.get_fqcn()}")

def execute_command(self, request: CellMessage) -> CellMessage:
metrics_group = ""
metrics_name = ""
try:
with CollectTimeContext() as context:
command_name = self.get_command_name(request)
metrics_group = command_name
metrics_name = command_name

# data = fobs.loads(request.payload)
data = request.payload
Expand Down Expand Up @@ -98,7 +100,7 @@ def execute_command(self, request: CellMessage) -> CellMessage:
else:
return make_reply(ReturnCode.INVALID_REQUEST, "No server command found", None)
finally:
self.publish_app_metrics(context.metrics, metrics_group)
publish_app_metrics(context.metrics, metrics_name, {}, self.data_bus)

def get_command_name(self, request):
if not isinstance(request, CellMessage):
Expand Down Expand Up @@ -140,14 +142,5 @@ def aux_communicate(self, request: CellMessage) -> CellMessage:
return_message = new_cell_message({}, None)
return return_message

def publish_app_metrics(self, metrics: dict, metric_group: str):
metrics_data = {}
for metric_name in metrics:
label = f"{metric_group}_{metric_name}"
metrics_value = metrics.get(metric_name)
metrics_data.update({label: metrics_value})

self.data_bus.publish([ReservedTopic.APP_METRICS], metrics_data)

def shutdown(self):
self.asked_to_stop = True
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
"site-2": 1.0
}
}
},
{
"id": "prometheus",
"path": "nvflare.fuel_opt.prometheus.metrics_collector.MetricsCollector"
}
],
"workflows": [
Expand Down

0 comments on commit 43d8656

Please sign in to comment.