Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ private static async Task<bool> CheckMosPrerequisitesAsync(
CancellationToken ct)
{
// Check 1: Verify all required service principals exist
// Cache SP IDs so Check 3 can reuse them without redundant Graph API lookups.
var firstPartyClientSpId = await graph.LookupServicePrincipalByAppIdAsync(config.TenantId,
MosConstants.TpsAppServicesClientAppId, ct);
if (string.IsNullOrWhiteSpace(firstPartyClientSpId))
Expand All @@ -37,6 +38,7 @@ private static async Task<bool> CheckMosPrerequisitesAsync(
logger.LogDebug("Verified service principal for {ConstantName} ({AppId})",
nameof(MosConstants.TpsAppServicesClientAppId), MosConstants.TpsAppServicesClientAppId);

var resourceSpIds = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
foreach (var resourceAppId in MosConstants.AllResourceAppIds)
{
var spId = await graph.LookupServicePrincipalByAppIdAsync(config.TenantId, resourceAppId, ct);
Expand All @@ -45,6 +47,7 @@ private static async Task<bool> CheckMosPrerequisitesAsync(
logger.LogDebug("Service principal for {ResourceAppId} not found - configuration needed", resourceAppId);
return false;
}
resourceSpIds[resourceAppId] = spId;
logger.LogDebug("Verified service principal for resource app ({ResourceAppId})", resourceAppId);
}

Expand Down Expand Up @@ -96,15 +99,15 @@ private static async Task<bool> CheckMosPrerequisitesAsync(
}

// Check 3: Verify admin consent is granted for all MOS resources
// Reuse SP IDs cached from Check 1 to avoid redundant Graph API lookups.
var mosResourceScopes = MosConstants.ResourcePermissions.GetAll()
.ToDictionary(p => p.ResourceAppId, p => p.ScopeName);

foreach (var (resourceAppId, scopeName) in mosResourceScopes)
{
var resourceSpId = await graph.LookupServicePrincipalByAppIdAsync(config.TenantId, resourceAppId, ct);
if (string.IsNullOrWhiteSpace(resourceSpId))
if (!resourceSpIds.TryGetValue(resourceAppId, out var resourceSpId))
{
logger.LogDebug("Service principal for {ResourceAppId} not found - configuration needed", resourceAppId);
logger.LogWarning("Service principal for {ResourceAppId} not found in cache (unexpected - should have been cached by Check 1)", resourceAppId);
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ public class GraphApiService
private readonly CommandExecutor _executor;
private readonly HttpClient _httpClient;
private readonly IMicrosoftGraphTokenProvider? _tokenProvider;

// Azure CLI token cache to avoid spawning az subprocess for every Graph API call.
// Tokens acquired via 'az account get-access-token' are typically valid for 60+ minutes;
// we cache them for a shorter window so the CLI still picks up token refreshes promptly.
private string? _cachedAzCliToken;
private string? _cachedAzCliTenantId;
private DateTimeOffset _cachedAzCliTokenExpiry = DateTimeOffset.MinValue;
internal static readonly TimeSpan AzCliTokenCacheDuration = TimeSpan.FromMinutes(5);

/// <summary>
/// Expiry time for the cached Azure CLI token. Internal for testing purposes.
/// </summary>
internal DateTimeOffset CachedAzCliTokenExpiry
{
get => _cachedAzCliTokenExpiry;
set => _cachedAzCliTokenExpiry = value;
}

/// <summary>
/// Optional custom client app ID to use for authentication with Microsoft Graph PowerShell.
Expand Down Expand Up @@ -212,16 +229,37 @@ private async Task<bool> EnsureGraphHeadersAsync(string tenantId, CancellationTo
else
{
// Use Azure CLI token (default fallback for operations that don't need special scopes)
_logger.LogDebug("Acquiring Graph token via Azure CLI (no specific scopes required)");
token = await GetGraphAccessTokenAsync(tenantId, ct);

if (string.IsNullOrWhiteSpace(token))
// Check if we have a cached token for this tenant that hasn't expired
if (_cachedAzCliToken != null
&& string.Equals(_cachedAzCliTenantId, tenantId, StringComparison.OrdinalIgnoreCase)
&& DateTimeOffset.UtcNow < _cachedAzCliTokenExpiry)
{
_logger.LogError("Failed to acquire Graph token via Azure CLI. Ensure 'az login' is completed.");
return false;
_logger.LogDebug("Using cached Azure CLI Graph token (expires in {Minutes:F1} minutes)",
(_cachedAzCliTokenExpiry - DateTimeOffset.UtcNow).TotalMinutes);
token = _cachedAzCliToken;
}
else
{
_logger.LogDebug("Acquiring Graph token via Azure CLI (no specific scopes required)");
token = await GetGraphAccessTokenAsync(tenantId, ct);

if (string.IsNullOrWhiteSpace(token))
{
// Clear cache on failure to ensure clean state
_cachedAzCliToken = null;
_cachedAzCliTenantId = null;
_cachedAzCliTokenExpiry = DateTimeOffset.MinValue;

_logger.LogDebug("Successfully acquired Graph token via Azure CLI");
_logger.LogError("Failed to acquire Graph token via Azure CLI. Ensure 'az login' is completed.");
return false;
}

// Cache the token for subsequent calls within the same command execution
_cachedAzCliToken = token;
_cachedAzCliTenantId = tenantId;
_cachedAzCliTokenExpiry = DateTimeOffset.UtcNow.Add(AzCliTokenCacheDuration);
_logger.LogDebug("Cached Azure CLI Graph token for {Duration} minutes", AzCliTokenCacheDuration.TotalMinutes);
}
}

// Remove all newline characters and trim whitespace to prevent header validation errors
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Net;
using System.Net.Http;
using FluentAssertions;
using Microsoft.Agents.A365.DevTools.Cli.Services;
using Microsoft.Extensions.Logging;
using NSubstitute;
using Xunit;

namespace Microsoft.Agents.A365.DevTools.Cli.Tests.Services;

/// <summary>
/// Tests to validate that Azure CLI Graph tokens are cached across consecutive
/// Graph API calls, avoiding redundant 'az' subprocess spawns.
/// </summary>
public class GraphApiServiceTokenCacheTests
{
/// <summary>
/// Helper: create a GraphApiService with a mock executor that counts calls
/// and returns a predictable token.
/// </summary>
private static (GraphApiService service, TestHttpMessageHandler handler, CommandExecutor executor) CreateService(string token = "cached-token")
{
var handler = new TestHttpMessageHandler();
var logger = Substitute.For<ILogger<GraphApiService>>();
var executor = Substitute.For<CommandExecutor>(Substitute.For<ILogger<CommandExecutor>>());

executor.ExecuteAsync(
Arg.Any<string>(), Arg.Any<string>(), Arg.Any<string?>(),
Arg.Any<bool>(), Arg.Any<bool>(), Arg.Any<CancellationToken>())
.Returns(callInfo =>
{
var cmd = callInfo.ArgAt<string>(0);
var args = callInfo.ArgAt<string>(1);
if (cmd == "az" && args != null && args.StartsWith("account show", StringComparison.OrdinalIgnoreCase))
return Task.FromResult(new CommandResult { ExitCode = 0, StandardOutput = "{}", StandardError = string.Empty });
if (cmd == "az" && args != null && args.Contains("get-access-token", StringComparison.OrdinalIgnoreCase))
return Task.FromResult(new CommandResult { ExitCode = 0, StandardOutput = token, StandardError = string.Empty });
return Task.FromResult(new CommandResult { ExitCode = 0, StandardOutput = string.Empty, StandardError = string.Empty });
});

var service = new GraphApiService(logger, executor, handler);
return (service, handler, executor);
}

[Fact]
public async Task MultipleGraphGetAsync_SameTenant_AcquiresTokenOnlyOnce()
{
// Arrange
var (service, handler, executor) = CreateService();

try
{
// Queue 3 successful GET responses
for (int i = 0; i < 3; i++)
{
handler.QueueResponse(new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent("{\"value\":[]}")
});
}

// Act - make 3 consecutive Graph GET calls to the same tenant
var r1 = await service.GraphGetAsync("tenant-1", "/v1.0/path1");
var r2 = await service.GraphGetAsync("tenant-1", "/v1.0/path2");
var r3 = await service.GraphGetAsync("tenant-1", "/v1.0/path3");

// Assert - all calls should succeed
r1.Should().NotBeNull();
r2.Should().NotBeNull();
r3.Should().NotBeNull();

// The token should be acquired only ONCE (1 account show + 1 get-access-token = 2 az calls)
await executor.Received(1).ExecuteAsync(
"az",
Arg.Is<string>(s => s.Contains("get-access-token")),
Arg.Any<string?>(), Arg.Any<bool>(), Arg.Any<bool>(), Arg.Any<CancellationToken>());

await executor.Received(1).ExecuteAsync(
"az",
Arg.Is<string>(s => s.Contains("account show")),
Arg.Any<string?>(), Arg.Any<bool>(), Arg.Any<bool>(), Arg.Any<CancellationToken>());
}
finally
{
handler.Dispose();
}
}

[Fact]
public async Task GraphGetAsync_DifferentTenants_AcquiresTokenForEach()
{
// Arrange
var (service, handler, executor) = CreateService();

try
{
// Queue 2 responses
for (int i = 0; i < 2; i++)
{
handler.QueueResponse(new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent("{\"value\":[]}")
});
}

// Act - make calls to different tenants
var r1 = await service.GraphGetAsync("tenant-1", "/v1.0/path1");
var r2 = await service.GraphGetAsync("tenant-2", "/v1.0/path2");

// Assert
r1.Should().NotBeNull();
r2.Should().NotBeNull();

// Token should be acquired twice (once per tenant)
await executor.Received(2).ExecuteAsync(
"az",
Arg.Is<string>(s => s.Contains("get-access-token")),
Arg.Any<string?>(), Arg.Any<bool>(), Arg.Any<bool>(), Arg.Any<CancellationToken>());
}
finally
{
handler.Dispose();
}
}

[Fact]
public async Task MixedGraphOperations_SameTenant_AcquiresTokenOnlyOnce()
{
// Arrange
var (service, handler, executor) = CreateService();

try
{
// Queue responses for GET, POST, GET sequence
handler.QueueResponse(new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent("{\"value\":[]}")
});
handler.QueueResponse(new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent("{\"id\":\"123\"}")
});
handler.QueueResponse(new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent("{\"value\":[]}")
});

// Act - interleave GET and POST calls
var r1 = await service.GraphGetAsync("tenant-1", "/v1.0/path1");
var r2 = await service.GraphPostAsync("tenant-1", "/v1.0/path2", new { name = "test" });
var r3 = await service.GraphGetAsync("tenant-1", "/v1.0/path3");

// Assert
r1.Should().NotBeNull();
r2.Should().NotBeNull();
r3.Should().NotBeNull();

// Only one token acquisition across all operations
await executor.Received(1).ExecuteAsync(
"az",
Arg.Is<string>(s => s.Contains("get-access-token")),
Arg.Any<string?>(), Arg.Any<bool>(), Arg.Any<bool>(), Arg.Any<CancellationToken>());
}
finally
{
handler.Dispose();
}
}

[Fact]
public void AzCliTokenCacheDuration_IsFiveMinutes()
{
// The cache duration should be a reasonable window to avoid stale tokens
// while eliminating redundant subprocess spawns within a single command.
GraphApiService.AzCliTokenCacheDuration.Should().Be(TimeSpan.FromMinutes(5));
}

[Fact]
public async Task GraphGetAsync_ExpiredCache_AcquiresNewToken()
{
// Arrange
var (service, handler, executor) = CreateService();

try
{
// Queue 2 successful GET responses
handler.QueueResponse(new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent("{\"value\":[]}")
});
handler.QueueResponse(new HttpResponseMessage(HttpStatusCode.OK)
{
Content = new StringContent("{\"value\":[]}")
});

// Act - First call should acquire token and cache it
await service.GraphGetAsync("tenant-1", "/v1.0/path1");

// Simulate cache expiry by setting expiry to past
service.CachedAzCliTokenExpiry = DateTimeOffset.UtcNow.AddMinutes(-1);

// Second call should acquire new token because cache expired
await service.GraphGetAsync("tenant-1", "/v1.0/path2");

// Assert - Token should be acquired twice (once for each call since cache expired)
await executor.Received(2).ExecuteAsync(
"az",
Arg.Is<string>(s => s.Contains("get-access-token")),
Arg.Any<string?>(), Arg.Any<bool>(), Arg.Any<bool>(), Arg.Any<CancellationToken>());
}
finally
{
handler.Dispose();
}
}
}
Loading