Skip to content

Commit

Permalink
Update to use the NWC versioning system.
Browse files Browse the repository at this point in the history
  • Loading branch information
jklein24 committed Oct 11, 2024
1 parent 1bd63a1 commit ffeda0b
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 9 deletions.
101 changes: 96 additions & 5 deletions nwc_backend/event_handlers/__tests__/nip47_event_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,12 @@ def create_request_event(
self,
method: Nip47RequestMethod = Nip47RequestMethod.PAY_INVOICE,
params: Optional[dict[str, Any]] = None,
version: Optional[str] = "1.0",
use_nip44: bool = True,
) -> Event:
if params is None:
params = self.get_default_request_params()
return (
builder = (
EventBuilder(
kind=KindEnum.WALLET_CONNECT_REQUEST(), # pyre-ignore[6]
content=json.dumps(
Expand All @@ -70,8 +71,10 @@ def create_request_event(
)
.encrypt_content(self.nwc_keys.public_key(), use_nip44=use_nip44)
.add_tag(["p", self.nwc_keys.public_key().to_hex()])
.build()
)
if version:
builder.add_tag(["v", version])
return builder.build()

def get_default_request_params(self) -> dict[str, Any]:
return {
Expand Down Expand Up @@ -191,7 +194,10 @@ async def test_failed__invalid_input_params(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(params={}, use_nip44=use_nip44)
version = "1.0" if use_nip44 else None
request_event = harness.create_request_event(
params={}, use_nip44=use_nip44, version=version
)
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
Expand Down Expand Up @@ -232,7 +238,10 @@ async def test_succeeded(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(use_nip44=use_nip44)
version = "1.0" if use_nip44 else None
request_event = harness.create_request_event(
use_nip44=use_nip44, version=version
)
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
Expand Down Expand Up @@ -289,7 +298,10 @@ async def test_failed__vasp_error_response(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(use_nip44=use_nip44)
version = "1.0" if use_nip44 else None
request_event = harness.create_request_event(
use_nip44=use_nip44, version=version
)
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
Expand Down Expand Up @@ -327,3 +339,82 @@ async def test_duplicate_event(
result = await db.session.execute(select(Nip47Request))
request = result.scalars().one()
assert request.id == nip47_event.id


@patch("nwc_backend.nostr.nostr_client.nostr_client.send_event", new_callable=AsyncMock)
async def test_failed__invalid_version(
mock_nostr_send: AsyncMock,
test_client: QuartClient,
) -> None:
mock_nostr_send.return_value = SendEventOutput(
id=EventId.from_hex(token_hex()),
output=Output(success=["wss://relay.getalby.com/v1"], failed={}),
)
async with test_client.app.app_context():
harness = Harness.prepare()
await create_nwc_connection(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(params={}, version="abc")
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
response_event = mock_nostr_send.call_args[0][0]
content = harness.validate_response_event(response_event, request_event.id())
assert content["result_type"] == Nip47RequestMethod.PAY_INVOICE.value
assert content["error"]["code"] == ErrorCode.OTHER.name


@patch("nwc_backend.nostr.nostr_client.nostr_client.send_event", new_callable=AsyncMock)
async def test_failed__unsupported(
mock_nostr_send: AsyncMock,
test_client: QuartClient,
) -> None:
mock_nostr_send.return_value = SendEventOutput(
id=EventId.from_hex(token_hex()),
output=Output(success=["wss://relay.getalby.com/v1"], failed={}),
)
async with test_client.app.app_context():
harness = Harness.prepare()
await create_nwc_connection(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(params={}, version="10.0")
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
response_event = mock_nostr_send.call_args[0][0]
content = harness.validate_response_event(response_event, request_event.id())
assert content["result_type"] == Nip47RequestMethod.PAY_INVOICE.value
assert content["error"]["code"] == ErrorCode.NOT_IMPLEMENTED.name


@patch("nwc_backend.nostr.nostr_client.nostr_client.send_event", new_callable=AsyncMock)
async def test_failed__wrong_encryption_for_version(
mock_nostr_send: AsyncMock,
test_client: QuartClient,
) -> None:
mock_nostr_send.return_value = SendEventOutput(
id=EventId.from_hex(token_hex()),
output=Output(success=["wss://relay.getalby.com/v1"], failed={}),
)
async with test_client.app.app_context():
harness = Harness.prepare()
await create_nwc_connection(
granted_permissions_groups=[PermissionsGroup.SEND_PAYMENTS],
keys=harness.client_app_keys,
)
request_event = harness.create_request_event(
params={}, version="1.0", use_nip44=False
)
await handle_nip47_event(request_event)

mock_nostr_send.assert_called_once()
response_event = mock_nostr_send.call_args[0][0]
content = harness.validate_response_event(
response_event, request_event.id(), expect_nip44=False
)
assert content["result_type"] == Nip47RequestMethod.PAY_INVOICE.value
assert content["error"]["code"] == ErrorCode.OTHER.name
47 changes: 47 additions & 0 deletions nwc_backend/event_handlers/nip47_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from nwc_backend.models.nwc_connection import NWCConnection
from nwc_backend.nostr.nostr_client import nostr_client
from nwc_backend.nostr.nostr_config import NostrConfig
from nwc_backend.nostr.versions import ParsedVersion, is_version_supported


async def handle_nip47_event(event: Event) -> None:
Expand Down Expand Up @@ -77,6 +78,22 @@ async def handle_nip47_event(event: Event) -> None:
return

method = Nip47RequestMethod(content["method"])

try:
_check_version(event)
except Nip47RequestException as ex:
error_response = create_nip47_error_response(
event=event,
method=method,
error=Nip47Error(
code=ex.error_code,
message=ex.error_message,
),
use_nip44=not is_nip04_encrypted,
)
await nostr_client.send_event(error_response)
return

if not nwc_connection.has_command_permission(method):
error_response = create_nip47_error_response(
event=event,
Expand Down Expand Up @@ -187,3 +204,33 @@ async def handle_nip47_event(event: Event) -> None:
await nip47_request.update_response_and_save(
response_event_id=output.id.to_hex(), response=response
)


def _check_version(event: Event) -> ParsedVersion:
is_nip04_encrypted = "?iv=" in event.content()
selected_version = ParsedVersion(0, 0)
version_tag = next((tag for tag in event.tags() if tag.as_vec()[0] == "v"), None)
if version_tag:
selected_version_str = version_tag.content() or "0.0"
try:
selected_version = ParsedVersion.load(selected_version_str)
except ValueError:
raise Nip47RequestException(
error_code=ErrorCode.OTHER,
error_message=f"Invalid version {selected_version_str}.",
)

if not is_version_supported(selected_version):
raise Nip47RequestException(
# TODO: Use ErrorCode.VERSION_NOT_SUPPORTED when added.
error_code=ErrorCode.NOT_IMPLEMENTED,
error_message=f"Unsupported version {selected_version}.",
)

if selected_version.major > 0 and is_nip04_encrypted:
raise Nip47RequestException(
error_code=ErrorCode.OTHER,
error_message="NIP04 encryption is not supported for version > 0. Please use NIP44.",
)

return selected_version
13 changes: 9 additions & 4 deletions nwc_backend/nostr/nostr_client_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from nwc_backend.models.nip47_request_method import Nip47RequestMethod
from nwc_backend.nostr.nostr_client import nostr_client
from nwc_backend.nostr.nostr_config import NostrConfig
from nwc_backend.nostr.versions import NWC_VERSIONS_SUPPORTED


class NotificationHandler(HandleNotification):
Expand Down Expand Up @@ -49,10 +50,14 @@ async def init_nostr_client() -> None:


async def _publish_nip47_info() -> None:
nip47_info_event = EventBuilder(
kind=KindEnum.WALLET_CONNECT_INFO(), # pyre-ignore[6]
content=" ".join([method.value for method in list(Nip47RequestMethod)]),
).build()
nip47_info_event = (
EventBuilder(
kind=KindEnum.WALLET_CONNECT_INFO(), # pyre-ignore[6]
content=" ".join([method.value for method in list(Nip47RequestMethod)]),
)
.add_tag(["v", " ".join(NWC_VERSIONS_SUPPORTED)])
.build()
)
response = await nostr_client.send_event(nip47_info_event)

logging.debug(
Expand Down
34 changes: 34 additions & 0 deletions nwc_backend/nostr/versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from functools import total_ordering
from typing import List


NWC_VERSIONS_SUPPORTED: List[str] = ["0.0", "1.0"]


@total_ordering
@dataclass
class ParsedVersion:
major: int
minor: int

@classmethod
def load(cls, version: str) -> "ParsedVersion":
[major, minor] = version.split(".")
return ParsedVersion(major=int(major), minor=int(minor))

def __str__(self) -> str:
return f"{self.major}.{self.minor}"

def __lt__(self, other: "ParsedVersion") -> bool:
return self.major < other.major or (
self.major == other.major and self.minor < other.minor
)


def is_version_supported(version: ParsedVersion) -> bool:
for version_str in NWC_VERSIONS_SUPPORTED:
supported_version = ParsedVersion.load(version_str)
if version.major == supported_version.major:
return version.minor <= supported_version.minor
return False

0 comments on commit ffeda0b

Please sign in to comment.