diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 3dfd03b..3432e00 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,5 +1,5 @@ ARG VARIANT=1.17-bullseye -FROM mcr.microsoft.com/vscode/devcontainers/go:0-${VARIANT} +FROM mcr.microsoft.com/devcontainers/go:0-${VARIANT} RUN apt-get update \ && apt-get install -y build-essential diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index cb5ff99..cdf1e72 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -12,7 +12,9 @@ "GO_OPENSSL_VERSION_OVERRIDE": "1.1.0", }, "onCreateCommand": "sh ${containerWorkspaceFolder}/scripts/openssl.sh ${GO_OPENSSL_VERSION_OVERRIDE}", - "extensions": [ - "golang.go" - ] + "customizations": { + "vscode": { + "extensions": ["golang.go"] + } + }, } \ No newline at end of file diff --git a/go.mod b/go.mod index d781a37..6b7e85c 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/microsoft/go-crypto-openssl -go 1.16 +go 1.17 diff --git a/openssl/sha.go b/openssl/sha.go index bb96a9b..12ef875 100644 --- a/openssl/sha.go +++ b/openssl/sha.go @@ -119,6 +119,26 @@ func (h *evpHash) Write(p []byte) (int, error) { return len(p), nil } +func (h *evpHash) WriteString(s string) (int, error) { + // TODO: use unsafe.StringData once we drop support + // for go1.19 and earlier. + hdr := (*struct { + Data *byte + Len int + })(unsafe.Pointer(&s)) + if len(s) > 0 && C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(hdr.Data), C.size_t(len(s))) == 0 { + panic("openssl: EVP_DigestUpdate failed") + } + return len(s), nil +} + +func (h *evpHash) WriteByte(c byte) error { + if C.go_openssl_EVP_DigestUpdate(h.ctx, unsafe.Pointer(&c), 1) == 0 { + panic("openssl: EVP_DigestUpdate failed") + } + return nil +} + func (h *evpHash) Size() int { return h.size } diff --git a/openssl/sha_test.go b/openssl/sha_test.go index 7859b06..3169fcb 100644 --- a/openssl/sha_test.go +++ b/openssl/sha_test.go @@ -10,6 +10,7 @@ import ( "bytes" "encoding" "hash" + "io" "testing" ) @@ -62,6 +63,23 @@ func TestSha(t *testing.T) { if !bytes.Equal(sum, initSum) { t.Errorf("got:%x want:%x", sum, initSum) } + + bw := h.(io.ByteWriter) + for i := 0; i < len(msg); i++ { + bw.WriteByte(msg[i]) + } + h.Reset() + sum = h.Sum(nil) + if !bytes.Equal(sum, initSum) { + t.Errorf("got:%x want:%x", sum, initSum) + } + + h.(io.StringWriter).WriteString(string(msg)) + h.Reset() + sum = h.Sum(nil) + if !bytes.Equal(sum, initSum) { + t.Errorf("got:%x want:%x", sum, initSum) + } }) } }