Skip to content

Commit 3d47105

Browse files
committed
fixed passing by value vs ref bug and added UTs
1 parent 26b2525 commit 3d47105

5 files changed

Lines changed: 155 additions & 85 deletions

File tree

internal/cmds/cmds.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ func enable(ctx *log.Context, h types.HandlerEnvironment, report *types.RunComma
221221
var rceps *extensionpolicysettingsrc.RCv2ExtensionPolicySettings
222222

223223
if _, err := os.Stat(policyPath); err == nil {
224-
err = extensionpolicysettingsrc.InitializeExtensionPolicySettings(ExtensionPolicyManagerPtr, policyPath, rceps)
224+
ExtensionPolicyManagerPtr, rceps, err = extensionpolicysettingsrc.InitializeExtensionPolicySettings(policyPath)
225225
if err != nil {
226226
return "", "", errors.Wrap(err, "failed to initialize extension policy settings"), constants.ExitCode_LoadExtensionPolicySettingsFailed
227227
}

internal/cmds/cmds_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package commands
22

33
import (
4+
"crypto/sha256"
5+
"encoding/hex"
46
"encoding/json"
57
"errors"
68
"io/ioutil"
@@ -17,6 +19,7 @@ import (
1719
"github.com/Azure/azure-extension-platform/pkg/handlerenv"
1820
"github.com/Azure/azure-extension-platform/pkg/logging"
1921
"github.com/Azure/run-command-handler-linux/internal/constants"
22+
"github.com/Azure/run-command-handler-linux/internal/extensionpolicysettingsrc"
2023
"github.com/Azure/run-command-handler-linux/internal/files"
2124
"github.com/Azure/run-command-handler-linux/internal/handlersettings"
2225
"github.com/Azure/run-command-handler-linux/internal/settings"
@@ -1445,3 +1448,77 @@ func mustReadFile(t *testing.T, p string) string {
14451448
}
14461449
return string(b)
14471450
}
1451+
1452+
// Test_downloadScript_BlockedByAllowlist verifies that downloadScript returns an error
1453+
// when the policy allows downloaded scripts (alloweddownloaded) but the script's
1454+
// SHA256 hash is not in the DownloadedScriptsAllowlist.
1455+
func Test_downloadScript_BlockedByAllowlist(t *testing.T) {
1456+
dir, err := ioutil.TempDir("", "")
1457+
require.Nil(t, err)
1458+
defer os.RemoveAll(dir)
1459+
1460+
scriptContent := []byte("#!/bin/bash\necho hello\n")
1461+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1462+
w.WriteHeader(http.StatusOK)
1463+
w.Write(scriptContent)
1464+
}))
1465+
defer srv.Close()
1466+
1467+
policy := &extensionpolicysettingsrc.RCv2ExtensionPolicySettings{
1468+
LimitScripts: "alloweddownloaded",
1469+
// A wrong hash — the script's actual hash is not this.
1470+
DownloadedScriptsAllowlist: []string{"0000000000000000000000000000000000000000000000000000000000000000"},
1471+
}
1472+
1473+
_, err = downloadScript(log.NewContext(log.NewNopLogger()),
1474+
dir,
1475+
&handlersettings.HandlerSettings{
1476+
PublicSettings: handlersettings.PublicSettings{
1477+
Source: &handlersettings.ScriptSource{ScriptURI: srv.URL + "/script.sh"},
1478+
},
1479+
},
1480+
policy,
1481+
)
1482+
require.Error(t, err)
1483+
require.Contains(t, err.Error(), "blocked by policy")
1484+
require.Contains(t, err.Error(), "item is not in the allowlist")
1485+
}
1486+
1487+
// Test_downloadScript_AllowedByAllowlist verifies that downloadScript succeeds
1488+
// when the policy allows downloaded scripts and the script's SHA256 hash IS
1489+
// present in the DownloadedScriptsAllowlist.
1490+
func Test_downloadScript_AllowedByAllowlist(t *testing.T) {
1491+
dir, err := os.MkdirTemp("", "")
1492+
require.Nil(t, err)
1493+
defer os.RemoveAll(dir)
1494+
1495+
// Content uses Unix LF only and has no BOM, so PostProcessFile leaves bytes
1496+
// unchanged, making the pre-computed hash match the on-disk file hash.
1497+
scriptContent := []byte("#!/bin/bash\necho hello\n")
1498+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1499+
w.WriteHeader(http.StatusOK)
1500+
w.Write(scriptContent)
1501+
}))
1502+
defer srv.Close()
1503+
1504+
// Compute the SHA256 hash that ValidateFileHashInAllowlist will compare against.
1505+
h := sha256.New()
1506+
h.Write(scriptContent)
1507+
correctHash := hex.EncodeToString(h.Sum(nil))
1508+
1509+
policy := &extensionpolicysettingsrc.RCv2ExtensionPolicySettings{
1510+
LimitScripts: "alloweddownloaded",
1511+
DownloadedScriptsAllowlist: []string{correctHash},
1512+
}
1513+
1514+
_, err = downloadScript(log.NewContext(log.NewNopLogger()),
1515+
dir,
1516+
&handlersettings.HandlerSettings{
1517+
PublicSettings: handlersettings.PublicSettings{
1518+
Source: &handlersettings.ScriptSource{ScriptURI: srv.URL + "/script.sh"},
1519+
},
1520+
},
1521+
policy,
1522+
)
1523+
require.NoError(t, err)
1524+
}

