Skip to content

Commit d9b1e1f

Browse files
authored
Reduce repeated Graph sign-in prompts (#139)
1 parent 2cff104 commit d9b1e1f

5 files changed

Lines changed: 202 additions & 36 deletions

File tree

src/Microsoft.Agents.A365.DevTools.Cli/Commands/SetupSubcommands/SetupHelpers.cs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,15 +195,32 @@ public static async Task EnsureResourcePermissionsAsync(
195195
if (string.IsNullOrWhiteSpace(config.AgentBlueprintId))
196196
throw new SetupValidationException("AgentBlueprintId (appId) is required.");
197197

198-
var blueprintSpObjectId = await graph.LookupServicePrincipalByAppIdAsync(config.TenantId, config.AgentBlueprintId, ct);
198+
// Use delegated token provider for *all* permission operations to avoid bouncing between Azure CLI auth and Microsoft Graph PowerShell auth.
199+
var permissionGrantScopes = AuthenticationConstants.RequiredPermissionGrantScopes;
200+
201+
// Pre-warm the delegated token once
202+
var user = await graph.GraphGetAsync(
203+
config.TenantId,
204+
"/v1.0/me?$select=id",
205+
ct,
206+
scopes: permissionGrantScopes);
207+
208+
if (user == null)
209+
{
210+
throw new SetupValidationException(
211+
"Failed to authenticate to Microsoft Graph with delegated permissions. " +
212+
"Please sign in when prompted and ensure your account has the required roles and permission scopes.");
213+
}
214+
215+
var blueprintSpObjectId = await graph.LookupServicePrincipalByAppIdAsync(config.TenantId, config.AgentBlueprintId, ct, permissionGrantScopes);
199216
if (string.IsNullOrWhiteSpace(blueprintSpObjectId))
200217
{
201218
throw new SetupValidationException($"Blueprint Service Principal not found for appId {config.AgentBlueprintId}. " +
202219
"The service principal may not have propagated yet. Wait a few minutes and retry.");
203220
}
204221

205222
// Ensure resource service principal exists
206-
var resourceSpObjectId = await graph.EnsureServicePrincipalForAppIdAsync(config.TenantId, resourceAppId, ct);
223+
var resourceSpObjectId = await graph.EnsureServicePrincipalForAppIdAsync(config.TenantId, resourceAppId, ct, permissionGrantScopes);
207224
if (string.IsNullOrWhiteSpace(resourceSpObjectId))
208225
{
209226
throw new SetupValidationException($"{resourceName} Service Principal not found for appId {resourceAppId}. " +
@@ -233,7 +250,7 @@ public static async Task EnsureResourcePermissionsAsync(
233250
blueprintSpObjectId, resourceSpObjectId, string.Join(' ', scopes));
234251

235252
var response = await graph.CreateOrUpdateOauth2PermissionGrantAsync(
236-
config.TenantId, blueprintSpObjectId, resourceSpObjectId, scopes, ct);
253+
config.TenantId, blueprintSpObjectId, resourceSpObjectId, scopes, ct, permissionGrantScopes);
237254

238255
if (!response)
239256
{

src/Microsoft.Agents.A365.DevTools.Cli/Constants/AuthenticationConstants.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,17 @@ public static class AuthenticationConstants
7979
"Directory.Read.All"
8080
};
8181

82+
/// <summary>
83+
/// Required scopes for oauth2 permission grants to service principals.
84+
/// These scopes enable the service principals to operate correctly with the necessary permissions.
85+
/// All scopes require admin consent.
86+
/// </summary>
87+
public static readonly string[] RequiredPermissionGrantScopes = new[]
88+
{
89+
"Application.ReadWrite.All",
90+
"DelegatedPermissionGrant.ReadWrite.All"
91+
};
92+
8293
/// <summary>
8394
/// Environment variable name for bearer token used in local development.
8495
/// This token is stored in .env files (Python/Node.js) or launchSettings.json (.NET)

src/Microsoft.Agents.A365.DevTools.Cli/Services/GraphApiService.cs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,10 @@ public async Task<bool> GraphDeleteAsync(
325325
/// Looks up a service principal by its application (client) ID.
326326
/// Virtual to allow mocking in unit tests using Moq.
327327
/// </summary>
328-
public virtual async Task<string?> LookupServicePrincipalByAppIdAsync(string tenantId, string appId, CancellationToken ct = default)
328+
public virtual async Task<string?> LookupServicePrincipalByAppIdAsync(
329+
string tenantId, string appId, CancellationToken ct = default, IEnumerable<string>? scopes = null)
329330
{
330-
var doc = await GraphGetAsync(tenantId, $"/v1.0/servicePrincipals?$filter=appId eq '{appId}'&$select=id", ct);
331+
var doc = await GraphGetAsync(tenantId, $"/v1.0/servicePrincipals?$filter=appId eq '{appId}'&$select=id", ct, scopes);
331332
if (doc == null) return null;
332333
if (!doc.RootElement.TryGetProperty("value", out var value) || value.GetArrayLength() == 0) return null;
333334
return value[0].GetProperty("id").GetString();
@@ -339,14 +340,14 @@ public async Task<bool> GraphDeleteAsync(
339340
/// Virtual to allow mocking in unit tests using Moq.
340341
/// </summary>
341342
public virtual async Task<string> EnsureServicePrincipalForAppIdAsync(
342-
string tenantId, string appId, CancellationToken ct = default)
343+
string tenantId, string appId, CancellationToken ct = default, IEnumerable<string>? scopes = null)
343344
{
344345
// Try existing
345-
var spId = await LookupServicePrincipalByAppIdAsync(tenantId, appId, ct);
346+
var spId = await LookupServicePrincipalByAppIdAsync(tenantId, appId, ct, scopes);
346347
if (!string.IsNullOrWhiteSpace(spId)) return spId!;
347348

348349
// Create SP for this application
349-
var created = await GraphPostAsync(tenantId, "/v1.0/servicePrincipals", new { appId }, ct);
350+
var created = await GraphPostAsync(tenantId, "/v1.0/servicePrincipals", new { appId }, ct, scopes);
350351
if (created == null || !created.RootElement.TryGetProperty("id", out var idProp))
351352
throw new InvalidOperationException($"Failed to create servicePrincipal for appId {appId}");
352353

@@ -358,15 +359,17 @@ public async Task<bool> CreateOrUpdateOauth2PermissionGrantAsync(
358359
string clientSpObjectId,
359360
string resourceSpObjectId,
360361
IEnumerable<string> scopes,
361-
CancellationToken ct = default)
362+
CancellationToken ct = default,
363+
IEnumerable<string>? permissionGrantScopes = null)
362364
{
363365
var desiredScopeString = string.Join(' ', scopes);
364366

365367
// Read existing
366368
var listDoc = await GraphGetAsync(
367369
tenantId,
368370
$"/v1.0/oauth2PermissionGrants?$filter=clientId eq '{clientSpObjectId}' and resourceId eq '{resourceSpObjectId}'",
369-
ct);
371+
ct,
372+
permissionGrantScopes);
370373

371374
var existing = listDoc?.RootElement.TryGetProperty("value", out var arr) == true && arr.GetArrayLength() > 0
372375
? arr[0]
@@ -382,7 +385,7 @@ public async Task<bool> CreateOrUpdateOauth2PermissionGrantAsync(
382385
resourceId = resourceSpObjectId,
383386
scope = desiredScopeString
384387
};
385-
var created = await GraphPostAsync(tenantId, "/v1.0/oauth2PermissionGrants", payload, ct);
388+
var created = await GraphPostAsync(tenantId, "/v1.0/oauth2PermissionGrants", payload, ct, permissionGrantScopes);
386389
return created != null; // success if response parsed
387390
}
388391

@@ -399,7 +402,7 @@ public async Task<bool> CreateOrUpdateOauth2PermissionGrantAsync(
399402
var id = existing.Value.GetProperty("id").GetString();
400403
if (string.IsNullOrWhiteSpace(id)) return false;
401404

402-
return await GraphPatchAsync(tenantId, $"/v1.0/oauth2PermissionGrants/{id}", new { scope = merged }, ct);
405+
return await GraphPatchAsync(tenantId, $"/v1.0/oauth2PermissionGrants/{id}", new { scope = merged }, ct, permissionGrantScopes);
403406
}
404407

405408
/// <summary>

src/Microsoft.Agents.A365.DevTools.Cli/Services/Internal/MicrosoftGraphTokenProvider.cs

Lines changed: 144 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
using System;
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT License.
3+
4+
using System;
5+
using System.Collections.Concurrent;
26
using System.Collections.Generic;
37
using System.Linq;
8+
using System.Text;
9+
using System.Text.Json;
410
using System.Threading;
511
using System.Threading.Tasks;
12+
using Microsoft.Agents.A365.DevTools.Cli.Constants;
613
using Microsoft.Agents.A365.DevTools.Cli.Helpers;
714
using Microsoft.Extensions.Logging;
815

@@ -11,11 +18,40 @@ namespace Microsoft.Agents.A365.DevTools.Cli.Services;
1118
/// <summary>
1219
/// Implements Microsoft Graph token acquisition via PowerShell Microsoft.Graph module.
1320
/// </summary>
14-
public sealed class MicrosoftGraphTokenProvider : IMicrosoftGraphTokenProvider
21+
public sealed class MicrosoftGraphTokenProvider : IMicrosoftGraphTokenProvider, IDisposable
1522
{
1623
private readonly CommandExecutor _executor;
1724
private readonly ILogger<MicrosoftGraphTokenProvider> _logger;
1825

26+
// Cache tokens per (tenant + clientId + scopes) for the lifetime of this CLI process.
27+
// This reduces repeated Connect-MgGraph prompts in setup flows.
28+
private readonly ConcurrentDictionary<string, CachedToken> _tokenCache = new();
29+
private readonly ConcurrentDictionary<string, SemaphoreSlim> _locks = new();
30+
31+
private sealed record CachedToken(string AccessToken, DateTimeOffset ExpiresOnUtc);
32+
33+
private bool _disposed;
34+
public void Dispose()
35+
{
36+
if (_disposed) return;
37+
_disposed = true;
38+
39+
foreach (var kvp in _locks)
40+
{
41+
try
42+
{
43+
kvp.Value.Dispose();
44+
}
45+
catch (Exception ex)
46+
{
47+
_logger.LogDebug(ex, "Failed to dispose semaphore for key '{Key}' in MicrosoftGraphTokenProvider.", kvp.Key);
48+
}
49+
}
50+
51+
_locks.Clear();
52+
_tokenCache.Clear();
53+
}
54+
1955
public MicrosoftGraphTokenProvider(
2056
CommandExecutor executor,
2157
ILogger<MicrosoftGraphTokenProvider> logger)
@@ -31,10 +67,6 @@ public MicrosoftGraphTokenProvider(
3167
string? clientAppId = null,
3268
CancellationToken ct = default)
3369
{
34-
_logger.LogInformation(
35-
"Acquiring Microsoft Graph delegated access token via PowerShell (Device Code: {UseDeviceCode})",
36-
useDeviceCode);
37-
3870
var validatedScopes = ValidateAndPrepareScopes(scopes);
3971
ValidateTenantId(tenantId);
4072

@@ -43,10 +75,61 @@ public MicrosoftGraphTokenProvider(
4375
ValidateClientAppId(clientAppId);
4476
}
4577

46-
var script = BuildPowerShellScript(tenantId, validatedScopes, useDeviceCode, clientAppId);
47-
var result = await ExecuteWithFallbackAsync(script, useDeviceCode, ct);
78+
var cacheKey = MakeCacheKey(tenantId, validatedScopes, clientAppId);
79+
var tokenExpirationMinutes = AuthenticationConstants.TokenExpirationBufferMinutes;
4880

49-
return ProcessResult(result);
81+
// Fast path: cached + not expiring soon
82+
if (_tokenCache.TryGetValue(cacheKey, out var cached) &&
83+
cached.ExpiresOnUtc > DateTimeOffset.UtcNow.AddMinutes(tokenExpirationMinutes) &&
84+
!string.IsNullOrWhiteSpace(cached.AccessToken))
85+
{
86+
_logger.LogDebug("Reusing cached Graph token for key {Key} expiring at {Exp}",
87+
cacheKey, cached.ExpiresOnUtc);
88+
return cached.AccessToken;
89+
}
90+
91+
// Single-flight: only one PowerShell auth per key at a time
92+
var gate = _locks.GetOrAdd(cacheKey, _ => new SemaphoreSlim(1, 1));
93+
await gate.WaitAsync(ct);
94+
try
95+
{
96+
// Re-check inside lock
97+
if (_tokenCache.TryGetValue(cacheKey, out cached) &&
98+
cached.ExpiresOnUtc > DateTimeOffset.UtcNow.AddMinutes(tokenExpirationMinutes) &&
99+
!string.IsNullOrWhiteSpace(cached.AccessToken))
100+
{
101+
_logger.LogDebug("Reusing cached Graph token (post-lock) for key {Key} expiring at {Exp}",
102+
cacheKey, cached.ExpiresOnUtc);
103+
return cached.AccessToken;
104+
}
105+
106+
_logger.LogInformation(
107+
"Acquiring Microsoft Graph delegated access token via PowerShell (Device Code: {UseDeviceCode})",
108+
useDeviceCode);
109+
110+
var script = BuildPowerShellScript(tenantId, validatedScopes, useDeviceCode, clientAppId);
111+
var result = await ExecuteWithFallbackAsync(script, useDeviceCode, ct);
112+
var token = ProcessResult(result);
113+
114+
if (string.IsNullOrWhiteSpace(token))
115+
{
116+
return null;
117+
}
118+
119+
// Cache expiry from JWT exp; if parsing fails, cache short (10 min) to still reduce spam
120+
if (!TryGetJwtExpiryUtc(token, out var expUtc))
121+
{
122+
expUtc = DateTimeOffset.UtcNow.AddMinutes(10);
123+
_logger.LogDebug("Could not parse JWT exp; caching token for a short duration until {Exp}", expUtc);
124+
}
125+
126+
_tokenCache[cacheKey] = new CachedToken(token, expUtc);
127+
return token;
128+
}
129+
finally
130+
{
131+
gate.Release();
132+
}
50133
}
51134

52135
private string[] ValidateAndPrepareScopes(IEnumerable<string> scopes)
@@ -249,4 +332,56 @@ private static bool IsValidJwtFormat(string token)
249332
return token.StartsWith("eyJ", StringComparison.Ordinal) &&
250333
token.Count(c => c == '.') == 2;
251334
}
335+
336+
private static string MakeCacheKey(string tenantId, IEnumerable<string> scopes, string? clientAppId)
337+
{
338+
var scopeKey = string.Join(" ", scopes
339+
.Where(s => !string.IsNullOrWhiteSpace(s))
340+
.Select(s => s.Trim())
341+
.Distinct(StringComparer.OrdinalIgnoreCase)
342+
.OrderBy(s => s, StringComparer.OrdinalIgnoreCase));
343+
344+
return $"{tenantId}::{clientAppId ?? ""}::{scopeKey}";
345+
}
346+
347+
private bool TryGetJwtExpiryUtc(string jwt, out DateTimeOffset expiresOnUtc)
348+
{
349+
expiresOnUtc = default;
350+
351+
if (string.IsNullOrWhiteSpace(jwt)) return false;
352+
353+
try
354+
{
355+
var parts = jwt.Split('.');
356+
if (parts.Length < 2) return false;
357+
358+
var payloadJson = Encoding.UTF8.GetString(Base64UrlDecode(parts[1]));
359+
using var doc = JsonDocument.Parse(payloadJson);
360+
361+
if (!doc.RootElement.TryGetProperty("exp", out var expEl)) return false;
362+
if (expEl.ValueKind != JsonValueKind.Number) return false;
363+
364+
// exp is seconds since Unix epoch
365+
var expSeconds = expEl.GetInt64();
366+
expiresOnUtc = DateTimeOffset.FromUnixTimeSeconds(expSeconds);
367+
return true;
368+
}
369+
catch (Exception ex)
370+
{
371+
_logger.LogDebug(ex, "Failed to parse JWT expiry (exp) from access token.");
372+
return false;
373+
}
374+
}
375+
376+
private static byte[] Base64UrlDecode(string input)
377+
{
378+
// Base64Url decode with padding fix
379+
var s = input.Replace('-', '+').Replace('_', '/');
380+
switch (s.Length % 4)
381+
{
382+
case 2: s += "=="; break;
383+
case 3: s += "="; break;
384+
}
385+
return Convert.FromBase64String(s);
386+
}
252387
}

0 commit comments

Comments
 (0)