diff --git a/openssl/sha.go b/openssl/sha.go index bb96a9b..12ef875 100644 --- a/openssl/sha.go +++ b/openssl/sha.go @@ -119,6 +119,26 @@ func (h *evpHash) Write(p []byte) (int, error) { return len(p), nil } +func (h *evpHash) WriteString(s string) (int, error) { + // TODO: use unsafe.StringData once we drop support + // for go1.19 and earlier. + hdr := (*struct { + Data *byte + Len int + })(unsafe.Pointer(&s)) + if len(s) > 0 && C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(hdr.Data), C.size_t(len(s))) == 0 { + panic("openssl: EVP_DigestUpdate failed") + } + return len(s), nil +} + +func (h *evpHash) WriteByte(c byte) error { + if C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(&c), 1) == 0 { + panic("openssl: EVP_DigestUpdate failed") + } + return nil +} + func (h *evpHash) Size() int { return h.size } diff --git a/openssl/sha_test.go b/openssl/sha_test.go index 7859b06..3169fcb 100644 --- a/openssl/sha_test.go +++ b/openssl/sha_test.go @@ -10,6 +10,7 @@ import ( "bytes" "encoding" "hash" + "io" "testing" ) @@ -62,6 +63,23 @@ func TestSha(t *testing.T) { if !bytes.Equal(sum, initSum) { t.Errorf("got:%x want:%x", sum, initSum) } + + bw := h.(io.ByteWriter) + for i := 0; i < len(msg); i++ { + bw.WriteByte(msg[i]) + } + h.Reset() + sum = h.Sum(nil) + if !bytes.Equal(sum, initSum) { + t.Errorf("got:%x want:%x", sum, initSum) + } + + h.(io.StringWriter).WriteString(string(msg)) + h.Reset() + sum = h.Sum(nil) + if !bytes.Equal(sum, initSum) { + t.Errorf("got:%x want:%x", sum, initSum) + } }) } }