internal/extensionpolicysettingsrc/extensionpolicysettingsrc.go

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,26 @@ import (
99
"github.com/pkg/errors"
1010
)
1111

12-
func InitializeExtensionPolicySettings(ExtensionPolicyManagerPtr *extensionpolicysettings.ExtensionPolicySettingsManager[RCv2ExtensionPolicySettings],
13-
policyPath string,
14-
rceps *RCv2ExtensionPolicySettings) error {
12+
func InitializeExtensionPolicySettings(policyPath string) (*extensionpolicysettings.ExtensionPolicySettingsManager[RCv2ExtensionPolicySettings], *RCv2ExtensionPolicySettings, error) {
13+
var ExtensionPolicyManagerPtr *extensionpolicysettings.ExtensionPolicySettingsManager[RCv2ExtensionPolicySettings]
14+
var rceps *RCv2ExtensionPolicySettings
15+
1516
ExtensionPolicyManagerPtr, err := extensionpolicysettings.NewExtensionPolicySettingsManager[RCv2ExtensionPolicySettings](policyPath)
1617
if err != nil {
17-
return errors.Wrap(err, "failed to create extension policy settings manager")
18+
return nil, nil, errors.Wrap(err, "failed to create extension policy settings manager")
1819
}
1920

2021
err = ExtensionPolicyManagerPtr.LoadExtensionPolicySettings()
2122
if err != nil {
22-
return errors.Wrap(err, "failed to load extension policy settings")
23+
return nil, nil, errors.Wrap(err, "failed to load extension policy settings")
2324
} else {
2425
rceps, err = ExtensionPolicyManagerPtr.GetSettings()
2526

2627
if err != nil {
27-
return errors.Wrap(err, "failed to get extension policy settings")
28+
return nil, nil, errors.Wrap(err, "failed to get extension policy settings")
2829
}
2930
}
30-
return nil
31+
return ExtensionPolicyManagerPtr, rceps, nil
3132
}
3233

3334
func InitialValidateHandlerSettingsAgainstPolicy(settings *handlersettings.HandlerSettings, policy *RCv2ExtensionPolicySettings) error {
@@ -47,9 +48,9 @@ func InitialValidateHandlerSettingsAgainstPolicy(settings *handlersettings.Handl
4748
return err
4849
}
4950
}
50-
if policy.DisableOutputBlobs {
51-
ValidateOutputBlob(settings, policy)
52-
}
51+
52+
// TO-DO: Validate Disable Outputblob and RequireSigning once those features are implemented for RCv2.
53+
5354
return nil
5455
}
5556

@@ -71,7 +72,11 @@ func ValidateCommandId(settings *handlersettings.HandlerSettings, policy *RCv2Ex
7172
// if list is empty, all commandIds are allowed
7273
return nil
7374
}
74-
return extensionpolicysettings.ValidateValueInAllowlist(settingsCommandId, allowedCommandIds)
75+
err := extensionpolicysettings.ValidateValueInAllowlist(settingsCommandId, allowedCommandIds)
76+
if err != nil {
77+
return errors.Wrapf(err, "command ID %s is not allowed by policy", settingsCommandId)
78+
}
79+
return nil
7580
}
7681

