diff --git a/internal/cmds/cmds.go b/internal/cmds/cmds.go index 864c001..57c731c 100755 --- a/internal/cmds/cmds.go +++ b/internal/cmds/cmds.go @@ -12,6 +12,7 @@ import ( "io" "os" "path/filepath" + "slices" "strconv" "strings" "time" @@ -75,6 +76,9 @@ var ( RunCmd = runCmd DataDir = constants.DataDir + // Used by unit tests to mock out executing the command + ExecCmdInDir = exec.ExecCmdInDir + ErrAlreadyProcessed = errors.New("the script configuration has already been processed, will not run again") ) @@ -85,17 +89,23 @@ func update(ctx *log.Context, h types.HandlerEnvironment, report *types.RunComma return "", "", err, exitCode } - err = rehydrateMrSeqFilesForProblematicUpgrades(ctx, h, extensionEvents) - if err != nil { - // If we fail on update, then there's a risk we could re-execute the customer's script. Don't take that chance. - // By failing Update, the extension goal state will fail. WALA will try us again on the next goal state. - ctx.Log("event", "Unable to rehydrate mrseq files") - return "", "", err, constants.ExitCode_CouldNotRehydrateMrSeq + // Figure out the directories from which and to where we're upgrading. We cannot entirely rely on the environment variables from the Guest Agent + upgradeFromVersionDirectory, upgradeToVersionDirectory, upgradeFromVersion := determineUpgradeVersionDirectories(ctx, extensionEvents) + + if compareVersions(constants.FirstVersionNoRehydration, upgradeFromVersion) > 0 { + // Rehydrate any mrseq files from the corresponding status file. + err = rehydrateMrSeqFilesForProblematicUpgrades(ctx, upgradeFromVersionDirectory, upgradeToVersionDirectory, extensionEvents) + if err != nil { + // If we fail on update, then there's a risk we could re-execute the customer's script. Don't take that chance. + // By failing Update, the extension goal state will fail. WALA will try us again on the next goal state. + ctx.Log("event", "Unable to rehydrate mrseq files") + return "", "", err, constants.ExitCode_CouldNotRehydrateMrSeq + } } // Copy any .mrseq or .status files -Most Recently executed Sequence number files and status files for Run Commands from old version to new version. // This is necessary to prevent rerunning of already executed Run Commands after upgrade of extension version, and also return their statuses. - copyError := CopyStateForUpdate(ctx, extensionEvents) + copyError := CopyStateForUpdate(ctx, upgradeFromVersionDirectory, upgradeToVersionDirectory, extensionEvents) if copyError != nil { return "", "", errors.Wrap(copyError, "Migrating *.mrseq or .status files failed during update."), constants.ExitCode_CopyStateForUpdateFailed } @@ -417,15 +427,15 @@ func resetSeqNum(ctx log.Logger, mrseqPath string, extensionEvents *extensioneve } // Copy state of the extension from old version to new version during update (.mrseq files, .status files) -func CopyStateForUpdate(ctx log.Logger, extensionEvents *extensionevents.ExtensionEventManager) error { +func CopyStateForUpdate(ctx log.Logger, upgradeFromVersionDirectory string, upgradeToVersionDirectory string, extensionEvents *extensionevents.ExtensionEventManager) error { // Copy .mrseq files (Most Recently executed Sequence number) that helps determine whether a sequence number of Run Command has been previously executed or not. - mrseqFilesNameList, mrseqFileCopyErr := copyFiles(ctx, constants.MrSeqFileExtension, "", extensionEvents) + mrseqFilesNameList, mrseqFileCopyErr := copyFiles(ctx, constants.MrSeqFileExtension, "", upgradeFromVersionDirectory, upgradeToVersionDirectory, extensionEvents) if mrseqFileCopyErr != nil { return mrseqFileCopyErr } // Copy .status files of already executed sequence numbers - _, statusFileCopyErr := copyFiles(ctx, ".status", constants.StatusFileDirectory, extensionEvents) + _, statusFileCopyErr := copyFiles(ctx, ".status", constants.StatusFileDirectory, upgradeFromVersionDirectory, upgradeToVersionDirectory, extensionEvents) if statusFileCopyErr != nil { return statusFileCopyErr } @@ -440,41 +450,102 @@ func CopyStateForUpdate(ctx log.Logger, extensionEvents *extensionevents.Extensi return nil } -func rehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, h types.HandlerEnvironment, extensionEvents *extensionevents.ExtensionEventManager) error { - // First, determine whether we're upgrading from a 'problematic' version, defined as one - // where we mistakenly deleted the mrseq files in the Disable call - newExtensionVersion := os.Getenv(constants.ExtensionVersionEnvName) - oldExtensionVersion := os.Getenv(constants.ExtensionVersionUpdatingFromEnvName) - newExtensionDirectory := os.Getenv(constants.ExtensionPathEnvName) - oldExtensionDirectory := strings.ReplaceAll(newExtensionDirectory, newExtensionVersion, oldExtensionVersion) - - // The following are problematic versions: - // Production: 1.3.17 - // Test: 1.8.0, 1.9.0 - isProblematicVersion := false - isTestExtension := strings.Contains(oldExtensionDirectory, constants.RunCommandTestExtensionName) - if isTestExtension { - isProblematicVersion = (oldExtensionVersion == constants.FirstTestVersionThatDeletesMrSeqFiles || oldExtensionVersion == constants.SecondTestVersionThatDeletesMrSeqFiles) +func determineUpgradeVersionDirectories(ctx *log.Context, extensionEvents *extensionevents.ExtensionEventManager) (upgradeFromVersionDirectory string, upgradeToVersionDirectory string, upgradeFromVersion string) { + // These two environment variables will tell us the extension versions involved, but won't actually tell us + // the from/to versions + upgradeToVersion := os.Getenv(constants.VersionEnvName) + extensionVersionValue := os.Getenv(constants.ExtensionVersionEnvName) + updatingFromVersionValue := os.Getenv(constants.ExtensionVersionUpdatingFromEnvName) + + // In some WALA versions, there is a bug where on downgrade it will send the same value for upgradeToVersion and upgradeFromVersion + // Newer versions will send the correct value for upgradeToVersion + // Therefore: + // Action | Old WALA | New WALA + // ---------------------------------------------------| ------------------------------------- + // Downgrade | upgradeToVersion == upgradeFromVersion | upgradeToVersion < upgradeFromVersion + // ------------------------------------------------------------------------------------------ + // Upgrade | upgradeToVersion > upgradeFromVersion | upgradeToVersion > upgradeFromVersion + // ------------------------------------------------------------------------------------------ + if upgradeToVersion == updatingFromVersionValue { + // This is a downgrade. We therefore need to use the extension version + upgradeFromVersion = extensionVersionValue } else { - isProblematicVersion = (oldExtensionVersion == constants.ProductionVersionThatDeletesMrSeqFiles) + // This is an upgrade on the old WALA or an upgrade or downgrade on the new WALA + upgradeFromVersion = updatingFromVersionValue } - if isProblematicVersion { - message := fmt.Sprintf("Rehydrating mrseq files deleted by version '%s' using status files", oldExtensionVersion) - ctx.Log("message", message) - extensionEvents.LogInformationalEvent("rehydratemrseq", message) - return doRehydrateMrSeqFilesForProblematicUpgrades(ctx, oldExtensionDirectory, newExtensionDirectory, extensionEvents) + // Determine the corresponding extension directories + extensionDirectory := os.Getenv(constants.ExtensionPathEnvName) + if strings.Contains(extensionDirectory, upgradeToVersion) { + upgradeToVersionDirectory = extensionDirectory + upgradeFromVersionDirectory = strings.ReplaceAll(extensionDirectory, upgradeToVersion, upgradeFromVersion) } else { - message := fmt.Sprintf("Previous extension version '%s' does not require mrseq hydration", oldExtensionVersion) - ctx.Log("message", message) - extensionEvents.LogInformationalEvent("rehydratemrseq", message) + upgradeFromVersionDirectory = extensionDirectory + upgradeToVersionDirectory = strings.ReplaceAll(extensionDirectory, upgradeFromVersion, upgradeToVersion) } - return nil + msg := fmt.Sprintf("determineUpgradeVersionDirectories: move from='%s' to='%s'", upgradeFromVersionDirectory, upgradeToVersionDirectory) + ctx.Log("message", msg) + extensionEvents.LogInformationalEvent("determineUpgradeVersions", msg) + + return upgradeFromVersionDirectory, upgradeToVersionDirectory, upgradeFromVersion +} + +// compareVersions compares two dotted version strings (e.g., "2.1", "2.1.0", "2.1.0.3"). +// Returns: +1 if a>b, -1 if a bParts[i] { + return 1 + } + if aParts[i] < bParts[i] { + return -1 + } + } + return 0 } -func doRehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, oldExtensionDirectory string, newExtensionDirectory string, extensionEvents *extensionevents.ExtensionEventManager) error { - oldExtensionStatusDirectory := filepath.Join(oldExtensionDirectory, constants.StatusFileDirectory) +// splitVersion converts "x.y.z.t" → []int{ x, y, z, t } (non-numeric parts treated as 0). +func splitVersion(v string) []int { + parts := strings.Split(v, ".") + out := make([]int, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + n, err := strconv.Atoi(p) + if err != nil { + n = 0 + } + out = append(out, n) + } + return out +} + +func padTo(in []int, size int) []int { + + if len(in) >= size { + return in[:size] + } + n := size - len(in) + + // Ensure capacity for the extra n elements without reallocating. + out := slices.Grow(in, n) + + // Extend length to size by appending n zero-values. + out = append(out, make([]int, n)...) + return out + +} + +func rehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, updateFromVersionDirectory string, updateToVersionDirectory string, extensionEvents *extensionevents.ExtensionEventManager) error { + oldExtensionStatusDirectory := filepath.Join(updateFromVersionDirectory, constants.StatusFileDirectory) extensionStatusDirectoryFDRef, err := os.Open(oldExtensionStatusDirectory) if err != nil { @@ -489,7 +560,7 @@ func doRehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, oldExtensionD // If we find any status files missing their corresponding mrseq, then rehydrate it by taking the seqNo from the status file name statusFiles, err := extensionStatusDirectoryFDRef.ReadDir(0) if err != nil { - errMessage := fmt.Sprintf("could not read directory entries from status directory %s", oldExtensionDirectory) + errMessage := fmt.Sprintf("could not read directory entries from status directory %s", updateFromVersionDirectory) ctx.Log("message", errMessage) extensionEvents.LogErrorEvent("rehydratemrseq", errMessage) return errors.Wrap(err, errMessage) @@ -507,7 +578,7 @@ func doRehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, oldExtensionD seqNo := parts[1] seqNoAsInt, _ := strconv.Atoi(seqNo) mrSeqFileName := extensionName + constants.MrSeqFileExtension - mrSeqFilePath := filepath.Join(newExtensionDirectory, mrSeqFileName) + mrSeqFilePath := filepath.Join(updateToVersionDirectory, mrSeqFileName) _, err = os.Stat(mrSeqFilePath) if err != nil { @@ -563,45 +634,39 @@ func doRehydrateMrSeqFilesForProblematicUpgrades(ctx *log.Context, oldExtensionD } // Copy files like *.mrseq (Most Recently executed Sequence number), .status files from old extension version to new extension version during update. -func copyFiles(ctx log.Logger, fileExtensionSuffix string, extensionSubdirectory string, extensionEvents *extensionevents.ExtensionEventManager) (*list.List, error) { +func copyFiles(ctx log.Logger, fileExtensionSuffix string, extensionSubdirectory string, upgradeFromVersionDirectory string, upgradeToVersionDirectory string, extensionEvents *extensionevents.ExtensionEventManager) (*list.List, error) { - newExtensionVersion := os.Getenv(constants.ExtensionVersionEnvName) - oldExtensionVersion := os.Getenv(constants.ExtensionVersionUpdatingFromEnvName) - - message := fmt.Sprintf("Migrating '%s' files from extension version '%s' to '%s'", fileExtensionSuffix, oldExtensionVersion, newExtensionVersion) + message := fmt.Sprintf("Migrating '%s' files from '%s' to '%s'", fileExtensionSuffix, upgradeFromVersionDirectory, upgradeToVersionDirectory) ctx.Log("message", message) extensionEvents.LogInformationalEvent("copyfiles", message) - newExtensionDirectory := os.Getenv(constants.ExtensionPathEnvName) - oldExtensionDirectory := strings.ReplaceAll(newExtensionDirectory, newExtensionVersion, oldExtensionVersion) - // Append subdirectory like "status" under extension folder if provided. if extensionSubdirectory != "" { - newExtensionDirectory = filepath.Join(newExtensionDirectory, extensionSubdirectory) - oldExtensionDirectory = filepath.Join(oldExtensionDirectory, extensionSubdirectory) + upgradeToVersionDirectory = filepath.Join(upgradeToVersionDirectory, extensionSubdirectory) + upgradeFromVersionDirectory = filepath.Join(upgradeFromVersionDirectory, extensionSubdirectory) // Create subdirectory like "status" directory if it does not exist - _, err := os.Open(newExtensionDirectory) + _, err := os.Open(upgradeToVersionDirectory) if err != nil { - errr := os.Mkdir(newExtensionDirectory, 0700) + errr := os.Mkdir(upgradeToVersionDirectory, 0700) if errr != nil { - errMessage := fmt.Sprintf("Failed to create directory '%s'", newExtensionDirectory) + errMessage := fmt.Sprintf("Failed to create directory '%s'", upgradeToVersionDirectory) extensionEvents.LogErrorEvent("copyfiles", errMessage) return nil, errors.Wrap(errr, errMessage) } } } - if oldExtensionDirectory == "" || newExtensionDirectory == "" { + if upgradeFromVersionDirectory == "" || upgradeToVersionDirectory == "" { errMessage := "oldExtesionDirectory or newExtensionDirectory is empty" extensionEvents.LogErrorEvent("copyfiles", errMessage) return nil, errors.New(errMessage) } // Check if the directory exists - sourceDirectoryFDRef, err := os.Open(oldExtensionDirectory) + sourceDirectoryFDRef, err := os.Open(upgradeFromVersionDirectory) if err != nil { - errMessage := fmt.Sprintf("could not open sourceDirectory %s", oldExtensionDirectory) + errMessage := fmt.Sprintf("could not open sourceDirectory %s", upgradeFromVersionDirectory) ctx.Log("message", errMessage) extensionEvents.LogErrorEvent("copyfiles", errMessage) return nil, errors.Wrap(err, errMessage) @@ -609,7 +674,7 @@ func copyFiles(ctx log.Logger, fileExtensionSuffix string, extensionSubdirectory directoryEntries, err := sourceDirectoryFDRef.ReadDir(0) if err != nil { - errMessage := fmt.Sprintf("could not read directory entries from sourceDirectory %s", oldExtensionDirectory) + errMessage := fmt.Sprintf("could not read directory entries from sourceDirectory %s", upgradeFromVersionDirectory) ctx.Log("message", errMessage) extensionEvents.LogErrorEvent("copyfiles", errMessage) return nil, errors.Wrap(err, errMessage) @@ -622,8 +687,8 @@ func copyFiles(ctx log.Logger, fileExtensionSuffix string, extensionSubdirectory fileName := dirEntry.Name() if strings.HasSuffix(fileName, fileExtensionSuffix) { - sourceFileFullPath := filepath.Join(oldExtensionDirectory, fileName) - destinationFileFullPath := filepath.Join(newExtensionDirectory, fileName) + sourceFileFullPath := filepath.Join(upgradeFromVersionDirectory, fileName) + destinationFileFullPath := filepath.Join(upgradeToVersionDirectory, fileName) sourceFile, sourceFileOpenError := os.Open(sourceFileFullPath) if sourceFileOpenError != nil { @@ -660,7 +725,7 @@ func copyFiles(ctx log.Logger, fileExtensionSuffix string, extensionSubdirectory } } - message = fmt.Sprintf("Migrated %d '%s' files from extension version '%s' to '%s'", numberOfFilesMigrated, fileExtensionSuffix, oldExtensionVersion, newExtensionVersion) + message = fmt.Sprintf("Migrated %d '%s' files from extension version '%s' to '%s'", numberOfFilesMigrated, fileExtensionSuffix, upgradeFromVersionDirectory, upgradeToVersionDirectory) ctx.Log("message", message) extensionEvents.LogInformationalEvent("copyfiles", message) @@ -867,7 +932,7 @@ func runCmd(ctx *log.Context, dir string, scriptFilePath string, cfg *handlerset defer pid.DeleteCurrentPidAndStartTime(metadata.PidFilePath) begin := time.Now() - err, exitCode = exec.ExecCmdInDir(ctx, scriptFilePath, dir, cfg) + err, exitCode = ExecCmdInDir(ctx, scriptFilePath, dir, cfg) elapsed := time.Since(begin) isSuccess := err == nil diff --git a/internal/cmds/cmds_test.go b/internal/cmds/cmds_test.go index 205c5ef..72f4e76 100755 --- a/internal/cmds/cmds_test.go +++ b/internal/cmds/cmds_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "reflect" "strconv" "strings" "testing" @@ -29,9 +30,9 @@ func Test_CopyMrseqFiles_MrseqFilesAreCopied(t *testing.T) { currentExtensionVersionDirectory := "Microsoft.CPlat.Core.RunCommandHandlerLinux-1.3.8" os.Setenv(constants.ExtensionPathEnvName, currentExtensionVersionDirectory) os.Setenv(constants.ExtensionVersionUpdatingFromEnvName, "1.3.7") - os.Setenv(constants.ExtensionVersionEnvName, "1.3.8") + os.Setenv(constants.VersionEnvName, "1.3.8") - currentVersion := os.Getenv(constants.ExtensionVersionEnvName) + currentVersion := os.Getenv(constants.VersionEnvName) previousVersion := os.Getenv(constants.ExtensionVersionUpdatingFromEnvName) previousExtensionVersionDirectory := strings.ReplaceAll(currentExtensionVersionDirectory, currentVersion, previousVersion) @@ -85,7 +86,7 @@ func Test_CopyMrseqFiles_MrseqFilesAreCopied(t *testing.T) { extensionLogger := logging.New(nil) extensionEventManager := extensionevents.New(extensionLogger, &handlerEnvironment) - err = CopyStateForUpdate(log.NewContext(log.NewNopLogger()), extensionEventManager) + err = CopyStateForUpdate(log.NewContext(log.NewNopLogger()), previousExtensionVersionDirectory, currentExtensionVersionDirectory, extensionEventManager) require.Nil(t, err) files, _ = ioutil.ReadDir(currentExtensionVersionDirectory) @@ -210,7 +211,7 @@ func Test_update_e2e_cmd(t *testing.T) { // We start on the old version os.Setenv(constants.ExtensionPathEnvName, oldVersionDirectory) - os.Setenv(constants.ExtensionVersionEnvName, "1.3.8") + os.Setenv(constants.VersionEnvName, "1.3.8") // Create two extensions enable_extension(t, fakeEnv, oldVersionDirectory, "happyChipmunk", true, 0) @@ -222,7 +223,7 @@ func Test_update_e2e_cmd(t *testing.T) { disable_extension(t, fakeEnv, oldVersionDirectory, "crazyChipmunk") // Step 2: WALA will call update - os.Setenv(constants.ExtensionVersionEnvName, "1.3.9") + os.Setenv(constants.VersionEnvName, "1.3.9") os.Setenv(constants.ExtensionPathEnvName, newVersionDirectory) os.Setenv(constants.ExtensionVersionUpdatingFromEnvName, "1.3.8") update_handler_env(&fakeEnv, newStatusPath, newVersionDirectory, newEventsPath) @@ -239,6 +240,69 @@ func Test_update_e2e_cmd(t *testing.T) { enable_extension(t, fakeEnv, newVersionDirectory, "crazyChipmunk", false, 0) } +func Test_update_e23_non_problematic_version(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "deletecmd") + defer os.RemoveAll(tempDir) + + DataDir, _ = os.MkdirTemp("", "datadir") + defer os.RemoveAll(DataDir) + + oldVersionDirectory := filepath.Join(tempDir, "Microsoft.CPlat.Core.RunCommandHandlerLinux-1.3.26") + newVersionDirectory := filepath.Join(tempDir, "Microsoft.CPlat.Core.RunCommandHandlerLinux-1.3.27") + err := os.Mkdir(oldVersionDirectory, 0755) + require.Nil(t, err, "Could not create old version subdirectory") + err = os.Mkdir(newVersionDirectory, 0755) + require.Nil(t, err, "Could not create new version subdirectory") + oldStatusPath := create_folder(t, oldVersionDirectory, constants.StatusFileDirectory) + newStatusPath := create_folder(t, newVersionDirectory, constants.StatusFileDirectory) + oldEventsPath := create_folder(t, oldVersionDirectory, constants.ExtensionEventsDirectory) + newEventsPath := create_folder(t, newVersionDirectory, constants.ExtensionEventsDirectory) + + fakeEnv := types.HandlerEnvironment{} + update_handler_env(&fakeEnv, oldStatusPath, oldVersionDirectory, oldEventsPath) + + // We start on the old version + os.Setenv(constants.ExtensionPathEnvName, oldVersionDirectory) + os.Setenv(constants.VersionEnvName, "1.3.26") + + // Create three extensions + enable_extension(t, fakeEnv, oldVersionDirectory, "happyChipmunk", true, 0) + enable_extension(t, fakeEnv, oldVersionDirectory, "crazyChipmunk", true, 0) + enable_extension(t, fakeEnv, oldVersionDirectory, "stubbornChipmunk", true, 0) + + // Run one of them again to obtain multiple status files + enable_extension(t, fakeEnv, oldVersionDirectory, "happyChipmunk", true, 1) + + // Now, pretend that the extension was updated + // Step 1: WALA calls Disable on our two extensions + disable_extension(t, fakeEnv, oldVersionDirectory, "happyChipmunk") + disable_extension(t, fakeEnv, oldVersionDirectory, "crazyChipmunk") + disable_extension(t, fakeEnv, oldVersionDirectory, "stubbornChipmunk") + + // Step 2: WALA will call update + os.Setenv(constants.VersionEnvName, "1.3.27") + os.Setenv(constants.ExtensionPathEnvName, newVersionDirectory) + os.Setenv(constants.ExtensionVersionUpdatingFromEnvName, "1.3.26") + update_handler_env(&fakeEnv, newStatusPath, newVersionDirectory, newEventsPath) + update_handler(t, fakeEnv, tempDir) + + // Now, WALA will uninstall the old extension + uninstall_handler(t, fakeEnv, tempDir) + + // Then, WALA will install the new extension + install_handler(t, fakeEnv, tempDir) + + // Now call enable and verify we did NOT re-execute the script + enable_extension(t, fakeEnv, newVersionDirectory, "happyChipmunk", false, 1) + enable_extension(t, fakeEnv, newVersionDirectory, "crazyChipmunk", false, 0) + enable_extension(t, fakeEnv, newVersionDirectory, "stubbornChipmunk", false, 0) + + // Run them again with a higher seqNo to ensure they're now executed + enable_extension(t, fakeEnv, newVersionDirectory, "happyChipmunk", true, 2) + enable_extension(t, fakeEnv, newVersionDirectory, "crazyChipmunk", true, 1) + enable_extension(t, fakeEnv, newVersionDirectory, "stubbornChipmunk", true, 1) +} + func Test_udpate_e2e_problematic_version(t *testing.T) { tempDir, _ := os.MkdirTemp("", "deletecmd") defer os.RemoveAll(tempDir) @@ -262,7 +326,7 @@ func Test_udpate_e2e_problematic_version(t *testing.T) { // We start on the old version os.Setenv(constants.ExtensionPathEnvName, oldVersionDirectory) - os.Setenv(constants.ExtensionVersionEnvName, "1.3.17") + os.Setenv(constants.VersionEnvName, "1.3.17") // Create three extensions enable_extension(t, fakeEnv, oldVersionDirectory, "happyChipmunk", true, 0) @@ -289,7 +353,7 @@ func Test_udpate_e2e_problematic_version(t *testing.T) { os.WriteFile(filepath.Join(oldStatusPath, "this.is.a.bad.chipmunk.0.status"), []byte("0"), os.FileMode(0600)) // Step 2: WALA will call update - os.Setenv(constants.ExtensionVersionEnvName, "1.3.18") + os.Setenv(constants.VersionEnvName, "1.3.18") os.Setenv(constants.ExtensionPathEnvName, newVersionDirectory) os.Setenv(constants.ExtensionVersionUpdatingFromEnvName, "1.3.17") update_handler_env(&fakeEnv, newStatusPath, newVersionDirectory, newEventsPath) @@ -435,6 +499,10 @@ func Test_runCmd_success(t *testing.T) { require.Nil(t, err) defer os.RemoveAll(dir) + // Ensure that the script succeeds + ExecCmdInDir = func(ctx *log.Context, scriptFilePath, workdir string, cfg *handlersettings.HandlerSettings) (error, int) { + return nil, 0 + } metadata := types.NewRCMetadata("extName", 0, constants.DownloadFolder, DataDir) err, exitCode := runCmd(log.NewContext(log.NewNopLogger()), dir, "", &handlersettings.HandlerSettings{ PublicSettings: handlersettings.PublicSettings{Source: &handlersettings.ScriptSource{Script: script}}, @@ -442,12 +510,6 @@ func Test_runCmd_success(t *testing.T) { require.Nil(t, err, "command should run successfully") require.Equal(t, constants.ExitCode_Okay, exitCode) - // check stdout stderr files - _, err = os.Stat(filepath.Join(dir, "stdout")) - require.Nil(t, err, "stdout should exist") - _, err = os.Stat(filepath.Join(dir, "stderr")) - require.Nil(t, err, "stderr should exist") - // Check embedded script if saved to file _, err = os.Stat(filepath.Join(dir, "script.sh")) require.Nil(t, err, "script.sh should exist") @@ -461,6 +523,11 @@ func Test_runCmd_fail(t *testing.T) { require.Nil(t, err) defer os.RemoveAll(dir) + // Ensure that the script fails + ExecCmdInDir = func(ctx *log.Context, scriptFilePath, workdir string, cfg *handlersettings.HandlerSettings) (error, int) { + return errors.New("the chipmunks have risen in revolt"), 42 + } + metadata := types.NewRCMetadata("extName", 0, constants.DownloadFolder, DataDir) err, exitCode := runCmd(log.NewContext(log.NewNopLogger()), dir, "", &handlersettings.HandlerSettings{ PublicSettings: handlersettings.PublicSettings{Source: &handlersettings.ScriptSource{Script: "non-existing-cmd"}}, @@ -692,12 +759,17 @@ func Test_TreatFailureAsDeploymentFailureIsTrue_Fails(t *testing.T) { require.Nil(t, err) defer os.RemoveAll(dir) + // Ensure that the script fails + ExecCmdInDir = func(ctx *log.Context, scriptFilePath, workdir string, cfg *handlersettings.HandlerSettings) (error, int) { + return errors.New("the chipmunks do not like the script"), 127 + } + metadata := types.NewRCMetadata("extName", 0, constants.DownloadFolder, DataDir) err, exitCode := runCmd(log.NewContext(log.NewNopLogger()), dir, "", &handlersettings.HandlerSettings{ PublicSettings: handlersettings.PublicSettings{Source: &handlersettings.ScriptSource{Script: script}, TreatFailureAsDeploymentFailure: true}, }, metadata) require.NotNil(t, err) - require.Contains(t, err.Error(), "failed to execute command: command terminated with exit status=127") + require.Contains(t, err.Error(), "failed to execute command: the chipmunks do not like the script") require.NotEqual(t, constants.ExitCode_Okay, exitCode) } @@ -711,6 +783,11 @@ func Test_TreatFailureAsDeploymentFailureIsTrue_SimpleScriptSucceeds(t *testing. require.Nil(t, err) defer os.RemoveAll(dir) + // Ensure that the script succeeds + ExecCmdInDir = func(ctx *log.Context, scriptFilePath, workdir string, cfg *handlersettings.HandlerSettings) (error, int) { + return nil, 0 + } + metadata := types.NewRCMetadata("extName", 0, constants.DownloadFolder, DataDir) err, exitCode := runCmd(log.NewContext(log.NewNopLogger()), dir, "", &handlersettings.HandlerSettings{ PublicSettings: handlersettings.PublicSettings{Source: &handlersettings.ScriptSource{Script: script}, TreatFailureAsDeploymentFailure: false}, @@ -718,3 +795,653 @@ func Test_TreatFailureAsDeploymentFailureIsTrue_SimpleScriptSucceeds(t *testing. require.Nil(t, err) require.Equal(t, constants.ExitCode_Okay, exitCode) } + +func TestPadTo(t *testing.T) { + tests := []struct { + name string + in []int + size int + expected []int + }{ + { + name: "Input longer than size", + in: []int{1, 2, 3, 4}, + size: 2, + expected: []int{1, 2}, + }, + { + name: "Input equal to size", + in: []int{1, 2, 3}, + size: 3, + expected: []int{1, 2, 3}, + }, + { + name: "Input shorter than size", + in: []int{1, 2}, + size: 5, + expected: []int{1, 2, 0, 0, 0}, + }, + { + name: "Empty input", + in: []int{}, + size: 3, + expected: []int{0, 0, 0}, + }, + { + name: "Size zero", + in: []int{1, 2, 3}, + size: 0, + expected: []int{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := padTo(tt.in, tt.size) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("padTo(%v, %d) = %v; expected %v", tt.in, tt.size, result, tt.expected) + } + }) + } +} + +func TestSplitVersion(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in string + expected []int + }{ + { + name: "simple-3-parts", + in: "1.2.3", + expected: []int{1, 2, 3}, + }, + { + name: "single-part", + in: "42", + expected: []int{42}, + }, + { + name: "leading-zeros", + in: "001.0002.00003", + expected: []int{1, 2, 3}, + }, + { + name: "spaces-around-parts", + in: " 1 . 2 . 3 ", + expected: []int{1, 2, 3}, + }, + { + name: "non-numeric-alpha", + in: "1.a.3", + expected: []int{1, 0, 3}, + }, + { + name: "non-numeric-mixed", + in: "1.2beta.3", + expected: []int{1, 0, 3}, + }, + { + name: "empty-string", + in: "", + // strings.Split("", ".") == []string{""} → p=="" → append 0 + expected: []int{0}, + }, + { + name: "consecutive-dots-empty-components", + in: "1..3....5", + // empty parts become 0 + expected: []int{1, 0, 3, 0, 0, 0, 5}, + }, + { + name: "trailing-dot", + in: "1.2.", + expected: []int{1, 2, 0}, + }, + { + name: "leading-dot", + in: ".2.3", + expected: []int{0, 2, 3}, + }, + { + name: "very-large-number", + in: "2147483647.0", + // Note: Go int is platform-dependent; still valid parsing. + expected: []int{2147483647, 0}, + }, + { + name: "zeros-only", + in: "0.0.0", + expected: []int{0, 0, 0}, + }, + { + name: "whitespace-only-component", + in: "1. .3", + // TrimSpace makes middle part "", thus 0 + expected: []int{1, 0, 3}, + }, + { + name: "unicode-digits-are-not-ASCII-digits", + in: "1.2", // Note: first char is full-width '1' (U+FF11) → non-ASCII → 0 + expected: []int{0, 2}, + }, + { + name: "dash-negative-like", + in: "1.-2.3", + // '-' makes component non-numeric → 0 + expected: []int{1, -2, 3}, + }, + { + name: "plus-sign", + in: "+1.2", + expected: []int{1, 2}, + }, + { + name: "long-many-parts", + in: "1.2.3.4.5.6.7.8.9", + expected: []int{1, 2, 3, 4, 5, 6, 7, 8, 9}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := splitVersion(tt.in) + if !reflect.DeepEqual(got, tt.expected) { + t.Fatalf("splitVersion(%q) = %v, want %v", tt.in, got, tt.expected) + } + }) + } +} + +func TestCompareVersions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + a string + b string + expected int + }{ + { + name: "equal-simple", + a: "1.2.3", + b: "1.2.3", + expected: 0, + }, + { + name: "equal-with-extra-zeros", + a: "1.2.3.0", + b: "1.2.3", + expected: 0, + }, + { + name: "a-greater-last-segment", + a: "1.2.3.4", + b: "1.2.3.3", + expected: 1, + }, + { + name: "b-greater-last-segment", + a: "1.2.3.3", + b: "1.2.3.4", + expected: -1, + }, + { + name: "a-greater-first-segment", + a: "2.0.0", + b: "1.9.9", + expected: 1, + }, + { + name: "b-greater-first-segment", + a: "1.9.9", + b: "2.0.0", + expected: -1, + }, + { + name: "normalize-length-a-shorter", + a: "1.2", + b: "1.2.0.1", + expected: -1, + }, + { + name: "normalize-length-b-shorter", + a: "1.2.0.1", + b: "1.2", + expected: 1, + }, + { + name: "leading-zeros-equal", + a: "01.002.0003", + b: "1.2.3", + expected: 0, + }, + { + name: "non-numeric-in-a", + a: "1.alpha.3", + b: "1.0.3", + expected: 0, // alpha → 0 + }, + { + name: "non-numeric-in-b", + a: "1.2.3", + b: "1.beta.3", + expected: 1, // beta → 0, so a > b + }, + { + name: "empty-strings", + a: "", + b: "", + expected: 0, + }, + { + name: "empty-vs-non-empty", + a: "", + b: "0.0.0.1", + expected: -1, + }, + { + name: "longer-than-4-segments-ignored-after-4", + a: "1.2.3.4.999", + b: "1.2.3.4.0", + expected: 0, // only first 4 segments matter + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := compareVersions(tt.a, tt.b) + if got != tt.expected { + t.Fatalf("compareVersions(%q, %q) = %d, want %d", tt.a, tt.b, got, tt.expected) + } + }) + } +} + +func Test_determineUpgradeVersionDirectories_Upgrade_PathContainsToVersion(t *testing.T) { + /* + Upgrade: updatingFrom != toVersion + Path contains the toVersion -> replace toVersion with fromVersion for the "from" dir + */ + to := "2.5.0" + from := "2.4.3" + curr := "2.5.0" // current extension version value (doesn't affect upgrade/downgrade decision) + dir := "/var/lib/waagent/My.Ext/" + to + "/" + + setEnvs(t, to, curr, from, dir) + + tempDir, _ := os.MkdirTemp("", "upgradetoversion") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + fromDir, toDir, gotFromVersion := determineUpgradeVersionDirectories(ctx, events) + + require.Equal(t, from, gotFromVersion, "upgradeFromVersion should be updating-from value on upgrade") + require.Equal(t, dir, toDir, "when path contains toVersion, toDir is the given path") + + expectedFromDir := strings.ReplaceAll(dir, to, from) + require.Equal(t, expectedFromDir, fromDir) +} + +func Test_determineUpgradeVersionDirectories_Upgrade_PathContainsFromVersion(t *testing.T) { + /* + Upgrade: updatingFrom != toVersion + Path contains the fromVersion -> replace fromVersion with toVersion for the "to" dir + */ + to := "3.0.1" + from := "2.9.9" + curr := "3.0.1" + dir := "/opt/exts/My.Ext/" + from + "/bin" + + setEnvs(t, to, curr, from, dir) + + tempDir, _ := os.MkdirTemp("", "upgradefromversion") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + fromDir, toDir, gotFromVersion := determineUpgradeVersionDirectories(ctx, events) + + require.Equal(t, from, gotFromVersion) + require.Equal(t, dir, fromDir, "when path does NOT contain toVersion, fromDir is the given path") + + expectedToDir := strings.ReplaceAll(dir, from, to) + require.Equal(t, expectedToDir, toDir) +} + +func Test_determineUpgradeVersionDirectories_Downgrade_PathContainsToVersion(t *testing.T) { + /* + Downgrade: updatingFrom == toVersion + upgradeFromVersion becomes extensionVersionValue (curr) + Path contains toVersion -> replace toVersion with fromVersion (curr) for the "from" dir + */ + to := "2.1.0" + curr := "2.3.0" // extensionVersionValue + fromUpdating := to + dir := "C:\\Packages\\Plugins\\My.Ext\\" + to + "\\" + + setEnvs(t, to, curr, fromUpdating, dir) + + tempDir, _ := os.MkdirTemp("", "downgradetoversion") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + fromDir, toDir, gotFromVersion := determineUpgradeVersionDirectories(ctx, events) + + require.Equal(t, curr, gotFromVersion, "on downgrade, fromVersion becomes the current extension version") + require.Equal(t, dir, toDir) + + expectedFromDir := strings.ReplaceAll(dir, to, curr) + require.Equal(t, expectedFromDir, fromDir) +} + +func Test_determineUpgradeVersionDirectories_Downgrade_PathContainsFromVersion(t *testing.T) { + /* + Downgrade: updatingFrom == toVersion + Path contains the computed fromVersion (curr), so toDir is replacement of fromVersion->toVersion + */ + to := "1.7.0" + curr := "1.7.5" + fromUpdating := to + dir := "/extensions/handler/" + curr + "/" + + setEnvs(t, to, curr, fromUpdating, dir) + + tempDir, _ := os.MkdirTemp("", "downgradefromversion") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + fromDir, toDir, gotFromVersion := determineUpgradeVersionDirectories(ctx, events) + + require.Equal(t, curr, gotFromVersion) + require.Equal(t, dir, fromDir) + + expectedToDir := strings.ReplaceAll(dir, curr, to) + require.Equal(t, expectedToDir, toDir) +} + +func Test_rehydrateMrSeqFiles_OpenStatusDirFails(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "OpenStatusDirFails") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := t.TempDir() // does not contain the status subdir + to := t.TempDir() + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.NotNil(t, err, "expected error when opening missing status dir") + require.True(t, strings.Contains(err.Error(), "Failed to open status directory"), "unexpected error message: %s", err.Error()) +} + +func Test_rehydrateMrSeqFiles_ReadDirFails(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "ReadDirFails") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + // Create a *file* at "/status" so os.Open succeeds but ReadDir fails. + statusPath := filepath.Join(from, constants.StatusFileDirectory) + require.NoError(t, os.WriteFile(statusPath, []byte("not a directory"), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.NotNil(t, err, "expected error when ReadDir fails") + require.True(t, strings.Contains(err.Error(), "could not read directory entries"), "unexpected error message: %s", err.Error()) +} + +func Test_rehydrateMrSeqFiles_IgnoresInvalidStatusFilename(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "IgnoresInvalidStatusFilename") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + // Proper status directory + require.NoError(t, os.MkdirAll(filepath.Join(from, constants.StatusFileDirectory), 0o755)) + + // Invalid filename (only two parts, missing seqNo). Should be ignored. + invalid := filepath.Join(from, constants.StatusFileDirectory, "alpha"+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(invalid, []byte(""), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.Nil(t, err, "unexpected error: %v", err) + + // No mrseq should be created for invalid filenames. + _, err = os.Stat(filepath.Join(to, "alpha"+constants.MrSeqFileExtension)) + require.True(t, os.IsNotExist(err), "alpha%s should not have been created", constants.MrSeqFileExtension) +} + +func Test_rehydrateMrSeqFiles_RehydrateMissingMrseq(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "RehydrateMissingMrseq") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + statusDir := filepath.Join(from, constants.StatusFileDirectory) + if err := os.MkdirAll(statusDir, 0o755); err != nil { + t.Fatalf("mkdir status: %v", err) + } + + // alpha.5.status → should create to/alpha.mrseq with "5" + alphaStatus := filepath.Join(statusDir, "alpha.5"+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(alphaStatus, []byte(""), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.Nil(t, err, "unexpected error: %v", err) + + mrseqPath := filepath.Join(to, "alpha"+constants.MrSeqFileExtension) + got := mustReadFile(t, mrseqPath) + require.Equal(t, "5", got, "mrseq content = %q, want %q", got, "5") +} + +func Test_rehydrateMrSeqFiles_UpdateExistingMrseqWhenHigherSeqFound(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "UpdateExistingMrseqWhenHigherSeqFound") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + statusDir := filepath.Join(from, constants.StatusFileDirectory) + require.NoError(t, os.MkdirAll(statusDir, 0o755)) + + // Existing mrseq=3 + mrseqPath := filepath.Join(to, "alpha"+constants.MrSeqFileExtension) + require.NoError(t, os.WriteFile(mrseqPath, []byte("3"), 0o600)) + + // Status reports seq=5 → should overwrite to "5" + alphaStatus := filepath.Join(statusDir, "alpha.5"+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(alphaStatus, []byte(""), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.Nil(t, err, "unexpected error: %v", err) + + got := mustReadFile(t, mrseqPath) + require.Equal(t, "5", got, "mrseq content = %q, want %q", got, "5") +} + +func Test_rehydrateMrSeqFiles_NoUpdateWhenExistingIsHigher(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "NoUpdateWhenExistingIsHigher") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + statusDir := filepath.Join(from, constants.StatusFileDirectory) + require.NoError(t, os.MkdirAll(statusDir, 0o755)) + + // Existing mrseq=7 + mrseqPath := filepath.Join(to, "alpha "+constants.MrSeqFileExtension) + require.NoError(t, os.WriteFile(mrseqPath, []byte("7"), 0o600)) + + // Status reports seq=5 → should NOT overwrite + alphaStatus := filepath.Join(statusDir, "alpha.5"+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(alphaStatus, []byte(""), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.Nil(t, err, "unexpected error: %v", err) + + got := mustReadFile(t, mrseqPath) + require.Equal(t, "7", got, "mrseq content = %q, want %q", got, "7") +} + +func Test_rehydrateMrSeqFiles_MultipleStatusFiles_TakesMax(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "TakesMax") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + statusDir := filepath.Join(from, constants.StatusFileDirectory) + require.NoError(t, os.MkdirAll(statusDir, 0o755)) + + // Both alpha.3.status and alpha.7.status — final mrseq should be "7" + for _, seq := range []int{3, 7} { + p := filepath.Join(statusDir, "alpha."+strconv.Itoa(seq)+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(p, []byte(""), 0o600)) + } + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.Nil(t, err, "unexpected error: %v", err) + + got := mustReadFile(t, filepath.Join(to, "alpha"+constants.MrSeqFileExtension)) + require.Equal(t, "7", got, "mrseq content = %q, want %q", got, "7") +} + +func Test_rehydrateMrSeqFiles_ReadExistingMrseqFails(t *testing.T) { + tempDir, _ := os.MkdirTemp("", "ReadExistingMrseqFails") + defer os.RemoveAll(tempDir) + handlerEnvironment := handlerenv.HandlerEnvironment{ + EventsFolder: tempDir, + } + + ctx := log.NewContext(log.NewNopLogger()) + extensionLogger := logging.New(nil) + events := extensionevents.New(extensionLogger, &handlerEnvironment) + + from := filepath.Join(tempDir, "from") + require.NoError(t, os.Mkdir(from, 0o755)) + to := filepath.Join(tempDir, "to") + require.NoError(t, os.Mkdir(to, 0o755)) + + statusDir := filepath.Join(from, constants.StatusFileDirectory) + require.NoError(t, os.MkdirAll(statusDir, 0o755)) + + // Make a directory at the mrseq path so ReadFile fails + mrseqPath := filepath.Join(to, "alpha"+constants.MrSeqFileExtension) + require.NoError(t, os.MkdirAll(mrseqPath, 0o755)) + + // Now create a status that would try to read/compare + alphaStatus := filepath.Join(statusDir, "alpha.5"+constants.StatusFileExtension) + require.NoError(t, os.WriteFile(alphaStatus, []byte(""), 0o600)) + + err := rehydrateMrSeqFilesForProblematicUpgrades(ctx, from, to, events) + require.NotNil(t, err, "expected error due to unreadable mrseq") + require.True(t, strings.Contains(err.Error(), "Could not read file"), "Unexpected error: %v", err.Error()) +} + +// setEnvs is a helper to seed the env for each scenario. +func setEnvs(t *testing.T, to, curr, from, dir string) { + t.Helper() + t.Setenv(constants.VersionEnvName, to) + t.Setenv(constants.ExtensionVersionEnvName, curr) + t.Setenv(constants.ExtensionVersionUpdatingFromEnvName, from) + t.Setenv(constants.ExtensionPathEnvName, dir) +} + +func mustReadFile(t *testing.T, p string) string { + b, err := os.ReadFile(p) + if err != nil { + t.Fatalf("read %s: %v", p, err) + } + return string(b) +} diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 94bdccd..87c3db2 100755 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -38,6 +38,9 @@ const ( // The output directory for logs of immediate run command ImmediateRCOutputDirectory = "/var/log/azure/run-command-handler/ImmediateRunCommandService.log" + // This is the directory where we place all the .mrseq files + WorkingDirectory = "../run-command-handler/workdir" + // Download folder to use for standard managed run command DownloadFolder = "download/" @@ -48,23 +51,24 @@ const ( RunCommandExtensionName = "Microsoft.CPlat.Core.RunCommandHandlerLinux" RunCommandTestExtensionName = "Microsoft.Azure.Extensions.Edp.RunCommandHandlerLinuxTest" - // List of problematic RCV2 versions that delete the mrseq files - ProductionVersionThatDeletesMrSeqFiles = "1.3.17" - FirstTestVersionThatDeletesMrSeqFiles = "1.8.0" - SecondTestVersionThatDeletesMrSeqFiles = "1.9.0" - // The current version of the extension. This value is provided by the agent for all commands. // See more in: https://github.com/Azure/azure-vmextension-publishing/wiki/2.0-Partner-Guide-Handler-Design-Details#236-summary - ExtensionVersionEnvName = "AZURE_GUEST_AGENT_EXTENSION_VERSION" + VersionEnvName = "VERSION" // This is the version the extension is updating from // See more in: https://github.com/Azure/azure-vmextension-publishing/wiki/2.0-Partner-Guide-Handler-Design-Details#236-summary ExtensionVersionUpdatingFromEnvName = "AZURE_GUEST_AGENT_UPDATING_FROM_VERSION" + // We'll use this variable to determine if there's a downgrade + ExtensionVersionEnvName = "AZURE_GUEST_AGENT_EXTENSION_VERSION" + // The path of the extension in the VM with full name. This value is provided by the agent for all commands. // See more in: https://github.com/Azure/azure-vmextension-publishing/wiki/2.0-Partner-Guide-Handler-Design-Details#236-summary ExtensionPathEnvName = "AZURE_GUEST_AGENT_EXTENSION_PATH" + // The first version from which no rehydration is necessary because we do not delete the .mrseq file in disable + FirstVersionNoRehydration = "1.3.26" + // The name of the immediate run command service ImmediateRunCommandHandlerName = "runCommandService" diff --git a/internal/service/serviceinstall.go b/internal/service/serviceinstall.go index 1014e2f..bd14aeb 100644 --- a/internal/service/serviceinstall.go +++ b/internal/service/serviceinstall.go @@ -39,7 +39,7 @@ func Register(ctx *log.Context, extensionEvents *extensionevents.ExtensionEventM extensionEvents.LogErrorEvent("register", "Systemd not supported. Failed to register service") return errors.New("Systemd not supported. Failed to register service") } - targetVersion := os.Getenv(constants.ExtensionVersionEnvName) + targetVersion := os.Getenv(constants.VersionEnvName) ctx.Log("message", "trying to register extension with version: "+targetVersion) ctx.Log("message", "Generating service configuration files") diff --git a/pkg/servicehandler/servicehandler.go b/pkg/servicehandler/servicehandler.go index e3567fd..078f025 100644 --- a/pkg/servicehandler/servicehandler.go +++ b/pkg/servicehandler/servicehandler.go @@ -109,7 +109,7 @@ func (handler *Handler) Register(ctx *log.Context, unitConfigContent string) err func (handler *Handler) DeRegister(ctx *log.Context) error { // We need to make sure the version that the VM Agent is trying to uninstall is the correct one. // Failing to check this can cause to uninstall the service during the update workflow. - targetVersion := os.Getenv(constants.ExtensionVersionEnvName) + targetVersion := os.Getenv(constants.VersionEnvName) ctx.Log("message", "trying to uninstall extension with version: "+targetVersion) installedVersion, err := handler.GetInstalledVersion(ctx) diff --git a/pkg/servicehandler/servicehandler_test.go b/pkg/servicehandler/servicehandler_test.go index f5ed7e1..0bc610b 100644 --- a/pkg/servicehandler/servicehandler_test.go +++ b/pkg/servicehandler/servicehandler_test.go @@ -301,7 +301,7 @@ func TestHandlerSuccessfulDeRegister(t *testing.T) { ctx := log.NewContext(log.NewSyncLogger(log.NewLogfmtLogger( os.Stdout))).With("time", log.DefaultTimestamp) - os.Setenv(constants.ExtensionVersionEnvName, installedTargetVersion) + os.Setenv(constants.VersionEnvName, installedTargetVersion) handler := NewHandler(m, config, ctx) err := handler.DeRegister(ctx) if err != nil { @@ -336,7 +336,7 @@ func TestHandlerSkipDeRegisterForNonInstalledTargetVersion(t *testing.T) { ctx := log.NewContext(log.NewSyncLogger(log.NewLogfmtLogger( os.Stdout))).With("time", log.DefaultTimestamp) - os.Setenv(constants.ExtensionVersionEnvName, nonInstalledTargetVersion) + os.Setenv(constants.VersionEnvName, nonInstalledTargetVersion) handler := NewHandler(m, config, ctx) err := handler.DeRegister(ctx) if err != nil {