1- using System ;
1+ // Copyright (c) Microsoft Corporation.
2+ // Licensed under the MIT License.
3+
4+ using System ;
5+ using System . Collections . Concurrent ;
26using System . Collections . Generic ;
37using System . Linq ;
8+ using System . Text ;
9+ using System . Text . Json ;
410using System . Threading ;
511using System . Threading . Tasks ;
12+ using Microsoft . Agents . A365 . DevTools . Cli . Constants ;
613using Microsoft . Agents . A365 . DevTools . Cli . Helpers ;
714using Microsoft . Extensions . Logging ;
815
@@ -11,11 +18,40 @@ namespace Microsoft.Agents.A365.DevTools.Cli.Services;
1118/// <summary>
1219/// Implements Microsoft Graph token acquisition via PowerShell Microsoft.Graph module.
1320/// </summary>
14- public sealed class MicrosoftGraphTokenProvider : IMicrosoftGraphTokenProvider
21+ public sealed class MicrosoftGraphTokenProvider : IMicrosoftGraphTokenProvider , IDisposable
1522{
1623 private readonly CommandExecutor _executor ;
1724 private readonly ILogger < MicrosoftGraphTokenProvider > _logger ;
1825
26+ // Cache tokens per (tenant + clientId + scopes) for the lifetime of this CLI process.
27+ // This reduces repeated Connect-MgGraph prompts in setup flows.
28+ private readonly ConcurrentDictionary < string , CachedToken > _tokenCache = new ( ) ;
29+ private readonly ConcurrentDictionary < string , SemaphoreSlim > _locks = new ( ) ;
30+
31+ private sealed record CachedToken ( string AccessToken , DateTimeOffset ExpiresOnUtc ) ;
32+
33+ private bool _disposed ;
34+ public void Dispose ( )
35+ {
36+ if ( _disposed ) return ;
37+ _disposed = true ;
38+
39+ foreach ( var kvp in _locks )
40+ {
41+ try
42+ {
43+ kvp . Value . Dispose ( ) ;
44+ }
45+ catch ( Exception ex )
46+ {
47+ _logger . LogDebug ( ex , "Failed to dispose semaphore for key '{Key}' in MicrosoftGraphTokenProvider." , kvp . Key ) ;
48+ }
49+ }
50+
51+ _locks . Clear ( ) ;
52+ _tokenCache . Clear ( ) ;
53+ }
54+
1955 public MicrosoftGraphTokenProvider (
2056 CommandExecutor executor ,
2157 ILogger < MicrosoftGraphTokenProvider > logger )
@@ -31,10 +67,6 @@ public MicrosoftGraphTokenProvider(
3167 string ? clientAppId = null ,
3268 CancellationToken ct = default )
3369 {
34- _logger . LogInformation (
35- "Acquiring Microsoft Graph delegated access token via PowerShell (Device Code: {UseDeviceCode})" ,
36- useDeviceCode ) ;
37-
3870 var validatedScopes = ValidateAndPrepareScopes ( scopes ) ;
3971 ValidateTenantId ( tenantId ) ;
4072
@@ -43,10 +75,61 @@ public MicrosoftGraphTokenProvider(
4375 ValidateClientAppId ( clientAppId ) ;
4476 }
4577
46- var script = BuildPowerShellScript ( tenantId , validatedScopes , useDeviceCode , clientAppId ) ;
47- var result = await ExecuteWithFallbackAsync ( script , useDeviceCode , ct ) ;
78+ var cacheKey = MakeCacheKey ( tenantId , validatedScopes , clientAppId ) ;
79+ var tokenExpirationMinutes = AuthenticationConstants . TokenExpirationBufferMinutes ;
4880
49- return ProcessResult ( result ) ;
81+ // Fast path: cached + not expiring soon
82+ if ( _tokenCache . TryGetValue ( cacheKey , out var cached ) &&
83+ cached . ExpiresOnUtc > DateTimeOffset . UtcNow . AddMinutes ( tokenExpirationMinutes ) &&
84+ ! string . IsNullOrWhiteSpace ( cached . AccessToken ) )
85+ {
86+ _logger . LogDebug ( "Reusing cached Graph token for key {Key} expiring at {Exp}" ,
87+ cacheKey , cached . ExpiresOnUtc ) ;
88+ return cached . AccessToken ;
89+ }
90+
91+ // Single-flight: only one PowerShell auth per key at a time
92+ var gate = _locks . GetOrAdd ( cacheKey , _ => new SemaphoreSlim ( 1 , 1 ) ) ;
93+ await gate . WaitAsync ( ct ) ;
94+ try
95+ {
96+ // Re-check inside lock
97+ if ( _tokenCache . TryGetValue ( cacheKey , out cached ) &&
98+ cached . ExpiresOnUtc > DateTimeOffset . UtcNow . AddMinutes ( tokenExpirationMinutes ) &&
99+ ! string . IsNullOrWhiteSpace ( cached . AccessToken ) )
100+ {
101+ _logger . LogDebug ( "Reusing cached Graph token (post-lock) for key {Key} expiring at {Exp}" ,
102+ cacheKey , cached . ExpiresOnUtc ) ;
103+ return cached . AccessToken ;
104+ }
105+
106+ _logger . LogInformation (
107+ "Acquiring Microsoft Graph delegated access token via PowerShell (Device Code: {UseDeviceCode})" ,
108+ useDeviceCode ) ;
109+
110+ var script = BuildPowerShellScript ( tenantId , validatedScopes , useDeviceCode , clientAppId ) ;
111+ var result = await ExecuteWithFallbackAsync ( script , useDeviceCode , ct ) ;
112+ var token = ProcessResult ( result ) ;
113+
114+ if ( string . IsNullOrWhiteSpace ( token ) )
115+ {
116+ return null ;
117+ }
118+
119+ // Cache expiry from JWT exp; if parsing fails, cache short (10 min) to still reduce spam
120+ if ( ! TryGetJwtExpiryUtc ( token , out var expUtc ) )
121+ {
122+ expUtc = DateTimeOffset . UtcNow . AddMinutes ( 10 ) ;
123+ _logger . LogDebug ( "Could not parse JWT exp; caching token for a short duration until {Exp}" , expUtc ) ;
124+ }
125+
126+ _tokenCache [ cacheKey ] = new CachedToken ( token , expUtc ) ;
127+ return token ;
128+ }
129+ finally
130+ {
131+ gate . Release ( ) ;
132+ }
50133 }
51134
52135 private string [ ] ValidateAndPrepareScopes ( IEnumerable < string > scopes )
@@ -249,4 +332,56 @@ private static bool IsValidJwtFormat(string token)
249332 return token . StartsWith ( "eyJ" , StringComparison . Ordinal ) &&
250333 token . Count ( c => c == '.' ) == 2 ;
251334 }
335+
336+ private static string MakeCacheKey ( string tenantId , IEnumerable < string > scopes , string ? clientAppId )
337+ {
338+ var scopeKey = string . Join ( " " , scopes
339+ . Where ( s => ! string . IsNullOrWhiteSpace ( s ) )
340+ . Select ( s => s . Trim ( ) )
341+ . Distinct ( StringComparer . OrdinalIgnoreCase )
342+ . OrderBy ( s => s , StringComparer . OrdinalIgnoreCase ) ) ;
343+
344+ return $ "{ tenantId } ::{ clientAppId ?? "" } ::{ scopeKey } ";
345+ }
346+
347+ private bool TryGetJwtExpiryUtc ( string jwt , out DateTimeOffset expiresOnUtc )
348+ {
349+ expiresOnUtc = default ;
350+
351+ if ( string . IsNullOrWhiteSpace ( jwt ) ) return false ;
352+
353+ try
354+ {
355+ var parts = jwt . Split ( '.' ) ;
356+ if ( parts . Length < 2 ) return false ;
357+
358+ var payloadJson = Encoding . UTF8 . GetString ( Base64UrlDecode ( parts [ 1 ] ) ) ;
359+ using var doc = JsonDocument . Parse ( payloadJson ) ;
360+
361+ if ( ! doc . RootElement . TryGetProperty ( "exp" , out var expEl ) ) return false ;
362+ if ( expEl . ValueKind != JsonValueKind . Number ) return false ;
363+
364+ // exp is seconds since Unix epoch
365+ var expSeconds = expEl . GetInt64 ( ) ;
366+ expiresOnUtc = DateTimeOffset . FromUnixTimeSeconds ( expSeconds ) ;
367+ return true ;
368+ }
369+ catch ( Exception ex )
370+ {
371+ _logger . LogDebug ( ex , "Failed to parse JWT expiry (exp) from access token." ) ;
372+ return false ;
373+ }
374+ }
375+
376+ private static byte [ ] Base64UrlDecode ( string input )
377+ {
378+ // Base64Url decode with padding fix
379+ var s = input . Replace ( '-' , '+' ) . Replace ( '_' , '/' ) ;
380+ switch ( s . Length % 4 )
381+ {
382+ case 2 : s += "==" ; break ;
383+ case 3 : s += "=" ; break ;
384+ }
385+ return Convert . FromBase64String ( s ) ;
386+ }
252387}
0 commit comments