Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(github/genai): add github vcs plugin for genai chatbot to make prs #5243

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements-base.in
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pdpyras
protobuf<4.24.0,>=3.6.1
psycopg2-binary
pydantic==1.*
PyGithub
pyparsing
python-dateutil
python-jose
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def run(self):
"duo_auth_mfa = dispatch.plugins.dispatch_duo.plugin:DuoMfaPlugin",
"generic_workflow = dispatch.plugins.generic_workflow.plugin:GenericWorkflowPlugin",
"github_monitor = dispatch.plugins.dispatch_github.plugin:GithubMonitorPlugin",
"github_version_control = dispatch.plugins.dispatch_github_vcs.plugin:GithubVersionControlPlugin",
"google_calendar_conference = dispatch.plugins.dispatch_google.calendar.plugin:GoogleCalendarConferencePlugin",
"google_docs_document = dispatch.plugins.dispatch_google.docs.plugin:GoogleDocsDocumentPlugin",
"google_drive_storage = dispatch.plugins.dispatch_google.drive.plugin:GoogleDriveStoragePlugin",
Expand Down
24 changes: 24 additions & 0 deletions src/dispatch/plugins/bases/version_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
.. module: dispatch.plugins.bases.version_control
:platform: Unix
:copyright: (c) 2024 by Netflix Inc., see AUTHORS for more
:license: Apache, see LICENSE for more details.
"""

from dispatch.plugins.base import Plugin


class VesionControlPlugin(Plugin):
type = "version-control"

def get_repo(self, **kwargs):
raise NotImplementedError

def clone_repo(self, **kwargs):
raise NotImplementedError

def create_pr(self, **kwargs):
raise NotImplementedError

def close_pr(self, **kwargs):
raise NotImplementedError
Empty file.
1 change: 1 addition & 0 deletions src/dispatch/plugins/dispatch_github_vcs/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.1.0"
17 changes: 17 additions & 0 deletions src/dispatch/plugins/dispatch_github_vcs/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pydantic import Field, HttpUrl, SecretStr

from dispatch.config import BaseConfigurationModel


class GithubConfiguration(BaseConfigurationModel):
"""Github configuration description."""

pat: SecretStr = Field(
title="Personal Access Token",
description="Fine-grained personal access tokens.",
)
base_url: HttpUrl = Field(
default="https://api.github.com",
title="GitHub API Base URL",
description="The base URL for the GitHub API. Use this to specify a GitHub Enterprise instance.",
)
143 changes: 143 additions & 0 deletions src/dispatch/plugins/dispatch_github_vcs/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from typing import Any

from github import Github, GithubException
from github.ContentFile import ContentFile
from github.PullRequest import PullRequest
from github.Repository import Repository

from dispatch.decorators import apply, counter, timer
from dispatch.plugins.bases.version_control import VesionControlPlugin

from ._version import __version__
from .config import GithubConfiguration


@apply(counter, exclude=["__init__"])
@apply(timer, exclude=["__init__"])
class GithubVersionControlPlugin(VesionControlPlugin):
title = "Github Plugin - Version Control"
slug = "github-version-control"
description = "Allows for interaction with Github Enterprise Server."
version = __version__

author = "Netflix"
author_url = "https://github.com/netflix/dispatch.git"

def __init__(self) -> None:
self.configuration_schema: type[GithubConfiguration] = GithubConfiguration
self.github_client: Github | None = None

def _initialize_client(self) -> None:
if not self.github_client:
token: str = self.configuration.pat.get_secret_value()
base_url: str = str(self.configuration.base_url)
self.github_client = Github(base_url=base_url, login_or_token=token)

def get_repo(self, repo_name: str) -> Repository:
"""Get a repository object."""
self._initialize_client()
try:
return self.github_client.get_repo(repo_name)
except GithubException as e:
raise Exception(f"Failed to get repository: {str(e)}") from e

def get_file_content(self, repo_name: str, file_path: str, ref: str = "main") -> str:
"""Get the content of a file from a repository."""
self._initialize_client()
repo: Repository = self.get_repo(repo_name)

try:
content_file: ContentFile = repo.get_contents(file_path, ref=ref)
return content_file.decoded_content.decode("utf-8")
except GithubException as e:
raise Exception(f"Failed to get file content: {str(e)}") from e

def create_pr(
self,
repo_name: str,
branch_name: str,
base_branch: str,
title: str,
body: str,
file_path: str,
file_content: str,
) -> int:
"""Create a pull request with detection tuning changes."""
self._initialize_client()
repo: Repository = self.get_repo(repo_name)

try:
# Create a new branch
source_branch = repo.get_branch(base_branch)
repo.create_git_ref(ref=f"refs/heads/{branch_name}", sha=source_branch.commit.sha)

# Create or update file in the new branch
repo.create_file(
path=file_path,
message=f"Update detection rules: {title}",
content=file_content,
branch=branch_name,
)

# Create pull request
pr: PullRequest = repo.create_pull(
title=title, body=body, head=branch_name, base=base_branch
)

return pr.number

except GithubException as e:
raise Exception(f"Failed to create pull request: {str(e)}") from e

def close_pr(self, repo_name: str, pr_number: int) -> bool:
"""Close a pull request."""
self._initialize_client()
repo: Repository = self.get_repo(repo_name)

try:
pr: PullRequest = repo.get_pull(pr_number)
pr.edit(state="closed")
return True
except GithubException as e:
raise Exception(f"Failed to close pull request: {str(e)}") from e

def update_pr(
self, repo_name: str, pr_number: int, file_path: str, file_content: str, commit_message: str
) -> bool:
"""Update an existing pull request with new changes."""
self._initialize_client()
repo: Repository = self.get_repo(repo_name)

try:
pr: PullRequest = repo.get_pull(pr_number)
branch_name: str = pr.head.ref

# Update file in the PR's branch
contents: ContentFile = repo.get_contents(file_path, ref=branch_name)
repo.update_file(
path=file_path,
message=commit_message,
content=file_content,
sha=contents.sha,
branch=branch_name,
)

return True
except GithubException as e:
raise Exception(f"Failed to update pull request: {str(e)}") from e

def get_pr_status(self, repo_name: str, pr_number: int) -> dict[str, Any]:
"""Get the status of a pull request."""
self._initialize_client()
repo: Repository = self.get_repo(repo_name)

try:
pr: PullRequest = repo.get_pull(pr_number)
return {
"state": pr.state,
"merged": pr.merged,
"mergeable": pr.mergeable,
"mergeable_state": pr.mergeable_state,
}
except GithubException as e:
raise Exception(f"Failed to get pull request status: {str(e)}") from e
88 changes: 88 additions & 0 deletions src/dispatch/plugins/dispatch_slack/case/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2505,3 +2505,91 @@ def handle_engagement_deny_submission_event(
channel=case.conversation.channel_id,
ts=signal_instance.engagement_thread_ts,
)


def tune_detection(
case: Case,
db_session: Session,
client: WebClient,
):
# Get associated signal
signal = case.signal_instances[0].signal if case.signal_instances else None

if not signal:
return "No associated signal found for this case."

# Get GitHub plugin
github_plugin = plugin_service.get_active_instance(
db_session=db_session, project_id=case.project.id, plugin_type="github-version-control"
)

if not github_plugin:
return "GitHub plugin is not configured for this project."

# Get current detection content
repo_name = "det-geiger-detections"
# repo_name = signal.github_repo or case.project.github_repo # Fallback to project-level config
file_path = f"detections/{signal.name}.yaml" # Adjust the path as needed
current_content = github_plugin.get_file_content(repo_name, file_path)

# Use AI to suggest improvements
ai_plugin = plugin_service.get_active_instance(
db_session=db_session, project_id=case.project.id, plugin_type="artificial-intelligence"
)

if not ai_plugin:
return "AI plugin is not configured for this project."

prompt = f"""
Given the following detection rule:

