Skip to content

Commit

Permalink
Merge pull request hashicorp#40533 from hashicorp/f-CaseInsensitiveSt…
Browse files Browse the repository at this point in the history
…ringType

Add a Terraform Plugin Framework case-insensitive custom string type
  • Loading branch information
ewbankkit authored Dec 11, 2024
2 parents 75720e8 + 44a526d commit 5eb68fc
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 21 deletions.
140 changes: 140 additions & 0 deletions internal/framework/types/case_insensitive_string.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package types

import (
"context"
"fmt"
"strings"

"github.com/hashicorp/terraform-plugin-framework/attr"
"github.com/hashicorp/terraform-plugin-framework/diag"
"github.com/hashicorp/terraform-plugin-framework/types"
"github.com/hashicorp/terraform-plugin-framework/types/basetypes"
"github.com/hashicorp/terraform-plugin-go/tftypes"
)

var (
_ basetypes.StringTypable = (*caseInsensitiveStringType)(nil)
)

type caseInsensitiveStringType struct {
basetypes.StringType
}

var (
CaseInsensitiveStringType = caseInsensitiveStringType{}
)

func (t caseInsensitiveStringType) Equal(o attr.Type) bool {
other, ok := o.(caseInsensitiveStringType)
if !ok {
return false
}

return t.StringType.Equal(other.StringType)
}

func (caseInsensitiveStringType) String() string {
return "CaseInsensitiveStringType"
}

func (t caseInsensitiveStringType) ValueFromString(_ context.Context, in types.String) (basetypes.StringValuable, diag.Diagnostics) {
var diags diag.Diagnostics

if in.IsNull() {
return CaseInsensitiveStringNull(), diags
}
if in.IsUnknown() {
return CaseInsensitiveStringUnknown(), diags
}

return CaseInsensitiveStringValue(in.ValueString()), diags
}

func (t caseInsensitiveStringType) ValueFromTerraform(ctx context.Context, in tftypes.Value) (attr.Value, error) {
attrValue, err := t.StringType.ValueFromTerraform(ctx, in)
if err != nil {
return nil, err
}

stringValue, ok := attrValue.(basetypes.StringValue)
if !ok {
return nil, fmt.Errorf("unexpected value type of %T", attrValue)
}

stringValuable, diags := t.ValueFromString(ctx, stringValue)
if diags.HasError() {
return nil, fmt.Errorf("unexpected error converting StringValue to StringValuable: %v", diags)
}

return stringValuable, nil
}

func (caseInsensitiveStringType) ValueType(context.Context) attr.Value {
return CaseInsensitiveString{}
}

var (
_ basetypes.StringValuable = (*CaseInsensitiveString)(nil)
_ basetypes.StringValuableWithSemanticEquals = (*CaseInsensitiveString)(nil)
)

type CaseInsensitiveString struct {
basetypes.StringValue
}

func CaseInsensitiveStringNull() CaseInsensitiveString {
return CaseInsensitiveString{StringValue: basetypes.NewStringNull()}
}

func CaseInsensitiveStringUnknown() CaseInsensitiveString {
return CaseInsensitiveString{StringValue: basetypes.NewStringUnknown()}
}

func CaseInsensitiveStringValue(value string) CaseInsensitiveString {
return CaseInsensitiveString{StringValue: basetypes.NewStringValue(value)}
}

func (v CaseInsensitiveString) Equal(o attr.Value) bool {
other, ok := o.(CaseInsensitiveString)
if !ok {
return false
}

return v.StringValue.Equal(other.StringValue)
}

func (CaseInsensitiveString) Type(context.Context) attr.Type {
return CaseInsensitiveStringType
}

func (v CaseInsensitiveString) StringSemanticEquals(ctx context.Context, newValuable basetypes.StringValuable) (bool, diag.Diagnostics) {
return caseInsensitiveStringSemanticEquals(ctx, v, newValuable)
}

// caseInsensitiveStringSemanticEquals returns whether oldValuable and newValuable are equal under simple Unicode case-folding.
func caseInsensitiveStringSemanticEquals[T basetypes.StringValuable](ctx context.Context, oldValuable T, newValuable basetypes.StringValuable) (bool, diag.Diagnostics) {
var diags diag.Diagnostics

newValue, ok := newValuable.(T)
if !ok {
return false, diags
}

old, d := oldValuable.ToStringValue(ctx)
diags.Append(d...)
if diags.HasError() {
return false, diags
}

new, d := newValue.ToStringValue(ctx)
diags.Append(d...)
if diags.HasError() {
return false, diags
}

// Case insensitive comparison.
return strings.EqualFold(old.ValueString(), new.ValueString()), diags
}
61 changes: 61 additions & 0 deletions internal/framework/types/case_insensitive_string_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package types_test

import (
"context"
"testing"

fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types"
)

func TestCaseInsensitiveStringSemanticEquals(t *testing.T) {
t.Parallel()

type testCase struct {
val1, val2 fwtypes.CaseInsensitiveString
equals bool
}
tests := map[string]testCase{
"both lowercase, equal": {
val1: fwtypes.CaseInsensitiveStringValue("thursday"),
val2: fwtypes.CaseInsensitiveStringValue("thursday"),
equals: true,
},
"both uppercase, equal": {
val1: fwtypes.CaseInsensitiveStringValue("THURSDAY"),
val2: fwtypes.CaseInsensitiveStringValue("THURSDAY"),
equals: true,
},
"first uppercase, second lowercase, equal": {
val1: fwtypes.CaseInsensitiveStringValue("THURSDAY"),
val2: fwtypes.CaseInsensitiveStringValue("thursday"),
equals: true,
},
"first lowercase, second uppercase, equal": {
val1: fwtypes.CaseInsensitiveStringValue("thursday"),
val2: fwtypes.CaseInsensitiveStringValue("THURSDAY"),
equals: true,
},
"not equal": {
val1: fwtypes.CaseInsensitiveStringValue("Thursday"),
val2: fwtypes.CaseInsensitiveStringValue("Friday"),
equals: false,
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()

ctx := context.Background()

equals, _ := test.val1.StringSemanticEquals(ctx, test.val2)

if got, want := equals, test.equals; got != want {
t.Errorf("StringSemanticEquals(%q, %q) = %v, want %v", test.val1, test.val2, got, want)
}
})
}
}
22 changes: 1 addition & 21 deletions internal/framework/types/once_a_week_window.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,7 @@ func (OnceAWeekWindow) Type(context.Context) attr.Type {
}

func (v OnceAWeekWindow) StringSemanticEquals(ctx context.Context, newValuable basetypes.StringValuable) (bool, diag.Diagnostics) {
var diags diag.Diagnostics

newValue, ok := newValuable.(OnceAWeekWindow)
if !ok {
return false, diags
}

old, d := v.ToStringValue(ctx)
diags.Append(d...)
if diags.HasError() {
return false, diags
}

new, d := newValue.ToStringValue(ctx)
diags.Append(d...)
if diags.HasError() {
return false, diags
}

// Case insensitive comparison.
return strings.EqualFold(old.ValueString(), new.ValueString()), diags
return caseInsensitiveStringSemanticEquals(ctx, v, newValuable)
}

func (v OnceAWeekWindow) ValidateAttribute(ctx context.Context, req xattr.ValidateAttributeRequest, resp *xattr.ValidateAttributeResponse) {
Expand Down

0 comments on commit 5eb68fc

Please sign in to comment.