From 580d46a55ef39c01ed560b45fe414eb08c762642 Mon Sep 17 00:00:00 2001 From: Joseph Calev Date: Mon, 8 Dec 2025 10:55:06 -0800 Subject: [PATCH 1/7] Handles updates --- internal/cleanup/cleanup.go | 45 ++- internal/cmds/cmds.go | 226 ++++++++---- internal/cmds/cmds_test.go | 584 +++++++++++++++++++++++++++++++- internal/constants/constants.go | 13 +- internal/pid/pid.go | 3 +- 5 files changed, 769 insertions(+), 102 deletions(-) diff --git a/internal/cleanup/cleanup.go b/internal/cleanup/cleanup.go index 32bcf6a..2d623c2 100644 --- a/internal/cleanup/cleanup.go +++ b/internal/cleanup/cleanup.go @@ -2,11 +2,8 @@ package cleanup import ( "fmt" - "os" "path/filepath" - "strconv" - "github.com/Azure/azure-extension-platform/pkg/utils" "github.com/Azure/run-command-handler-linux/internal/constants" "github.com/Azure/run-command-handler-linux/internal/types" "github.com/Azure/run-command-handler-linux/pkg/linuxutils" @@ -41,28 +38,28 @@ func deleteAllScriptsAndSettings(ctx *log.Context, metadata types.RCMetadata, h } func deleteScriptsAndSettingsExceptMostRecent(ctx *log.Context, metadata types.RCMetadata, h types.HandlerEnvironment, runAsUser string) { - runtimeSettingsRegexFormat := metadata.ExtName + ".\\d+.settings" - runtimeSettingsLastSeqNumFormat := metadata.ExtName + ".%d.settings" + //runtimeSettingsRegexFormat := metadata.ExtName + ".\\d+.settings" + //runtimeSettingsLastSeqNumFormat := metadata.ExtName + ".%d.settings" // check if directory exists - _, err := os.Open(metadata.DownloadPath) - if err == nil { - err := utils.TryClearExtensionScriptsDirectoriesAndSettingsFilesExceptMostRecent(metadata.DownloadPath, h.HandlerEnvironment.ConfigFolder, "", - uint64(metadata.SeqNum), runtimeSettingsRegexFormat, runtimeSettingsLastSeqNumFormat) - if err != nil { - ctx.Log("event", "could not clear settings and script files", "error", err) - } - } else { - ctx.Log("message", "directory does not exist. Skipping cleanup") - } + //_, err := os.Open(metadata.DownloadPath) + //if err == nil { + // err := utils.TryClearExtensionScriptsDirectoriesAndSettingsFilesExceptMostRecent(metadata.DownloadPath, h.HandlerEnvironment.ConfigFolder, "", + // uint64(metadata.SeqNum), runtimeSettingsRegexFormat, runtimeSettingsLastSeqNumFormat) + // if err != nil { + // ctx.Log("event", "could not clear settings and script files", "error", err) + // } + //} else { + // ctx.Log("message", "directory does not exist. Skipping cleanup") + //} - if runAsUser != "" { - runAsDownloadParent := filepath.Join(fmt.Sprintf(constants.RunAsDir, runAsUser), metadata.DownloadDir) - seqNumString := strconv.Itoa(metadata.SeqNum) - ctx.Log("message", "removing all files from the download 'runas' directory "+runAsDownloadParent) - err = utils.TryDeleteDirectoriesExcept(runAsDownloadParent, seqNumString) - if err != nil { - ctx.Log("event", "could not clear runas script") - } - } + //if runAsUser != "" { + // runAsDownloadParent := filepath.Join(fmt.Sprintf(constants.RunAsDir, runAsUser), metadata.DownloadDir) + // seqNumString := strconv.Itoa(metadata.SeqNum) + // ctx.Log("message", "removing all files from the download 'runas' directory "+runAsDownloadParent) + // err = utils.TryDeleteDirectoriesExcept(runAsDownloadParent, seqNumString) + // if err != nil { + // ctx.Log("event", "could not clear runas script") + // } + //} } diff --git a/internal/cmds/cmds.go b/internal/cmds/cmds.go index 864c001..dbbfd96 100755 --- a/internal/cmds/cmds.go +++ b/internal/cmds/cmds.go @@ -75,6 +75,9 @@ var ( RunCmd = runCmd DataDir = constants.DataDir + // Used by unit tests to mock out executing the command + ExecCmdInDir = exec.ExecCmdInDir + ErrAlreadyProcessed = errors.New("the script configuration has already been processed, will not run again") ) @@ -85,17 +88,23 @@ func update(ctx *log.Context, h types.HandlerEnvironment, report *types.RunComma return "", "", err, exitCode } - err = rehydrateMrSeqFilesForProblematicUpgrades(ctx, h, extensionEvents) - if err != nil { - // If we fail on update, then there's a risk we could re-execute the customer's script. Don't take that chance. - // By failing Update, the extension goal state will fail. WALA will try us again on the next goal state. - ctx.Log("event", "Unable to rehydrate mrseq files") - return "", "", err, constants.ExitCode_CouldNotRehydrateMrSeq + // Figure out the directories from which and to where we're upgrading. We cannot entirely rely on the environment variables from the Guest Agent + upgradeFromVersionDirectory, upgradeToVersionDirectory, upgradeFromVersion := determineUpgradeVersionDirectories(ctx, extensionEvents) + + if compareVersions(constants.FirstVersionNoRehydration, upgradeFromVersion) > 0 { + // Rehydrate any mrseq files from the corresponding status file. + err = rehydrateMrSeqFilesForProblematicUpgrades(ctx, upgradeFromVersionDirectory, upgradeToVersionDirectory, extensionEvents) + if err != nil { + // If we fail on update, then there's a risk we could re-execute the customer's script. Don't take that chance. + // By failing Update, the extension goal state will fail. WALA will try us again on the next goal state. + ctx.Log("event", "Unable to rehydrate mrseq files") + return "", "", err, constants.ExitCode_CouldNotRehydrateMrSeq + } } // Copy any .mrseq or .status files -Most Recently executed Sequence number files and status files for Run Commands from old version to new version. // This is necessary to prevent rerunning of already executed Run Commands after upgrade of extension version, and also return their statuses. - copyError := CopyStateForUpdate(ctx, extensionEvents) + copyError := CopyStateForUpdate(ctx, upgradeFromVersionDirectory, upgradeToVersionDirectory, extensionEvents) if copyError != nil { return "", "", errors.Wrap(copyError, "Migrating *.mrseq or .status files failed during update."), constants.ExitCode_CopyStateForUpdateFailed } @@ -417,15 +426,15 @@ func resetSeqNum(ctx log.Logger, mrseqPath string, extensionEvents *extensioneve } // Copy state of the extension from old version to new version during update (.mrseq files, .status files) -func CopyStateForUpdate(ctx log.Logger, extensionEvents *extensionevents.ExtensionEventManager) error { +func CopyStateForUpdate(ctx log.Logger, upgradeFromVersionDirectory string, upgradeToVersionDirectory string, extensionEvents *extensionevents.ExtensionEventManager) error { // Copy .mrseq files (Most Recently executed Sequence number) that helps determine whether a sequence number of Run Command has been previously executed or not. - mrseqFilesNameList, mrseqFileCopyErr := copyFiles(ctx, constants.MrSeqFileExtension, "", extensionEvents) + mrseqFilesNameList, mrseqFileCopyErr := copyFiles(ctx, constants.MrSeqFileExtension, "", upgradeFromVersionDirectory, upgradeToVersionDirectory, extensionEvents) if mrseqFileCopyErr != nil { return mrseqFileCopyErr } // Copy .status files of already executed sequence numbers - _, statusFileCopyErr := copyFiles(ctx, ".status", constants.StatusFileDirectory, extensionEvents) + _, statusFileCopyErr := copyFiles(ctx, ".status", constants.StatusFileDirectory, upgradeFromVersionDirectory, upgradeToVersionDirectory, extensionEvents) if statusFileCopyErr != nil { return statusFileCopyErr } @@ -440,41 +449,140 @@ func CopyStateForUpdate(ctx log.Logger, extensionEvents *extensionevents.Extensi return nil } -func rehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, h types.HandlerEnvironment, extensionEvents *extensionevents.ExtensionEventManager) error { - // First, determine whether we're upgrading from a 'problematic' version, defined as one - // where we mistakenly deleted the mrseq files in the Disable call - newExtensionVersion := os.Getenv(constants.ExtensionVersionEnvName) - oldExtensionVersion := os.Getenv(constants.ExtensionVersionUpdatingFromEnvName) - newExtensionDirectory := os.Getenv(constants.ExtensionPathEnvName) - oldExtensionDirectory := strings.ReplaceAll(newExtensionDirectory, newExtensionVersion, oldExtensionVersion) - - // The following are problematic versions: - // Production: 1.3.17 - // Test: 1.8.0, 1.9.0 - isProblematicVersion := false - isTestExtension := strings.Contains(oldExtensionDirectory, constants.RunCommandTestExtensionName) - if isTestExtension { - isProblematicVersion = (oldExtensionVersion == constants.FirstTestVersionThatDeletesMrSeqFiles || oldExtensionVersion == constants.SecondTestVersionThatDeletesMrSeqFiles) - } else { - isProblematicVersion = (oldExtensionVersion == constants.ProductionVersionThatDeletesMrSeqFiles) +func determineUpgradeVersionDirectories(ctx *log.Context, extensionEvents *extensionevents.ExtensionEventManager) (upgradeFromVersionDirectory string, upgradeToVersionDirectory string, upgradeFromVersion string) { + // These two environment variables will tell us the extension versions involved, but won't actually tell us + // the from/to versions + firstExtensionVersion := os.Getenv(constants.ExtensionVersionEnvName) + secondExtensionVersion := os.Getenv(constants.ExtensionVersionUpdatingFromEnvName) + + // To determine to which version we're actually upgrading, we'll need to look into the folders + // The higher version isn't necessarily the one we're upgrading to, since we may be downgrading + // If one has at least one .mrseq file, and the other has none, then we're upgrading to the one that has none + // If neither has a .mrseq file, then just choose the higher version number + // If both have .mrseq files, then this shouldn't happen, but for the sake of sanity choose the highe version number + firstExtensionDirectory := os.Getenv(constants.ExtensionPathEnvName) + secondExtensionDirectory := strings.ReplaceAll(firstExtensionDirectory, firstExtensionVersion, secondExtensionVersion) + + // Check for *.mrseq presence in each directory + firstHasMrseq := hasMrseq(ctx, firstExtensionDirectory) + secondHasMrseq := hasMrseq(ctx, secondExtensionDirectory) + + // If one has mrseq and the other doesn't → upgrade to the one without mrseq + if firstHasMrseq != secondHasMrseq { + if firstHasMrseq && !secondHasMrseq { + upgradeToVersionDirectory, upgradeFromVersionDirectory = secondExtensionDirectory, firstExtensionDirectory + upgradeFromVersion = firstExtensionVersion + } else { + upgradeToVersionDirectory, upgradeFromVersionDirectory = firstExtensionDirectory, secondExtensionDirectory + upgradeFromVersion = secondExtensionVersion + } + + msg := fmt.Sprintf("determineUpgradeVersions: mrseq-guided choice → to='%s' from='%s'", upgradeToVersionDirectory, upgradeFromVersionDirectory) + ctx.Log("message", msg) + extensionEvents.LogInformationalEvent("determineUpgradeVersions", msg) + + return upgradeFromVersionDirectory, upgradeToVersionDirectory, upgradeFromVersion } - if isProblematicVersion { - message := fmt.Sprintf("Rehydrating mrseq files deleted by version '%s' using status files", oldExtensionVersion) - ctx.Log("message", message) - extensionEvents.LogInformationalEvent("rehydratemrseq", message) - return doRehydrateMrSeqFilesForProblematicUpgrades(ctx, oldExtensionDirectory, newExtensionDirectory, extensionEvents) - } else { - message := fmt.Sprintf("Previous extension version '%s' does not require mrseq hydration", oldExtensionVersion) - ctx.Log("message", message) - extensionEvents.LogInformationalEvent("rehydratemrseq", message) + // Rule 2 & 3: neither has mrseq OR both have mrseq → choose higher version number as upgradeTo + switch c := compareVersions(firstExtensionVersion, secondExtensionVersion); { + case c > 0: + upgradeToVersionDirectory, upgradeFromVersionDirectory = firstExtensionDirectory, secondExtensionDirectory + upgradeFromVersion = secondExtensionVersion + case c < 0: + upgradeToVersionDirectory, upgradeFromVersionDirectory = secondExtensionDirectory, firstExtensionDirectory + upgradeFromVersion = firstExtensionVersion + default: + // Equal versions (shouldn’t normally happen in an upgrade path). Keep first as "to". + upgradeToVersionDirectory, upgradeFromVersionDirectory = firstExtensionDirectory, secondExtensionDirectory + upgradeFromVersion = secondExtensionVersion } - return nil + msg := fmt.Sprintf("determineUpgradeVersions: version-ordered choice → to='%s' from='%s' (mrseq first=%t second=%t)", upgradeToVersionDirectory, upgradeFromVersionDirectory, firstHasMrseq, secondHasMrseq) + ctx.Log("message", msg) + extensionEvents.LogInformationalEvent("determineUpgradeVersions", msg) + + return upgradeFromVersionDirectory, upgradeToVersionDirectory, upgradeFromVersion +} + +// hasMrseq returns true if the given directory contains at least one *.mrseq file. +// It is resilient to missing directories and IO errors (logs and returns false). +func hasMrseq(ctx *log.Context, dir string) bool { + if dir == "" { + return false + } + // Resolve glob pattern + pattern := filepath.Join(dir, "*.mrseq") + + matches, err := filepath.Glob(pattern) + if err != nil { + ctx.Log("error", fmt.Sprintf("hasMrseq: glob error for '%s': %v", pattern, err)) + return false + } + return len(matches) > 0 } -func doRehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, oldExtensionDirectory string, newExtensionDirectory string, extensionEvents *extensionevents.ExtensionEventManager) error { - oldExtensionStatusDirectory := filepath.Join(oldExtensionDirectory, constants.StatusFileDirectory) +// compareVersions compares two dotted version strings (e.g., "2.1", "2.1.0", "2.1.0.3"). +// Returns: +1 if a>b, -1 if a bParts[i] { + return 1 + } + if aParts[i] < bParts[i] { + return -1 + } + } + return 0 +} + +// splitVersion converts "x.y.z.t" → []int{ x, y, z, t } (non-numeric parts treated as 0). +func splitVersion(v string) []int { + parts := strings.Split(v, ".") + out := make([]int, 0, len(parts)) + for _, p := range parts { + // Trim any stray spaces; non-numeric gets 0 + p = strings.TrimSpace(p) + n := 0 + for i := 0; i < len(p); i++ { + if p[i] < '0' || p[i] > '9' { + // non-numeric component; keep as 0 + n = 0 + goto done + } + } + if p != "" { + // safe Atoi without error branch since we checked digits + for i := 0; i < len(p); i++ { + n = n*10 + int(p[i]-'0') + } + } + done: + out = append(out, n) + } + return out +} + +func padTo(in []int, size int) []int { + if len(in) >= size { + return in[:size] + } + out := make([]int, size) + copy(out, in) + // remaining default to 0 + return out +} + +func rehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, updateFromVersionDirectory string, updateToVersionDirectory string, extensionEvents *extensionevents.ExtensionEventManager) error { + oldExtensionStatusDirectory := filepath.Join(updateFromVersionDirectory, constants.StatusFileDirectory) extensionStatusDirectoryFDRef, err := os.Open(oldExtensionStatusDirectory) if err != nil { @@ -489,7 +597,7 @@ func doRehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, oldExtensionD // If we find any status files missing their corresponding mrseq, then rehydrate it by taking the seqNo from the status file name statusFiles, err := extensionStatusDirectoryFDRef.ReadDir(0) if err != nil { - errMessage := fmt.Sprintf("could not read directory entries from status directory %s", oldExtensionDirectory) + errMessage := fmt.Sprintf("could not read directory entries from status directory %s", updateFromVersionDirectory) ctx.Log("message", errMessage) extensionEvents.LogErrorEvent("rehydratemrseq", errMessage) return errors.Wrap(err, errMessage) @@ -507,7 +615,7 @@ func doRehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, oldExtensionD seqNo := parts[1] seqNoAsInt, _ := strconv.Atoi(seqNo) mrSeqFileName := extensionName + constants.MrSeqFileExtension - mrSeqFilePath := filepath.Join(newExtensionDirectory, mrSeqFileName) + mrSeqFilePath := filepath.Join(updateToVersionDirectory, mrSeqFileName) _, err = os.Stat(mrSeqFilePath) if err != nil { @@ -563,45 +671,39 @@ func doRehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, oldExtensionD } // Copy files like *.mrseq (Most Recently executed Sequence number), .status files from old extension version to new extension version during update. -func copyFiles(ctx log.Logger, fileExtensionSuffix string, extensionSubdirectory string, extensionEvents *extensionevents.ExtensionEventManager) (*list.List, error) { - - newExtensionVersion := os.Getenv(constants.ExtensionVersionEnvName) - oldExtensionVersion := os.Getenv(constants.ExtensionVersionUpdatingFromEnvName) +func copyFiles(ctx log.Logger, fileExtensionSuffix string, extensionSubdirectory string, upgradeFromVersionDirectory string, upgradeToVersionDirectory string, extensionEvents *extensionevents.ExtensionEventManager) (*list.List, error) { - message := fmt.Sprintf("Migrating '%s' files from extension version '%s' to '%s'", fileExtensionSuffix, oldExtensionVersion, newExtensionVersion) + message := fmt.Sprintf("Migrating '%s' files from '%s' to '%s'", fileExtensionSuffix, upgradeFromVersionDirectory, upgradeToVersionDirectory) ctx.Log("message", message) extensionEvents.LogInformationalEvent("copyfiles", message) - newExtensionDirectory := os.Getenv(constants.ExtensionPathEnvName) - oldExtensionDirectory := strings.ReplaceAll(newExtensionDirectory, newExtensionVersion, oldExtensionVersion) - // Append subdirectory like "status" under extension folder if provided. if extensionSubdirectory != "" { - newExtensionDirectory = filepath.Join(newExtensionDirectory, extensionSubdirectory) - oldExtensionDirectory = filepath.Join(oldExtensionDirectory, extensionSubdirectory) + upgradeToVersionDirectory = filepath.Join(upgradeToVersionDirectory, extensionSubdirectory) + upgradeFromVersionDirectory = filepath.Join(upgradeFromVersionDirectory, extensionSubdirectory) // Create subdirectory like "status" directory if it does not exist - _, err := os.Open(newExtensionDirectory) + _, err := os.Open(upgradeToVersionDirectory) if err != nil { - errr := os.Mkdir(newExtensionDirectory, 0700) + errr := os.Mkdir(upgradeToVersionDirectory, 0700) if errr != nil { - errMessage := fmt.Sprintf("Failed to create directory '%s'", newExtensionDirectory) + errMessage := fmt.Sprintf("Failed to create directory '%s'", upgradeToVersionDirectory) extensionEvents.LogErrorEvent("copyfiles", errMessage) return nil, errors.Wrap(errr, errMessage) } } } - if oldExtensionDirectory == "" || newExtensionDirectory == "" { + if upgradeFromVersionDirectory == "" || upgradeToVersionDirectory == "" { errMessage := "oldExtesionDirectory or newExtensionDirectory is empty" extensionEvents.LogErrorEvent("copyfiles", errMessage) return nil, errors.New(errMessage) } // Check if the directory exists - sourceDirectoryFDRef, err := os.Open(oldExtensionDirectory) + sourceDirectoryFDRef, err := os.Open(upgradeFromVersionDirectory) if err != nil { - errMessage := fmt.Sprintf("could not open sourceDirectory %s", oldExtensionDirectory) + errMessage := fmt.Sprintf("could not open sourceDirectory %s", upgradeFromVersionDirectory) ctx.Log("message", errMessage) extensionEvents.LogErrorEvent("copyfiles", errMessage) return nil, errors.Wrap(err, errMessage) @@ -609,7 +711,7 @@ func copyFiles(ctx log.Logger, fileExtensionSuffix string, extensionSubdirectory directoryEntries, err := sourceDirectoryFDRef.ReadDir(0) if err != nil { - errMessage := fmt.Sprintf("could not read directory entries from sourceDirectory %s", oldExtensionDirectory) + errMessage := fmt.Sprintf("could not read directory entries from sourceDirectory %s", upgradeFromVersionDirectory) ctx.Log("message", errMessage) extensionEvents.LogErrorEvent("copyfiles", errMessage) return nil, errors.Wrap(err, errMessage) @@ -622,8 +724,8 @@ func copyFiles(ctx log.Logger, fileExtensionSuffix string, extensionSubdirectory fileName := dirEntry.Name() if strings.HasSuffix(fileName, fileExtensionSuffix) { - sourceFileFullPath := filepath.Join(oldExtensionDirectory, fileName) - destinationFileFullPath := filepath.Join(newExtensionDirectory, fileName) + sourceFileFullPath := filepath.Join(upgradeFromVersionDirectory, fileName) + destinationFileFullPath := filepath.Join(upgradeToVersionDirectory, fileName) sourceFile, sourceFileOpenError := os.Open(sourceFileFullPath) if sourceFileOpenError != nil { @@ -660,7 +762,7 @@ func copyFiles(ctx log.Logger, fileExtensionSuffix string, extensionSubdirectory } } - message = fmt.Sprintf("Migrated %d '%s' files from extension version '%s' to '%s'", numberOfFilesMigrated, fileExtensionSuffix, oldExtensionVersion, newExtensionVersion) + message = fmt.Sprintf("Migrated %d '%s' files from extension version '%s' to '%s'", numberOfFilesMigrated, fileExtensionSuffix, upgradeFromVersionDirectory, upgradeToVersionDirectory) ctx.Log("message", message) extensionEvents.LogInformationalEvent("copyfiles", message) @@ -867,7 +969,7 @@ func runCmd(ctx *log.Context, dir string, scriptFilePath string, cfg *handlerset defer pid.DeleteCurrentPidAndStartTime(metadata.PidFilePath) begin := time.Now() - err, exitCode = exec.ExecCmdInDir(ctx, scriptFilePath, dir, cfg) + err, exitCode = ExecCmdInDir(ctx, scriptFilePath, dir, cfg) elapsed := time.Since(begin) isSuccess := err == nil diff --git a/internal/cmds/cmds_test.go b/internal/cmds/cmds_test.go index 205c5ef..fdfd448 100755 --- a/internal/cmds/cmds_test.go +++ b/internal/cmds/cmds_test.go @@ -3,11 +3,13 @@ package commands import ( "encoding/json" "errors" + "fmt" "io/ioutil" "net/http" "net/http/httptest" "os" "path/filepath" + "reflect" "strconv" "strings" "testing" @@ -85,7 +87,7 @@ func Test_CopyMrseqFiles_MrseqFilesAreCopied(t *testing.T) { extensionLogger := logging.New(nil) extensionEventManager := extensionevents.New(extensionLogger, &handlerEnvironment) - err = CopyStateForUpdate(log.NewContext(log.NewNopLogger()), extensionEventManager) + err = CopyStateForUpdate(log.NewContext(log.NewNopLogger()), previousExtensionVersionDirectory, currentExtensionVersionDirectory, extensionEventManager) require.Nil(t, err) files, _ = ioutil.ReadDir(currentExtensionVersionDirectory) @@ -239,6 +241,69 @@ func Test_update_e2e_cmd(t *testing.T) { enable_extension(t, fakeEnv, newVersionDirectory, "crazyChipmunk", false, 0) } +func Test_update_e23_non_problematic_version(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "deletecmd") + defer os.RemoveAll(tempDir) + + DataDir, _ = os.MkdirTemp("", "datadir") + defer os.RemoveAll(DataDir) + + oldVersionDirectory := filepath.Join(tempDir, "Microsoft.CPlat.Core.RunCommandHandlerLinux-1.3.26") + newVersionDirectory := filepath.Join(tempDir, "Microsoft.CPlat.Core.RunCommandHandlerLinux-1.3.27") + err := os.Mkdir(oldVersionDirectory, 0755) + require.Nil(t, err, "Could not create old version subdirectory") + err = os.Mkdir(newVersionDirectory, 0755) + require.Nil(t, err, "Could not create new version subdirectory") + oldStatusPath := create_folder(t, oldVersionDirectory, constants.StatusFileDirectory) + newStatusPath := create_folder(t, newVersionDirectory, constants.StatusFileDirectory) + oldEventsPath := create_folder(t, oldVersionDirectory, constants.ExtensionEventsDirectory) + newEventsPath := create_folder(t, newVersionDirectory, constants.ExtensionEventsDirectory) + + fakeEnv := types.HandlerEnvironment{} + update_handler_env(&fakeEnv, oldStatusPath, oldVersionDirectory, oldEventsPath) + + // We start on the old version + os.Setenv(constants.ExtensionPathEnvName, oldVersionDirectory) + os.Setenv(constants.ExtensionVersionEnvName, "1.3.26") + + // Create three extensions + enable_extension(t, fakeEnv, oldVersionDirectory, "happyChipmunk", true, 0) + enable_extension(t, fakeEnv, oldVersionDirectory, "crazyChipmunk", true, 0) + enable_extension(t, fakeEnv, oldVersionDirectory, "stubbornChipmunk", true, 0) + + // Run one of them again to obtain multiple status files + enable_extension(t, fakeEnv, oldVersionDirectory, "happyChipmunk", true, 1) + + // Now, pretend that the extension was updated + // Step 1: WALA calls Disable on our two extensions + disable_extension(t, fakeEnv, oldVersionDirectory, "happyChipmunk") + disable_extension(t, fakeEnv, oldVersionDirectory, "crazyChipmunk") + disable_extension(t, fakeEnv, oldVersionDirectory, "stubbornChipmunk") + + // Step 2: WALA will call update + os.Setenv(constants.ExtensionVersionEnvName, "1.3.27") + os.Setenv(constants.ExtensionPathEnvName, newVersionDirectory) + os.Setenv(constants.ExtensionVersionUpdatingFromEnvName, "1.3.26") + update_handler_env(&fakeEnv, newStatusPath, newVersionDirectory, newEventsPath) + update_handler(t, fakeEnv, tempDir) + + // Now, WALA will uninstall the old extension + uninstall_handler(t, fakeEnv, tempDir) + + // Then, WALA will install the new extension + install_handler(t, fakeEnv, tempDir) + + // Now call enable and verify we did NOT re-execute the script + enable_extension(t, fakeEnv, newVersionDirectory, "happyChipmunk", false, 1) + enable_extension(t, fakeEnv, newVersionDirectory, "crazyChipmunk", false, 0) + enable_extension(t, fakeEnv, newVersionDirectory, "stubbornChipmunk", false, 0) + + // Run them again with a higher seqNo to ensure they're now executed + enable_extension(t, fakeEnv, newVersionDirectory, "happyChipmunk", true, 2) + enable_extension(t, fakeEnv, newVersionDirectory, "crazyChipmunk", true, 1) + enable_extension(t, fakeEnv, newVersionDirectory, "stubbornChipmunk", true, 1) +} + func Test_udpate_e2e_problematic_version(t *testing.T) { tempDir, _ := os.MkdirTemp("", "deletecmd") defer os.RemoveAll(tempDir) @@ -435,6 +500,10 @@ func Test_runCmd_success(t *testing.T) { require.Nil(t, err) defer os.RemoveAll(dir) + // Ensure that the script succeeds + ExecCmdInDir = func(ctx *log.Context, scriptFilePath, workdir string, cfg *handlersettings.HandlerSettings) (error, int) { + return nil, 0 + } metadata := types.NewRCMetadata("extName", 0, constants.DownloadFolder, DataDir) err, exitCode := runCmd(log.NewContext(log.NewNopLogger()), dir, "", &handlersettings.HandlerSettings{ PublicSettings: handlersettings.PublicSettings{Source: &handlersettings.ScriptSource{Script: script}}, @@ -442,12 +511,6 @@ func Test_runCmd_success(t *testing.T) { require.Nil(t, err, "command should run successfully") require.Equal(t, constants.ExitCode_Okay, exitCode) - // check stdout stderr files - _, err = os.Stat(filepath.Join(dir, "stdout")) - require.Nil(t, err, "stdout should exist") - _, err = os.Stat(filepath.Join(dir, "stderr")) - require.Nil(t, err, "stderr should exist") - // Check embedded script if saved to file _, err = os.Stat(filepath.Join(dir, "script.sh")) require.Nil(t, err, "script.sh should exist") @@ -461,6 +524,11 @@ func Test_runCmd_fail(t *testing.T) { require.Nil(t, err) defer os.RemoveAll(dir) + // Ensure that the script fails + ExecCmdInDir = func(ctx *log.Context, scriptFilePath, workdir string, cfg *handlersettings.HandlerSettings) (error, int) { + return errors.New("the chipmunks have risen in revolt"), 42 + } + metadata := types.NewRCMetadata("extName", 0, constants.DownloadFolder, DataDir) err, exitCode := runCmd(log.NewContext(log.NewNopLogger()), dir, "", &handlersettings.HandlerSettings{ PublicSettings: handlersettings.PublicSettings{Source: &handlersettings.ScriptSource{Script: "non-existing-cmd"}}, @@ -692,12 +760,17 @@ func Test_TreatFailureAsDeploymentFailureIsTrue_Fails(t *testing.T) { require.Nil(t, err) defer os.RemoveAll(dir) + // Ensure that the script fails + ExecCmdInDir = func(ctx *log.Context, scriptFilePath, workdir string, cfg *handlersettings.HandlerSettings) (error, int) { + return errors.New("the chipmunks do not like the script"), 127 + } + metadata := types.NewRCMetadata("extName", 0, constants.DownloadFolder, DataDir) err, exitCode := runCmd(log.NewContext(log.NewNopLogger()), dir, "", &handlersettings.HandlerSettings{ PublicSettings: handlersettings.PublicSettings{Source: &handlersettings.ScriptSource{Script: script}, TreatFailureAsDeploymentFailure: true}, }, metadata) require.NotNil(t, err) - require.Contains(t, err.Error(), "failed to execute command: command terminated with exit status=127") + require.Contains(t, err.Error(), "failed to execute command: the chipmunks do not like the script") require.NotEqual(t, constants.ExitCode_Okay, exitCode) } @@ -711,6 +784,11 @@ func Test_TreatFailureAsDeploymentFailureIsTrue_SimpleScriptSucceeds(t *testing. require.Nil(t, err) defer os.RemoveAll(dir) + // Ensure that the script succeeds + ExecCmdInDir = func(ctx *log.Context, scriptFilePath, workdir string, cfg *handlersettings.HandlerSettings) (error, int) { + return nil, 0 + } + metadata := types.NewRCMetadata("extName", 0, constants.DownloadFolder, DataDir) err, exitCode := runCmd(log.NewContext(log.NewNopLogger()), dir, "", &handlersettings.HandlerSettings{ PublicSettings: handlersettings.PublicSettings{Source: &handlersettings.ScriptSource{Script: script}, TreatFailureAsDeploymentFailure: false}, @@ -718,3 +796,493 @@ func Test_TreatFailureAsDeploymentFailureIsTrue_SimpleScriptSucceeds(t *testing. require.Nil(t, err) require.Equal(t, constants.ExitCode_Okay, exitCode) } + +func TestPadTo(t *testing.T) { + tests := []struct { + name string + in []int + size int + expected []int + }{ + { + name: "Input longer than size", + in: []int{1, 2, 3, 4}, + size: 2, + expected: []int{1, 2}, + }, + { + name: "Input equal to size", + in: []int{1, 2, 3}, + size: 3, + expected: []int{1, 2, 3}, + }, + { + name: "Input shorter than size", + in: []int{1, 2}, + size: 5, + expected: []int{1, 2, 0, 0, 0}, + }, + { + name: "Empty input", + in: []int{}, + size: 3, + expected: []int{0, 0, 0}, + }, + { + name: "Size zero", + in: []int{1, 2, 3}, + size: 0, + expected: []int{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := padTo(tt.in, tt.size) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("padTo(%v, %d) = %v; expected %v", tt.in, tt.size, result, tt.expected) + } + }) + } +} + +func TestSplitVersion(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in string + expected []int + }{ + { + name: "simple-3-parts", + in: "1.2.3", + expected: []int{1, 2, 3}, + }, + { + name: "single-part", + in: "42", + expected: []int{42}, + }, + { + name: "leading-zeros", + in: "001.0002.00003", + expected: []int{1, 2, 3}, + }, + { + name: "spaces-around-parts", + in: " 1 . 2 . 3 ", + expected: []int{1, 2, 3}, + }, + { + name: "non-numeric-alpha", + in: "1.a.3", + expected: []int{1, 0, 3}, + }, + { + name: "non-numeric-mixed", + in: "1.2beta.3", + expected: []int{1, 0, 3}, + }, + { + name: "empty-string", + in: "", + // strings.Split("", ".") == []string{""} → p=="" → append 0 + expected: []int{0}, + }, + { + name: "consecutive-dots-empty-components", + in: "1..3....5", + // empty parts become 0 + expected: []int{1, 0, 3, 0, 0, 0, 5}, + }, + { + name: "trailing-dot", + in: "1.2.", + expected: []int{1, 2, 0}, + }, + { + name: "leading-dot", + in: ".2.3", + expected: []int{0, 2, 3}, + }, + { + name: "very-large-number", + in: "2147483647.0", + // Note: Go int is platform-dependent; still valid parsing. + expected: []int{2147483647, 0}, + }, + { + name: "zeros-only", + in: "0.0.0", + expected: []int{0, 0, 0}, + }, + { + name: "whitespace-only-component", + in: "1. .3", + // TrimSpace makes middle part "", thus 0 + expected: []int{1, 0, 3}, + }, + { + name: "unicode-digits-are-not-ASCII-digits", + in: "1.2", // Note: first char is full-width '1' (U+FF11) → non-ASCII → 0 + expected: []int{0, 2}, + }, + { + name: "dash-negative-like", + in: "1.-2.3", + // '-' makes component non-numeric → 0 + expected: []int{1, 0, 3}, + }, + { + name: "plus-sign", + in: "+1.2", + expected: []int{0, 2}, + }, + { + name: "long-many-parts", + in: "1.2.3.4.5.6.7.8.9", + expected: []int{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := splitVersion(tt.in) + if !reflect.DeepEqual(got, tt.expected) { + t.Fatalf("splitVersion(%q) = %v, want %v", tt.in, got, tt.expected) + } + }) + } +} + +func TestCompareVersions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + a string + b string + expected int + }{ + { + name: "equal-simple", + a: "1.2.3", + b: "1.2.3", + expected: 0, + }, + { + name: "equal-with-extra-zeros", + a: "1.2.3.0", + b: "1.2.3", + expected: 0, + }, + { + name: "a-greater-last-segment", + a: "1.2.3.4", + b: "1.2.3.3", + expected: 1, + }, + { + name: "b-greater-last-segment", + a: "1.2.3.3", + b: "1.2.3.4", + expected: -1, + }, + { + name: "a-greater-first-segment", + a: "2.0.0", + b: "1.9.9", + expected: 1, + }, + { + name: "b-greater-first-segment", + a: "1.9.9", + b: "2.0.0", + expected: -1, + }, + { + name: "normalize-length-a-shorter", + a: "1.2", + b: "1.2.0.1", + expected: -1, + }, + { + name: "normalize-length-b-shorter", + a: "1.2.0.1", + b: "1.2", + expected: 1, + }, + { + name: "leading-zeros-equal", + a: "01.002.0003", + b: "1.2.3", + expected: 0, + }, + { + name: "non-numeric-in-a", + a: "1.alpha.3", + b: "1.0.3", + expected: 0, // alpha → 0 + }, + { + name: "non-numeric-in-b", + a: "1.2.3", + b: "1.beta.3", + expected: 1, // beta → 0, so a > b + }, + { + name: "empty-strings", + a: "", + b: "", + expected: 0, + }, + { + name: "empty-vs-non-empty", + a: "", + b: "0.0.0.1", + expected: -1, + }, + { + name: "longer-than-4-segments-ignored-after-4", + a: "1.2.3.4.999", + b: "1.2.3.4.0", + expected: 0, // only first 4 segments matter + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := compareVersions(tt.a, tt.b) + if got != tt.expected { + t.Fatalf("compareVersions(%q, %q) = %d, want %d", tt.a, tt.b, got, tt.expected) + } + }) + } +} + +func TestHasMrseq(t *testing.T) { + ctx := log.NewContext(log.NewNopLogger()) + + t.Run("empty dir string returns false", func(t *testing.T) { + if got := hasMrseq(ctx, ""); got { + t.Fatalf("hasMrseq(ctx, \"\") = true; want false") + } + }) + + t.Run("non-existent directory returns false", func(t *testing.T) { + nonExistent := filepath.Join(t.TempDir(), "this-dir-does-not-exist") + // Ensure it truly doesn't exist + if _, err := os.Stat(nonExistent); !os.IsNotExist(err) { + t.Fatalf("test setup: expected directory to not exist: %s", nonExistent) + } + if got := hasMrseq(ctx, nonExistent); got { + t.Fatalf("hasMrseq(ctx, %q) = true; want false", nonExistent) + } + }) + + t.Run("empty directory returns false", func(t *testing.T) { + dir := t.TempDir() + if got := hasMrseq(ctx, dir); got { + t.Fatalf("hasMrseq(ctx, %q) = true; want false", dir) + } + }) + + t.Run("directory with one .mrseq file returns true", func(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "run1.mrseq") + if err := os.WriteFile(f, []byte("dummy"), 0o644); err != nil { + t.Fatalf("test setup: write %s: %v", f, err) + } + if got := hasMrseq(ctx, dir); !got { + t.Fatalf("hasMrseq(ctx, %q) = false; want true", dir) + } + }) + + t.Run("directory with multiple .mrseq files returns true", func(t *testing.T) { + dir := t.TempDir() + for i := 1; i <= 3; i++ { + name := filepath.Join(dir, fmt.Sprintf("batch_%d.mrseq", i)) + if err := os.WriteFile(name, []byte("dummy"), 0o644); err != nil { + t.Fatalf("test setup: write %s: %v", name, err) + } + } + if got := hasMrseq(ctx, dir); !got { + t.Fatalf("hasMrseq(ctx, %q) = false; want true", dir) + } + }) + + t.Run("directory with non-mrseq files only returns false", func(t *testing.T) { + dir := t.TempDir() + others := []string{"a.txt", "b.mrseq.bak", "c.mrseqq", "d.MRSEQ"} // case-sensitive on most platforms + for _, name := range others { + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte("dummy"), 0o644); err != nil { + t.Fatalf("test setup: write %s: %v", path, err) + } + } + if got := hasMrseq(ctx, dir); got { + t.Fatalf("hasMrseq(ctx, %q) = true; want false", dir) + } + }) + + t.Run("non-recursive: file only in subdirectory does not count", func(t *testing.T) { + dir := t.TempDir() + sub := filepath.Join(dir, "sub") + if err := os.MkdirAll(sub, 0o755); err != nil { + t.Fatalf("test setup: mkdir %s: %v", sub, err) + } + f := filepath.Join(sub, "nested.mrseq") + if err := os.WriteFile(f, []byte("dummy"), 0o644); err != nil { + t.Fatalf("test setup: write %s: %v", f, err) + } + // Glob(dir, "*.mrseq") should not find files in subdir + if got := hasMrseq(ctx, dir); got { + t.Fatalf("hasMrseq(ctx, %q) = true; want false (non-recursive glob)", dir) + } + }) + + // Optional: demonstrate that unrelated extensions don't affect the outcome when at least one *.mrseq exists. + t.Run("mixed files: presence of .mrseq wins", func(t *testing.T) { + dir := t.TempDir() + _ = os.WriteFile(filepath.Join(dir, "a.txt"), []byte("dummy"), 0o644) + _ = os.WriteFile(filepath.Join(dir, "b.log"), []byte("dummy"), 0o644) + _ = os.WriteFile(filepath.Join(dir, "c.mrseq"), []byte("dummy"), 0o644) + if got := hasMrseq(ctx, dir); !got { + t.Fatalf("hasMrseq(ctx, %q) = false; want true", dir) + } + }) +} + +func makeDirWithMrseq(t *testing.T, dir string, addMrseq bool, version string) string { + sub := filepath.Join(dir, version) + if err := os.MkdirAll(sub, 0o755); err != nil { + t.Fatalf("test setup: mkdir %s: %v", sub, err) + } + + if addMrseq { + f := filepath.Join(sub, "floopster.mrseq") + if err := os.WriteFile(f, []byte("0"), 0o644); err != nil { + t.Fatalf("setup: write %s: %v", f, err) + } + } + return sub +} + +func TestDetermineUpgradeVersionDirectories(t *testing.T) { + tests := []struct { + name string + firstVersion string + secondVersion string + firstHasMrseq bool + secondHasMrseq bool + expectedToSuffix string + expectedFrom string + }{ + { + name: "first has mrseq, second does not → upgrade to second", + firstVersion: "1.0.0", + secondVersion: "2.0.0", + firstHasMrseq: true, + secondHasMrseq: false, + expectedToSuffix: "second", + expectedFrom: "1.0.0", + }, + { + name: "second has mrseq, first does not → upgrade to first", + firstVersion: "1.0.0", + secondVersion: "2.0.0", + firstHasMrseq: false, + secondHasMrseq: true, + expectedToSuffix: "first", + expectedFrom: "2.0.0", + }, + { + name: "neither has mrseq → choose higher version (second)", + firstVersion: "1.0.0", + secondVersion: "2.0.0", + firstHasMrseq: false, + secondHasMrseq: false, + expectedToSuffix: "second", + expectedFrom: "1.0.0", + }, + { + name: "both have mrseq → choose higher version (second)", + firstVersion: "1.0.0", + secondVersion: "2.0.0", + firstHasMrseq: true, + secondHasMrseq: true, + expectedToSuffix: "second", + expectedFrom: "1.0.0", + }, + { + name: "equal versions → choose first as upgradeTo", + firstVersion: "1.0.0", + secondVersion: "1.0.0", + firstHasMrseq: false, + secondHasMrseq: false, + expectedToSuffix: "first", + expectedFrom: "1.0.0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Prepare directories + dir := t.TempDir() + firstDir := makeDirWithMrseq(t, dir, tt.firstHasMrseq, tt.firstVersion) + secondDir := makeDirWithMrseq(t, dir, tt.secondHasMrseq, tt.secondVersion) + + // Simulate environment variables + os.Setenv(constants.ExtensionVersionEnvName, tt.firstVersion) + os.Setenv(constants.ExtensionVersionUpdatingFromEnvName, tt.secondVersion) + os.Setenv(constants.ExtensionPathEnvName, firstDir) + + // Replace secondDir logic: mimic original code's substitution + // (strings.ReplaceAll(firstDir, firstVersion, secondVersion)) + // For test simplicity, override secondDir directly + // but ensure substitution works if versions appear in path + if strings.Contains(firstDir, tt.firstVersion) { + secondDir = strings.ReplaceAll(firstDir, tt.firstVersion, tt.secondVersion) + } + + tempDir, _ := os.MkdirTemp("", "determineupgrade") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + ctx := log.NewContext(log.NewNopLogger()) + + gotFromDir, gotToDir, gotFromVersion := determineUpgradeVersionDirectories(ctx, events) + + // Validate upgradeFromVersion + if gotFromVersion != tt.expectedFrom { + t.Errorf("upgradeFromVersion = %q; want %q", gotFromVersion, tt.expectedFrom) + } + + // Validate which directory chosen as upgradeTo + if tt.expectedToSuffix == "first" { + if gotToDir != firstDir { + t.Errorf("upgradeToDir = %q; want firstDir %q", gotToDir, firstDir) + } + if gotFromDir != secondDir { + t.Errorf("upgradeFromDir = %q; want secondDir %q", gotFromDir, secondDir) + } + } else { + if gotToDir != secondDir { + t.Errorf("upgradeToDir = %q; want secondDir %q", gotToDir, secondDir) + } + if gotFromDir != firstDir { + t.Errorf("upgradeFromDir = %q; want firstDir %q", gotFromDir, firstDir) + } + } + }) + } +} diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 94bdccd..d7457b6 100755 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -38,6 +38,9 @@ const ( // The output directory for logs of immediate run command ImmediateRCOutputDirectory = "/var/log/azure/run-command-handler/ImmediateRunCommandService.log" + // This is the directory where we place all the .mrseq files + WorkingDirectory = "../run-command-handler/workdir" + // Download folder to use for standard managed run command DownloadFolder = "download/" @@ -48,14 +51,9 @@ const ( RunCommandExtensionName = "Microsoft.CPlat.Core.RunCommandHandlerLinux" RunCommandTestExtensionName = "Microsoft.Azure.Extensions.Edp.RunCommandHandlerLinuxTest" - // List of problematic RCV2 versions that delete the mrseq files - ProductionVersionThatDeletesMrSeqFiles = "1.3.17" - FirstTestVersionThatDeletesMrSeqFiles = "1.8.0" - SecondTestVersionThatDeletesMrSeqFiles = "1.9.0" - // The current version of the extension. This value is provided by the agent for all commands. // See more in: https://github.com/Azure/azure-vmextension-publishing/wiki/2.0-Partner-Guide-Handler-Design-Details#236-summary - ExtensionVersionEnvName = "AZURE_GUEST_AGENT_EXTENSION_VERSION" + ExtensionVersionEnvName = "VERSION" // This is the version the extension is updating from // See more in: https://github.com/Azure/azure-vmextension-publishing/wiki/2.0-Partner-Guide-Handler-Design-Details#236-summary @@ -65,6 +63,9 @@ const ( // See more in: https://github.com/Azure/azure-vmextension-publishing/wiki/2.0-Partner-Guide-Handler-Design-Details#236-summary ExtensionPathEnvName = "AZURE_GUEST_AGENT_EXTENSION_PATH" + // The first version from which no rehydration is necessary because we do not delete the .mrseq file in disable + FirstVersionNoRehydration = "1.3.26" + // The name of the immediate run command service ImmediateRunCommandHandlerName = "runCommandService" diff --git a/internal/pid/pid.go b/internal/pid/pid.go index bf46afe..6145e6d 100644 --- a/internal/pid/pid.go +++ b/internal/pid/pid.go @@ -7,7 +7,6 @@ import ( "os/exec" "strconv" "strings" - "syscall" "github.com/go-kit/kit/log" "github.com/pkg/errors" @@ -94,7 +93,7 @@ func KillPreviousExtension(ctx *log.Context, pidFilePath string) { if ctx != nil { ctx.Log("event", "check process", "Active previous execution found. Killing pid ", previousPid) } - syscall.Kill(-previousPid, syscall.SIGKILL) // Negative pid means kill the whole process group + //syscall.Kill(-previousPid, syscall.SIGKILL) // Negative pid means kill the whole process group DeleteCurrentPidAndStartTime(pidFilePath) } } From d04754e2ec0e63e89991d4358e33b691f44b2622 Mon Sep 17 00:00:00 2001 From: Joseph Calev Date: Tue, 16 Dec 2025 13:59:51 -0800 Subject: [PATCH 2/7] PR feedback --- internal/cleanup/cleanup.go | 45 ++++++++++++++++++++----------------- internal/cmds/cmds.go | 2 +- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/internal/cleanup/cleanup.go b/internal/cleanup/cleanup.go index 2d623c2..be8634b 100644 --- a/internal/cleanup/cleanup.go +++ b/internal/cleanup/cleanup.go @@ -2,8 +2,11 @@ package cleanup import ( "fmt" + "os" "path/filepath" + "strconv" + "github.com/!azure/azure-extension-platform/pkg/utils" "github.com/Azure/run-command-handler-linux/internal/constants" "github.com/Azure/run-command-handler-linux/internal/types" "github.com/Azure/run-command-handler-linux/pkg/linuxutils" @@ -38,28 +41,28 @@ func deleteAllScriptsAndSettings(ctx *log.Context, metadata types.RCMetadata, h } func deleteScriptsAndSettingsExceptMostRecent(ctx *log.Context, metadata types.RCMetadata, h types.HandlerEnvironment, runAsUser string) { - //runtimeSettingsRegexFormat := metadata.ExtName + ".\\d+.settings" - //runtimeSettingsLastSeqNumFormat := metadata.ExtName + ".%d.settings" + runtimeSettingsRegexFormat := metadata.ExtName + ".\\d+.settings" + runtimeSettingsLastSeqNumFormat := metadata.ExtName + ".%d.settings" // check if directory exists - //_, err := os.Open(metadata.DownloadPath) - //if err == nil { - // err := utils.TryClearExtensionScriptsDirectoriesAndSettingsFilesExceptMostRecent(metadata.DownloadPath, h.HandlerEnvironment.ConfigFolder, "", - // uint64(metadata.SeqNum), runtimeSettingsRegexFormat, runtimeSettingsLastSeqNumFormat) - // if err != nil { - // ctx.Log("event", "could not clear settings and script files", "error", err) - // } - //} else { - // ctx.Log("message", "directory does not exist. Skipping cleanup") - //} + _, err := os.Open(metadata.DownloadPath) + if err == nil { + err := utils.TryClearExtensionScriptsDirectoriesAndSettingsFilesExceptMostRecent(metadata.DownloadPath, h.HandlerEnvironment.ConfigFolder, "", + uint64(metadata.SeqNum), runtimeSettingsRegexFormat, runtimeSettingsLastSeqNumFormat) + if err != nil { + ctx.Log("event", "could not clear settings and script files", "error", err) + } + } else { + ctx.Log("message", "directory does not exist. Skipping cleanup") + } - //if runAsUser != "" { - // runAsDownloadParent := filepath.Join(fmt.Sprintf(constants.RunAsDir, runAsUser), metadata.DownloadDir) - // seqNumString := strconv.Itoa(metadata.SeqNum) - // ctx.Log("message", "removing all files from the download 'runas' directory "+runAsDownloadParent) - // err = utils.TryDeleteDirectoriesExcept(runAsDownloadParent, seqNumString) - // if err != nil { - // ctx.Log("event", "could not clear runas script") - // } - //} + if runAsUser != "" { + runAsDownloadParent := filepath.Join(fmt.Sprintf(constants.RunAsDir, runAsUser), metadata.DownloadDir) + seqNumString := strconv.Itoa(metadata.SeqNum) + ctx.Log("message", "removing all files from the download 'runas' directory "+runAsDownloadParent) + err = utils.TryDeleteDirectoriesExcept(runAsDownloadParent, seqNumString) + if err != nil { + ctx.Log("event", "could not clear runas script") + } + } } diff --git a/internal/cmds/cmds.go b/internal/cmds/cmds.go index dbbfd96..d804c47 100755 --- a/internal/cmds/cmds.go +++ b/internal/cmds/cmds.go @@ -459,7 +459,7 @@ func determineUpgradeVersionDirectories(ctx *log.Context, extensionEvents *exten // The higher version isn't necessarily the one we're upgrading to, since we may be downgrading // If one has at least one .mrseq file, and the other has none, then we're upgrading to the one that has none // If neither has a .mrseq file, then just choose the higher version number - // If both have .mrseq files, then this shouldn't happen, but for the sake of sanity choose the highe version number + // If both have .mrseq files, then this shouldn't happen, but for the sake of sanity choose the higher version number firstExtensionDirectory := os.Getenv(constants.ExtensionPathEnvName) secondExtensionDirectory := strings.ReplaceAll(firstExtensionDirectory, firstExtensionVersion, secondExtensionVersion) From 136a3cdb0ba2e0f158f70379832f2c8de4b498cd Mon Sep 17 00:00:00 2001 From: Joseph Calev Date: Mon, 5 Jan 2026 10:51:30 -0800 Subject: [PATCH 3/7] Fixes build break --- internal/cleanup/cleanup.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/cleanup/cleanup.go b/internal/cleanup/cleanup.go index be8634b..32bcf6a 100644 --- a/internal/cleanup/cleanup.go +++ b/internal/cleanup/cleanup.go @@ -6,7 +6,7 @@ import ( "path/filepath" "strconv" - "github.com/!azure/azure-extension-platform/pkg/utils" + "github.com/Azure/azure-extension-platform/pkg/utils" "github.com/Azure/run-command-handler-linux/internal/constants" "github.com/Azure/run-command-handler-linux/internal/types" "github.com/Azure/run-command-handler-linux/pkg/linuxutils" From b9dda0169c46ac6b9022895d6fdcc373227f65c0 Mon Sep 17 00:00:00 2001 From: Joseph Calev Date: Tue, 13 Jan 2026 16:44:06 -0800 Subject: [PATCH 4/7] Changes approach on how we determine the upgrade from and to versions --- internal/cmds/cmds.go | 114 +++----- internal/cmds/cmds_test.go | 335 ++++++++-------------- internal/constants/constants.go | 5 +- internal/pid/pid.go | 3 +- internal/service/serviceinstall.go | 2 +- pkg/servicehandler/servicehandler.go | 2 +- pkg/servicehandler/servicehandler_test.go | 4 +- 7 files changed, 169 insertions(+), 296 deletions(-) diff --git a/internal/cmds/cmds.go b/internal/cmds/cmds.go index d804c47..973b00d 100755 --- a/internal/cmds/cmds.go +++ b/internal/cmds/cmds.go @@ -12,6 +12,7 @@ import ( "io" "os" "path/filepath" + "slices" "strconv" "strings" "time" @@ -452,76 +453,39 @@ func CopyStateForUpdate(ctx log.Logger, upgradeFromVersionDirectory string, upgr func determineUpgradeVersionDirectories(ctx *log.Context, extensionEvents *extensionevents.ExtensionEventManager) (upgradeFromVersionDirectory string, upgradeToVersionDirectory string, upgradeFromVersion string) { // These two environment variables will tell us the extension versions involved, but won't actually tell us // the from/to versions - firstExtensionVersion := os.Getenv(constants.ExtensionVersionEnvName) - secondExtensionVersion := os.Getenv(constants.ExtensionVersionUpdatingFromEnvName) - - // To determine to which version we're actually upgrading, we'll need to look into the folders - // The higher version isn't necessarily the one we're upgrading to, since we may be downgrading - // If one has at least one .mrseq file, and the other has none, then we're upgrading to the one that has none - // If neither has a .mrseq file, then just choose the higher version number - // If both have .mrseq files, then this shouldn't happen, but for the sake of sanity choose the higher version number - firstExtensionDirectory := os.Getenv(constants.ExtensionPathEnvName) - secondExtensionDirectory := strings.ReplaceAll(firstExtensionDirectory, firstExtensionVersion, secondExtensionVersion) - - // Check for *.mrseq presence in each directory - firstHasMrseq := hasMrseq(ctx, firstExtensionDirectory) - secondHasMrseq := hasMrseq(ctx, secondExtensionDirectory) - - // If one has mrseq and the other doesn't → upgrade to the one without mrseq - if firstHasMrseq != secondHasMrseq { - if firstHasMrseq && !secondHasMrseq { - upgradeToVersionDirectory, upgradeFromVersionDirectory = secondExtensionDirectory, firstExtensionDirectory - upgradeFromVersion = firstExtensionVersion - } else { - upgradeToVersionDirectory, upgradeFromVersionDirectory = firstExtensionDirectory, secondExtensionDirectory - upgradeFromVersion = secondExtensionVersion - } - - msg := fmt.Sprintf("determineUpgradeVersions: mrseq-guided choice → to='%s' from='%s'", upgradeToVersionDirectory, upgradeFromVersionDirectory) - ctx.Log("message", msg) - extensionEvents.LogInformationalEvent("determineUpgradeVersions", msg) - - return upgradeFromVersionDirectory, upgradeToVersionDirectory, upgradeFromVersion + upgradeToVersion := os.Getenv(constants.VersionEnvName) + extensionVersionValue := os.Getenv(constants.ExtensionVersionEnvName) + updatingFromVersionValue := os.Getenv(constants.ExtensionVersionUpdatingFromEnvName) + upgradeType := "upgrade" + + // First, we need to determine if this is an upgrade or a downgrade + // This is a downgrade if the updating from version is equal to the version + if upgradeToVersion == updatingFromVersionValue { + // This is a downgrade + upgradeFromVersion = extensionVersionValue + upgradeType = "downgrade" + } else { + // This is an upgrade + upgradeFromVersion = updatingFromVersionValue } - // Rule 2 & 3: neither has mrseq OR both have mrseq → choose higher version number as upgradeTo - switch c := compareVersions(firstExtensionVersion, secondExtensionVersion); { - case c > 0: - upgradeToVersionDirectory, upgradeFromVersionDirectory = firstExtensionDirectory, secondExtensionDirectory - upgradeFromVersion = secondExtensionVersion - case c < 0: - upgradeToVersionDirectory, upgradeFromVersionDirectory = secondExtensionDirectory, firstExtensionDirectory - upgradeFromVersion = firstExtensionVersion - default: - // Equal versions (shouldn’t normally happen in an upgrade path). Keep first as "to". - upgradeToVersionDirectory, upgradeFromVersionDirectory = firstExtensionDirectory, secondExtensionDirectory - upgradeFromVersion = secondExtensionVersion + // Determine the corresponding extension directories + extensionDirectory := os.Getenv(constants.ExtensionPathEnvName) + if strings.Contains(extensionDirectory, upgradeToVersion) { + upgradeToVersionDirectory = extensionDirectory + upgradeFromVersionDirectory = strings.ReplaceAll(extensionDirectory, upgradeToVersion, upgradeFromVersion) + } else { + upgradeFromVersionDirectory = extensionDirectory + upgradeToVersionDirectory = strings.ReplaceAll(extensionDirectory, upgradeFromVersion, upgradeToVersion) } - msg := fmt.Sprintf("determineUpgradeVersions: version-ordered choice → to='%s' from='%s' (mrseq first=%t second=%t)", upgradeToVersionDirectory, upgradeFromVersionDirectory, firstHasMrseq, secondHasMrseq) + msg := fmt.Sprintf("determineUpgradeVersionDirectories: %s from='%s' to='%s'", upgradeType, upgradeToVersionDirectory, upgradeFromVersionDirectory) ctx.Log("message", msg) extensionEvents.LogInformationalEvent("determineUpgradeVersions", msg) return upgradeFromVersionDirectory, upgradeToVersionDirectory, upgradeFromVersion } -// hasMrseq returns true if the given directory contains at least one *.mrseq file. -// It is resilient to missing directories and IO errors (logs and returns false). -func hasMrseq(ctx *log.Context, dir string) bool { - if dir == "" { - return false - } - // Resolve glob pattern - pattern := filepath.Join(dir, "*.mrseq") - - matches, err := filepath.Glob(pattern) - if err != nil { - ctx.Log("error", fmt.Sprintf("hasMrseq: glob error for '%s': %v", pattern, err)) - return false - } - return len(matches) > 0 -} - // compareVersions compares two dotted version strings (e.g., "2.1", "2.1.0", "2.1.0.3"). // Returns: +1 if a>b, -1 if a '9' { - // non-numeric component; keep as 0 - n = 0 - goto done - } - } - if p != "" { - // safe Atoi without error branch since we checked digits - for i := 0; i < len(p); i++ { - n = n*10 + int(p[i]-'0') - } + n, err := strconv.Atoi(p) + if err != nil { + n = 0 } - done: out = append(out, n) } return out } func padTo(in []int, size int) []int { + if len(in) >= size { return in[:size] } - out := make([]int, size) - copy(out, in) - // remaining default to 0 + n := size - len(in) + + // Ensure capacity for the extra n elements without reallocating. + out := slices.Grow(in, n) + + // Extend length to size by appending n zero-values. + out = append(out, make([]int, n)...) return out + } func rehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, updateFromVersionDirectory string, updateToVersionDirectory string, extensionEvents *extensionevents.ExtensionEventManager) error { diff --git a/internal/cmds/cmds_test.go b/internal/cmds/cmds_test.go index fdfd448..c09b5d9 100755 --- a/internal/cmds/cmds_test.go +++ b/internal/cmds/cmds_test.go @@ -3,7 +3,6 @@ package commands import ( "encoding/json" "errors" - "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -31,9 +30,9 @@ func Test_CopyMrseqFiles_MrseqFilesAreCopied(t *testing.T) { currentExtensionVersionDirectory := "Microsoft.CPlat.Core.RunCommandHandlerLinux-1.3.8" os.Setenv(constants.ExtensionPathEnvName, currentExtensionVersionDirectory) os.Setenv(constants.ExtensionVersionUpdatingFromEnvName, "1.3.7") - os.Setenv(constants.ExtensionVersionEnvName, "1.3.8") + os.Setenv(constants.VersionEnvName, "1.3.8") - currentVersion := os.Getenv(constants.ExtensionVersionEnvName) + currentVersion := os.Getenv(constants.VersionEnvName) previousVersion := os.Getenv(constants.ExtensionVersionUpdatingFromEnvName) previousExtensionVersionDirectory := strings.ReplaceAll(currentExtensionVersionDirectory, currentVersion, previousVersion) @@ -212,7 +211,7 @@ func Test_update_e2e_cmd(t *testing.T) { // We start on the old version os.Setenv(constants.ExtensionPathEnvName, oldVersionDirectory) - os.Setenv(constants.ExtensionVersionEnvName, "1.3.8") + os.Setenv(constants.VersionEnvName, "1.3.8") // Create two extensions enable_extension(t, fakeEnv, oldVersionDirectory, "happyChipmunk", true, 0) @@ -224,7 +223,7 @@ func Test_update_e2e_cmd(t *testing.T) { disable_extension(t, fakeEnv, oldVersionDirectory, "crazyChipmunk") // Step 2: WALA will call update - os.Setenv(constants.ExtensionVersionEnvName, "1.3.9") + os.Setenv(constants.VersionEnvName, "1.3.9") os.Setenv(constants.ExtensionPathEnvName, newVersionDirectory) os.Setenv(constants.ExtensionVersionUpdatingFromEnvName, "1.3.8") update_handler_env(&fakeEnv, newStatusPath, newVersionDirectory, newEventsPath) @@ -264,7 +263,7 @@ func Test_update_e23_non_problematic_version(t *testing.T) { // We start on the old version os.Setenv(constants.ExtensionPathEnvName, oldVersionDirectory) - os.Setenv(constants.ExtensionVersionEnvName, "1.3.26") + os.Setenv(constants.VersionEnvName, "1.3.26") // Create three extensions enable_extension(t, fakeEnv, oldVersionDirectory, "happyChipmunk", true, 0) @@ -281,7 +280,7 @@ func Test_update_e23_non_problematic_version(t *testing.T) { disable_extension(t, fakeEnv, oldVersionDirectory, "stubbornChipmunk") // Step 2: WALA will call update - os.Setenv(constants.ExtensionVersionEnvName, "1.3.27") + os.Setenv(constants.VersionEnvName, "1.3.27") os.Setenv(constants.ExtensionPathEnvName, newVersionDirectory) os.Setenv(constants.ExtensionVersionUpdatingFromEnvName, "1.3.26") update_handler_env(&fakeEnv, newStatusPath, newVersionDirectory, newEventsPath) @@ -327,7 +326,7 @@ func Test_udpate_e2e_problematic_version(t *testing.T) { // We start on the old version os.Setenv(constants.ExtensionPathEnvName, oldVersionDirectory) - os.Setenv(constants.ExtensionVersionEnvName, "1.3.17") + os.Setenv(constants.VersionEnvName, "1.3.17") // Create three extensions enable_extension(t, fakeEnv, oldVersionDirectory, "happyChipmunk", true, 0) @@ -354,7 +353,7 @@ func Test_udpate_e2e_problematic_version(t *testing.T) { os.WriteFile(filepath.Join(oldStatusPath, "this.is.a.bad.chipmunk.0.status"), []byte("0"), os.FileMode(0600)) // Step 2: WALA will call update - os.Setenv(constants.ExtensionVersionEnvName, "1.3.18") + os.Setenv(constants.VersionEnvName, "1.3.18") os.Setenv(constants.ExtensionPathEnvName, newVersionDirectory) os.Setenv(constants.ExtensionVersionUpdatingFromEnvName, "1.3.17") update_handler_env(&fakeEnv, newStatusPath, newVersionDirectory, newEventsPath) @@ -932,12 +931,12 @@ func TestSplitVersion(t *testing.T) { name: "dash-negative-like", in: "1.-2.3", // '-' makes component non-numeric → 0 - expected: []int{1, 0, 3}, + expected: []int{1, -2, 3}, }, { name: "plus-sign", in: "+1.2", - expected: []int{0, 2}, + expected: []int{1, 2}, }, { name: "long-many-parts", @@ -1065,224 +1064,136 @@ func TestCompareVersions(t *testing.T) { } } -func TestHasMrseq(t *testing.T) { - ctx := log.NewContext(log.NewNopLogger()) +func Test_determineUpgradeVersionDirectories_Upgrade_PathContainsToVersion(t *testing.T) { + /* + Upgrade: updatingFrom != toVersion + Path contains the toVersion -> replace toVersion with fromVersion for the "from" dir + */ + to := "2.5.0" + from := "2.4.3" + curr := "2.5.0" // current extension version value (doesn't affect upgrade/downgrade decision) + dir := "/var/lib/waagent/My.Ext/" + to + "/" - t.Run("empty dir string returns false", func(t *testing.T) { - if got := hasMrseq(ctx, ""); got { - t.Fatalf("hasMrseq(ctx, \"\") = true; want false") - } - }) + setEnvs(t, to, curr, from, dir) - t.Run("non-existent directory returns false", func(t *testing.T) { - nonExistent := filepath.Join(t.TempDir(), "this-dir-does-not-exist") - // Ensure it truly doesn't exist - if _, err := os.Stat(nonExistent); !os.IsNotExist(err) { - t.Fatalf("test setup: expected directory to not exist: %s", nonExistent) - } - if got := hasMrseq(ctx, nonExistent); got { - t.Fatalf("hasMrseq(ctx, %q) = true; want false", nonExistent) - } - }) + tempDir, _ := os.MkdirTemp("", "upgradetoversion") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } - t.Run("empty directory returns false", func(t *testing.T) { - dir := t.TempDir() - if got := hasMrseq(ctx, dir); got { - t.Fatalf("hasMrseq(ctx, %q) = true; want false", dir) - } - }) + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) - t.Run("directory with one .mrseq file returns true", func(t *testing.T) { - dir := t.TempDir() - f := filepath.Join(dir, "run1.mrseq") - if err := os.WriteFile(f, []byte("dummy"), 0o644); err != nil { - t.Fatalf("test setup: write %s: %v", f, err) - } - if got := hasMrseq(ctx, dir); !got { - t.Fatalf("hasMrseq(ctx, %q) = false; want true", dir) - } - }) - - t.Run("directory with multiple .mrseq files returns true", func(t *testing.T) { - dir := t.TempDir() - for i := 1; i <= 3; i++ { - name := filepath.Join(dir, fmt.Sprintf("batch_%d.mrseq", i)) - if err := os.WriteFile(name, []byte("dummy"), 0o644); err != nil { - t.Fatalf("test setup: write %s: %v", name, err) - } - } - if got := hasMrseq(ctx, dir); !got { - t.Fatalf("hasMrseq(ctx, %q) = false; want true", dir) - } - }) - - t.Run("directory with non-mrseq files only returns false", func(t *testing.T) { - dir := t.TempDir() - others := []string{"a.txt", "b.mrseq.bak", "c.mrseqq", "d.MRSEQ"} // case-sensitive on most platforms - for _, name := range others { - path := filepath.Join(dir, name) - if err := os.WriteFile(path, []byte("dummy"), 0o644); err != nil { - t.Fatalf("test setup: write %s: %v", path, err) - } - } - if got := hasMrseq(ctx, dir); got { - t.Fatalf("hasMrseq(ctx, %q) = true; want false", dir) - } - }) + fromDir, toDir, gotFromVersion := determineUpgradeVersionDirectories(ctx, events) - t.Run("non-recursive: file only in subdirectory does not count", func(t *testing.T) { - dir := t.TempDir() - sub := filepath.Join(dir, "sub") - if err := os.MkdirAll(sub, 0o755); err != nil { - t.Fatalf("test setup: mkdir %s: %v", sub, err) - } - f := filepath.Join(sub, "nested.mrseq") - if err := os.WriteFile(f, []byte("dummy"), 0o644); err != nil { - t.Fatalf("test setup: write %s: %v", f, err) - } - // Glob(dir, "*.mrseq") should not find files in subdir - if got := hasMrseq(ctx, dir); got { - t.Fatalf("hasMrseq(ctx, %q) = true; want false (non-recursive glob)", dir) - } - }) - - // Optional: demonstrate that unrelated extensions don't affect the outcome when at least one *.mrseq exists. - t.Run("mixed files: presence of .mrseq wins", func(t *testing.T) { - dir := t.TempDir() - _ = os.WriteFile(filepath.Join(dir, "a.txt"), []byte("dummy"), 0o644) - _ = os.WriteFile(filepath.Join(dir, "b.log"), []byte("dummy"), 0o644) - _ = os.WriteFile(filepath.Join(dir, "c.mrseq"), []byte("dummy"), 0o644) - if got := hasMrseq(ctx, dir); !got { - t.Fatalf("hasMrseq(ctx, %q) = false; want true", dir) - } - }) + require.Equal(t, from, gotFromVersion, "upgradeFromVersion should be updating-from value on upgrade") + require.Equal(t, dir, toDir, "when path contains toVersion, toDir is the given path") + + expectedFromDir := strings.ReplaceAll(dir, to, from) + require.Equal(t, expectedFromDir, fromDir) } -func makeDirWithMrseq(t *testing.T, dir string, addMrseq bool, version string) string { - sub := filepath.Join(dir, version) - if err := os.MkdirAll(sub, 0o755); err != nil { - t.Fatalf("test setup: mkdir %s: %v", sub, err) - } +func Test_determineUpgradeVersionDirectories_Upgrade_PathContainsFromVersion(t *testing.T) { + /* + Upgrade: updatingFrom != toVersion + Path contains the fromVersion -> replace fromVersion with toVersion for the "to" dir + */ + to := "3.0.1" + from := "2.9.9" + curr := "3.0.1" + dir := "/opt/exts/My.Ext/" + from + "/bin" - if addMrseq { - f := filepath.Join(sub, "floopster.mrseq") - if err := os.WriteFile(f, []byte("0"), 0o644); err != nil { - t.Fatalf("setup: write %s: %v", f, err) - } + setEnvs(t, to, curr, from, dir) + + tempDir, _ := os.MkdirTemp("", "upgradefromversion") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, } - return sub + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + fromDir, toDir, gotFromVersion := determineUpgradeVersionDirectories(ctx, events) + + require.Equal(t, from, gotFromVersion) + require.Equal(t, dir, fromDir, "when path does NOT contain toVersion, fromDir is the given path") + + expectedToDir := strings.ReplaceAll(dir, from, to) + require.Equal(t, expectedToDir, toDir) } -func TestDetermineUpgradeVersionDirectories(t *testing.T) { - tests := []struct { - name string - firstVersion string - secondVersion string - firstHasMrseq bool - secondHasMrseq bool - expectedToSuffix string - expectedFrom string - }{ - { - name: "first has mrseq, second does not → upgrade to second", - firstVersion: "1.0.0", - secondVersion: "2.0.0", - firstHasMrseq: true, - secondHasMrseq: false, - expectedToSuffix: "second", - expectedFrom: "1.0.0", - }, - { - name: "second has mrseq, first does not → upgrade to first", - firstVersion: "1.0.0", - secondVersion: "2.0.0", - firstHasMrseq: false, - secondHasMrseq: true, - expectedToSuffix: "first", - expectedFrom: "2.0.0", - }, - { - name: "neither has mrseq → choose higher version (second)", - firstVersion: "1.0.0", - secondVersion: "2.0.0", - firstHasMrseq: false, - secondHasMrseq: false, - expectedToSuffix: "second", - expectedFrom: "1.0.0", - }, - { - name: "both have mrseq → choose higher version (second)", - firstVersion: "1.0.0", - secondVersion: "2.0.0", - firstHasMrseq: true, - secondHasMrseq: true, - expectedToSuffix: "second", - expectedFrom: "1.0.0", - }, - { - name: "equal versions → choose first as upgradeTo", - firstVersion: "1.0.0", - secondVersion: "1.0.0", - firstHasMrseq: false, - secondHasMrseq: false, - expectedToSuffix: "first", - expectedFrom: "1.0.0", - }, +func Test_determineUpgradeVersionDirectories_Downgrade_PathContainsToVersion(t *testing.T) { + /* + Downgrade: updatingFrom == toVersion + upgradeFromVersion becomes extensionVersionValue (curr) + Path contains toVersion -> replace toVersion with fromVersion (curr) for the "from" dir + */ + to := "2.1.0" + curr := "2.3.0" // extensionVersionValue + fromUpdating := to + dir := "C:\\Packages\\Plugins\\My.Ext\\" + to + "\\" + + setEnvs(t, to, curr, fromUpdating, dir) + + tempDir, _ := os.MkdirTemp("", "downgradetoversion") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Prepare directories - dir := t.TempDir() - firstDir := makeDirWithMrseq(t, dir, tt.firstHasMrseq, tt.firstVersion) - secondDir := makeDirWithMrseq(t, dir, tt.secondHasMrseq, tt.secondVersion) - - // Simulate environment variables - os.Setenv(constants.ExtensionVersionEnvName, tt.firstVersion) - os.Setenv(constants.ExtensionVersionUpdatingFromEnvName, tt.secondVersion) - os.Setenv(constants.ExtensionPathEnvName, firstDir) - - // Replace secondDir logic: mimic original code's substitution - // (strings.ReplaceAll(firstDir, firstVersion, secondVersion)) - // For test simplicity, override secondDir directly - // but ensure substitution works if versions appear in path - if strings.Contains(firstDir, tt.firstVersion) { - secondDir = strings.ReplaceAll(firstDir, tt.firstVersion, tt.secondVersion) - } + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) - tempDir, _ := os.MkdirTemp("", "determineupgrade") - defer os.RemoveAll(tempDir) - handlerEnvironment := handlerenv.HandlerEnvironment{ - EventsFolder: tempDir, - } + fromDir, toDir, gotFromVersion := determineUpgradeVersionDirectories(ctx, events) - extensionLogger := logging.New(nil) - events := extensionevents.New(extensionLogger, &handlerEnvironment) - ctx := log.NewContext(log.NewNopLogger()) + require.Equal(t, curr, gotFromVersion, "on downgrade, fromVersion becomes the current extension version") + require.Equal(t, dir, toDir) - gotFromDir, gotToDir, gotFromVersion := determineUpgradeVersionDirectories(ctx, events) + expectedFromDir := strings.ReplaceAll(dir, to, curr) + require.Equal(t, expectedFromDir, fromDir) +} - // Validate upgradeFromVersion - if gotFromVersion != tt.expectedFrom { - t.Errorf("upgradeFromVersion = %q; want %q", gotFromVersion, tt.expectedFrom) - } +func Test_determineUpgradeVersionDirectories_Downgrade_PathContainsFromVersion(t *testing.T) { + /* + Downgrade: updatingFrom == toVersion + Path contains the computed fromVersion (curr), so toDir is replacement of fromVersion->toVersion + */ + to := "1.7.0" + curr := "1.7.5" + fromUpdating := to + dir := "/extensions/handler/" + curr + "/" - // Validate which directory chosen as upgradeTo - if tt.expectedToSuffix == "first" { - if gotToDir != firstDir { - t.Errorf("upgradeToDir = %q; want firstDir %q", gotToDir, firstDir) - } - if gotFromDir != secondDir { - t.Errorf("upgradeFromDir = %q; want secondDir %q", gotFromDir, secondDir) - } - } else { - if gotToDir != secondDir { - t.Errorf("upgradeToDir = %q; want secondDir %q", gotToDir, secondDir) - } - if gotFromDir != firstDir { - t.Errorf("upgradeFromDir = %q; want firstDir %q", gotFromDir, firstDir) - } - } - }) + setEnvs(t, to, curr, fromUpdating, dir) + + tempDir, _ := os.MkdirTemp("", "downgradefromversion") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + fromDir, toDir, gotFromVersion := determineUpgradeVersionDirectories(ctx, events) + + require.Equal(t, curr, gotFromVersion) + require.Equal(t, dir, fromDir) + + expectedToDir := strings.ReplaceAll(dir, curr, to) + require.Equal(t, expectedToDir, toDir) +} + +// setEnvs is a helper to seed the env for each scenario. +func setEnvs(t *testing.T, to, curr, from, dir string) { + t.Helper() + t.Setenv(constants.VersionEnvName, to) + t.Setenv(constants.ExtensionVersionEnvName, curr) + t.Setenv(constants.ExtensionVersionUpdatingFromEnvName, from) + t.Setenv(constants.ExtensionPathEnvName, dir) } diff --git a/internal/constants/constants.go b/internal/constants/constants.go index d7457b6..87c3db2 100755 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -53,12 +53,15 @@ const ( // The current version of the extension. This value is provided by the agent for all commands. // See more in: https://github.com/Azure/azure-vmextension-publishing/wiki/2.0-Partner-Guide-Handler-Design-Details#236-summary - ExtensionVersionEnvName = "VERSION" + VersionEnvName = "VERSION" // This is the version the extension is updating from // See more in: https://github.com/Azure/azure-vmextension-publishing/wiki/2.0-Partner-Guide-Handler-Design-Details#236-summary ExtensionVersionUpdatingFromEnvName = "AZURE_GUEST_AGENT_UPDATING_FROM_VERSION" + // We'll use this variable to determine if there's a downgrade + ExtensionVersionEnvName = "AZURE_GUEST_AGENT_EXTENSION_VERSION" + // The path of the extension in the VM with full name. This value is provided by the agent for all commands. // See more in: https://github.com/Azure/azure-vmextension-publishing/wiki/2.0-Partner-Guide-Handler-Design-Details#236-summary ExtensionPathEnvName = "AZURE_GUEST_AGENT_EXTENSION_PATH" diff --git a/internal/pid/pid.go b/internal/pid/pid.go index 6145e6d..bf46afe 100644 --- a/internal/pid/pid.go +++ b/internal/pid/pid.go @@ -7,6 +7,7 @@ import ( "os/exec" "strconv" "strings" + "syscall" "github.com/go-kit/kit/log" "github.com/pkg/errors" @@ -93,7 +94,7 @@ func KillPreviousExtension(ctx *log.Context, pidFilePath string) { if ctx != nil { ctx.Log("event", "check process", "Active previous execution found. Killing pid ", previousPid) } - //syscall.Kill(-previousPid, syscall.SIGKILL) // Negative pid means kill the whole process group + syscall.Kill(-previousPid, syscall.SIGKILL) // Negative pid means kill the whole process group DeleteCurrentPidAndStartTime(pidFilePath) } } diff --git a/internal/service/serviceinstall.go b/internal/service/serviceinstall.go index 1014e2f..bd14aeb 100644 --- a/internal/service/serviceinstall.go +++ b/internal/service/serviceinstall.go @@ -39,7 +39,7 @@ func Register(ctx *log.Context, extensionEvents *extensionevents.ExtensionEventM extensionEvents.LogErrorEvent("register", "Systemd not supported. Failed to register service") return errors.New("Systemd not supported. Failed to register service") } - targetVersion := os.Getenv(constants.ExtensionVersionEnvName) + targetVersion := os.Getenv(constants.VersionEnvName) ctx.Log("message", "trying to register extension with version: "+targetVersion) ctx.Log("message", "Generating service configuration files") diff --git a/pkg/servicehandler/servicehandler.go b/pkg/servicehandler/servicehandler.go index e3567fd..078f025 100644 --- a/pkg/servicehandler/servicehandler.go +++ b/pkg/servicehandler/servicehandler.go @@ -109,7 +109,7 @@ func (handler *Handler) Register(ctx *log.Context, unitConfigContent string) err func (handler *Handler) DeRegister(ctx *log.Context) error { // We need to make sure the version that the VM Agent is trying to uninstall is the correct one. // Failing to check this can cause to uninstall the service during the update workflow. - targetVersion := os.Getenv(constants.ExtensionVersionEnvName) + targetVersion := os.Getenv(constants.VersionEnvName) ctx.Log("message", "trying to uninstall extension with version: "+targetVersion) installedVersion, err := handler.GetInstalledVersion(ctx) diff --git a/pkg/servicehandler/servicehandler_test.go b/pkg/servicehandler/servicehandler_test.go index f5ed7e1..0bc610b 100644 --- a/pkg/servicehandler/servicehandler_test.go +++ b/pkg/servicehandler/servicehandler_test.go @@ -301,7 +301,7 @@ func TestHandlerSuccessfulDeRegister(t *testing.T) { ctx := log.NewContext(log.NewSyncLogger(log.NewLogfmtLogger( os.Stdout))).With("time", log.DefaultTimestamp) - os.Setenv(constants.ExtensionVersionEnvName, installedTargetVersion) + os.Setenv(constants.VersionEnvName, installedTargetVersion) handler := NewHandler(m, config, ctx) err := handler.DeRegister(ctx) if err != nil { @@ -336,7 +336,7 @@ func TestHandlerSkipDeRegisterForNonInstalledTargetVersion(t *testing.T) { ctx := log.NewContext(log.NewSyncLogger(log.NewLogfmtLogger( os.Stdout))).With("time", log.DefaultTimestamp) - os.Setenv(constants.ExtensionVersionEnvName, nonInstalledTargetVersion) + os.Setenv(constants.VersionEnvName, nonInstalledTargetVersion) handler := NewHandler(m, config, ctx) err := handler.DeRegister(ctx) if err != nil { From b002a7ba418bac9a65185cb25b990dece1fd59fa Mon Sep 17 00:00:00 2001 From: Joseph Calev Date: Wed, 14 Jan 2026 09:38:44 -0800 Subject: [PATCH 5/7] Improves comment --- internal/cmds/cmds.go | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/internal/cmds/cmds.go b/internal/cmds/cmds.go index 973b00d..c5f9ad5 100755 --- a/internal/cmds/cmds.go +++ b/internal/cmds/cmds.go @@ -456,14 +456,19 @@ func determineUpgradeVersionDirectories(ctx *log.Context, extensionEvents *exten upgradeToVersion := os.Getenv(constants.VersionEnvName) extensionVersionValue := os.Getenv(constants.ExtensionVersionEnvName) updatingFromVersionValue := os.Getenv(constants.ExtensionVersionUpdatingFromEnvName) - upgradeType := "upgrade" - // First, we need to determine if this is an upgrade or a downgrade - // This is a downgrade if the updating from version is equal to the version + // In some WALA versions, there is a bug where on downgrade it will send the same value for upgradeToVersion and upgradeFromVersion + // Newer versions will send the correct value for upgradeToVersion + // Therefore: + // Action | Old WALA | New WALA + // ---------------------------------------------------| ------------------------------------- + // Downgrade | upgradeToVersion == upgradeFromVersion | upgradeToVersion < upgradeFromVersion + // ------------------------------------------------------------------------------------------ + // Upgrade | upgradeToVersion > upgradeFromVersion | upgradeToVersion > upgradeFromVersion + // ------------------------------------------------------------------------------------------ if upgradeToVersion == updatingFromVersionValue { - // This is a downgrade + // This is a downgrade. We therefore need to use the extension version upgradeFromVersion = extensionVersionValue - upgradeType = "downgrade" } else { // This is an upgrade upgradeFromVersion = updatingFromVersionValue @@ -479,7 +484,7 @@ func determineUpgradeVersionDirectories(ctx *log.Context, extensionEvents *exten upgradeToVersionDirectory = strings.ReplaceAll(extensionDirectory, upgradeFromVersion, upgradeToVersion) } - msg := fmt.Sprintf("determineUpgradeVersionDirectories: %s from='%s' to='%s'", upgradeType, upgradeToVersionDirectory, upgradeFromVersionDirectory) + msg := fmt.Sprintf("determineUpgradeVersionDirectories: move from='%s' to='%s'", upgradeToVersionDirectory, upgradeFromVersionDirectory) ctx.Log("message", msg) extensionEvents.LogInformationalEvent("determineUpgradeVersions", msg) From 1cfae7c023e09d711dd2741d9ee5c54e378cb024 Mon Sep 17 00:00:00 2001 From: Joseph Calev Date: Fri, 16 Jan 2026 13:45:28 -0800 Subject: [PATCH 6/7] Adds more unit tests --- internal/cmds/cmds.go | 2 +- internal/cmds/cmds_test.go | 248 +++++++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+), 1 deletion(-) diff --git a/internal/cmds/cmds.go b/internal/cmds/cmds.go index c5f9ad5..dac4b2d 100755 --- a/internal/cmds/cmds.go +++ b/internal/cmds/cmds.go @@ -470,7 +470,7 @@ func determineUpgradeVersionDirectories(ctx *log.Context, extensionEvents *exten // This is a downgrade. We therefore need to use the extension version upgradeFromVersion = extensionVersionValue } else { - // This is an upgrade + // This is an upgrade on the old WALA or an upgrade or downgrade on the new WALA upgradeFromVersion = updatingFromVersionValue } diff --git a/internal/cmds/cmds_test.go b/internal/cmds/cmds_test.go index c09b5d9..72f4e76 100755 --- a/internal/cmds/cmds_test.go +++ b/internal/cmds/cmds_test.go @@ -1189,6 +1189,246 @@ func Test_determineUpgradeVersionDirectories_Downgrade_PathContainsFromVersion(t require.Equal(t, expectedToDir, toDir) } +func Test_rehydrateMrSeqFiles_OpenStatusDirFails(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "OpenStatusDirFails") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := t.TempDir() // does not contain the status subdir + to := t.TempDir() + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.NotNil(t, err, "expected error when opening missing status dir") + require.True(t, strings.Contains(err.Error(), "Failed to open status directory"), "unexpected error message: %s", err.Error()) +} + +func Test_rehydrateMrSeqFiles_ReadDirFails(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "ReadDirFails") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + // Create a *file* at "/status" so os.Open succeeds but ReadDir fails. + statusPath := filepath.Join(from, constants.StatusFileDirectory) + require.NoError(t, os.WriteFile(statusPath, []byte("not a directory"), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.NotNil(t, err, "expected error when ReadDir fails") + require.True(t, strings.Contains(err.Error(), "could not read directory entries"), "unexpected error message: %s", err.Error()) +} + +func Test_rehydrateMrSeqFiles_IgnoresInvalidStatusFilename(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "IgnoresInvalidStatusFilename") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + // Proper status directory + require.NoError(t, os.MkdirAll(filepath.Join(from, constants.StatusFileDirectory), 0o755)) + + // Invalid filename (only two parts, missing seqNo). Should be ignored. + invalid := filepath.Join(from, constants.StatusFileDirectory, "alpha"+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(invalid, []byte(""), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.Nil(t, err, "unexpected error: %v", err) + + // No mrseq should be created for invalid filenames. + _, err = os.Stat(filepath.Join(to, "alpha"+constants.MrSeqFileExtension)) + require.True(t, os.IsNotExist(err), "alpha%s should not have been created", constants.MrSeqFileExtension) +} + +func Test_rehydrateMrSeqFiles_RehydrateMissingMrseq(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "RehydrateMissingMrseq") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + statusDir := filepath.Join(from, constants.StatusFileDirectory) + if err := os.MkdirAll(statusDir, 0o755); err != nil { + t.Fatalf("mkdir status: %v", err) + } + + // alpha.5.status → should create to/alpha.mrseq with "5" + alphaStatus := filepath.Join(statusDir, "alpha.5"+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(alphaStatus, []byte(""), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.Nil(t, err, "unexpected error: %v", err) + + mrseqPath := filepath.Join(to, "alpha"+constants.MrSeqFileExtension) + got := mustReadFile(t, mrseqPath) + require.Equal(t, "5", got, "mrseq content = %q, want %q", got, "5") +} + +func Test_rehydrateMrSeqFiles_UpdateExistingMrseqWhenHigherSeqFound(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "UpdateExistingMrseqWhenHigherSeqFound") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + statusDir := filepath.Join(from, constants.StatusFileDirectory) + require.NoError(t, os.MkdirAll(statusDir, 0o755)) + + // Existing mrseq=3 + mrseqPath := filepath.Join(to, "alpha"+constants.MrSeqFileExtension) + require.NoError(t, os.WriteFile(mrseqPath, []byte("3"), 0o600)) + + // Status reports seq=5 → should overwrite to "5" + alphaStatus := filepath.Join(statusDir, "alpha.5"+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(alphaStatus, []byte(""), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.Nil(t, err, "unexpected error: %v", err) + + got := mustReadFile(t, mrseqPath) + require.Equal(t, "5", got, "mrseq content = %q, want %q", got, "5") +} + +func Test_rehydrateMrSeqFiles_NoUpdateWhenExistingIsHigher(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "NoUpdateWhenExistingIsHigher") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + statusDir := filepath.Join(from, constants.StatusFileDirectory) + require.NoError(t, os.MkdirAll(statusDir, 0o755)) + + // Existing mrseq=7 + mrseqPath := filepath.Join(to, "alpha "+constants.MrSeqFileExtension) + require.NoError(t, os.WriteFile(mrseqPath, []byte("7"), 0o600)) + + // Status reports seq=5 → should NOT overwrite + alphaStatus := filepath.Join(statusDir, "alpha.5"+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(alphaStatus, []byte(""), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.Nil(t, err, "unexpected error: %v", err) + + got := mustReadFile(t, mrseqPath) + require.Equal(t, "7", got, "mrseq content = %q, want %q", got, "7") +} + +func Test_rehydrateMrSeqFiles_MultipleStatusFiles_TakesMax(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "TakesMax") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + statusDir := filepath.Join(from, constants.StatusFileDirectory) + require.NoError(t, os.MkdirAll(statusDir, 0o755)) + + // Both alpha.3.status and alpha.7.status — final mrseq should be "7" + for _, seq := range []int{3, 7} { + p := filepath.Join(statusDir, "alpha."+strconv.Itoa(seq)+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(p, []byte(""), 0o600)) + } + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.Nil(t, err, "unexpected error: %v", err) + + got := mustReadFile(t, filepath.Join(to, "alpha"+constants.MrSeqFileExtension)) + require.Equal(t, "7", got, "mrseq content = %q, want %q", got, "7") +} + +func Test_rehydrateMrSeqFiles_ReadExistingMrseqFails(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "ReadExistingMrseqFails") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + statusDir := filepath.Join(from, constants.StatusFileDirectory) + require.NoError(t, os.MkdirAll(statusDir, 0o755)) + + // Make a directory at the mrseq path so ReadFile fails + mrseqPath := filepath.Join(to, "alpha"+constants.MrSeqFileExtension) + require.NoError(t, os.MkdirAll(mrseqPath, 0o755)) + + // Now create a status that would try to read/compare + alphaStatus := filepath.Join(statusDir, "alpha.5"+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(alphaStatus, []byte(""), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.NotNil(t, err, "expected error due to unreadable mrseq") + require.True(t, strings.Contains(err.Error(), "Could not read file"), "Unexpected error: %v", err.Error()) +} + // setEnvs is a helper to seed the env for each scenario. func setEnvs(t *testing.T, to, curr, from, dir string) { t.Helper() @@ -1197,3 +1437,11 @@ func setEnvs(t *testing.T, to, curr, from, dir string) { t.Setenv(constants.ExtensionVersionUpdatingFromEnvName, from) t.Setenv(constants.ExtensionPathEnvName, dir) } + +func mustReadFile(t *testing.T, p string) string { + b, err := os.ReadFile(p) + if err != nil { + t.Fatalf("read %s: %v", p, err) + } + return string(b) +} From 72b7c50dbe0a9776374f4bdeb8032a8b9dd220b2 Mon Sep 17 00:00:00 2001 From: Joseph Calev Date: Tue, 20 Jan 2026 13:09:51 -0800 Subject: [PATCH 7/7] PR fix --- internal/cmds/cmds.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/cmds/cmds.go b/internal/cmds/cmds.go index dac4b2d..57c731c 100755 --- a/internal/cmds/cmds.go +++ b/internal/cmds/cmds.go @@ -484,7 +484,7 @@ func determineUpgradeVersionDirectories(ctx *log.Context, extensionEvents *exten upgradeToVersionDirectory = strings.ReplaceAll(extensionDirectory, upgradeFromVersion, upgradeToVersion) } - msg := fmt.Sprintf("determineUpgradeVersionDirectories: move from='%s' to='%s'", upgradeToVersionDirectory, upgradeFromVersionDirectory) + msg := fmt.Sprintf("determineUpgradeVersionDirectories: move from='%s' to='%s'", upgradeFromVersionDirectory, upgradeToVersionDirectory) ctx.Log("message", msg) extensionEvents.LogInformationalEvent("determineUpgradeVersions", msg)