Skip to content
This repository has been archived by the owner on Jul 15, 2024. It is now read-only.

Commit

Permalink
Merge pull request caikit#634 from HonakerM/move_service_names_to_int…
Browse files Browse the repository at this point in the history
…erfaces

Move Service/Server Constant Generation to Interfaces
  • Loading branch information
gabe-l-hart authored Jan 8, 2024
2 parents 0486d1a + bbb3c53 commit dff142c
Show file tree
Hide file tree
Showing 10 changed files with 399 additions and 167 deletions.
22 changes: 22 additions & 0 deletions caikit/core/toolkit/name_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common string functions that are generally helpful for generating runtime RPC names
and other Protobuf names
"""


def snake_to_upper_camel(string: str) -> str:
"""Simple snake -> upper camel conversion for descriptors"""
return "".join([part[0].upper() + part[1:] for part in string.split("_")])
57 changes: 13 additions & 44 deletions caikit/runtime/http_server/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@
# Standard
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Type, Union, get_args
import asyncio
import inspect
import io
import json
import os
import re
import signal
import ssl
import tempfile
Expand Down Expand Up @@ -71,6 +69,16 @@
CaikitCoreStatusCode,
)
from caikit.core.toolkit.sync_to_async import async_wrap_iter
from caikit.runtime.names import (
HEALTH_ENDPOINT,
MODEL_ID,
MODELS_INFO_ENDPOINT,
OPTIONAL_INPUTS_KEY,
REQUIRED_INPUTS_KEY,
RUNTIME_INFO_ENDPOINT,
StreamEventTypes,
get_http_route_name,
)
from caikit.runtime.server_base import RuntimeServerBase
from caikit.runtime.service_factory import ServicePackage
from caikit.runtime.service_generation.rpcs import (
Expand Down Expand Up @@ -120,26 +128,6 @@
}


# These keys are used to define the logical sections of the request and response
# data structures.
REQUIRED_INPUTS_KEY = "inputs"
OPTIONAL_INPUTS_KEY = "parameters"
MODEL_ID = "model_id"

# Endpoint to use for health checks
HEALTH_ENDPOINT = "/health"

# Endpoint to use for server info
RUNTIME_INFO_ENDPOINT = "/info/version"
MODELS_INFO_ENDPOINT = "/info/models"


# Stream event types enum
class StreamEventTypes(Enum):
MESSAGE = "message"
ERROR = "error"


# Small dataclass for consolidating TLS files
@dataclass
class _TlsFiles:
Expand Down Expand Up @@ -429,7 +417,7 @@ def _train_add_unary_input_unary_output_handler(self, rpc: CaikitRPCBase):
pydantic_response = dataobject_to_pydantic(response_data_object)

@self.app.post(
self._get_route(rpc),
get_http_route_name(rpc.name),
responses=self._get_response_openapi(
response_data_object, pydantic_response
),
Expand Down Expand Up @@ -490,7 +478,7 @@ def _add_unary_input_unary_output_handler(self, rpc: TaskPredictRPC):
pydantic_response = dataobject_to_pydantic(response_data_object)

@self.app.post(
self._get_route(rpc),
get_http_route_name(rpc.name),
responses=self._get_response_openapi(
response_data_object, pydantic_response
),
Expand Down Expand Up @@ -573,7 +561,7 @@ def _add_unary_input_stream_output_handler(self, rpc: CaikitRPCBase):

# pylint: disable=unused-argument
@self.app.post(
self._get_route(rpc),
get_http_route_name(rpc.name),
response_model=pydantic_response,
openapi_extra=self._get_request_openapi(pydantic_request),
)
Expand Down Expand Up @@ -656,25 +644,6 @@ async def _generator() -> pydantic_response:

return EventSourceResponse(_generator())

def _get_route(self, rpc: CaikitRPCBase) -> str:
"""Get the REST route for this rpc"""
if rpc.name.endswith("Predict"):
task_name = re.sub(
r"(?<!^)(?=[A-Z])",
"-",
re.sub("Task$", "", re.sub("Predict$", "", rpc.name)),
).lower()
route = "/".join([self.config.runtime.http.route_prefix, "task", task_name])
if route[0] != "/":
route = "/" + route
return route
if rpc.name.endswith("Train"):
route = "/".join([self.config.runtime.http.route_prefix, rpc.name])
if route[0] != "/":
route = "/" + route
return route
raise NotImplementedError("No support for train rpcs yet!")

def _get_request_dataobject(self, rpc: CaikitRPCBase) -> Type[DataBase]:
"""Get the dataobject request for the given rpc"""
is_inference_rpc = hasattr(rpc, "task")
Expand Down
Loading

0 comments on commit dff142c

Please sign in to comment.