{current_content}

And considering the case details:

{case.description}

Suggest improvements to the detection rule to reduce false positives and increase accuracy.
"""

ai_response = ai_plugin.chat_completion([{"role": "user", "content": prompt}])

# Create a pull request with the suggested changes
branch_name = f"tune-detection-{case.id}"
pr_title = f"Tune detection for case {case.id}"
pr_body = f"Suggested improvements for detection related to case {case.id}\n\n{ai_response}"

pr_number = github_plugin.create_pr(
repo_name, branch_name, "main", pr_title, pr_body, file_path, ai_response
)

return f"Created a pull request (#{pr_number}) with suggested improvements to the detection rule. Please review and merge if appropriate."


@message_dispatcher.add(subject=CaseSubjects.case)
def handle_case_message(
ack: Ack,
body: dict,
user: DispatchUser,
context: BoltContext,
db_session: Session,
client: WebClient,
):
ack()

# Check if the message is in a thread
if "thread_ts" not in body:
return

message_text = body["text"].lower()

# Check for keywords related to tuning the detection
tuning_keywords = ["tune detection", "improve detection", "update signal", "refine signal"]
if any(keyword in message_text for keyword in tuning_keywords):
case = case_service.get(db_session=db_session, case_id=context["subject"].id)

# Call the tune_detection function
result = tune_detection(case, db_session, client)

# Post the result in the thread
client.chat_postMessage(channel=body["channel"], thread_ts=body["thread_ts"], text=result)
Loading