From ca8c22fe0d08dc5026f2004591eb0a717a46f8fc Mon Sep 17 00:00:00 2001 From: boreq Date: Wed, 1 Nov 2023 15:53:36 +0900 Subject: [PATCH] Normalize relay addresses --- service/app/downloader.go | 44 +++++++++++++++++++--------- service/domain/relay_address.go | 16 +++++++++- service/domain/relay_address_test.go | 31 ++++++++++++++++++++ 3 files changed, 76 insertions(+), 15 deletions(-) create mode 100644 service/domain/relay_address_test.go diff --git a/service/app/downloader.go b/service/app/downloader.go index d290f6f..e133cec 100644 --- a/service/app/downloader.go +++ b/service/app/downloader.go @@ -235,22 +235,12 @@ func (d *PublicKeyDownloader) storeMetrics() { d.metrics.ReportNumberOfPublicKeyDownloaderRelays(d.publicKey, len(d.downloaders)) } -func (d *PublicKeyDownloader) refreshRelays(longCtx context.Context) error { - fnCtx, fnCtxCancel := context.WithCancel(longCtx) - defer fnCtxCancel() - - relayAddresses, err := d.relaySource.GetRelays(fnCtx, d.publicKey) +func (d *PublicKeyDownloader) refreshRelays(ctx context.Context) error { + relayAddressesSet, err := d.getRelayAddresses(ctx) if err != nil { - return errors.Wrap(err, "error getting relayAddresses") + return errors.Wrap(err, "error getting relay addresses") } - d.logger.Debug(). - WithField("numberOfAddresses", len(relayAddresses)). - WithField("publicKey", d.publicKey.Hex()). - Message("got relay addresses") - - relayAddressesSet := internal.NewSet(relayAddresses) - d.downloadersLock.Lock() defer d.downloadersLock.Unlock() @@ -270,7 +260,7 @@ func (d *PublicKeyDownloader) refreshRelays(longCtx context.Context) error { WithField("relayAddress", relayAddress.String()). Message("creating a relay downloader") - ctx, cancel := context.WithCancel(longCtx) + ctx, cancel := context.WithCancel(ctx) go d.downloadMessages(ctx, relayAddress) d.downloaders[relayAddress] = cancel } @@ -279,6 +269,32 @@ func (d *PublicKeyDownloader) refreshRelays(longCtx context.Context) error { return nil } +func (d *PublicKeyDownloader) getRelayAddresses(ctx context.Context) (*internal.Set[domain.RelayAddress], error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + relayAddresses, err := d.relaySource.GetRelays(ctx, d.publicKey) + if err != nil { + return nil, errors.Wrap(err, "error getting relayAddresses") + } + + d.logger.Debug(). + WithField("numberOfAddresses", len(relayAddresses)). + WithField("publicKey", d.publicKey.Hex()). + Message("got relay addresses") + + normalizedRelayAddresses := internal.NewEmptySet[domain.RelayAddress]() + for _, relayAddress := range relayAddresses { + normalizedRelayAddress, err := domain.NormalizeRelayAddress(relayAddress) + if err != nil { + return nil, errors.Wrapf(err, "error normalizing a relay address '%s'", relayAddress.String()) + } + normalizedRelayAddresses.Put(normalizedRelayAddress) + } + + return normalizedRelayAddresses, nil +} + func (d *PublicKeyDownloader) downloadMessages(ctx context.Context, relayAddress domain.RelayAddress) { t := howFarIntoThePastToLook for eventOrEOSE := range d.relayEventDownloader.GetEvents(ctx, d.publicKey, relayAddress, domain.EventKindsToDownload(), &t) { diff --git a/service/domain/relay_address.go b/service/domain/relay_address.go index 5ef7bb9..8c6ecff 100644 --- a/service/domain/relay_address.go +++ b/service/domain/relay_address.go @@ -6,15 +6,24 @@ import ( "github.com/boreq/errors" ) +const ( + protocolWs = "ws://" + protocolWss = "wss://" +) + type RelayAddress struct { s string } func NewRelayAddress(s string) (RelayAddress, error) { - if !strings.HasPrefix(s, "ws://") && !strings.HasPrefix(s, "wss://") { + if !strings.HasPrefix(s, protocolWs) && !strings.HasPrefix(s, protocolWss) { return RelayAddress{}, errors.New("invalid protocol") } + if s == protocolWs || s == protocolWss { + return RelayAddress{}, errors.New("just protocol") + } + return RelayAddress{s: s}, nil } @@ -29,3 +38,8 @@ func MustNewRelayAddress(s string) RelayAddress { func (r RelayAddress) String() string { return r.s } + +func NormalizeRelayAddress(relayAddress RelayAddress) (RelayAddress, error) { + addr := strings.TrimSuffix(relayAddress.String(), "/") + return NewRelayAddress(addr) +} diff --git a/service/domain/relay_address_test.go b/service/domain/relay_address_test.go new file mode 100644 index 0000000..ed56474 --- /dev/null +++ b/service/domain/relay_address_test.go @@ -0,0 +1,31 @@ +package domain + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNormalizeRelayAddress(t *testing.T) { + testCases := []struct { + In string + Out string + }{ + { + In: "wss://nos.social", + Out: "wss://nos.social", + }, + { + In: "wss://nos.social/", + Out: "wss://nos.social", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.In, func(t *testing.T) { + address, err := NormalizeRelayAddress(MustNewRelayAddress(testCase.In)) + require.NoError(t, err) + require.Equal(t, testCase.Out, address.String()) + }) + } +}