7782
func ValidateRunAsUser(settings *handlersettings.HandlerSettings, policy *RCv2ExtensionPolicySettings) error {
@@ -83,14 +88,3 @@ func ValidateRunAsUser(settings *handlersettings.HandlerSettings, policy *RCv2Ex
8388
}
8489
return nil
8590
}
86-
87-
func ValidateOutputBlob(settings *handlersettings.HandlerSettings, policy *RCv2ExtensionPolicySettings) {
88-
if policy.DisableOutputBlobs {
89-
// Log a warning that output blobs are disabled by policy. The command will still execute, but no output blobs will be created.
90-
if settings.OutputBlobURI != "" {
91-
fmt.Println("Warning: Output blobs are disabled by policy. The provided output blob URI will be ignored and no output blobs will be created for this command.")
92-
} else {
93-
fmt.Println("Warning: Output blobs are disabled by policy. No output blobs will be created for this command.")
94-
}
95-
}
96-
}

internal/extensionpolicysettingsrc/extensionpolicysettingsrc_test.go

Lines changed: 56 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"path/filepath"
77
"testing"
88

9-
"github.com/Azure/azure-extension-platform/pkg/extensionpolicysettings"
109
"github.com/Azure/run-command-handler-linux/internal/handlersettings"
1110
"github.com/stretchr/testify/require"
1211
)
@@ -44,10 +43,7 @@ func captureStdout(t *testing.T, fn func()) string {
4443
}
4544

4645
func TestInitializeExtensionPolicySettings_InvalidPath_ReturnsError(t *testing.T) {
47-
var mgr *extensionpolicysettings.ExtensionPolicySettingsManager[RCv2ExtensionPolicySettings]
48-
out := &RCv2ExtensionPolicySettings{}
49-
50-
err := InitializeExtensionPolicySettings(mgr, "/definitely/not/found/policy.json", out)
46+
_, _, err := InitializeExtensionPolicySettings("/definitely/not/found/policy.json")
5147
require.Error(t, err)
5248
require.Contains(t, err.Error(), "failed to")
5349
}
@@ -60,10 +56,7 @@ func TestInitializeExtensionPolicySettings_ValidFile_ReturnsNil(t *testing.T) {
6056
err := os.WriteFile(policyPath, []byte("{}"), 0600)
6157
require.NoError(t, err)
6258

63-
var mgr *extensionpolicysettings.ExtensionPolicySettingsManager[RCv2ExtensionPolicySettings]
64-
out := &RCv2ExtensionPolicySettings{}
65-
66-
err = InitializeExtensionPolicySettings(mgr, policyPath, out)
59+
_, _, err = InitializeExtensionPolicySettings(policyPath)
6760
require.NoError(t, err)
6861
}
6962

@@ -75,17 +68,16 @@ func TestInitializeExtensionPolicySettings_CurrentBehavior_DoesNotPopulateOutput
7568
err := os.WriteFile(policyPath, []byte(payload), 0600)
7669
require.NoError(t, err)
7770

78-
var mgr *extensionpolicysettings.ExtensionPolicySettingsManager[RCv2ExtensionPolicySettings]
7971
out := &RCv2ExtensionPolicySettings{}
8072

81-
err = InitializeExtensionPolicySettings(mgr, policyPath, out)
73+
_, out, err = InitializeExtensionPolicySettings(policyPath)
8274
require.NoError(t, err)
8375

84-
// Documents current implementation behavior (pointer reassignment inside function).
85-
require.Equal(t, "", out.LimitScripts)
86-
require.Equal(t, "", out.RunAsUser)
76+
require.Equal(t, "inline", out.LimitScripts)
77+
require.Equal(t, "alice", out.RunAsUser)
8778
}
8879

