Skip to content

Commit

Permalink
Adding INF2 (transformers-neuronx) compilation latencies to SageMaker…
Browse files Browse the repository at this point in the history
… Health Metrics (#1185)

Co-authored-by: Tyler Osterberg <[email protected]>
  • Loading branch information
Lokiiiiii and tosterberg authored Oct 17, 2023
1 parent 182ae0e commit 1da14b1
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 2 deletions.
23 changes: 22 additions & 1 deletion engines/python/build.gradle
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
group "ai.djl.python"

/**
 * Rewrites setup/djl_python/__init__.py in place, dropping every line that
 * starts with `__version__` and keeping all other lines in their original order.
 */
def stripPackageVersion() {
    def target = file("setup/djl_python/__init__.py")
    // Read everything up front: the file is truncated as soon as the writer opens it.
    def keptLines = target.readLines().findAll { line -> !(line =~ /^__version__.*$/) }

    target.withWriter { writer ->
        keptLines.each { line -> writer.println line }
    }
}

dependencies {
api platform("ai.djl:bom:${project.version}")
api "ai.djl:api"
Expand All @@ -21,6 +34,13 @@ sourceSets {
}

processResources {
doFirst {
stripPackageVersion()
def versionLine = new StringBuilder()
versionLine.append("__version__ = '${djl_version}'")
file("setup/djl_python/__init__.py").append(versionLine)
}

exclude "build", "*.egg-info", "__pycache__", "PyPiDescription.rst", "setup.py"
outputs.file file("${project.buildDir}/classes/java/main/native/lib/python.properties")
doLast {
Expand Down Expand Up @@ -61,4 +81,5 @@ clean.doFirst {
delete "setup/djl_python/scheduler/__pycache__/"
delete "src/test/resources/accumulate/__pycache__/"
delete System.getProperty("user.home") + "/.djl.ai/python"
}
stripPackageVersion()
}
46 changes: 46 additions & 0 deletions engines/python/setup/djl_python/sm_log_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/usr/bin/env python
#
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.


import copy
from collections import defaultdict
from djl_python import __version__
import logging


# https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/logging-and-monitoring.html
class SMLogFilter(logging.Filter):
    """Logging filter that rewrites tagged metric messages into a telemetry line.

    Only records whose ``msg`` is a string of the form
    ``"<marker>: <metric name>: <value> <units>"`` are accepted, where
    ``<marker>`` is one of ``sm_log_markers``. Accepted records are copied and
    the copy's message is rewritten to::

        <marker>.Count:<n>|#DJLServing:<version>,<MetricName>:<value> <units>

    where ``<MetricName>`` is the metric name with each word capitalised and
    the whitespace removed, and ``<n>`` is how many times that metric name has
    been seen so far. Every other record is rejected.

    NOTE(review): ``filter`` returns the rewritten *copy*. The standard library
    only substitutes a record returned by a filter on Python 3.12+; on earlier
    versions any truthy return just lets the original, unrewritten record
    through. Confirm the target runtime.
    """

    sm_log_markers = ['ModelServerError', 'UserScriptError', 'SysHealth']
    # Defined on the class, so the counts are shared by every SMLogFilter instance.
    counter = defaultdict(int)

    def filter(self, record):
        """Return a rewritten copy of ``record`` if it is a tagged metric, else ``False``.

        Any failure while parsing or copying is reported through
        ``logging.warning`` and the record is rejected rather than raising
        into the logging machinery.
        """
        try:
            message = record.msg
            if not isinstance(message, str):
                return False
            prefixes = tuple(marker + ':' for marker in self.sm_log_markers)
            if not message.startswith(prefixes):
                return False

            # Work on a copy so the record seen by other handlers is untouched.
            altered_record = copy.deepcopy(record)
            # Exactly three ':'-separated fields are required; anything else
            # raises ValueError here and is reported below.
            tag, metric_name, metric = [
                part.strip() for part in message.split(':')
            ]
            # split() without a separator tolerates repeated whitespace, which
            # split(' ') turned into empty words / extra fields.
            words = metric_name.split()
            if not words:
                raise ValueError("metric name is empty")
            value, units = metric.split()
            altered_metric_name = ''.join(
                word[0].upper() + word[1:] for word in words)
            altered_record.msg = f"{tag}.Count:{self.count(altered_metric_name)}|#DJLServing:{__version__},{altered_metric_name}:{value} {units}"
            return altered_record
        except Exception as exc:
            logging.warning(f"Forwarding {str(record)} failed due to {str(exc)}")
            return False

    def count(self, key):
        """Increment and return the number of times ``key`` has been counted."""
        self.counter[key] += 1
        return self.counter[key]
43 changes: 43 additions & 0 deletions engines/python/setup/djl_python/tests/test_sm_log_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import unittest
from unittest.mock import MagicMock
from djl_python.sm_log_filter import SMLogFilter
import logging


class TestSMLogFilter(unittest.TestCase):
    """Unit tests for SMLogFilter: rewrite on a hit, warn on malformed input, reject on a miss."""

    def test_filter_hit(self):
        """A well-formed SysHealth metric is rewritten and its count increments per call."""
        log_filter = SMLogFilter()
        metric = "LLMShardingAndCompilationLatency"
        # The counter is a class attribute shared by all instances, so measure
        # from its current value instead of assuming this test runs first.
        baseline = SMLogFilter.counter.get(metric, 0)

        for offset in (1, 2):
            with self.subTest(call=offset):
                record = MagicMock()
                record.msg = "SysHealth: LLM sharding and compilation latency: 845.62 secs"
                actual = log_filter.filter(record).msg
                head, _, tail = actual.partition('|')
                self.assertEqual(head, f"SysHealth.Count:{baseline + offset}")
                # The version itself is build-dependent, so only the dimension key is checked.
                self.assertTrue(tail.startswith("#DJLServing:"))
                self.assertEqual(actual.split(',')[1], f"{metric}:845.62 secs")

    def test_filter_warning(self):
        """A tagged message with an extra ':' field is rejected and a warning is logged."""
        log_filter = SMLogFilter()
        record = MagicMock()
        record.msg = "SysHealth: LLM sharding and compilation latency: 845.62 : secs"

        with self.assertLogs(level=logging.WARNING):
            actual = log_filter.filter(record)
        self.assertFalse(actual)

    def test_filter_miss(self):
        """A message without a known marker prefix is rejected."""
        log_filter = SMLogFilter()
        record = MagicMock()
        record.msg = "LLM sharding and compilation latency: 845.62 : secs"
        actual = log_filter.filter(record)
        self.assertFalse(actual)

2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/transformers_neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def load_model(self, model_type):
self.model.to_neuron()
os.chdir(path)
elapsed = time.time() - start
logging.info(f"LLM sharding and compiling completed with {elapsed}s")
logging.info(f"SysHealth: LLM sharding and compilation latency: {elapsed} secs")

def initialize(self, properties):
# Neuron recommendation for transformersneuronx speedup
Expand Down
11 changes: 11 additions & 0 deletions engines/python/setup/djl_python_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from djl_python.inputs import Input
from djl_python.outputs import Output
from djl_python.service_loader import load_model_service
from djl_python.sm_log_filter import SMLogFilter

SOCKET_ACCEPT_TIMEOUT = 30.0

Expand Down Expand Up @@ -145,6 +146,7 @@ def main():
logging.basicConfig(stream=sys.stdout,
format="%(message)s",
level=logging.INFO)
configure_sm_logging()
logging.info(
f"{pid} - djl_python_engine started with args: {sys.argv[1:]}")
args = ArgParser.python_engine_args().parse_args()
Expand Down Expand Up @@ -175,6 +177,15 @@ def main():
os.remove(pid_file)


def configure_sm_logging():
    """Attach a telemetry file handler to the root logger when it is configured.

    The handler is only installed when the ``SM_TELEMETRY_LOG_REV_2022_12``
    environment variable holds a non-empty value, which is used as the path of
    the log file. Records reaching that handler are filtered through
    ``SMLogFilter``. When the variable is unset or empty, nothing is changed.
    """
    telemetry_log_path = os.environ.get('SM_TELEMETRY_LOG_REV_2022_12')
    if not telemetry_log_path:
        # An empty value is treated like an unset one: FileHandler('') would
        # otherwise fail while opening the path and abort startup.
        return

    # https://docs.aws.amazon.com/deep-learning-containers/latest/devguide/logging-and-monitoring.html
    sm_log_handler = logging.FileHandler(filename=telemetry_log_path)
    sm_log_handler.addFilter(SMLogFilter())
    logging.getLogger().addHandler(sm_log_handler)


if __name__ == "__main__":
    main()
    # Reaching this line means main() returned; report that with a non-zero
    # status. sys.exit is used because the bare exit() helper is injected by
    # the site module and is not guaranteed to exist.
    sys.exit(1)
11 changes: 11 additions & 0 deletions engines/python/setup/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

import os
import setuptools.command.build_py
from setuptools import setup, find_packages

Expand All @@ -28,6 +29,15 @@ def detect_version():
return None


def add_version(version_str):
    """Ensure djl_python/__init__.py declares exactly ``__version__ = '<version_str>'``.

    The path is resolved relative to the current working directory. If the
    file's only ``__version__`` assignment already matches, the file is left
    untouched. Otherwise every existing ``__version__`` assignment is dropped
    and a single up-to-date one is written at the end, so a stale assignment
    can never shadow the requested version at import time.

    Args:
        version_str: The version to record; it is interpolated as-is into the
            assignment.
    """
    init_path = os.path.join("djl_python", "__init__.py")
    djl_version_string = f"__version__ = '{version_str}'"

    def is_version_line(line):
        # Matches top-level "__version__ = ..." assignments only.
        name, sep, _ = line.partition("=")
        return bool(sep) and name.rstrip() == "__version__"

    with open(init_path, 'r') as f:
        lines = f.read().splitlines()

    current = [line.strip() for line in lines if is_version_line(line)]
    if current == [djl_version_string]:
        return

    kept = [line for line in lines if not is_version_line(line)]
    # Drop trailing blank lines so repeated rewrites do not accumulate them.
    while kept and not kept[-1].strip():
        kept.pop()
    # Same layout as a plain append: a blank line, then the assignment.
    new_lines = kept + ['', djl_version_string] if kept else [djl_version_string]
    with open(init_path, 'w') as f:
        f.write('\n'.join(new_lines))


def pypi_description():
with open('PyPiDescription.rst') as df:
return df.read()
Expand All @@ -41,6 +51,7 @@ def run(self):

if __name__ == '__main__':
version = detect_version()
add_version(version)

requirements = ['psutil', 'packaging', 'wheel']

Expand Down

0 comments on commit 1da14b1

Please sign in to comment.