Skip to content

Commit

Permalink
Fixes object to hashtable conversion for default params #3033 (#3175)
Browse files Browse the repository at this point in the history
  • Loading branch information
BernieWhite authored Nov 11, 2024
1 parent 547e2b4 commit 267fb26
Show file tree
Hide file tree
Showing 7 changed files with 723 additions and 626 deletions.
6 changes: 6 additions & 0 deletions docs/CHANGELOG-v1.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ See [upgrade notes][1] for helpful information when upgrading from previous vers

## Unreleased

What's changed since pre-release v1.40.0-B0103:

- Bug fixes:
- Fixed object to hashtable conversion for default parameter values by @BernieWhite.
[#3033](https://github.com/Azure/PSRule.Rules.Azure/issues/3033)

## v1.40.0-B0103 (pre-release)

What's changed since pre-release v1.40.0-B0063:
Expand Down
1,175 changes: 587 additions & 588 deletions src/PSRule.Rules.Azure/Common/ResourceHelper.cs

Large diffs are not rendered by default.

69 changes: 52 additions & 17 deletions src/PSRule.Rules.Azure/Data/Template/TemplateVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ internal TemplateContext()
_Symbols = new Dictionary<string, IDeploymentSymbol>(StringComparer.OrdinalIgnoreCase);
}

