diff --git a/CHANGELOG.md b/CHANGELOG.md index 2eee4cba..0b06f42e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,7 @@ Agents provisioned before this release need `Agent365.Observability.OtelWrite` g - `--yes` / `-y` option on `develop-mcp publish` — skips the interactive "Proceed with publish? (y/N)" confirmation. ### Fixed +- Commands requiring authentication no longer return misleading Graph 403 errors when Windows has multiple cached work accounts. The CLI detects a wrong-tenant token (`tid` claim mismatch), clears the MSAL token cache, and retries automatically (#430). - `setup all` now exits silently on Ctrl+C instead of printing `ERROR: Setup failed: A task was canceled.` followed by a misleading partial summary. - `setup all --m365` no longer fails with `AADSTS650053` because the Messaging Bot scope was hard-coded to scopes the resource SP does not publish (issue #429). - `setup all` no longer fails with `AADSTS650053` for any drift between requested scopes and what a resource SP actually publishes (issue #429). Unpublished scopes are filtered out before building the consent URL; per-resource warnings surface what was dropped. diff --git a/src/Microsoft.Agents.A365.DevTools.Cli/Constants/AuthenticationConstants.cs b/src/Microsoft.Agents.A365.DevTools.Cli/Constants/AuthenticationConstants.cs index 64820e7d..da4c339f 100644 --- a/src/Microsoft.Agents.A365.DevTools.Cli/Constants/AuthenticationConstants.cs +++ b/src/Microsoft.Agents.A365.DevTools.Cli/Constants/AuthenticationConstants.cs @@ -88,6 +88,13 @@ public static string[] GetRequiredRedirectUris(string clientAppId) /// public const string TokenCacheFileName = "auth-token.json"; + /// + /// MSAL persistent token cache file name. + /// Used by MsalBrowserCredential (WAM/browser auth) and referenced by + /// AuthenticationService when clearing stale cross-tenant cached tokens. + /// + public const string MsalCacheFileName = "msal-token-cache"; + /// /// Token expiration buffer in minutes /// Tokens are considered expired this many minutes before actual expiration diff --git a/src/Microsoft.Agents.A365.DevTools.Cli/Services/AuthenticationService.cs b/src/Microsoft.Agents.A365.DevTools.Cli/Services/AuthenticationService.cs index a7c60aaf..13a6d9f5 100644 --- a/src/Microsoft.Agents.A365.DevTools.Cli/Services/AuthenticationService.cs +++ b/src/Microsoft.Agents.A365.DevTools.Cli/Services/AuthenticationService.cs @@ -63,14 +63,17 @@ public class AuthenticationService : IAuthenticationService { private readonly ILogger _logger; private readonly string _tokenCachePath; + // Stored so ClearStaleTokenCachesAsync can compute the MSAL cache path without a + // cross-class dependency on MsalBrowserCredential's private static field. + private readonly string _cacheDir; public AuthenticationService(ILogger logger) { _logger = logger; var appDataPath = Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData); - var cacheDir = Path.Combine(appDataPath, AuthenticationConstants.ApplicationName); - Directory.CreateDirectory(cacheDir); - _tokenCachePath = Path.Combine(cacheDir, AuthenticationConstants.TokenCacheFileName); + _cacheDir = Path.Combine(appDataPath, AuthenticationConstants.ApplicationName); + Directory.CreateDirectory(_cacheDir); + _tokenCachePath = Path.Combine(_cacheDir, AuthenticationConstants.TokenCacheFileName); } /// @@ -198,6 +201,38 @@ public async Task GetAccessTokenAsync( _logger.LogDebug("Authentication required for Agent 365 Tools"); var token = await AuthenticateInteractivelyAsync(resourceUrl, tenantId, clientId, scopes, useInteractiveBrowser, loginHint: userId, ct: ct); + // Self-heal: validate the tid claim in the returned JWT against the requested tenant. + // WAM may silently select a cached work account from a different tenant when multiple + // Windows accounts are present (issue #430). On mismatch, clear both our JSON cache + // and the MSAL persistent cache to reset WAM's account selection, then retry once. + // Only compare tid when the requested tenantId is a GUID — JWT tid claims are always + // GUIDs, so comparison against a domain-form tenantId (e.g. contoso.onmicrosoft.com) + // would always appear as a mismatch, causing unnecessary cache clears and retry loops. + if (!string.IsNullOrWhiteSpace(tenantId) && Guid.TryParse(tenantId, out _)) + { + var returnedTid = JwtHelper.TryDecodeClaim(token.AccessToken, "tid"); + if (!string.IsNullOrWhiteSpace(returnedTid) && + !string.Equals(returnedTid, tenantId, StringComparison.OrdinalIgnoreCase)) + { + _logger.LogWarning( + "Authentication returned token for tenant {ReturnedTenant} but {RequestedTenant} is required. " + + "Clearing cached credentials and retrying...", + returnedTid, tenantId); + await ClearStaleTokenCachesAsync(); + // Retry once with the same parameters — MSAL disk cache is now empty so WAM + // gets a clean slate and will either pick the correct account or prompt. + token = await AuthenticateInteractivelyAsync(resourceUrl, tenantId, clientId, scopes, useInteractiveBrowser, loginHint: userId, ct: ct); + var retryTid = JwtHelper.TryDecodeClaim(token.AccessToken, "tid"); + if (!string.IsNullOrWhiteSpace(retryTid) && + !string.Equals(retryTid, tenantId, StringComparison.OrdinalIgnoreCase)) + { + throw new AzureAuthenticationException( + $"The account selected does not match the configured tenant ({tenantId}). " + + $"Ensure 'az login' targets the correct tenant, or select the correct account when prompted."); + } + } + } + // Validate the token identity before caching: if a userId was requested, // ensure the returned token is actually for that user. WAM may return a // guest/cross-app token for an account it considers "equivalent" (same Microsoft @@ -348,7 +383,13 @@ private async Task AuthenticateInteractivelyAsync( { AccessToken = tokenResult.Token, ExpiresOn = tokenResult.ExpiresOn.UtcDateTime, - TenantId = effectiveTenantId + // Store the decoded JWT tid only when the requested tenantId is also a GUID. + // If callers pass a domain name (e.g. contoso.onmicrosoft.com), storing the + // GUID tid would cause the next cache-read comparison to always fail, forcing + // re-authentication on every run. + TenantId = Guid.TryParse(effectiveTenantId, out _) + ? JwtHelper.TryDecodeClaim(tokenResult.Token, "tid") ?? effectiveTenantId + : effectiveTenantId }; } catch (MsalAuthenticationFailedException ex) when (ex.Message.Contains("code_expired") || ex.InnerException?.Message.Contains("code_expired") == true) @@ -686,27 +727,49 @@ protected virtual TokenCredential CreateDeviceCodeCredential(string clientId, st private static string? TryExtractUpnFromJwt(string? jwt) { - if (string.IsNullOrWhiteSpace(jwt)) return null; + // Try the UPN claim variants in order of specificity. + // Delegates to JwtHelper.TryDecodeClaim for the shared Base64Url decode. + return JwtHelper.TryDecodeClaim(jwt, "upn") + ?? JwtHelper.TryDecodeClaim(jwt, "preferred_username") + ?? JwtHelper.TryDecodeClaim(jwt, "unique_name"); + } + + /// + /// Deletes both the JSON token cache and the MSAL persistent cache. + /// Each deletion is independently non-fatal; errors are logged at Debug level. + /// + private Task ClearStaleTokenCachesAsync() + { + // 1. Our JSON token cache try { - var parts = jwt.Split('.'); - if (parts.Length < 2) return null; - var payload = parts[1]; - // JWT uses Base64Url encoding: replace URL-safe chars before standard Base64 decode. - payload = payload.Replace('-', '+').Replace('_', '/'); - // Restore Base64 padding stripped by JWT encoding. - payload = payload.PadRight(payload.Length + (4 - payload.Length % 4) % 4, '='); - var bytes = Convert.FromBase64String(payload); - using var doc = JsonDocument.Parse(bytes); - if (doc.RootElement.TryGetProperty("upn", out var upn) && !string.IsNullOrWhiteSpace(upn.GetString())) - return upn.GetString(); - if (doc.RootElement.TryGetProperty("preferred_username", out var pref) && !string.IsNullOrWhiteSpace(pref.GetString())) - return pref.GetString(); - if (doc.RootElement.TryGetProperty("unique_name", out var uniqueName) && !string.IsNullOrWhiteSpace(uniqueName.GetString())) - return uniqueName.GetString(); + if (File.Exists(_tokenCachePath)) + { + File.Delete(_tokenCachePath); + _logger.LogDebug("Cleared JSON token cache at {Path}", _tokenCachePath); + } } - catch { } // Static helper — no logger access. Caller logs via ResolveLoginHintFromCacheAsync. - return null; + catch (Exception ex) + { + _logger.LogDebug(ex, "Failed to clear JSON token cache: {Message}", ex.Message); + } + + // 2. MSAL persistent cache (WAM/browser) + var msalCachePath = Path.Combine(_cacheDir, AuthenticationConstants.MsalCacheFileName); + try + { + if (File.Exists(msalCachePath)) + { + File.Delete(msalCachePath); + _logger.LogDebug("Cleared MSAL token cache at {Path}", msalCachePath); + } + } + catch (Exception ex) + { + _logger.LogDebug(ex, "Failed to clear MSAL token cache: {Message}", ex.Message); + } + + return Task.CompletedTask; } /// diff --git a/src/Microsoft.Agents.A365.DevTools.Cli/Services/GraphApiService.cs b/src/Microsoft.Agents.A365.DevTools.Cli/Services/GraphApiService.cs index b5db10bf..8ed1388f 100644 --- a/src/Microsoft.Agents.A365.DevTools.Cli/Services/GraphApiService.cs +++ b/src/Microsoft.Agents.A365.DevTools.Cli/Services/GraphApiService.cs @@ -1911,22 +1911,7 @@ public virtual async Task DeleteAgentInstanceAsync( /// Returns null if the token cannot be decoded or the claim is absent. /// private static string? TryDecodeTokenClaim(string token, string claimName) - { - try - { - var parts = token.Split('.'); - if (parts.Length < 2) return null; - var payload = parts[1]; - payload = payload.PadRight(payload.Length + (4 - payload.Length % 4) % 4, '='); - var json = System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(payload)); - using var doc = JsonDocument.Parse(json); - return doc.RootElement.TryGetProperty(claimName, out var claim) ? claim.GetString() : null; - } - catch - { - return null; - } - } + => JwtHelper.TryDecodeClaim(token, claimName); /// /// Attempts to extract a human-readable error message from a Graph API JSON error response body. diff --git a/src/Microsoft.Agents.A365.DevTools.Cli/Services/Helpers/JwtHelper.cs b/src/Microsoft.Agents.A365.DevTools.Cli/Services/Helpers/JwtHelper.cs new file mode 100644 index 00000000..bceca592 --- /dev/null +++ b/src/Microsoft.Agents.A365.DevTools.Cli/Services/Helpers/JwtHelper.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json; + +namespace Microsoft.Agents.A365.DevTools.Cli.Services.Helpers; + +/// +/// Shared helper for decoding JWT token claims. +/// Consolidates the duplicated Base64Url-decode logic that previously existed in +/// AuthenticationService.TryExtractUpnFromJwt, GraphApiService.TryDecodeTokenClaim, +/// and MsalBrowserCredential inline decode blocks. +/// +internal static class JwtHelper +{ + /// + /// Decodes a single claim from the payload of a JWT string. + /// Returns null if the token is malformed, the claim is absent, or decoding fails. + /// + internal static string? TryDecodeClaim(string? jwt, string claimName) + { + if (string.IsNullOrWhiteSpace(jwt)) return null; + try + { + var parts = jwt.Split('.'); + if (parts.Length < 2) return null; + var payload = parts[1]; + // JWT uses Base64Url encoding: restore standard Base64 chars and padding. + payload = payload.Replace('-', '+').Replace('_', '/'); + payload = payload.PadRight(payload.Length + (4 - payload.Length % 4) % 4, '='); + var bytes = Convert.FromBase64String(payload); + using var doc = JsonDocument.Parse(bytes); + return doc.RootElement.TryGetProperty(claimName, out var claim) + ? claim.GetString() + : null; + } + catch + { + return null; + } + } +} diff --git a/src/Microsoft.Agents.A365.DevTools.Cli/Services/Internal/MicrosoftGraphTokenProvider.cs b/src/Microsoft.Agents.A365.DevTools.Cli/Services/Internal/MicrosoftGraphTokenProvider.cs index 965af087..b6581c53 100644 --- a/src/Microsoft.Agents.A365.DevTools.Cli/Services/Internal/MicrosoftGraphTokenProvider.cs +++ b/src/Microsoft.Agents.A365.DevTools.Cli/Services/Internal/MicrosoftGraphTokenProvider.cs @@ -13,6 +13,7 @@ using Azure.Core; using Microsoft.Agents.A365.DevTools.Cli.Constants; using Microsoft.Agents.A365.DevTools.Cli.Helpers; +using Microsoft.Agents.A365.DevTools.Cli.Services.Helpers; using Microsoft.Extensions.Logging; namespace Microsoft.Agents.A365.DevTools.Cli.Services; @@ -178,6 +179,65 @@ public MicrosoftGraphTokenProvider( return null; } + // Self-heal: validate the tid claim in the returned token against the requested tenant. + // WAM may silently select a cached Windows work account from a different tenant when + // multiple accounts are present (issue #430). On mismatch, clear the in-memory cache + // entry and the MSAL persistent disk cache, then retry once with forceRefresh so WAM + // gets a clean slate and either picks the correct account or prompts the user. + // Only compare tid when tenantId is a GUID — JWT tid claims are always GUIDs, + // so a domain-form tenantId (e.g. contoso.onmicrosoft.com) would always appear + // as a mismatch and clear caches unnecessarily. + var returnedTid = JwtHelper.TryDecodeClaim(token, "tid"); + if (!string.IsNullOrWhiteSpace(returnedTid) && + Guid.TryParse(tenantId, out _) && + !string.Equals(returnedTid, tenantId, StringComparison.OrdinalIgnoreCase)) + { + _logger.LogWarning( + "Graph token returned for tenant {ReturnedTenant} but {RequestedTenant} is required. " + + "Clearing cached credentials and retrying...", + returnedTid, tenantId); + + // Evict in-memory entry so the retry actually calls through to MSAL/PS. + _tokenCache.TryRemove(cacheKey, out _); + + // Delete the MSAL persistent cache file so WAM starts with a clean account list. + var msalCachePath = Path.Combine( + Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), + AuthenticationConstants.ApplicationName, + AuthenticationConstants.MsalCacheFileName); + try + { + if (File.Exists(msalCachePath)) + { + File.Delete(msalCachePath); + _logger.LogDebug("Cleared MSAL token cache at {Path}", msalCachePath); + } + } + catch (Exception ex) + { + _logger.LogDebug(ex, "Failed to clear MSAL token cache: {Message}", ex.Message); + } + + // Retry once — do not recurse; use the underlying acquirer directly. + var retryToken = MsalTokenAcquirerOverride != null + ? await MsalTokenAcquirerOverride(tenantId, validatedScopes, clientAppId, ct) + : await AcquireGraphTokenViaMsalAsync(tenantId, validatedScopes, clientAppId, ct, loginHint, forceRefresh: true); + + if (!string.IsNullOrWhiteSpace(retryToken)) + token = retryToken; + + // Fail fast if the retry also returned the wrong tenant — caching and returning + // a known-bad token would produce the same misleading 403s the fix is meant to prevent. + var retryTid = JwtHelper.TryDecodeClaim(token, "tid"); + if (!string.IsNullOrWhiteSpace(retryTid) && + !string.Equals(retryTid, tenantId, StringComparison.OrdinalIgnoreCase)) + { + throw new InvalidOperationException( + $"Graph token retry returned token for tenant {retryTid} but {tenantId} is required. " + + $"Ensure 'az login' targets the correct tenant, or select the correct account when prompted."); + } + } + // Cache expiry from JWT exp; if parsing fails, cache short (10 min) to still reduce spam if (!TryGetJwtExpiryUtc(token, out var expUtc)) { diff --git a/src/Microsoft.Agents.A365.DevTools.Cli/Services/MsalBrowserCredential.cs b/src/Microsoft.Agents.A365.DevTools.Cli/Services/MsalBrowserCredential.cs index 4d275bd7..e47a5468 100644 --- a/src/Microsoft.Agents.A365.DevTools.Cli/Services/MsalBrowserCredential.cs +++ b/src/Microsoft.Agents.A365.DevTools.Cli/Services/MsalBrowserCredential.cs @@ -50,7 +50,7 @@ public sealed class MsalBrowserCredential : TokenCredential private static MsalCacheHelper? _cacheHelper; private static readonly object _cacheHelperLock = new(); - private static readonly string CacheFileName = "msal-token-cache"; + private static readonly string CacheFileName = AuthenticationConstants.MsalCacheFileName; private static readonly string CacheDirectory = Path.Combine( Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), AuthenticationConstants.ApplicationName); diff --git a/src/Tests/Microsoft.Agents.A365.DevTools.Cli.Tests/Services/AuthenticationServiceTests.cs b/src/Tests/Microsoft.Agents.A365.DevTools.Cli.Tests/Services/AuthenticationServiceTests.cs index cfddfec4..845168e5 100644 --- a/src/Tests/Microsoft.Agents.A365.DevTools.Cli.Tests/Services/AuthenticationServiceTests.cs +++ b/src/Tests/Microsoft.Agents.A365.DevTools.Cli.Tests/Services/AuthenticationServiceTests.cs @@ -7,6 +7,7 @@ using Microsoft.Agents.A365.DevTools.Cli.Exceptions; using Microsoft.Agents.A365.DevTools.Cli.Models; using Microsoft.Agents.A365.DevTools.Cli.Services; +using Microsoft.Agents.A365.DevTools.Cli.Services.Helpers; using NSubstitute; using System.Text.Json; using Microsoft.Agents.A365.DevTools.Cli.Constants; @@ -1058,4 +1059,162 @@ public async Task ResolveLoginHintFromCacheAsync_WhenJwtIsMalformed_ReturnsNull( } #endregion + + #region WAM wrong-tenant self-heal (issue #430) + + // Builds a minimal JWT whose payload contains only the supplied tid claim. + // Re-uses the existing BuildJwt helper (same Base64Url encoding). + private static StubTokenCredential CredentialWithTid(string tid) + => new(BuildJwt(new { tid }), DateTimeOffset.UtcNow.AddHours(1)); + + /// + /// TestableAuthenticationService variant that uses a call-indexed credential list. + /// Each call to CreateBrowserCredential pops the next credential from the queue. + /// + private sealed class SequencedAuthenticationService : AuthenticationService + { + private readonly Queue _credentials; + + public SequencedAuthenticationService( + ILogger logger, + IEnumerable credentials) + : base(logger) + => _credentials = new Queue(credentials); + + protected override TokenCredential CreateBrowserCredential(string clientId, string tenantId, string? loginHint = null) + => _credentials.Count > 0 ? _credentials.Dequeue() : base.CreateBrowserCredential(clientId, tenantId, loginHint); + } + + // Computes the MSAL cache path so tests can back it up and restore it. + private static string MsalCachePath => Path.Combine( + Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), + AuthenticationConstants.ApplicationName, + AuthenticationConstants.MsalCacheFileName); + + [Fact] + public async Task GetAccessTokenAsync_WhenFirstTokenHasWrongTid_ClearsCachesAndRetries() + { + // Arrange — first call returns wrong-tenant token; second returns correct tenant. + var correctTenant = "aaaaaaaa-0000-0000-0000-aaaaaaaaaaaa"; + var wrongTenant = "bbbbbbbb-0000-0000-0000-bbbbbbbbbbbb"; + + var logger = Substitute.For>(); + var sut = new SequencedAuthenticationService(logger, new[] + { + CredentialWithTid(wrongTenant), // attempt 1 — WAM picked the wrong account + CredentialWithTid(correctTenant) // attempt 2 — after cache clear, correct account + }); + + // Back up any pre-existing real MSAL cache so the test does not destroy it. + string? originalCache = File.Exists(MsalCachePath) + ? await File.ReadAllTextAsync(MsalCachePath) + : null; + + try + { + // Act + var result = await sut.GetAccessTokenAsync( + "https://graph.microsoft.com", + tenantId: correctTenant, + forceRefresh: true, + useInteractiveBrowser: true); + + // Assert — the token returned contains the correct tid + var tid = JwtHelper.TryDecodeClaim(result, "tid"); + tid.Should().Be(correctTenant, + because: "after the self-heal retry the correct-tenant token must be returned"); + + // A warning must have been logged to indicate that the wrong account was detected + logger.Received().Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Clearing cached credentials")), + Arg.Any(), + Arg.Any>()); + } + finally + { + sut.ClearCache(); + if (originalCache is null) { if (File.Exists(MsalCachePath)) File.Delete(MsalCachePath); } + else { await File.WriteAllTextAsync(MsalCachePath, originalCache); } + } + } + + [Fact] + public async Task GetAccessTokenAsync_WhenBothAttemptsReturnWrongTid_ThrowsAzureAuthenticationException() + { + // Arrange — both calls return a token for the wrong tenant (worst-case: retry also wrong). + var correctTenant = "aaaaaaaa-0000-0000-0000-aaaaaaaaaaaa"; + var wrongTenant = "bbbbbbbb-0000-0000-0000-bbbbbbbbbbbb"; + + var logger = Substitute.For>(); + var sut = new SequencedAuthenticationService(logger, new[] + { + CredentialWithTid(wrongTenant), // attempt 1 + CredentialWithTid(wrongTenant) // attempt 2 — still wrong after cache clear + }); + + // Back up any pre-existing real MSAL cache so the test does not destroy it. + string? originalCache = File.Exists(MsalCachePath) + ? await File.ReadAllTextAsync(MsalCachePath) + : null; + + try + { + // Act + Func act = async () => await sut.GetAccessTokenAsync( + "https://graph.microsoft.com", + tenantId: correctTenant, + forceRefresh: true, + useInteractiveBrowser: true); + + // Assert + await act.Should().ThrowAsync( + because: "when both attempts return the wrong tenant the CLI must fail with a clear error " + + "rather than silently proceeding with a token that will cause 403 on every Graph call"); + } + finally + { + sut.ClearCache(); + if (originalCache is null) { if (File.Exists(MsalCachePath)) File.Delete(MsalCachePath); } + else { await File.WriteAllTextAsync(MsalCachePath, originalCache); } + } + } + + [Fact] + public async Task GetAccessTokenAsync_WhenTokenTidMatchesConfiguredTenant_NoRetryOccurs() + { + // Arrange — token tid matches configured tenant; self-heal path must NOT fire. + var correctTenant = "aaaaaaaa-0000-0000-0000-aaaaaaaaaaaa"; + + var logger = Substitute.For>(); + var sut = new SequencedAuthenticationService(logger, new[] + { + CredentialWithTid(correctTenant) // only one credential queued — a retry would throw + }); + + try + { + // Act — should succeed without touching the second (non-existent) credential + var result = await sut.GetAccessTokenAsync( + "https://graph.microsoft.com", + tenantId: correctTenant, + forceRefresh: true, + useInteractiveBrowser: true); + + // Assert — no warning about clearing caches + logger.DidNotReceive().Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Clearing cached credentials")), + Arg.Any(), + Arg.Any>()); + } + finally + { + sut.ClearCache(); + } + } + + #endregion } diff --git a/src/Tests/Microsoft.Agents.A365.DevTools.Cli.Tests/Services/MicrosoftGraphTokenProviderTests.cs b/src/Tests/Microsoft.Agents.A365.DevTools.Cli.Tests/Services/MicrosoftGraphTokenProviderTests.cs index 07a016ba..33a834d9 100644 --- a/src/Tests/Microsoft.Agents.A365.DevTools.Cli.Tests/Services/MicrosoftGraphTokenProviderTests.cs +++ b/src/Tests/Microsoft.Agents.A365.DevTools.Cli.Tests/Services/MicrosoftGraphTokenProviderTests.cs @@ -2,13 +2,16 @@ // Licensed under the MIT License. using FluentAssertions; +using Microsoft.Agents.A365.DevTools.Cli.Constants; using Microsoft.Agents.A365.DevTools.Cli.Services; +using Microsoft.Agents.A365.DevTools.Cli.Services.Helpers; using Microsoft.Extensions.Logging; using NSubstitute; using Xunit; namespace Microsoft.Agents.A365.DevTools.Cli.Tests.Services; +[Collection("AuthTests")] public class MicrosoftGraphTokenProviderTests { private readonly ILogger _logger; @@ -501,4 +504,122 @@ await _executor.Received(1).ExecuteWithStreamingAsync( Arg.Any(), Arg.Any(), Arg.Any?>(), Arg.Any(), Arg.Any()); } + + // ── WAM wrong-tenant self-heal (issue #430) ─────────────────────────────── + + private static string BuildJwt(object payload) + { + var json = System.Text.Json.JsonSerializer.Serialize(payload); + var payloadB64 = Convert.ToBase64String(System.Text.Encoding.UTF8.GetBytes(json)) + .Replace('+', '-').Replace('/', '_').TrimEnd('='); + return $"header.{payloadB64}.signature"; + } + + [Fact] + public async Task GetMgGraphAccessTokenAsync_WhenMsalTokenHasWrongTid_ClearsFileAndRetries() + { + // Arrange + var correctTenant = "aaaaaaaa-0000-0000-0000-aaaaaaaaaaaa"; + var wrongTenant = "bbbbbbbb-0000-0000-0000-bbbbbbbbbbbb"; + var clientAppId = "87654321-4321-4321-4321-cba987654321"; + var scopes = new[] { "User.Read" }; + + // Write a dummy MSAL cache file to verify it gets deleted on mismatch. + var msalCachePath = Path.Combine( + Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), + AuthenticationConstants.ApplicationName, + AuthenticationConstants.MsalCacheFileName); + Directory.CreateDirectory(Path.GetDirectoryName(msalCachePath)!); + + // Back up any pre-existing real MSAL cache so the test does not destroy it. + string? originalCache = File.Exists(msalCachePath) + ? await File.ReadAllTextAsync(msalCachePath) + : null; + + try + { + await File.WriteAllTextAsync(msalCachePath, "stale-msal-cache"); + + var callCount = 0; + var provider = new MicrosoftGraphTokenProvider(_executor, _logger) + { + MsalTokenAcquirerOverride = (tid, _, _, _) => + { + callCount++; + // First call: wrong tenant (simulates WAM picking stale cached account) + // Second call: correct tenant (after cache clear WAM uses the right account) + var returnedTid = callCount == 1 ? wrongTenant : correctTenant; + return Task.FromResult(BuildJwt(new { tid = returnedTid })); + } + }; + + // Act + var token = await provider.GetMgGraphAccessTokenAsync(correctTenant, scopes, false, clientAppId); + + // Assert — retry was triggered and correct-tenant token returned + callCount.Should().Be(2, + because: "a tid mismatch on the first MSAL call must trigger exactly one retry"); + var returnedTid2 = JwtHelper.TryDecodeClaim(token, "tid"); + returnedTid2.Should().Be(correctTenant, + because: "the token returned after self-heal must be for the configured tenant"); + + // MSAL cache file must have been deleted to give WAM a clean slate + File.Exists(msalCachePath).Should().BeFalse( + because: "the stale MSAL cache must be removed so WAM re-evaluates account selection on retry"); + + // Warning must have been logged + _logger.Received().Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Clearing cached credentials")), + Arg.Any(), + Arg.Any>()); + } + finally + { + // Restore the original MSAL cache (or remove the test file if none existed before). + if (originalCache is null) + { + if (File.Exists(msalCachePath)) File.Delete(msalCachePath); + } + else + { + await File.WriteAllTextAsync(msalCachePath, originalCache); + } + } + } + + [Fact] + public async Task GetMgGraphAccessTokenAsync_WhenTokenTidMatchesConfiguredTenant_NoRetryOccurs() + { + // Arrange — token already has the correct tid; self-heal must not fire. + var correctTenant = "aaaaaaaa-0000-0000-0000-aaaaaaaaaaaa"; + var clientAppId = "87654321-4321-4321-4321-cba987654321"; + var scopes = new[] { "User.Read" }; + + var callCount = 0; + var provider = new MicrosoftGraphTokenProvider(_executor, _logger) + { + MsalTokenAcquirerOverride = (_, _, _, _) => + { + callCount++; + return Task.FromResult(BuildJwt(new { tid = correctTenant })); + } + }; + + // Act + var token = await provider.GetMgGraphAccessTokenAsync(correctTenant, scopes, false, clientAppId); + + // Assert — exactly one MSAL call; no retry + callCount.Should().Be(1, + because: "when the returned tid matches the configured tenant no retry should occur"); + token.Should().NotBeNullOrEmpty(); + + _logger.DidNotReceive().Log( + LogLevel.Warning, + Arg.Any(), + Arg.Is(o => o.ToString()!.Contains("Clearing cached credentials")), + Arg.Any(), + Arg.Any>()); + } }