Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public MicrosoftGraphTokenProvider(
useDeviceCode);

var script = BuildPowerShellScript(tenantId, validatedScopes, useDeviceCode, clientAppId);
var result = await ExecuteWithFallbackAsync(script, useDeviceCode, ct);
var result = await ExecuteWithFallbackAsync(script, ct);
var token = ProcessResult(result);

if (string.IsNullOrWhiteSpace(token))
Expand Down Expand Up @@ -216,17 +216,16 @@ private static string BuildScopesArray(string[] scopes)

private async Task<CommandResult> ExecuteWithFallbackAsync(
string script,
bool useDeviceCode,
CancellationToken ct)
{
// Try PowerShell Core first (cross-platform)
var result = await ExecutePowerShellAsync("pwsh", script, useDeviceCode, ct);
var result = await ExecutePowerShellAsync("pwsh", script, ct);

// Fallback to Windows PowerShell if pwsh is not available
if (!result.Success && IsPowerShellNotFoundError(result))
{
_logger.LogDebug("PowerShell Core not found, falling back to Windows PowerShell");
result = await ExecutePowerShellAsync("powershell", script, useDeviceCode, ct);
result = await ExecutePowerShellAsync("powershell", script, ct);
}

return result;
Expand All @@ -235,33 +234,17 @@ private async Task<CommandResult> ExecuteWithFallbackAsync(
private async Task<CommandResult> ExecutePowerShellAsync(
string shell,
string script,
bool useDeviceCode,
CancellationToken ct)
{
var arguments = BuildPowerShellArguments(shell, script);

if (useDeviceCode)
{
// Use streaming for device code flow so user sees the instructions in real-time
return await _executor.ExecuteWithStreamingAsync(
command: shell,
arguments: arguments,
workingDirectory: null,
outputPrefix: "",
interactive: true, // Allow user to see device code instructions
cancellationToken: ct);
}
else
{
// Use standard execution for browser flow (captures output silently)
return await _executor.ExecuteAsync(
command: shell,
arguments: arguments,
workingDirectory: null,
captureOutput: true,
suppressErrorLogging: true, // We handle logging ourselves
cancellationToken: ct);
}
return await _executor.ExecuteWithStreamingAsync(
command: shell,
arguments: arguments,
workingDirectory: null,
outputPrefix: "",
interactive: true,
cancellationToken: ct);
}

private static string BuildPowerShellArguments(string shell, string script)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ public async Task GetMgGraphAccessTokenAsync_WithValidClientAppId_IncludesClient
var clientAppId = "87654321-4321-4321-4321-cba987654321";
var expectedToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signature";

_executor.ExecuteAsync(
_executor.ExecuteWithStreamingAsync(
Arg.Any<string>(),
Arg.Is<string>(args => args.Contains($"-ClientId '{clientAppId}'")),
Arg.Any<string?>(),
Arg.Any<bool>(),
Arg.Any<string>(),
Arg.Any<bool>(),
Arg.Any<CancellationToken>())
.Returns(new CommandResult { ExitCode = 0, StandardOutput = expectedToken, StandardError = string.Empty });
Expand All @@ -45,11 +45,11 @@ public async Task GetMgGraphAccessTokenAsync_WithValidClientAppId_IncludesClient

// Assert
token.Should().Be(expectedToken);
await _executor.Received(1).ExecuteAsync(
await _executor.Received(1).ExecuteWithStreamingAsync(
Arg.Is<string>(cmd => cmd == "pwsh" || cmd == "powershell"),
Arg.Is<string>(args => args.Contains($"-ClientId '{clientAppId}'")),
Arg.Any<string?>(),
Arg.Any<bool>(),
Arg.Any<string>(),
Arg.Any<bool>(),
Arg.Any<CancellationToken>());
}
Expand All @@ -62,11 +62,11 @@ public async Task GetMgGraphAccessTokenAsync_WithoutClientAppId_OmitsClientIdPar
var scopes = new[] { "User.Read" };
var expectedToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signature";

_executor.ExecuteAsync(
_executor.ExecuteWithStreamingAsync(
Arg.Any<string>(),
Arg.Is<string>(args => !args.Contains("-ClientId")),
Arg.Any<string?>(),
Arg.Any<bool>(),
Arg.Any<string>(),
Arg.Any<bool>(),
Arg.Any<CancellationToken>())
.Returns(new CommandResult { ExitCode = 0, StandardOutput = expectedToken, StandardError = string.Empty });
Expand All @@ -78,11 +78,11 @@ public async Task GetMgGraphAccessTokenAsync_WithoutClientAppId_OmitsClientIdPar

// Assert
token.Should().Be(expectedToken);
await _executor.Received(1).ExecuteAsync(
await _executor.Received(1).ExecuteWithStreamingAsync(
Arg.Any<string>(),
Arg.Is<string>(args => !args.Contains("-ClientId")),
Arg.Any<string?>(),
Arg.Any<bool>(),
Arg.Any<string>(),
Arg.Any<bool>(),
Arg.Any<CancellationToken>());
}
Expand Down Expand Up @@ -152,11 +152,11 @@ public async Task GetMgGraphAccessTokenAsync_WhenExecutionFails_ReturnsNull()
var tenantId = "12345678-1234-1234-1234-123456789abc";
var scopes = new[] { "User.Read" };

_executor.ExecuteAsync(
_executor.ExecuteWithStreamingAsync(
Arg.Any<string>(),
Arg.Any<string>(),
Arg.Any<string?>(),
Arg.Any<bool>(),
Arg.Any<string>(),
Arg.Any<bool>(),
Arg.Any<CancellationToken>())
.Returns(new CommandResult { ExitCode = 1, StandardOutput = string.Empty, StandardError = "PowerShell error" });
Expand All @@ -178,11 +178,11 @@ public async Task GetMgGraphAccessTokenAsync_WithValidToken_ReturnsToken()
var scopes = new[] { "User.Read" };
var expectedToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signature";

_executor.ExecuteAsync(
_executor.ExecuteWithStreamingAsync(
Arg.Any<string>(),
Arg.Any<string>(),
Arg.Any<string?>(),
Arg.Any<bool>(),
Arg.Any<string>(),
Arg.Any<bool>(),
Arg.Any<CancellationToken>())
.Returns(new CommandResult { ExitCode = 0, StandardOutput = expectedToken, StandardError = string.Empty });
Expand Down Expand Up @@ -223,11 +223,11 @@ public async Task GetMgGraphAccessTokenAsync_EscapesSingleQuotesInClientAppId()
var clientAppId = "87654321-4321-4321-4321-cba987654321";
var expectedToken = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.signature";

_executor.ExecuteAsync(
_executor.ExecuteWithStreamingAsync(
Arg.Any<string>(),
Arg.Is<string>(args => !args.Contains("''")), // Should not have escaped quotes for valid GUID
Arg.Any<string?>(),
Arg.Any<bool>(),
Arg.Any<string>(),
Arg.Any<bool>(),
Arg.Any<CancellationToken>())
.Returns(new CommandResult { ExitCode = 0, StandardOutput = expectedToken, StandardError = string.Empty });
Expand Down