Skip to content

Commit

Permalink
wallet - add WalletKeyset struct
Browse files Browse the repository at this point in the history
  • Loading branch information
elnosh committed Jun 11, 2024
1 parent bde63cf commit 79f33f0
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 56 deletions.
51 changes: 30 additions & 21 deletions crypto/keyset.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,43 @@ import (

const maxOrder = 64

// KeysetsMap maps a mint url to map of string keyset id to keyset
type KeysetsMap map[string]map[string]Keyset

type Keyset struct {
Id string
MintURL string
Unit string
Active bool
Keys map[uint64]KeyPair
Id string
//MintURL string
Unit string
Active bool
Keys map[uint64]KeyPair
}

type KeyPair struct {
PrivateKey *secp256k1.PrivateKey
PublicKey *secp256k1.PublicKey
}

// KeysetsMap maps a mint url to map of string keyset id to keyset
type KeysetsMap map[string]map[string]WalletKeyset

type WalletKeyset struct {
Id string
MintURL string
Unit string
Active bool
PublicKeys map[uint64]*secp256k1.PublicKey
Counter uint64
}

func GenerateKeyset(seed, derivationPath string) *Keyset {
keys := make(map[uint64]KeyPair, maxOrder)

pks := make(map[uint64]*secp256k1.PublicKey)
for i := 0; i < maxOrder; i++ {
amount := uint64(math.Pow(2, float64(i)))
hash := sha256.Sum256([]byte(seed + derivationPath + strconv.FormatUint(amount, 10)))
privKey, pubKey := btcec.PrivKeyFromBytes(hash[:])
keys[amount] = KeyPair{PrivateKey: privKey, PublicKey: pubKey}
pks[amount] = pubKey
}
keysetId := DeriveKeysetId(keys)
keysetId := DeriveKeysetId(pks)
return &Keyset{Id: keysetId, Unit: "sat", Active: true, Keys: keys}
}

Expand All @@ -50,15 +61,15 @@ func GenerateKeyset(seed, derivationPath string) *Keyset {
// - HASH_SHA256 the concatenated public keys
// - take the first 14 characters of the hex-encoded hash
// - prefix it with a keyset ID version byte
func DeriveKeysetId(keyset map[uint64]KeyPair) string {
func DeriveKeysetId(keyset map[uint64]*secp256k1.PublicKey) string {
type pubkey struct {
amount uint64
pk *secp256k1.PublicKey
}
pubkeys := make([]pubkey, len(keyset))
i := 0
for amount, key := range keyset {
pubkeys[i] = pubkey{amount, key.PublicKey}
pubkeys[i] = pubkey{amount, key}
i++
}
sort.Slice(pubkeys, func(i, j int) bool {
Expand Down Expand Up @@ -87,19 +98,17 @@ func (ks *Keyset) DerivePublic() map[uint64]string {
}

type KeysetTemp struct {
Id string
MintURL string
Unit string
Active bool
Keys map[uint64]json.RawMessage
Id string
Unit string
Active bool
Keys map[uint64]json.RawMessage
}

func (ks *Keyset) MarshalJSON() ([]byte, error) {
temp := &KeysetTemp{
Id: ks.Id,
MintURL: ks.MintURL,
Unit: ks.Unit,
Active: ks.Active,
Id: ks.Id,
Unit: ks.Unit,
Active: ks.Active,
Keys: func() map[uint64]json.RawMessage {
m := make(map[uint64]json.RawMessage)
for k, v := range ks.Keys {
Expand All @@ -121,7 +130,7 @@ func (ks *Keyset) UnmarshalJSON(data []byte) error {
}

ks.Id = temp.Id
ks.MintURL = temp.MintURL
//ks.MintURL = temp.MintURL
ks.Unit = temp.Unit
ks.Active = temp.Active

Expand Down
5 changes: 2 additions & 3 deletions crypto/keyset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestDeriveKeysetId(t *testing.T) {
}

for _, test := range tests {
keys := make(map[uint64]KeyPair)
keys := make(map[uint64]*secp256k1.PublicKey)

for amount, pubkey := range test.pubkeys {
pubkeyBytes, _ := hex.DecodeString(pubkey)
Expand All @@ -102,8 +102,7 @@ func TestDeriveKeysetId(t *testing.T) {
t.Errorf("error parsing pub key: %v", err)
}

keyPair := KeyPair{PublicKey: publicKey}
keys[amount] = keyPair
keys[amount] = publicKey
}

id := DeriveKeysetId(keys)
Expand Down
34 changes: 31 additions & 3 deletions wallet/storage/bolt.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func (db *BoltDB) DeleteProof(secret string) error {
})
}

func (db *BoltDB) SaveKeyset(keyset *crypto.Keyset) error {
func (db *BoltDB) SaveKeyset(keyset crypto.WalletKeyset) error {
jsonKeyset, err := json.Marshal(keyset)
if err != nil {
return fmt.Errorf("invalid keyset format: %v", err)
Expand All @@ -157,13 +157,13 @@ func (db *BoltDB) GetKeysets() crypto.KeysetsMap {
keysetsb := tx.Bucket([]byte(keysetsBucket))

return keysetsb.ForEach(func(mintURL, v []byte) error {
mintKeysets := make(map[string]crypto.Keyset)
mintKeysets := make(map[string]crypto.WalletKeyset)

mintBucket := keysetsb.Bucket(mintURL)
c := mintBucket.Cursor()

for k, v := c.First(); k != nil; k, v = c.Next() {
var keyset crypto.Keyset
var keyset crypto.WalletKeyset
if err := json.Unmarshal(v, &keyset); err != nil {
return err
}
Expand All @@ -179,6 +179,34 @@ func (db *BoltDB) GetKeysets() crypto.KeysetsMap {
}

return keysets

}

func (db *BoltDB) IncrementKeysetCounter(keysetId string) error {
if err := db.bolt.Update(func(tx *bolt.Tx) error {
keysetsb := tx.Bucket([]byte(keysetsBucket))
keysetBytes := keysetsb.Get([]byte(keysetId))
if keysetBytes == nil {
return errors.New("keyset does not exist")
}

var keyset crypto.WalletKeyset
err := json.Unmarshal(keysetBytes, &keyset)
if err != nil {
return fmt.Errorf("error reading keyset from db: %v", err)
}
keyset.Counter += 1

jsonBytes, err := json.Marshal(keyset)
if err != nil {
return err
}
return keysetsb.Put([]byte(keysetId), jsonBytes)
}); err != nil {
return err
}

return nil
}

func (db *BoltDB) SaveInvoice(invoice Invoice) error {
Expand Down
3 changes: 2 additions & 1 deletion wallet/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ type DB interface {
GetProofsByKeysetId(string) cashu.Proofs
GetProofs() cashu.Proofs
DeleteProof(string) error
SaveKeyset(*crypto.Keyset) error
SaveKeyset(crypto.WalletKeyset) error
GetKeysets() crypto.KeysetsMap
IncrementKeysetCounter(string) error
SaveInvoice(Invoice) error
GetInvoice(string) *Invoice
GetInvoices() []Invoice
Expand Down
55 changes: 27 additions & 28 deletions wallet/wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ type Wallet struct {
type walletMint struct {
mintURL string
// active keysets from mint
activeKeysets map[string]crypto.Keyset
activeKeysets map[string]crypto.WalletKeyset
// list of inactive keysets (if any) from mint
inactiveKeysets map[string]crypto.Keyset
inactiveKeysets map[string]crypto.WalletKeyset
}

func InitStorage(path string) (storage.DB, error) {
Expand Down Expand Up @@ -71,7 +71,7 @@ func LoadWallet(config Config) (*Wallet, error) {
wallet.currentMint = mint
} else { // if mint is already known, check if active keyset has changed
// get last stored active sat keyset
var lastActiveSatKeyset crypto.Keyset
var lastActiveSatKeyset crypto.WalletKeyset
for _, keyset := range walletMint.activeKeysets {
if keyset.Unit == "sat" {
lastActiveSatKeyset = keyset
Expand All @@ -92,7 +92,7 @@ func LoadWallet(config Config) (*Wallet, error) {

// there is new keyset, change last active to inactive
lastActiveSatKeyset.Active = false
db.SaveKeyset(&lastActiveSatKeyset)
db.SaveKeyset(lastActiveSatKeyset)
break
}
}
Expand All @@ -104,7 +104,7 @@ func LoadWallet(config Config) (*Wallet, error) {
return nil, fmt.Errorf("error getting keysets from mint: %v", err)
}
for _, keyset := range activeKeysets {
db.SaveKeyset(&keyset)
db.SaveKeyset(keyset)
}
walletMint.activeKeysets = activeKeysets
}
Expand Down Expand Up @@ -143,27 +143,27 @@ func (w *Wallet) addMint(mint string) (*walletMint, error) {
}

for _, keyset := range mintInfo.activeKeysets {
w.db.SaveKeyset(&keyset)
w.db.SaveKeyset(keyset)
}
for _, keyset := range mintInfo.inactiveKeysets {
w.db.SaveKeyset(&keyset)
w.db.SaveKeyset(keyset)
}
w.mints[mintURL] = *mintInfo

return mintInfo, nil
}

func GetMintActiveKeysets(mintURL string) (map[string]crypto.Keyset, error) {
func GetMintActiveKeysets(mintURL string) (map[string]crypto.WalletKeyset, error) {
keysetsResponse, err := GetActiveKeysets(mintURL)
if err != nil {
return nil, fmt.Errorf("error getting active keyset from mint: %v", err)
}

activeKeysets := make(map[string]crypto.Keyset)
activeKeysets := make(map[string]crypto.WalletKeyset)
for i, keyset := range keysetsResponse.Keysets {
if keyset.Unit == "sat" {
activeKeyset := crypto.Keyset{MintURL: mintURL, Unit: keyset.Unit, Active: true}
keys := make(map[uint64]crypto.KeyPair)
activeKeyset := crypto.WalletKeyset{MintURL: mintURL, Unit: keyset.Unit, Active: true}
keys := make(map[uint64]*secp256k1.PublicKey)
for amount, key := range keysetsResponse.Keysets[i].Keys {
pkbytes, err := hex.DecodeString(key)
if err != nil {
Expand All @@ -173,10 +173,10 @@ func GetMintActiveKeysets(mintURL string) (map[string]crypto.Keyset, error) {
if err != nil {
return nil, err
}
keys[amount] = crypto.KeyPair{PublicKey: pubkey}
keys[amount] = pubkey
}
activeKeyset.Keys = keys
id := crypto.DeriveKeysetId(activeKeyset.Keys)
activeKeyset.PublicKeys = keys
id := crypto.DeriveKeysetId(activeKeyset.PublicKeys)
activeKeyset.Id = id
activeKeysets[id] = activeKeyset
}
Expand All @@ -185,17 +185,16 @@ func GetMintActiveKeysets(mintURL string) (map[string]crypto.Keyset, error) {
return activeKeysets, nil
}

func GetMintInactiveKeysets(mintURL string) (map[string]crypto.Keyset, error) {
func GetMintInactiveKeysets(mintURL string) (map[string]crypto.WalletKeyset, error) {
keysetsResponse, err := GetAllKeysets(mintURL)
if err != nil {
return nil, fmt.Errorf("error getting keysets from mint: %v", err)
}

inactiveKeysets := make(map[string]crypto.Keyset)

inactiveKeysets := make(map[string]crypto.WalletKeyset)
for _, keysetRes := range keysetsResponse.Keysets {
if !keysetRes.Active && keysetRes.Unit == "sat" {
keyset := crypto.Keyset{
keyset := crypto.WalletKeyset{
Id: keysetRes.Id,
MintURL: mintURL,
Unit: keysetRes.Unit,
Expand Down Expand Up @@ -375,7 +374,7 @@ func (w *Wallet) Receive(token cashu.Token, swap bool) (uint64, error) {
walletMint = *mint
}

var activeSatKeyset crypto.Keyset
var activeSatKeyset crypto.WalletKeyset
for _, k := range walletMint.activeKeysets {
activeSatKeyset = k
break
Expand Down Expand Up @@ -583,7 +582,7 @@ func (w *Wallet) getProofsForAmount(amount uint64, mintURL string) (cashu.Proofs
return selectedProofs, nil
}

var activeSatKeyset crypto.Keyset
var activeSatKeyset crypto.WalletKeyset
for _, k := range selectedMint.activeKeysets {
activeSatKeyset = k
break
Expand Down Expand Up @@ -659,7 +658,7 @@ func newBlindedMessage(id string, amount uint64, B_ *secp256k1.PublicKey) cashu.
}

// returns Blinded messages, secrets - [][]byte, and list of r
func createBlindedMessages(amount uint64, keyset crypto.Keyset) (cashu.BlindedMessages, []string, []*secp256k1.PrivateKey, error) {
func createBlindedMessages(amount uint64, keyset crypto.WalletKeyset) (cashu.BlindedMessages, []string, []*secp256k1.PrivateKey, error) {
splitAmounts := cashu.AmountSplit(amount)
splitLen := len(splitAmounts)

Expand Down Expand Up @@ -701,7 +700,7 @@ func createBlindedMessages(amount uint64, keyset crypto.Keyset) (cashu.BlindedMe

// constructProofs unblinds the blindedSignatures and returns the proofs
func constructProofs(blindedSignatures cashu.BlindedSignatures,
secrets []string, rs []*secp256k1.PrivateKey, keyset *crypto.Keyset) (cashu.Proofs, error) {
secrets []string, rs []*secp256k1.PrivateKey, keyset *crypto.WalletKeyset) (cashu.Proofs, error) {

if len(blindedSignatures) != len(secrets) || len(blindedSignatures) != len(rs) {
return nil, errors.New("lengths do not match")
Expand All @@ -718,12 +717,12 @@ func constructProofs(blindedSignatures cashu.BlindedSignatures,
return nil, err
}

keyp, ok := keyset.Keys[blindedSignature.Amount]
pubkey, ok := keyset.PublicKeys[blindedSignature.Amount]
if !ok {
return nil, errors.New("key not found")
}

C := crypto.UnblindSignature(C_, rs[i], keyp.PublicKey)
C := crypto.UnblindSignature(C_, rs[i], pubkey)
Cstr := hex.EncodeToString(C.SerializeCompressed())

proof := cashu.Proof{Amount: blindedSignature.Amount,
Expand All @@ -735,8 +734,8 @@ func constructProofs(blindedSignatures cashu.BlindedSignatures,
return proofs, nil
}

func (w *Wallet) GetActiveSatKeyset() crypto.Keyset {
var activeKeyset crypto.Keyset
func (w *Wallet) GetActiveSatKeyset() crypto.WalletKeyset {
var activeKeyset crypto.WalletKeyset
for _, keyset := range w.currentMint.activeKeysets {
if keyset.Unit == "sat" {
activeKeyset = keyset
Expand All @@ -751,8 +750,8 @@ func (w *Wallet) getWalletMints() map[string]walletMint {

keysets := w.db.GetKeysets()
for k, mintKeysets := range keysets {
activeKeysets := make(map[string]crypto.Keyset)
inactiveKeysets := make(map[string]crypto.Keyset)
activeKeysets := make(map[string]crypto.WalletKeyset)
inactiveKeysets := make(map[string]crypto.WalletKeyset)
for _, keyset := range mintKeysets {
if keyset.Active {
activeKeysets[keyset.Id] = keyset
Expand Down

0 comments on commit 79f33f0

Please sign in to comment.