80+
// Test that validation passes and fails as expected.
8981
func TestInitialValidateHandlerSettingsAgainstPolicy(t *testing.T) {
9082
t.Run("nil policy", func(t *testing.T) {
9183
settings := makeSettings(handlersettings.InlineScript, "", "", "")
@@ -94,6 +86,8 @@ func TestInitialValidateHandlerSettingsAgainstPolicy(t *testing.T) {
9486
require.Contains(t, err.Error(), "no policy provided")
9587
})
9688

89+
// This test mimicks running an inline script, but policy only allows gallery scripts.
90+
// Validation fails.
9791
t.Run("script type blocked by policy", func(t *testing.T) {
9892
settings := makeSettings(handlersettings.InlineScript, "", "", "")
9993
policy := &RCv2ExtensionPolicySettings{
@@ -105,7 +99,9 @@ func TestInitialValidateHandlerSettingsAgainstPolicy(t *testing.T) {
10599
require.Contains(t, err.Error(), "script type inline is not allowed by policy")
106100
})
107101

108-
t.Run("command id not in allowlist", func(t *testing.T) {
102+
// This test mimicks running a commandId that is not in the allowlist.
103+
// Additionally, only commandId types are allowed.
104+
t.Run("command ID not in allowlist", func(t *testing.T) {
109105
settings := makeSettings(handlersettings.CommandIdScript, "restartVM", "", "")
110106
policy := &RCv2ExtensionPolicySettings{
111107
LimitScripts: "allowedcommandid",
@@ -128,7 +124,33 @@ func TestInitialValidateHandlerSettingsAgainstPolicy(t *testing.T) {
128124
require.Contains(t, err.Error(), "does not match")
129125
})
130126

131-
t.Run("all checks pass", func(t *testing.T) {
127+
t.Run("enforce limitScripts must be set. If not set, all commands fail", func(t *testing.T) {
128+
settings := makeSettings(handlersettings.CommandIdScript, "safeCommand", " Alice ", "https://example/blob")
129+
policy := &RCv2ExtensionPolicySettings{
130+
LimitScripts: "",
131+
CommandIdAllowlist: []string{"safeCommand"},
132+
RunAsUser: "Alice",
133+
DisableOutputBlobs: true,
134+
}
135+
136+
err := InitialValidateHandlerSettingsAgainstPolicy(settings, policy)
137+
require.Contains(t, err.Error(), "script type commandId is not allowed by policy")
138+
})
139+
140+
t.Run("all checks pass commandId", func(t *testing.T) {
141+
settings := makeSettings(handlersettings.CommandIdScript, "safeCommand", " Alice ", "https://example/blob")
142+
policy := &RCv2ExtensionPolicySettings{
143+
LimitScripts: "allowall",
144+
CommandIdAllowlist: []string{"safeCommand"},
145+
RunAsUser: "alice",
146+
DisableOutputBlobs: true,
147+
}
148+
149+
err := InitialValidateHandlerSettingsAgainstPolicy(settings, policy)
150+
require.NoError(t, err)
151+
})
152+
153+
t.Run("all checks pass commandId", func(t *testing.T) {
132154
settings := makeSettings(handlersettings.CommandIdScript, "safeCommand", " Alice ", "https://example/blob")
133155
policy := &RCv2ExtensionPolicySettings{
134156
LimitScripts: "allowall",
@@ -140,6 +162,19 @@ func TestInitialValidateHandlerSettingsAgainstPolicy(t *testing.T) {
140162
err := InitialValidateHandlerSettingsAgainstPolicy(settings, policy)
141163
require.NoError(t, err)
142164
})
165+
166+
t.Run("all checks pass downloadedScript", func(t *testing.T) {
167+
settings := makeSettings(handlersettings.DownloadedScript, "safeCommand", " Alice ", "https://example/blob")
168+
policy := &RCv2ExtensionPolicySettings{
169+
LimitScripts: "alloweddownloaded",
170+
CommandIdAllowlist: []string{"safeCommand"},
171+
RunAsUser: "alice",
172+
DisableOutputBlobs: true,
173+
}
174+
175+
err := InitialValidateHandlerSettingsAgainstPolicy(settings, policy)
176+
require.NoError(t, err)
177+
})
143178
}
144179

145180
func TestValidateScriptTypeAgainstPolicy(t *testing.T) {
@@ -154,7 +189,8 @@ func TestValidateScriptTypeAgainstPolicy(t *testing.T) {
154189
require.Contains(t, err.Error(), "script type gallery is not allowed by policy")
155190
})
156191

157-
t.Run("invalid policy token currently treated as blocked", func(t *testing.T) {
192+
// This tests edge case where policy has an invalid script type token.
193+
t.Run("invalid policy token is treated as blocked", func(t *testing.T) {
158194
err := ValidateScriptTypeAgainstPolicy(handlersettings.InlineScript, "notARealScriptType")
159195
require.Error(t, err)
160196
require.Contains(t, err.Error(), "script type inline is not allowed by policy")
@@ -186,7 +222,8 @@ func TestValidateCommandId(t *testing.T) {
186222
CommandIdAllowlist: []string{"safeCommand", "other"},
187223
}
188224
err := ValidateCommandId(settings, policy)
189-
require.Error(t, err)
225+
require.Contains(t, err.Error(), "command ID restartVM is not allowed by policy")
226+
require.Contains(t, err.Error(), "item is not in the allowlist")
190227
})
191228
}
192229

@@ -207,46 +244,6 @@ func TestValidateRunAsUser(t *testing.T) {
207244
}
208245
err := ValidateRunAsUser(settings, policy)
209246
require.Error(t, err)
210-
require.Contains(t, err.Error(), "does not match")
211-
})
212-
}
213-
214-
func TestValidateOutputBlob(t *testing.T) {
215-
t.Run("policy does not disable output blobs prints nothing", func(t *testing.T) {
216-
settings := makeSettings(handlersettings.InlineScript, "", "", "https://example/blob")
217-
policy := &RCv2ExtensionPolicySettings{
218-
DisableOutputBlobs: false,
219-
}
220-
221-
out := captureStdout(t, func() {
222-
ValidateOutputBlob(settings, policy)
223-
})
224-
require.Equal(t, "", out)
225-
})
226-
227-
t.Run("disabled with output blob uri prints ignore warning", func(t *testing.T) {
228-
settings := makeSettings(handlersettings.InlineScript, "", "", "https://example/blob")
229-
policy := &RCv2ExtensionPolicySettings{
230-
DisableOutputBlobs: true,
231-
}
232-
233-
out := captureStdout(t, func() {
234-
ValidateOutputBlob(settings, policy)
235-
})
236-
require.Contains(t, out, "Output blobs are disabled by policy")
237-
require.Contains(t, out, "provided output blob URI will be ignored")
238-
})
239-
240-
t.Run("disabled without output blob uri prints no blob warning", func(t *testing.T) {
241-
settings := makeSettings(handlersettings.InlineScript, "", "", "")
242-
policy := &RCv2ExtensionPolicySettings{
243-
DisableOutputBlobs: true,
244-
}
245-
246-
out := captureStdout(t, func() {
247-
ValidateOutputBlob(settings, policy)
248-
})
249-
require.Contains(t, out, "Output blobs are disabled by policy")
250-
require.Contains(t, out, "No output blobs will be created")
247+
require.Contains(t, err.Error(), "RunAsUser 'bob' in settings does not match RunAsUser 'alice' in policy")
251248
})
252249
}

internal/extensionpolicysettingsrc/types.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ func (rceps RCv2ExtensionPolicySettings) ValidateFormat() error {
8686
flag, err := StringToAllowedScriptTypeFlag(string(rceps.LimitScripts))
8787
// Requirements:
8888
// 1. If RequireSigning is not "none", FileRootCert must be present and non-empty.
89+
// TO-DO: implement RequireSigning and FileRootCert validation once signature verification is implemented for RCv2.
8990
// 2. LimitScripts must be a valid AllowedScriptType value. so map/check the value to the AllowedScriptTypeFlag bitmask.
9091
if rceps.LimitScripts != "" {
9192
if err != nil {
@@ -107,9 +108,10 @@ func (rceps RCv2ExtensionPolicySettings) ValidateFormat() error {
107108
return nil
108109
}
109110

110-
// This function compares a script type (of type ScriptType, defined in this file) to the allowed script types
111-
// (of type AllowedScriptTypeFlag, also defined in this file) listed in the policy. These values and mappings
112-
// are specific to Run Command, hence why they are defined here and not in the shared library.
111+
// This function compares a script type (of string type ScriptType, defined in this file) to the allowed script types
112+
// (of type AllowedScriptTypeFlag, also defined in this file) listed in the policy.
113+
// Depending on the string case (the value of scriptType), it checks if the corresponding bit is enabled in the allowed script types bitmask.
114+
// These values and mappings are specific to Run Command, hence why they are defined here and not in the shared library.
113115
func CompareScriptTypeToAllowedScriptType(scriptType handlersettings.ScriptType, allowedScriptTypes AllowedScriptTypeFlag) error {
114116
switch scriptType {
115117
case handlersettings.InlineScript:

0 commit comments

Comments
 (0)