diff --git a/impl/internal/dht/scheduler.go b/impl/internal/dht/scheduler.go index 10bb9fce..5fcbf4bd 100644 --- a/impl/internal/dht/scheduler.go +++ b/impl/internal/dht/scheduler.go @@ -21,11 +21,11 @@ func NewScheduler() Scheduler { } // Schedule schedules a job to run and starts it asynchronously -func (s *Scheduler) Schedule(_ string, job func()) error { +func (s *Scheduler) Schedule(schedule string, job func()) error { if s.job != nil { return errors.New("job already scheduled") } - j, err := s.scheduler.Cron("* * * * *").Do(job) + j, err := s.scheduler.Cron(schedule).Do(job) if err != nil { return err } diff --git a/impl/internal/did/did.go b/impl/internal/did/did.go index 477989f5..29e62d49 100644 --- a/impl/internal/did/did.go +++ b/impl/internal/did/did.go @@ -4,6 +4,7 @@ import ( "crypto/ed25519" "encoding/base64" "fmt" + "strconv" "strings" "github.com/TBD54566975/ssi-sdk/crypto" @@ -16,13 +17,22 @@ import ( ) type ( - DHT string + DHT string + TypeIndex int ) const ( // Prefix did:dht prefix Prefix = "did:dht" DHTMethod did.Method = "dht" + + Organization TypeIndex = 1 + GovernmentOrganization TypeIndex = 2 + Corporation TypeIndex = 3 + LocalBusiness TypeIndex = 4 + SoftwarePackage TypeIndex = 5 + WebApplication TypeIndex = 6 + FinancialInstitution TypeIndex = 7 ) func (d DHT) IsValid() bool { @@ -179,8 +189,8 @@ func GetDIDDHTIdentifier(pubKey []byte) string { return strings.Join([]string{Prefix, zbase32.EncodeToString(pubKey)}, ":") } -// ToDNSPacket converts a DID DHT Document to a DNS packet -func (d DHT) ToDNSPacket(doc did.Document) (*dns.Msg, error) { +// ToDNSPacket converts a DID DHT Document to a DNS packet with an optional list of types to include +func (d DHT) ToDNSPacket(doc did.Document, types []TypeIndex) (*dns.Msg, error) { var records []dns.RR var rootRecord []string keyLookup := make(map[string]string) @@ -311,6 +321,24 @@ func (d DHT) ToDNSPacket(doc did.Document) (*dns.Msg, error) { } records = append(records, &rootAnswer) + // add types record + if len(types) != 0 { + var typesStr []string + for _, t := range types { + typesStr = append(typesStr, strconv.Itoa(int(t))) + } + typesAnswer := dns.TXT{ + Hdr: dns.RR_Header{ + Name: "_typ._did.", + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 7200, + }, + Txt: []string{"id=" + strings.Join(typesStr, ",")}, + } + records = append(records, &typesAnswer) + } + // build the dns packet return &dns.Msg{ MsgHdr: dns.MsgHdr{ @@ -323,11 +351,12 @@ func (d DHT) ToDNSPacket(doc did.Document) (*dns.Msg, error) { } // FromDNSPacket converts a DNS packet to a DID DHT Document -func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, error) { +func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, []TypeIndex, error) { doc := did.Document{ ID: d.String(), } + var types []TypeIndex keyLookup := make(map[string]string) for _, rr := range msg.Answer { switch record := rr.(type) { @@ -341,15 +370,15 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, error) { // Convert keyBase64URL back to PublicKeyJWK pubKeyBytes, err := base64.RawURLEncoding.DecodeString(keyBase64URL) if err != nil { - return nil, err + return nil, nil, err } pubKey, err := crypto.BytesToPubKey(pubKeyBytes, keyType) if err != nil { - return nil, err + return nil, nil, err } pubKeyJWK, err := jwx.PublicKeyToPublicKeyJWK(vmID, pubKey) if err != nil { - return nil, err + return nil, nil, err } vm := did.VerificationMethod{ @@ -375,6 +404,18 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, error) { } doc.Services = append(doc.Services, service) + } else if record.Hdr.Name == "_typ._did." { + if record.Txt[0] == "" || len(record.Txt) != 1 { + return nil, nil, fmt.Errorf("invalid types record") + } + typesStr := strings.Split(strings.TrimPrefix(record.Txt[0], "id="), ",") + for _, t := range typesStr { + tInt, err := strconv.Atoi(t) + if err != nil { + return nil, nil, err + } + types = append(types, TypeIndex(tInt)) + } } else if record.Hdr.Name == "_did." { rootData := strings.Join(record.Txt, ";") rootItems := strings.Split(rootData, ";") @@ -421,7 +462,7 @@ func (d DHT) FromDNSPacket(msg *dns.Msg) (*did.Document, error) { } } - return &doc, nil + return &doc, types, nil } func parseTxtData(data string) map[string]string { diff --git a/impl/internal/did/did_test.go b/impl/internal/did/did_test.go index 096b8702..532ddb58 100644 --- a/impl/internal/did/did_test.go +++ b/impl/internal/did/did_test.go @@ -121,13 +121,34 @@ func TestToDNSPacket(t *testing.T) { require.NotEmpty(t, doc) didID := DHT(doc.ID) - packet, err := didID.ToDNSPacket(*doc) + packet, err := didID.ToDNSPacket(*doc, nil) require.NoError(t, err) require.NotEmpty(t, packet) - decodedDoc, err := didID.FromDNSPacket(packet) + decodedDoc, types, err := didID.FromDNSPacket(packet) require.NoError(t, err) require.NotEmpty(t, decodedDoc) + require.Empty(t, types) + + assert.EqualValues(t, *doc, *decodedDoc) + }) + + t.Run("doc with types - test to dns packet round trip", func(t *testing.T) { + privKey, doc, err := GenerateDIDDHT(CreateDIDDHTOpts{}) + require.NoError(t, err) + require.NotEmpty(t, privKey) + require.NotEmpty(t, doc) + + didID := DHT(doc.ID) + packet, err := didID.ToDNSPacket(*doc, []TypeIndex{1, 2, 3}) + require.NoError(t, err) + require.NotEmpty(t, packet) + + decodedDoc, types, err := didID.FromDNSPacket(packet) + require.NoError(t, err) + require.NotEmpty(t, decodedDoc) + require.NotEmpty(t, types) + require.Equal(t, types, []TypeIndex{1, 2, 3}) assert.EqualValues(t, *doc, *decodedDoc) }) @@ -169,13 +190,14 @@ func TestToDNSPacket(t *testing.T) { require.NotEmpty(t, doc) didID := DHT(doc.ID) - packet, err := didID.ToDNSPacket(*doc) + packet, err := didID.ToDNSPacket(*doc, nil) require.NoError(t, err) require.NotEmpty(t, packet) - decodedDoc, err := didID.FromDNSPacket(packet) + decodedDoc, types, err := didID.FromDNSPacket(packet) require.NoError(t, err) require.NotEmpty(t, decodedDoc) + require.Empty(t, types) assert.EqualValues(t, *doc, *decodedDoc) }) diff --git a/impl/pkg/dht/pkarr_test.go b/impl/pkg/dht/pkarr_test.go index 62dfc9c4..284d1089 100644 --- a/impl/pkg/dht/pkarr_test.go +++ b/impl/pkg/dht/pkarr_test.go @@ -101,7 +101,7 @@ func TestGetPutDIDDHT(t *testing.T) { require.NotEmpty(t, doc) didID := did.DHT(doc.ID) - didDocPacket, err := didID.ToDNSPacket(*doc) + didDocPacket, err := didID.ToDNSPacket(*doc, nil) require.NoError(t, err) putReq, err := CreatePKARRPublishRequest(privKey, *didDocPacket) @@ -120,7 +120,7 @@ func TestGetPutDIDDHT(t *testing.T) { require.NotEmpty(t, gotMsg.Answer) d := did.DHT("did:dht:" + gotID) - gotDoc, err := d.FromDNSPacket(gotMsg) + gotDoc, _, err := d.FromDNSPacket(gotMsg) require.NoError(t, err) require.NotEmpty(t, gotDoc) } diff --git a/impl/pkg/server/server_pkarr_test.go b/impl/pkg/server/server_pkarr_test.go index 50598ded..f654e3cd 100644 --- a/impl/pkg/server/server_pkarr_test.go +++ b/impl/pkg/server/server_pkarr_test.go @@ -81,7 +81,7 @@ func generateDIDPutRequest(t *testing.T) (string, []byte) { require.NoError(t, err) require.NotEmpty(t, doc) - packet, err := did.DHT(doc.ID).ToDNSPacket(*doc) + packet, err := did.DHT(doc.ID).ToDNSPacket(*doc, nil) assert.NoError(t, err) assert.NotEmpty(t, packet) diff --git a/impl/pkg/service/pkarr_test.go b/impl/pkg/service/pkarr_test.go index 357aa512..4d73cb6e 100644 --- a/impl/pkg/service/pkarr_test.go +++ b/impl/pkg/service/pkarr_test.go @@ -35,7 +35,7 @@ func TestPKARRService(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, doc) - packet, err := did.DHT(doc.ID).ToDNSPacket(*doc) + packet, err := did.DHT(doc.ID).ToDNSPacket(*doc, nil) assert.NoError(t, err) assert.NotEmpty(t, packet) diff --git a/impl/pkg/storage/pkarr_test.go b/impl/pkg/storage/pkarr_test.go index e41d8722..bd17703a 100644 --- a/impl/pkg/storage/pkarr_test.go +++ b/impl/pkg/storage/pkarr_test.go @@ -21,7 +21,7 @@ func TestPKARRStorage(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, doc) - packet, err := did.DHT(doc.ID).ToDNSPacket(*doc) + packet, err := did.DHT(doc.ID).ToDNSPacket(*doc, nil) assert.NoError(t, err) assert.NotEmpty(t, packet)