internal TemplateContext(PipelineContext context, SubscriptionOption subscription, ResourceGroupOption resourceGroup, TenantOption tenant, ManagementGroupOption managementGroup, ParameterDefaultsOption parameterDefaults)
internal TemplateContext(PipelineContext context, SubscriptionOption subscription, ResourceGroupOption resourceGroup, TenantOption tenant, ManagementGroupOption managementGroup, IDictionary<string, object> parameterDefaults)
: this()
{
Pipeline = context;
Expand All @@ -136,27 +136,27 @@ internal TemplateContext(PipelineContext context, SubscriptionOption subscriptio
ManagementGroup = managementGroup;

if (parameterDefaults != null)
ParameterDefaults = parameterDefaults;
ParameterDefaults = new Dictionary<string, object>(parameterDefaults, StringComparer.OrdinalIgnoreCase);
}

internal TemplateContext(PipelineContext context)
: this()
{
Pipeline = context;
if (context?.Option?.Configuration?.Subscription != null)
Subscription = context?.Option?.Configuration?.Subscription;
Subscription = context.Option.Configuration.Subscription;

if (context?.Option?.Configuration?.ResourceGroup != null)
ResourceGroup = context?.Option?.Configuration?.ResourceGroup;
ResourceGroup = context.Option.Configuration.ResourceGroup;

if (context?.Option?.Configuration?.Tenant != null)
Tenant = context?.Option?.Configuration?.Tenant;
Tenant = context.Option.Configuration.Tenant;

if (context?.Option?.Configuration?.ManagementGroup != null)
ManagementGroup = context?.Option?.Configuration?.ManagementGroup;
ManagementGroup = context.Option.Configuration.ManagementGroup;

if (context?.Option?.Configuration?.ParameterDefaults != null)
ParameterDefaults = context?.Option?.Configuration?.ParameterDefaults;
ParameterDefaults = new Dictionary<string, object>(context.Option.Configuration.ParameterDefaults, StringComparer.OrdinalIgnoreCase);
}

private Dictionary<string, IParameterValue> Parameters { get; }
Expand All @@ -173,7 +173,7 @@ internal TemplateContext(PipelineContext context)

public ManagementGroupOption ManagementGroup { get; internal set; }

public ParameterDefaultsOption ParameterDefaults { get; private set; }
public IDictionary<string, object> ParameterDefaults { get; private set; }

/// <inheritdoc/>
public DeploymentValue Deployment => _Deployment.Count > 0 ? _Deployment.Peek() : null;
Expand Down Expand Up @@ -628,15 +628,51 @@ internal bool TryParameterAssignment(string parameterName, out JToken value)
internal bool TryParameterDefault(string parameterName, ParameterType type, out JToken value)
{
value = default;
return type.Type switch
switch (type.Type)
{
TypePrimitive.String or TypePrimitive.SecureString => ParameterDefaults.TryGetString(parameterName, out value),
TypePrimitive.Bool => ParameterDefaults.TryGetBool(parameterName, out value),
TypePrimitive.Int => ParameterDefaults.TryGetLong(parameterName, out value),
TypePrimitive.Array => ParameterDefaults.TryGetArray(parameterName, out value),
TypePrimitive.Object or TypePrimitive.SecureObject => ParameterDefaults.TryGetObject(parameterName, out value),
_ => false,
case TypePrimitive.String:
case TypePrimitive.SecureString:
if (ParameterDefaults.TryGetString(parameterName, out var s))
{
value = new JValue(s);
return true;
}
break;

case TypePrimitive.Bool:
if (ParameterDefaults.TryGetBool(parameterName, out var b))
{
value = new JValue(b);
return true;
}
break;

case TypePrimitive.Int:
if (ParameterDefaults.TryGetLong(parameterName, out var i))
{
value = new JValue(i);
return true;
}
break;

case TypePrimitive.Array:
if (ParameterDefaults.TryGetArray(parameterName, out var a))
{
value = a;
return true;
}
break;

case TypePrimitive.Object:
case TypePrimitive.SecureObject:
if (ParameterDefaults.TryGetObject(parameterName, out var o))
{
value = o;
return true;
}
break;
};
return false;
}

internal bool TryParameter(string parameterName)
Expand Down Expand Up @@ -1492,7 +1528,6 @@ private TemplateContext GetDeploymentContext(TemplateContext context, string dep
var resourceGroup = new ResourceGroupOption(context.ResourceGroup);
var tenant = new TenantOption(context.Tenant);
var managementGroup = new ManagementGroupOption(context.ManagementGroup);
var parameterDefaults = new ParameterDefaultsOption(context.ParameterDefaults);
if (TryStringProperty(resource, PROPERTY_SUBSCRIPTIONID, out var subscriptionId))
{
var targetSubscriptionId = ExpandString(context, subscriptionId);
Expand All @@ -1510,7 +1545,7 @@ private TemplateContext GetDeploymentContext(TemplateContext context, string dep
resourceGroup.SubscriptionId = subscription.SubscriptionId;
TryObjectProperty(template, PROPERTY_PARAMETERS, out var templateParameters);

var deploymentContext = new TemplateContext(context.Pipeline, subscription, resourceGroup, tenant, managementGroup, parameterDefaults);
var deploymentContext = new TemplateContext(context.Pipeline, subscription, resourceGroup, tenant, managementGroup, context.ParameterDefaults);

// Handle custom type definitions early to allow type mapping of parameters if required.
if (TryObjectProperty(template, PROPERTY_DEFINITIONS, out var definitions))
Expand Down
33 changes: 28 additions & 5 deletions src/PSRule.Rules.Azure/DictionaryExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using Newtonsoft.Json.Linq;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
Expand Down Expand Up @@ -82,12 +83,14 @@ public static bool TryPopBool(this IDictionary<string, object> dictionary, strin
public static bool TryGetString(this IDictionary<string, object> dictionary, string key, out string value)
{
value = null;
if (!dictionary.TryGetValue(key, out var o))
return false;

if (o is string result)
if (dictionary.TryGetValue(key, out var o) && o is string s)
{
value = result;
value = s;
return true;
}
if (dictionary.TryGetValue(key, out s))
{
value = s;
return true;
}
return false;
Expand Down Expand Up @@ -121,6 +124,26 @@ public static bool TryGetLong(this IDictionary<string, object> dictionary, strin
return false;
}

public static bool TryGetArray(this IDictionary<string, object> dictionary, string parameterName, out JToken value)
{
value = default;
if (!dictionary.TryGetValue<List<object>>(parameterName, out var result))
return false;

value = JArray.FromObject(result);
return true;
}

public static bool TryGetObject(this IDictionary<string, object> dictionary, string parameterName, out JToken value)
{
value = default;
if (!dictionary.TryGetValue<Dictionary<object, object>>(parameterName, out var result))
return false;

value = JObject.FromObject(result);
return true;
}

/// <summary>
/// Add an item to the dictionary if it doesn't already exist in the dictionary.
/// </summary>
Expand Down
14 changes: 12 additions & 2 deletions src/PSRule.Rules.Azure/PSObjectExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,19 @@ internal static bool GetPath(this PSObject sourceObject, out string path)
internal static Hashtable ToHashtable(this PSObject o)
{
var result = new Hashtable();
foreach (var p in o.Properties)
if (o.BaseObject is IDictionary d)
{
result[p.Name] = p.Value;
foreach (var k in d.Keys)
{
result[k.ToString()] = d[k];
}
}
else
{
foreach (var p in o.Properties)
{
result[p.Name] = p.Value;
}
}
return result;
}
Expand Down
26 changes: 13 additions & 13 deletions tests/PSRule.Rules.Azure.Tests/ResourceHelperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,29 @@ public void TryResourceIdComponentsFromResourceId()
"Microsoft.Network/virtualNetworks/vnet-A/subnets/GatewaySubnet"
};

Assert.True(ResourceHelper.TryResourceIdComponents(id[0], out var subscriptionId, out var resourceGroupName, out string[] resourceTypeComponents, out string[] nameComponents));
Assert.True(ResourceHelper.TryResourceIdComponents(id[0], out var subscriptionId, out var resourceGroupName, out string[]? resourceTypeComponents, out string[]? nameComponents));
Assert.Equal("00000000-0000-0000-0000-000000000000", subscriptionId);
Assert.Equal("rg-test", resourceGroupName);
Assert.Equal("microsoft.operationalinsights/workspaces", resourceTypeComponents[0]);
Assert.Equal("workspace001", nameComponents[0]);
Assert.Equal("microsoft.operationalinsights/workspaces", resourceTypeComponents?[0]);
Assert.Equal("workspace001", nameComponents?[0]);

Assert.True(ResourceHelper.TryResourceIdComponents(id[1], out subscriptionId, out resourceGroupName, out resourceTypeComponents, out nameComponents));
Assert.Equal("ffffffff-ffff-ffff-ffff-ffffffffffff", subscriptionId);
Assert.Equal("test-rg", resourceGroupName);
Assert.Equal("Microsoft.Network/virtualNetworks", resourceTypeComponents[0]);
Assert.Equal("subnets", resourceTypeComponents[1]);
Assert.Equal("vnet-A", nameComponents[0]);
Assert.Equal("GatewaySubnet", nameComponents[1]);
Assert.Equal("Microsoft.Network/virtualNetworks", resourceTypeComponents?[0]);
Assert.Equal("subnets", resourceTypeComponents?[1]);
Assert.Equal("vnet-A", nameComponents?[0]);
Assert.Equal("GatewaySubnet", nameComponents?[1]);

Assert.True(ResourceHelper.TryResourceIdComponents(id[2], out _, out _, out resourceTypeComponents, out nameComponents));
Assert.Equal("Microsoft.Network/virtualNetworks", resourceTypeComponents[0]);
Assert.Equal("vnet-A", nameComponents[0]);
Assert.Equal("Microsoft.Network/virtualNetworks", resourceTypeComponents?[0]);
Assert.Equal("vnet-A", nameComponents?[0]);

Assert.True(ResourceHelper.TryResourceIdComponents(id[3], out _, out _, out resourceTypeComponents, out nameComponents));
Assert.Equal("Microsoft.Network/virtualNetworks", resourceTypeComponents[0]);
Assert.Equal("subnets", resourceTypeComponents[1]);
Assert.Equal("vnet-A", nameComponents[0]);
Assert.Equal("GatewaySubnet", nameComponents[1]);
Assert.Equal("Microsoft.Network/virtualNetworks", resourceTypeComponents?[0]);
Assert.Equal("subnets", resourceTypeComponents?[1]);
Assert.Equal("vnet-A", nameComponents?[0]);
Assert.Equal("GatewaySubnet", nameComponents?[1]);
}

[Fact]
Expand Down
26 changes: 25 additions & 1 deletion tests/PSRule.Rules.Azure.Tests/RuntimeServiceTests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Collections;
using System.Management.Automation;
using PSRule.Rules.Azure.Runtime;

Expand Down Expand Up @@ -89,7 +90,7 @@ public void WithAzureDeployment_WhenValid_ShouldSetOptions()
}

[Fact]
public void WithParameterDefaults_WhenValid_ShouldSetOptions()
public void WithParameterDefaults_WhenValidPSObject_ShouldSetOptions()
{
var runtime = GetRuntimeService();
var pso = new PSObject();
Expand All @@ -107,6 +108,29 @@ public void WithParameterDefaults_WhenValid_ShouldSetOptions()
Assert.Equal("2", value2);
}

[Fact]
public void WithParameterDefaults_WhenValidHashtable_ShouldSetOptions()
{
var runtime = GetRuntimeService();
var hashtable = new Hashtable
{
["value1"] = "1",
["value2"] = "2"
};

var pso = new PSObject(hashtable);

// Act
runtime.WithParameterDefaults(pso);

// Assert
var actual = runtime.ToPSRuleOption();
Assert.True(actual.Configuration.ParameterDefaults.TryGetString("value1", out string value1));
Assert.Equal("1", value1);
Assert.True(actual.Configuration.ParameterDefaults.TryGetString("value2", out string value2));
Assert.Equal("2", value2);
}

#region Helper methods

private static RuntimeService GetRuntimeService()
Expand Down

0 comments on commit 267fb26

Please sign in to comment.