Skip to content

Commit

Permalink
Merge pull request #345 from shunichironomura/add-slack-reporter
Browse files Browse the repository at this point in the history
add-slack-reporter
  • Loading branch information
shunichironomura authored Sep 23, 2024
2 parents d53eaac + 8d4d64b commit b5bcc8c
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 17 deletions.
3 changes: 2 additions & 1 deletion capsula/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"PlatformContext",
"ReporterBase",
"Run",
"SlackReporter",
"TimeWatcher",
"UncaughtExceptionWatcher",
"WatcherBase",
Expand Down Expand Up @@ -47,7 +48,7 @@
from ._decorator import context, pass_pre_run_capsule, reporter, run, watcher
from ._encapsulator import Encapsulator
from ._exceptions import CapsulaConfigurationError, CapsulaError, CapsulaUninitializedError
from ._reporter import JsonDumpReporter, ReporterBase
from ._reporter import JsonDumpReporter, ReporterBase, SlackReporter
from ._root import current_run_name, record
from ._run import CapsuleParams, CommandInfo, FuncInfo, Run
from ._utils import search_for_project_root
Expand Down
22 changes: 20 additions & 2 deletions capsula/_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from datetime import datetime, timezone
from enum import Enum
from random import choices
from string import ascii_letters, digits
from typing import Literal, NoReturn

import typer
Expand All @@ -8,7 +11,12 @@
from ._backport import Annotated
from ._config import load_config
from ._context import ContextBase
from ._run import CapsuleParams, generate_default_run_dir, get_project_root
from ._run import (
CapsuleParams,
default_run_name_factory,
get_default_vault_dir,
get_project_root,
)
from ._utils import get_default_config_path

app = typer.Typer()
Expand Down Expand Up @@ -41,10 +49,20 @@ def enc(
reporters = config[phase_key]["reporters"]

exec_info = None
run_dir = generate_default_run_dir(exec_info=exec_info)

vault_dir = get_default_vault_dir(exec_info)

run_name = default_run_name_factory(
exec_info,
"".join(choices(ascii_letters + digits, k=4)),
datetime.now(timezone.utc),
)

run_dir = vault_dir / run_name
run_dir.mkdir(exist_ok=True, parents=True)
params = CapsuleParams(
exec_info=exec_info,
run_name=run_name,
run_dir=run_dir,
phase=phase.value,
project_root=get_project_root(exec_info),
Expand Down
3 changes: 2 additions & 1 deletion capsula/_reporter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__all__ = ["JsonDumpReporter", "ReporterBase"]
__all__ = ["JsonDumpReporter", "ReporterBase", "SlackReporter"]
from ._base import ReporterBase
from ._json import JsonDumpReporter
from ._slack import SlackReporter
56 changes: 56 additions & 0 deletions capsula/_reporter/_slack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Callable, ClassVar, Literal

from slack_sdk import WebClient

from ._base import ReporterBase

if TYPE_CHECKING:
from capsula._capsule import Capsule
from capsula._run import CapsuleParams

logger = logging.getLogger(__name__)


class SlackReporter(ReporterBase):
"""Reporter to send a message to a Slack channel."""

_run_name_to_thread_ts: ClassVar[dict[str, str]] = {}

@classmethod
def builder(
cls,
*,
channel: str,
token: str,
) -> Callable[[CapsuleParams], SlackReporter]:
def build(params: CapsuleParams) -> SlackReporter:
return cls(phase=params.phase, channel=channel, token=token, run_name=params.run_name)

return build

def __init__(self, *, phase: Literal["pre", "in", "post"], channel: str, token: str, run_name: str) -> None:
self._phase = phase
self._channel = channel
self._token = token
self._run_name = run_name

def report(self, capsule: Capsule) -> None: # noqa: ARG002
client = WebClient(token=self._token)
thread_ts = SlackReporter._run_name_to_thread_ts.get(self._run_name)
if self._phase == "pre":
message = f"Capsule run `{self._run_name}` started"
response = client.chat_postMessage(channel=self._channel, text=message, thread_ts=thread_ts)
SlackReporter._run_name_to_thread_ts[self._run_name] = response["ts"]
elif self._phase == "in":
pass # Do nothing for now
elif self._phase == "post":
message = f"Capsule run `{self._run_name}` completed"
response = client.chat_postMessage(
channel=self._channel,
text=message,
thread_ts=thread_ts,
)
SlackReporter._run_name_to_thread_ts[self._run_name] = response["ts"]
16 changes: 3 additions & 13 deletions capsula/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class CommandInfo:
@dataclass
class CapsuleParams:
exec_info: FuncInfo | CommandInfo | None
run_name: str
run_dir: Path
phase: Literal["pre", "in", "post"]
project_root: Path
Expand All @@ -64,7 +65,7 @@ class CapsuleParams:
ExecInfo: TypeAlias = Union[FuncInfo, CommandInfo]


def get_vault_dir(exec_info: ExecInfo | None) -> Path:
def get_default_vault_dir(exec_info: ExecInfo | None) -> Path:
project_root = get_project_root(exec_info)
return project_root / "vault"

Expand All @@ -85,18 +86,6 @@ def default_run_name_factory(exec_info: ExecInfo | None, random_str: str, timest
return ("" if exec_name is None else f"{exec_name}_") + f"{datetime_str}_{random_str}"


def generate_default_run_dir(exec_info: ExecInfo | None = None) -> Path:
vault_dir = get_vault_dir(exec_info)

run_name = default_run_name_factory(
exec_info,
"".join(choices(ascii_letters + digits, k=4)),
datetime.now(timezone.utc),
)

return vault_dir / run_name


def get_project_root(exec_info: ExecInfo | None = None) -> Path:
if exec_info is None or isinstance(exec_info, CommandInfo):
return search_for_project_root(Path.cwd())
Expand Down Expand Up @@ -322,6 +311,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: # noqa: C901, PLR09

params = CapsuleParams(
exec_info=func_info,
run_name=run_name,
run_dir=self._run_dir,
phase="pre",
project_root=get_project_root(func_info),
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"typing-extensions>=4.7.1; python_version<'3.11'",
"orjson>=3.9.15",
"typer>=0.9.0",
"slack-sdk>=3.33.1",
]
license = { file = "LICENSE" }
keywords = ["reproducibility", "cli"]
Expand Down

0 comments on commit b5bcc8c

Please sign in to comment.