Skip to content

Commit

Permalink
Improve async task flow for connections
Browse files Browse the repository at this point in the history
  • Loading branch information
merlinz01 committed Oct 25, 2024
1 parent 47335b0 commit edee67d
Show file tree
Hide file tree
Showing 10 changed files with 315 additions and 148 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ dist/
node_modules
*.tar.gz
*.whl
.ruff-cache
.ruff-cache
tests/.test-data
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ This is the changelog for RedPepper.

- Add option to disable TLS keys file mode check.

### Changed

- Improve async task flow for connections for better reliability and testability.

## [0.0.15]

### Security
Expand Down
65 changes: 38 additions & 27 deletions src/agent/redpepper/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from redpepper.common.connection import Connection
from redpepper.common.messages_pb2 import (
BYE,
CLIENTHELLO,
COMMAND,
COMMANDPROGRESS,
Expand All @@ -37,45 +38,67 @@
class Agent:
"""RedPepper Agent"""

config: AgentConfig
"""Agent configuration"""

conn: Connection
"""Connection to the manager"""

connected: trio.Event
"""Event that is set when the agent is connected"""

def __init__(self, config: AgentConfig, config_file=None):
self.config = config
self.conn: Connection
self.data_slots: dict[int, Slot] = {}
self.last_message_id = 100
self.tls_context = config.load_tls_context(ssl.Purpose.SERVER_AUTH)
self.connected = trio.Event()

async def connect(self):
async def run(self):
"""Run the agent"""
host = self.config.manager_host
port = self.config.manager_port
self.remote_address = (host, port)
logger.info("Connecting to manager at %s:%s", host, port)
stream = await trio.open_ssl_over_tcp_stream(
host, port, ssl_context=self.tls_context
)
self.conn = Connection(
stream, self.config.ping_timeout, self.config.ping_interval
)
logger.debug("Performing SSL handshake with manager")
await self.conn.stream.do_handshake()
await stream.do_handshake()
logger.info("Connected to manager at %s:%s", host, port)
self.conn = Connection(self.config, stream)
await self.handshake()
self.conn.message_handlers[COMMAND] = self.handle_command
self.conn.message_handlers[RESPONSE] = self.handle_response
self.connected.set()
await self.conn.run()

async def shutdown(self):
await self.conn.close()

async def handshake(self):
hello_slot = Slot()
self.conn.message_handlers[SERVERHELLO] = hello_slot.set # type: ignore
hello = Message()
hello.type = CLIENTHELLO
hello.client_hello.clientID = self.config.agent_id
hello.client_hello.auth = self.config.agent_secret.get_secret_value()
logger.debug("Sending client hello message to manager")
await self.conn.send_message(hello)
await self.conn.send_message_direct(hello)
try:
server_hello: Message = await hello_slot.get(self.config.hello_timeout)
with trio.fail_after(self.config.hello_timeout):
server_hello: Message = await self.conn.receive_message_direct()
except trio.TooSlowError:
logger.error("Handshake timed out")
await self.conn.close()
return
finally:
del self.conn.message_handlers[SERVERHELLO]
raise
if server_hello.type == BYE:
logger.error("Authentication failed: %s", server_hello.bye.reason)
await self.conn.close()
raise ValueError(f"Authentication failed: {server_hello.bye.reason}")
if server_hello.type != SERVERHELLO:
logger.error("Expected server hello message, got %s", server_hello.type)
await self.conn.close()
raise ValueError(
"Expected server hello message, got %s" % server_hello.type
)
logger.debug(
"Checking server hello message with version %s",
server_hello.server_hello.version,
Expand All @@ -85,9 +108,7 @@ async def handshake(self):
"Unsupported server version %s", server_hello.server_hello.version
)
await self.conn.close()
return
self.conn.message_handlers[COMMAND] = self.handle_command
self.conn.message_handlers[RESPONSE] = self.handle_response
raise ValueError("Unsupported server version")

async def handle_command(self, message: Message):
cmdtype = message.command.type
Expand Down Expand Up @@ -474,13 +495,3 @@ def evaluate_condition(self, condition):
raise ValueError(f"Invalid file condition verb {verb}")
logger.error("Invalid condition name: %s", k)
raise ValueError(f"Invalid condition name: {k!r}")

async def run(self):
"""Run the agent"""
await self.connect()
async with trio.open_nursery() as nursery:
nursery.start_soon(self.conn.run)
nursery.start_soon(self.handshake)

async def shutdown(self):
await self.conn.close()
1 change: 1 addition & 0 deletions src/common/redpepper/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def from_file(
class ConnectionConfig(pydantic.BaseModel):
ping_timeout: int = 5
ping_interval: int = 30
max_message_size: int = 1024 * 1024


class TLSConfig(pydantic.BaseModel):
Expand Down
Loading

0 comments on commit edee67d

Please sign in to comment.