NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Tab indentation and the position of blank context lines are inferred from the
hunk line counts (TODO confirm), and the "index" blob ids predate the edits
made to the added lines, so regenerate it with `git diff` before applying.

diff --git a/pkg/extensionpolicysettings/extensionpolicysettings.go b/pkg/extensionpolicysettings/extensionpolicysettings.go
index b33982f..9252617 100644
--- a/pkg/extensionpolicysettings/extensionpolicysettings.go
+++ b/pkg/extensionpolicysettings/extensionpolicysettings.go
@@ -1,15 +1,13 @@
 package extensionpolicysettings
 
 import (
-	"crypto/sha1"
-	"crypto/sha256"
-	"encoding/hex"
 	"encoding/json"
 	"fmt"
 	"os"
 	"strings"
 
 	"github.com/Azure/azure-extension-platform/pkg/extensionerrors"
+	"github.com/Azure/azure-extension-platform/pkg/hashutils"
 )
 
 type ExtensionPolicySettings interface {
@@ -76,14 +74,6 @@ func (epsm *ExtensionPolicySettingsManager[T]) GetSettings() (*T, error) {
 }
 
 // Validation Helper Functions
-type HashType int
-
-const (
-	HashTypeNone HashType = iota
-	HashTypeSHA1
-	HashTypeSHA256
-)
-
 func ValidateValueInAllowlist(value string, allowlist []string) error {
 	if len(allowlist) == 0 {
 		return extensionerrors.ErrPolicyAllowlistEmpty
@@ -104,7 +94,7 @@ func ValidateValueInAllowlist(value string, allowlist []string) error {
 // determines if the content is allowlisted. If hashOpt is not HashTypeNone, it will compute the hash of the file content.
 // If extensions don't want to validate a filepath but a value directly, they can call ValidateValueInAllowlist,
 // which this function calls.
-func ValidateFileHashInAllowlist(filePath string, allowlist []string, hashOpt HashType) error {
+func ValidateFileHashInAllowlist(filePath string, allowlist []string, hashOpt hashutils.HashType) error {
 	if len(allowlist) == 0 {
 		return extensionerrors.ErrPolicyAllowlistEmpty
 	}
@@ -117,35 +107,23 @@ func ValidateFileHashInAllowlist(filePath string, allowlist []string, hashOpt Ha
 		return fmt.Errorf("file to validate does not exist: %w", err)
 	}
 
-	content, err := os.ReadFile(filePath)
-	if err != nil {
-		return fmt.Errorf("failed to read file %s for validation: %w", filePath, err)
-	}
-
-	value := string(content)
-
-	if hashOpt != HashTypeNone {
-		value, err := ComputeFileHash(value, hashOpt)
+	if hashOpt == hashutils.HashTypeNone {
+		// If no hashing is needed, we can directly validate the file content against the allowlist.
+		content, err := os.ReadFile(filePath)
 		if err != nil {
-			return fmt.Errorf("error occured when hashing contents of file %s for validation: %w", filePath, err)
+			return fmt.Errorf("failed to read file %s for validation: %w", filePath, err)
 		}
-		return ValidateValueInAllowlist(value, allowlist)
+		return ValidateValueInAllowlist(string(content), allowlist)
 	}
 
-	return ValidateValueInAllowlist(value, allowlist)
-}
-
-// ComputeFileHash computes the hash of a file or leaves string as is.
-func ComputeFileHash(contents string, hashOpt HashType) (string, error) {
-	var hashStr string
-	switch hashOpt {
-	case HashTypeSHA1:
-		hash := sha1.Sum([]byte(contents))
-		hashStr = hex.EncodeToString(hash[:])
-	default:
-		hash := sha256.Sum256([]byte(contents))
-		hashStr = hex.EncodeToString(hash[:])
+	// GetHashAlgorithm returns an error for hash types it does not support.
+	hashAlg, err := hashutils.GetHashAlgorithm(hashOpt)
+	if err != nil {
+		return fmt.Errorf("error occured when getting hash algorithm for file %s: %w", filePath, err)
 	}
-
-	return hashStr, nil
+	value, err := hashutils.ComputeFileHash(filePath, hashAlg)
+	if err != nil {
+		return fmt.Errorf("error occured when hashing contents of file %s for validation: %w", filePath, err)
+	}
+	return ValidateValueInAllowlist(value, allowlist)
 }
diff --git a/pkg/extensionpolicysettings/extensionpolicysettings_test.go b/pkg/extensionpolicysettings/extensionpolicysettings_test.go
index 0f556da..73c09ea 100644
--- a/pkg/extensionpolicysettings/extensionpolicysettings_test.go
+++ b/pkg/extensionpolicysettings/extensionpolicysettings_test.go
@@ -12,6 +12,7 @@ import (
 	"testing"
 
 	"github.com/Azure/azure-extension-platform/pkg/extensionerrors"
+	"github.com/Azure/azure-extension-platform/pkg/hashutils"
 	"github.com/stretchr/testify/require"
 )
 
@@ -184,19 +185,19 @@ func TestValidateAgainstAllowlist(t *testing.T) {
 	require.Equal(t, "true", manager.settings.RequiresSigning)
 	require.NotEmpty(t, manager.settings.AllowedScripts)
 
-	require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script1.sh", manager.settings.AllowedScripts, HashTypeSHA256))
-	require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script2.sh", manager.settings.AllowedScripts, HashTypeSHA256))
-	require.ErrorIs(t, ValidateFileHashInAllowlist("./testutils/testscripts/script3.sh", manager.settings.AllowedScripts, HashTypeSHA256), extensionerrors.ErrItemNotInAllowlist)
-	require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script5.sh", manager.settings.AllowedScripts, HashTypeSHA1))
+	require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script1.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA256))
+	require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script2.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA256))
+	require.ErrorIs(t, ValidateFileHashInAllowlist("./testutils/testscripts/script3.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA256), extensionerrors.ErrItemNotInAllowlist)
+	require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script5.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA1))
 
 	// Empty filepath
-	require.ErrorIs(t, ValidateFileHashInAllowlist("", manager.settings.AllowedScripts, HashTypeSHA256), extensionerrors.ErrEmptyFilepathToValidate)
+	require.ErrorIs(t, ValidateFileHashInAllowlist("", manager.settings.AllowedScripts, hashutils.HashTypeSHA256), extensionerrors.ErrEmptyFilepathToValidate)
 	// Missing file
-	require.Error(t, ValidateFileHashInAllowlist("./testutils/testscripts/missing.sh", manager.settings.AllowedScripts, HashTypeSHA256))
+	require.Error(t, ValidateFileHashInAllowlist("./testutils/testscripts/missing.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA256))
 	// Now, empty list.
-	require.ErrorIs(t, ValidateFileHashInAllowlist("./testutils/testscripts/script1.sh", []string{}, HashTypeSHA256), extensionerrors.ErrPolicyAllowlistEmpty)
+	require.ErrorIs(t, ValidateFileHashInAllowlist("./testutils/testscripts/script1.sh", []string{}, hashutils.HashTypeSHA256), extensionerrors.ErrPolicyAllowlistEmpty)
 	// Empty file
-	require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script4.sh", manager.settings.AllowedScripts, HashTypeSHA256))
+	require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script4.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA256))
 
 }
 
@@ -207,28 +208,8 @@ func TestValidateFileHashInAllowlist_HashTypeNone_UsesRawContent(t *testing.T) {
 	require.NoError(t, writeToFile(filePath, content))
 	defer cleanupFile(filePath)
 
-	require.NoError(t, ValidateFileHashInAllowlist(filePath, []string{content}, HashTypeNone))
-	require.ErrorIs(t, ValidateFileHashInAllowlist(filePath, []string{"different-content"}, HashTypeNone), extensionerrors.ErrItemNotInAllowlist)
-}
-
-func TestComputeFileHash(t *testing.T) {
-	input := "abc"
-
-	sha1Expected := sha1.Sum([]byte(input))
-	sha256Expected := sha256.Sum256([]byte(input))
-
-	gotSHA1, err := ComputeFileHash(input, HashTypeSHA1)
-	require.NoError(t, err)
-	require.Equal(t, hex.EncodeToString(sha1Expected[:]), gotSHA1)
-
-	gotSHA256, err := ComputeFileHash(input, HashTypeSHA256)
-	require.NoError(t, err)
-	require.Equal(t, hex.EncodeToString(sha256Expected[:]), gotSHA256)
-
-	// Current behavior: unknown hash type falls back to SHA256.
-	gotUnknown, err := ComputeFileHash(input, HashType(999))
-	require.NoError(t, err)
-	require.Equal(t, hex.EncodeToString(sha256Expected[:]), gotUnknown)
+	require.NoError(t, ValidateFileHashInAllowlist(filePath, []string{content}, hashutils.HashTypeNone))
+	require.ErrorIs(t, ValidateFileHashInAllowlist(filePath, []string{"different-content"}, hashutils.HashTypeNone), extensionerrors.ErrItemNotInAllowlist)
 }
 
 // Helper functions for tests
diff --git a/pkg/hashutils/hashutils.go b/pkg/hashutils/hashutils.go
new file mode 100644
index 0000000..007dabc
--- /dev/null
+++ b/pkg/hashutils/hashutils.go
@@ -0,0 +1,86 @@
+// Package hashutils provides helpers for computing hex-encoded hashes of
+// strings and files.
+package hashutils
+
+import (
+	"crypto/sha1"
+	"crypto/sha256"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"hash"
+	"io"
+	"io/fs"
+	"os"
+)
+
+// HashType selects the hash algorithm returned by GetHashAlgorithm.
+type HashType int
+
+const (
+	// HashTypeNone means no hashing; GetHashAlgorithm returns an error for it.
+	HashTypeNone HashType = 0
+	// HashTypeSHA1 selects SHA-1.
+	HashTypeSHA1 HashType = 1
+	// HashTypeSHA256 selects SHA-256.
+	HashTypeSHA256 HashType = 2
+)
+
+// GetHashAlgorithm returns a new hasher for hashOpt. It returns an error for
+// HashTypeNone and for any other value that is not a supported hash type.
+func GetHashAlgorithm(hashOpt HashType) (hash.Hash, error) {
+	switch hashOpt {
+	case HashTypeSHA1:
+		return sha1.New(), nil
+	case HashTypeSHA256:
+		return sha256.New(), nil
+	default:
+		return nil, fmt.Errorf("unsupported hash type option: %v", hashOpt)
+	}
+}
+
+// ComputeFileHash streams the file at filePath into hashAlg and returns the
+// resulting digest as a lowercase hex string.
+//
+// This is a separate function from ComputeHash because streaming the file
+// contents into the hasher avoids reading the entire file into memory at once.
+// hashAlg is not reset first, so anything already written to it is part of the
+// returned digest; pass a fresh hasher to get the hash of the file alone.
+//
+// It returns an error if filePath is empty, hashAlg is nil, or the file does
+// not exist, cannot be opened, or cannot be read.
+func ComputeFileHash(filePath string, hashAlg hash.Hash) (string, error) {
+	if filePath == "" {
+		return "", fmt.Errorf("file path cannot be empty")
+	}
+	if hashAlg == nil {
+		return "", fmt.Errorf("hash algorithm cannot be nil")
+	}
+
+	// Open directly instead of calling os.Stat first: the open error already
+	// reports a missing file, and skipping the pre-check avoids a second
+	// syscall and a window in which the file could disappear.
+	f, err := os.Open(filePath)
+	if err != nil {
+		if errors.Is(err, fs.ErrNotExist) {
+			return "", fmt.Errorf("file does not exist at path: %s", filePath)
+		}
+		return "", fmt.Errorf("failed to open file for hashing: %w", err)
+	}
+	defer f.Close()
+
+	if _, err := io.Copy(hashAlg, f); err != nil {
+		return "", fmt.Errorf("failed to read file for hashing: %w", err)
+	}
+
+	return hex.EncodeToString(hashAlg.Sum(nil)), nil
+}
+
+// ComputeHash writes contents into hashAlg and returns the resulting digest as
+// a lowercase hex string. hashAlg is not reset first, so pass a fresh hasher to
+// get the hash of contents alone.
+func ComputeHash(contents string, hashAlg hash.Hash) string {
+	// hash.Hash documents that Write never returns an error.
+	hashAlg.Write([]byte(contents))
+	return hex.EncodeToString(hashAlg.Sum(nil))
+}
diff --git a/pkg/hashutils/hashutils_test.go b/pkg/hashutils/hashutils_test.go
new file mode 100644
index 0000000..818e9b4
--- /dev/null
+++ b/pkg/hashutils/hashutils_test.go
@@ -0,0 +1,166 @@
+package hashutils
+
+import (
+	"crypto/sha1"
+	"crypto/sha256"
+	"encoding/hex"
+	"os"
+	"runtime"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestComputeFileHash_Success(t *testing.T) {
+	content := []byte("hello world")
+	tmpFile, err := os.CreateTemp(t.TempDir(), "hash_test_*.txt")
+	if err != nil {
+		t.Fatalf("failed to create temp file: %v", err)
+	}
+	if _, err := tmpFile.Write(content); err != nil {
+		t.Fatalf("failed to write to temp file: %v", err)
+	}
+	tmpFile.Close()
+	defer os.Remove(tmpFile.Name())
+
+	hashAlg := sha256.New()
+	got, err := ComputeFileHash(tmpFile.Name(), hashAlg)
+
+	require.Nil(t, err, "expected no error from ComputeFileHash, got: %v", err)
+	require.NotEmpty(t, got, "expected non-empty hash result")
+
+	// Verify hash matches expected
+	expected := sha256.Sum256(content)
+	require.Equal(t, hex.EncodeToString(expected[:]), got, "hash mismatch")
+}
+
+func TestComputeFileHash_EmptyFilePath(t *testing.T) {
+	hashAlg := sha256.New()
+	_, err := ComputeFileHash("", hashAlg)
+	require.NotNil(t, err, "expected error for empty file path")
+	require.Equal(t, "file path cannot be empty", err.Error(), "unexpected error message")
+}
+
+func TestComputeFileHash_NilHashAlgorithm(t *testing.T) {
+	_, err := ComputeFileHash("./any_file.txt", nil)
+	require.NotNil(t, err, "expected error for nil hash algorithm")
+	require.Equal(t, "hash algorithm cannot be nil", err.Error(), "unexpected error message")
+}
+
+func TestComputeFileHash_FileDoesNotExist(t *testing.T) {
+	hashAlg := sha256.New()
+	nonExistentPath := "./nonexistent_file.txt"
+	_, err := ComputeFileHash(nonExistentPath, hashAlg)
+	require.NotNil(t, err, "expected error for non-existent file path")
+	require.Contains(t, err.Error(), "file does not exist at path", "unexpected error message")
+}
+
+func TestComputeFileHash_FileNotReadable(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("permission-based test not reliable on Windows")
+	}
+	if os.Geteuid() == 0 {
+		t.Skip("root can open files regardless of permission bits")
+	}
+
+	tmpFile, err := os.CreateTemp(t.TempDir(), "no_read_*.txt")
+	if err != nil {
+		t.Fatalf("failed to create temp file: %v", err)
+	}
+	if _, err := tmpFile.WriteString("some content"); err != nil {
+		t.Fatalf("failed to write to temp file: %v", err)
+	}
+	tmpFile.Close()
+	defer os.Remove(tmpFile.Name())
+
+	// Remove read permission
+	if err := os.Chmod(tmpFile.Name(), 0000); err != nil {
+		t.Fatalf("failed to change file permissions: %v", err)
+	}
+	t.Cleanup(func() {
+		os.Chmod(tmpFile.Name(), 0644) // restore for cleanup
+	})
+
+	hashAlg := sha256.New()
+	_, err = ComputeFileHash(tmpFile.Name(), hashAlg)
+	require.NotNil(t, err, "expected error for unreadable file")
+	require.Contains(t, err.Error(), "failed to open file for hashing", "unexpected error message")
+}
+
+func TestComputeFileHash_EmptyFile(t *testing.T) {
+	tmpFile, err := os.CreateTemp(t.TempDir(), "empty_*.txt")
+	if err != nil {
+		t.Fatalf("failed to create temp file: %v", err)
+	}
+	tmpFile.Close()
+	defer os.Remove(tmpFile.Name())
+
+	hashAlg := sha256.New()
+	got, err := ComputeFileHash(tmpFile.Name(), hashAlg)
+	if err != nil {
+		t.Fatalf("unexpected error for empty file: %v", err)
+	}
+
+	expected := sha256.Sum256([]byte{})
+	require.Equal(t, hex.EncodeToString(expected[:]), got, "hash mismatch for empty file")
+}
+
+func TestComputeHash_Success(t *testing.T) {
+	input := "hello world"
+	hashAlg := sha256.New()
+	got := ComputeHash(input, hashAlg)
+	require.NotEmpty(t, got, "expected non-empty hash result")
+
+	expected := sha256.Sum256([]byte(input))
+	require.Equal(t, hex.EncodeToString(expected[:]), got, "hash mismatch")
+}
+
+func TestComputeHash_EmptyString(t *testing.T) {
+	hashAlg := sha256.New()
+	got := ComputeHash("", hashAlg)
+	require.NotEmpty(t, got, "expected non-empty hash result for empty string")
+
+	expected := sha256.Sum256([]byte{})
+	require.Equal(t, hex.EncodeToString(expected[:]), got, "hash mismatch for empty string")
+}
+
+func TestComputeHash_DifferentInputsDifferentHashes(t *testing.T) {
+	hash1 := ComputeHash("input1", sha256.New())
+	hash2 := ComputeHash("input2", sha256.New())
+	require.NotEqual(t, hash1, hash2, "expected different hashes for different inputs")
+}
+
+func TestComputeHash_SameInputSameHash(t *testing.T) {
+	input := "consistent input"
+	hash1 := ComputeHash(input, sha256.New())
+	hash2 := ComputeHash(input, sha256.New())
+	require.Equal(t, hash1, hash2, "expected same hash for same input")
+}
+
+func TestComputeHash_DifferentAlgorithm(t *testing.T) {
+	input := "test"
+	sha256Hash := ComputeHash(input, sha256.New())
+	sha1Hash := ComputeHash(input, sha1.New())
+
+	require.NotEqual(t, sha256Hash, sha1Hash, "expected different hashes for different algorithms")
+	require.Equal(t, 64, len(sha256Hash), "expected SHA-256 hex string length 64")
+	require.Equal(t, 40, len(sha1Hash), "expected SHA-1 hex string length 40")
+}
+
+func TestGetHashAlgorithm_SupportedTypes(t *testing.T) {
+	sha1Alg, err := GetHashAlgorithm(HashTypeSHA1)
+	require.NoError(t, err)
+	require.Equal(t, sha1.Size, sha1Alg.Size(), "expected a SHA-1 hasher")
+
+	sha256Alg, err := GetHashAlgorithm(HashTypeSHA256)
+	require.NoError(t, err)
+	require.Equal(t, sha256.Size, sha256Alg.Size(), "expected a SHA-256 hasher")
+}
+
+func TestGetHashAlgorithm_UnsupportedTypes(t *testing.T) {
+	for _, hashOpt := range []HashType{HashTypeNone, HashType(999)} {
+		hashAlg, err := GetHashAlgorithm(hashOpt)
+		require.Error(t, err, "expected error for hash type %v", hashOpt)
+		require.Nil(t, hashAlg, "expected nil hasher for hash type %v", hashOpt)
+	}
+}