Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 15 additions & 38 deletions pkg/extensionpolicysettings/extensionpolicysettings.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
package extensionpolicysettings

import (
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"strings"

"github.com/Azure/azure-extension-platform/pkg/extensionerrors"
"github.com/Azure/azure-extension-platform/pkg/hashutils"
)

type ExtensionPolicySettings interface {
Expand Down Expand Up @@ -76,14 +74,6 @@ func (epsm *ExtensionPolicySettingsManager[T]) GetSettings() (*T, error) {
}

// Validation Helper Functions
// HashType selects how file content is transformed before it is compared
// against an allowlist.
type HashType int

// Hash options accepted by the validation helpers.
const (
	// HashTypeNone means the content is used as-is, without hashing.
	HashTypeNone HashType = 0
	// HashTypeSHA1 selects a SHA-1 digest.
	HashTypeSHA1 HashType = 1
	// HashTypeSHA256 selects a SHA-256 digest.
	HashTypeSHA256 HashType = 2
)

func ValidateValueInAllowlist(value string, allowlist []string) error {
if len(allowlist) == 0 {
return extensionerrors.ErrPolicyAllowlistEmpty
Expand All @@ -104,7 +94,7 @@ func ValidateValueInAllowlist(value string, allowlist []string) error {
// determines if the content is allowlisted. If hashOpt is not HashTypeNone, it will compute the hash of the file content.
// If extensions don't want to validate a filepath but a value directly, they can call ValidateValueInAllowlist,
// which this function calls.
func ValidateFileHashInAllowlist(filePath string, allowlist []string, hashOpt HashType) error {
func ValidateFileHashInAllowlist(filePath string, allowlist []string, hashOpt hashutils.HashType) error {
if len(allowlist) == 0 {
return extensionerrors.ErrPolicyAllowlistEmpty
}
Expand All @@ -117,35 +107,22 @@ func ValidateFileHashInAllowlist(filePath string, allowlist []string, hashOpt Ha
return fmt.Errorf("file to validate does not exist: %w", err)
}

content, err := os.ReadFile(filePath)
if err != nil {
return fmt.Errorf("failed to read file %s for validation: %w", filePath, err)
}

value := string(content)

if hashOpt != HashTypeNone {
value, err := ComputeFileHash(value, hashOpt)
if hashOpt == hashutils.HashTypeNone {
// If no hashing is needed, we can directly validate the file content against the allowlist.
content, err := os.ReadFile(filePath)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Streaming via io.Copy instead of os.ReadFile into memory is probably good for large files.

if err != nil {
return fmt.Errorf("error occured when hashing contents of file %s for validation: %w", filePath, err)
return fmt.Errorf("failed to read file %s for validation: %w", filePath, err)
}
return ValidateValueInAllowlist(value, allowlist)
return ValidateValueInAllowlist(string(content), allowlist)
}

return ValidateValueInAllowlist(value, allowlist)
}

// ComputeFileHash computes the hash of a file or leaves string as is.
func ComputeFileHash(contents string, hashOpt HashType) (string, error) {
var hashStr string
switch hashOpt {
case HashTypeSHA1:
hash := sha1.Sum([]byte(contents))
hashStr = hex.EncodeToString(hash[:])
default:
hash := sha256.Sum256([]byte(contents))
hashStr = hex.EncodeToString(hash[:])
hashAlg, err := hashutils.GetHashAlgorithm(hashutils.HashType(hashOpt))
if err != nil {
return fmt.Errorf("error occured when getting hash algorithm for file %s: %w", filePath, err)
}

return hashStr, nil
value, err := hashutils.ComputeFileHash(filePath, hashAlg)
if err != nil {
return fmt.Errorf("error occured when hashing contents of file %s for validation: %w", filePath, err)
}
return ValidateValueInAllowlist(value, allowlist)
}
41 changes: 11 additions & 30 deletions pkg/extensionpolicysettings/extensionpolicysettings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"

"github.com/Azure/azure-extension-platform/pkg/extensionerrors"
"github.com/Azure/azure-extension-platform/pkg/hashutils"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -184,19 +185,19 @@ func TestValidateAgainstAllowlist(t *testing.T) {
require.Equal(t, "true", manager.settings.RequiresSigning)
require.NotEmpty(t, manager.settings.AllowedScripts)

require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script1.sh", manager.settings.AllowedScripts, HashTypeSHA256))
require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script2.sh", manager.settings.AllowedScripts, HashTypeSHA256))
require.ErrorIs(t, ValidateFileHashInAllowlist("./testutils/testscripts/script3.sh", manager.settings.AllowedScripts, HashTypeSHA256), extensionerrors.ErrItemNotInAllowlist)
require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script5.sh", manager.settings.AllowedScripts, HashTypeSHA1))
require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script1.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA256))
require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script2.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA256))
require.ErrorIs(t, ValidateFileHashInAllowlist("./testutils/testscripts/script3.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA256), extensionerrors.ErrItemNotInAllowlist)
require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script5.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA1))

// Empty filepath
require.ErrorIs(t, ValidateFileHashInAllowlist("", manager.settings.AllowedScripts, HashTypeSHA256), extensionerrors.ErrEmptyFilepathToValidate)
require.ErrorIs(t, ValidateFileHashInAllowlist("", manager.settings.AllowedScripts, hashutils.HashTypeSHA256), extensionerrors.ErrEmptyFilepathToValidate)
// Missing file
require.Error(t, ValidateFileHashInAllowlist("./testutils/testscripts/missing.sh", manager.settings.AllowedScripts, HashTypeSHA256))
require.Error(t, ValidateFileHashInAllowlist("./testutils/testscripts/missing.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA256))
// Now, empty list.
require.ErrorIs(t, ValidateFileHashInAllowlist("./testutils/testscripts/script1.sh", []string{}, HashTypeSHA256), extensionerrors.ErrPolicyAllowlistEmpty)
require.ErrorIs(t, ValidateFileHashInAllowlist("./testutils/testscripts/script1.sh", []string{}, hashutils.HashTypeSHA256), extensionerrors.ErrPolicyAllowlistEmpty)
// Empty file
require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script4.sh", manager.settings.AllowedScripts, HashTypeSHA256))
require.NoError(t, ValidateFileHashInAllowlist("./testutils/testscripts/script4.sh", manager.settings.AllowedScripts, hashutils.HashTypeSHA256))

}

Expand All @@ -207,28 +208,8 @@ func TestValidateFileHashInAllowlist_HashTypeNone_UsesRawContent(t *testing.T) {
require.NoError(t, writeToFile(filePath, content))
defer cleanupFile(filePath)

require.NoError(t, ValidateFileHashInAllowlist(filePath, []string{content}, HashTypeNone))
require.ErrorIs(t, ValidateFileHashInAllowlist(filePath, []string{"different-content"}, HashTypeNone), extensionerrors.ErrItemNotInAllowlist)
}

func TestComputeFileHash(t *testing.T) {
input := "abc"

sha1Expected := sha1.Sum([]byte(input))
sha256Expected := sha256.Sum256([]byte(input))

gotSHA1, err := ComputeFileHash(input, HashTypeSHA1)
require.NoError(t, err)
require.Equal(t, hex.EncodeToString(sha1Expected[:]), gotSHA1)

gotSHA256, err := ComputeFileHash(input, HashTypeSHA256)
require.NoError(t, err)
require.Equal(t, hex.EncodeToString(sha256Expected[:]), gotSHA256)

// Current behavior: unknown hash type falls back to SHA256.
gotUnknown, err := ComputeFileHash(input, HashType(999))
require.NoError(t, err)
require.Equal(t, hex.EncodeToString(sha256Expected[:]), gotUnknown)
require.NoError(t, ValidateFileHashInAllowlist(filePath, []string{content}, hashutils.HashTypeNone))
require.ErrorIs(t, ValidateFileHashInAllowlist(filePath, []string{"different-content"}, hashutils.HashTypeNone), extensionerrors.ErrItemNotInAllowlist)
}

// Helper functions for tests
Expand Down
65 changes: 65 additions & 0 deletions pkg/hashutils/hashutils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package hashutils

import (
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"fmt"
"hash"
"io"
"os"
)

// HashType selects which hash algorithm, if any, is applied to content.
type HashType int

// Supported hash options. The numeric values are assigned explicitly rather
// than derived from declaration order.
const (
	// HashTypeNone means no hashing; GetHashAlgorithm rejects it as unsupported.
	HashTypeNone = HashType(0)
	// HashTypeSHA1 selects SHA-1.
	HashTypeSHA1 = HashType(1)
	// HashTypeSHA256 selects SHA-256.
	HashTypeSHA256 = HashType(2)
)

// GetHashAlgorithm returns a new hasher for the requested hash option.
//
// HashTypeSHA1 yields a SHA-1 hasher and HashTypeSHA256 a SHA-256 hasher.
// Every other value, including HashTypeNone, returns a nil hasher and an
// "unsupported hash type option" error.
func GetHashAlgorithm(hashOpt HashType) (hash.Hash, error) {
	if hashOpt == HashTypeSHA1 {
		return sha1.New(), nil
	}
	if hashOpt == HashTypeSHA256 {
		return sha256.New(), nil
	}
	return nil, fmt.Errorf("unsupported hash type option: %v", hashOpt)
}

// ComputeFileHash streams the file at filePath into hashAlg and returns the
// resulting digest as a hex string.
//
// This is a separate function from ComputeHash because streaming the file
// contents into the hasher is more efficient than reading the entire file
// into memory at once, especially for larger files.
//
// hashAlg is used as given and is not reset, so any data already written to
// it is included in the digest.
//
// An error is returned when filePath is empty, hashAlg is nil, the file does
// not exist (this error wraps the underlying error, so
// errors.Is(err, os.ErrNotExist) reports true), the file cannot be opened,
// or reading it fails.
func ComputeFileHash(filePath string, hashAlg hash.Hash) (string, error) {
	if filePath == "" {
		return "", fmt.Errorf("file path cannot be empty")
	}
	if hashAlg == nil {
		// Fail with an error rather than panicking inside io.Copy.
		return "", fmt.Errorf("hash algorithm cannot be nil")
	}

	// Open directly instead of calling os.Stat first: it saves a syscall and
	// avoids a window in which the file could vanish between check and open.
	f, err := os.Open(filePath)
	if err != nil {
		if os.IsNotExist(err) {
			return "", fmt.Errorf("file does not exist at path: %s: %w", filePath, err)
		}
		return "", fmt.Errorf("failed to open file for hashing: %w", err)
	}
	defer f.Close()

	// Stream the file contents to the hasher, which is more efficient for large files.
	if _, err := io.Copy(hashAlg, f); err != nil {
		return "", fmt.Errorf("failed to read file for hashing: %w", err)
	}

	return hex.EncodeToString(hashAlg.Sum(nil)), nil
}

// ComputeHash writes contents into hashAlg and returns the resulting digest
// as a hex string.
//
// hashAlg is not reset first, so data previously written to it is part of
// the digest.
func ComputeHash(contents string, hashAlg hash.Hash) string {
	hashAlg.Write([]byte(contents))
	return hex.EncodeToString(hashAlg.Sum(nil))
}
137 changes: 137 additions & 0 deletions pkg/hashutils/hashutils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package hashutils

import (
"crypto/sha1"
"crypto/sha256"
"encoding/hex"
"os"
"runtime"
"testing"

"github.com/stretchr/testify/require"
)

// TestComputeFileHash_Success verifies that hashing a temp file containing
// "hello world" with SHA-256 yields the same hex digest as sha256.Sum256.
func TestComputeFileHash_Success(t *testing.T) {
	payload := []byte("hello world")

	f, createErr := os.CreateTemp(t.TempDir(), "hash_test_*.txt")
	if createErr != nil {
		t.Fatalf("failed to create temp file: %v", createErr)
	}
	path := f.Name()
	if _, writeErr := f.Write(payload); writeErr != nil {
		t.Fatalf("failed to write to temp file: %v", writeErr)
	}
	f.Close()
	defer os.Remove(path)

	digest, err := ComputeFileHash(path, sha256.New())

	require.Nil(t, err, "expected no error from ComputeFileHash, got: %v", err)
	require.NotEmpty(t, digest, "expected non-empty hash result")

	// The digest must equal the one computed directly from the payload.
	want := sha256.Sum256(payload)
	require.Equal(t, hex.EncodeToString(want[:]), digest, "hash mismatch")
}

// TestComputeFileHash_EmptyFilePath verifies that an empty path is rejected
// with the exact "file path cannot be empty" message.
func TestComputeFileHash_EmptyFilePath(t *testing.T) {
	_, err := ComputeFileHash("", sha256.New())

	require.NotNil(t, err, "expected error for empty file path")
	require.Equal(t, "file path cannot be empty", err.Error(), "unexpected error message")
}

// TestComputeFileHash_FileDoesNotExist verifies that hashing a path with no
// file behind it fails with a "file does not exist at path" error.
func TestComputeFileHash_FileDoesNotExist(t *testing.T) {
	const missingPath = "./nonexistent_file.txt"

	_, err := ComputeFileHash(missingPath, sha256.New())

	require.NotNil(t, err, "expected error for non-existent file path")
	require.Contains(t, err.Error(), "file does not exist at path", "unexpected error message")
}

// TestComputeFileHash_FileNotReadable verifies that a file which exists but
// cannot be opened for reading produces a "failed to open file for hashing"
// error.
//
// The test is skipped on Windows and when running as root, because in both
// cases removing the permission bits does not reliably prevent the open.
func TestComputeFileHash_FileNotReadable(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("permission-based test not reliable on Windows")
	}
	// Root can open files regardless of their mode bits, so the chmod below
	// would not make the file unreadable and the assertions would fail.
	if os.Geteuid() == 0 {
		t.Skip("permission-based test not reliable when running as root")
	}

	tmpFile, err := os.CreateTemp(t.TempDir(), "no_read_*.txt")
	if err != nil {
		t.Fatalf("failed to create temp file: %v", err)
	}
	if _, err := tmpFile.WriteString("some content"); err != nil {
		t.Fatalf("failed to write to temp file: %v", err)
	}
	if err := tmpFile.Close(); err != nil {
		t.Fatalf("failed to close temp file: %v", err)
	}
	defer os.Remove(tmpFile.Name())

	// Remove read permission
	if err := os.Chmod(tmpFile.Name(), 0000); err != nil {
		t.Fatalf("failed to change file permissions: %v", err)
	}
	t.Cleanup(func() {
		os.Chmod(tmpFile.Name(), 0644) // restore for cleanup
	})

	hashAlg := sha256.New()
	_, err = ComputeFileHash(tmpFile.Name(), hashAlg)
	require.NotNil(t, err, "expected error for unreadable file")
	require.Contains(t, err.Error(), "failed to open file for hashing", "unexpected error message")
}

// TestComputeFileHash_EmptyFile verifies that hashing a zero-length file
// succeeds and yields the SHA-256 digest of empty input.
func TestComputeFileHash_EmptyFile(t *testing.T) {
	f, createErr := os.CreateTemp(t.TempDir(), "empty_*.txt")
	if createErr != nil {
		t.Fatalf("failed to create temp file: %v", createErr)
	}
	path := f.Name()
	f.Close()
	defer os.Remove(path)

	digest, hashErr := ComputeFileHash(path, sha256.New())
	if hashErr != nil {
		t.Fatalf("unexpected error for empty file: %v", hashErr)
	}

	want := sha256.Sum256([]byte{})
	require.Equal(t, hex.EncodeToString(want[:]), digest, "hash mismatch for empty file")
}

// TestComputeHash_Success verifies that ComputeHash with a SHA-256 hasher
// matches sha256.Sum256 for the same string.
func TestComputeHash_Success(t *testing.T) {
	const input = "hello world"

	digest := ComputeHash(input, sha256.New())
	require.NotEmpty(t, digest, "expected non-empty hash result")

	want := sha256.Sum256([]byte(input))
	require.Equal(t, hex.EncodeToString(want[:]), digest, "hash mismatch")
}

// TestComputeHash_EmptyString verifies that hashing the empty string still
// returns a digest, equal to the SHA-256 of empty input.
func TestComputeHash_EmptyString(t *testing.T) {
	digest := ComputeHash("", sha256.New())
	require.NotEmpty(t, digest, "expected non-empty hash result for empty string")

	want := sha256.Sum256([]byte{})
	require.Equal(t, hex.EncodeToString(want[:]), digest, "hash mismatch for empty string")
}

// TestComputeHash_DifferentInputsDifferentHashes verifies that two distinct
// inputs hashed with fresh SHA-256 hashers produce distinct digests.
func TestComputeHash_DifferentInputsDifferentHashes(t *testing.T) {
	first := ComputeHash("input1", sha256.New())
	second := ComputeHash("input2", sha256.New())

	require.NotEqual(t, first, second, "expected different hashes for different inputs")
}

// TestComputeHash_SameInputSameHash verifies that hashing the same input
// twice, each time with a fresh SHA-256 hasher, gives the same digest.
func TestComputeHash_SameInputSameHash(t *testing.T) {
	const input = "consistent input"

	first := ComputeHash(input, sha256.New())
	second := ComputeHash(input, sha256.New())

	require.Equal(t, first, second, "expected same hash for same input")
}

// TestComputeHash_DifferentAlgorithm verifies that SHA-256 and SHA-1 give
// different digests for the same input, with hex lengths of 64 and 40
// characters respectively.
func TestComputeHash_DifferentAlgorithm(t *testing.T) {
	const input = "test"

	digest256 := ComputeHash(input, sha256.New())
	digest1 := ComputeHash(input, sha1.New())

	require.NotEqual(t, digest256, digest1, "expected different hashes for different algorithms")
	require.Equal(t, 64, len(digest256), "expected SHA-256 hex string length 64")
	require.Equal(t, 40, len(digest1), "expected SHA-1 hex string length 40")
}
Loading