diff --git a/src/Components/Server/src/Circuits/CircuitHost.cs b/src/Components/Server/src/Circuits/CircuitHost.cs index b02317a4f12b..be8193706686 100644 --- a/src/Components/Server/src/Circuits/CircuitHost.cs +++ b/src/Components/Server/src/Circuits/CircuitHost.cs @@ -33,7 +33,7 @@ internal partial class CircuitHost : IAsyncDisposable private bool _onConnectionDownFired; private bool _disposed; private long _startTime; - private PersistedCircuitState _persistedCircuitState; + private PersistedCircuitState? _persistedCircuitState; // This event is fired when there's an unrecoverable exception coming from the circuit, and // it need so be torn down. The registry listens to this even so that the circuit can @@ -944,7 +944,7 @@ internal PersistedCircuitState TakePersistedCircuitState() return result; } - internal async Task SendPersistedStateToClient(string rootComponents, string applicationState, CancellationToken cancellation) + internal async Task SendPersistedStateToClient(string rootComponents, string applicationState, DateTimeOffset expiration, CancellationToken cancellation) { try { @@ -953,6 +953,7 @@ internal async Task SendPersistedStateToClient(string rootComponents, stri CircuitId.Secret, rootComponents, applicationState, + expiration.ToUnixTimeMilliseconds(), cancellationToken: cancellation); return succeded; } diff --git a/src/Components/Server/src/Circuits/CircuitPersistenceManager.cs b/src/Components/Server/src/Circuits/CircuitPersistenceManager.cs index 0d9d0be4f252..9b53a40fe6fd 100644 --- a/src/Components/Server/src/Circuits/CircuitPersistenceManager.cs +++ b/src/Components/Server/src/Circuits/CircuitPersistenceManager.cs @@ -25,7 +25,16 @@ await circuit.Renderer.Dispatcher.InvokeAsync(async () => { var renderer = circuit.Renderer; var persistenceManager = circuit.Services.GetRequiredService(); - var collector = new CircuitPersistenceManagerCollector(circuitOptions, serverComponentSerializer, circuit.Renderer); + + // TODO (OR): Select solution variant + // Variant B: Client-side check + var distributedRetention = circuitOptions.Value.PersistedCircuitDistributedRetentionPeriod; + var localRetention = circuitOptions.Value.PersistedCircuitInMemoryRetentionPeriod; + var maxRetention = (distributedRetention > localRetention ? distributedRetention : localRetention) ?? ServerComponentSerializationSettings.DataExpiration; + var expiration = DateTimeOffset.UtcNow.Add(maxRetention); + + var collector = new CircuitPersistenceManagerCollector(serverComponentSerializer, circuit.Renderer, maxRetention); + using var subscription = persistenceManager.State.RegisterOnPersisting( collector.PersistRootComponents, RenderMode.InteractiveServer); @@ -34,7 +43,7 @@ await circuit.Renderer.Dispatcher.InvokeAsync(async () => if (saveStateToClient) { - await SaveStateToClient(circuit, collector.PersistedCircuitState, cancellation); + await SaveStateToClient(circuit, collector.PersistedCircuitState, expiration, cancellation); } else { @@ -46,10 +55,10 @@ await circuitPersistenceProvider.PersistCircuitAsync( }); } - internal async Task SaveStateToClient(CircuitHost circuit, PersistedCircuitState state, CancellationToken cancellation = default) + internal async Task SaveStateToClient(CircuitHost circuit, PersistedCircuitState state, DateTimeOffset expiration, CancellationToken cancellation = default) { var (rootComponents, applicationState) = await ToProtectedStateAsync(state); - if (!await circuit.SendPersistedStateToClient(rootComponents, applicationState, cancellation)) + if (!await circuit.SendPersistedStateToClient(rootComponents, applicationState, expiration, cancellation)) { try { @@ -101,6 +110,27 @@ public async Task ResumeCircuitAsync(CircuitId circuitId, return await circuitPersistenceProvider.RestoreCircuitAsync(circuitId, cancellation); } + internal static bool CheckRootComponentMarkers(IServerComponentDeserializer serverComponentDeserializer, byte[] rootComponents) + { + var persistedMarkers = TryDeserializeMarkers(rootComponents); + + if (persistedMarkers == null) + { + return false; + } + + foreach (var marker in persistedMarkers) + { + if (!serverComponentDeserializer.TryDeserializeWebRootComponentDescriptor(marker.Value, out var _)) + { + // OR: Expired state + return false; + } + } + + return true; + } + // We are going to construct a RootComponentOperationBatch but we are going to replace the descriptors from the client with the // descriptors that we have persisted when pausing the circuit. // The way pausing and resuming works is that when the client starts the resume process, it 'simulates' that an SSR has happened and @@ -152,6 +182,7 @@ internal static RootComponentOperationBatch ToRootComponentOperationBatch( if (!serverComponentDeserializer.TryDeserializeWebRootComponentDescriptor(operation.Marker.Value, out var descriptor)) { + // OR: Expired state return null; } @@ -159,48 +190,55 @@ internal static RootComponentOperationBatch ToRootComponentOperationBatch( } return batch; + } - static Dictionary TryDeserializeMarkers(byte[] rootComponents) + private static Dictionary TryDeserializeMarkers(byte[] rootComponents) + { + if (rootComponents == null || rootComponents.Length == 0) { - if (rootComponents == null || rootComponents.Length == 0) - { - return null; - } + return null; + } - try - { - return JsonSerializer.Deserialize>( - rootComponents, - JsonSerializerOptionsProvider.Options); - } - catch - { - return null; - } + try + { + return JsonSerializer.Deserialize>( + rootComponents, + JsonSerializerOptionsProvider.Options); + } + catch + { + return null; } } - private class CircuitPersistenceManagerCollector( - IOptions circuitOptions, - ServerComponentSerializer serverComponentSerializer, - RemoteRenderer renderer) - : IPersistentComponentStateStore + private class CircuitPersistenceManagerCollector : IPersistentComponentStateStore { + private readonly ServerComponentSerializer _serverComponentSerializer; + private readonly RemoteRenderer _renderer; + private readonly TimeSpan _maxRetention; + + public CircuitPersistenceManagerCollector( + ServerComponentSerializer serverComponentSerializer, + RemoteRenderer renderer, + TimeSpan maxRetention) + { + _serverComponentSerializer = serverComponentSerializer; + _renderer = renderer; + _maxRetention = maxRetention; + } + internal PersistedCircuitState PersistedCircuitState { get; private set; } public Task PersistRootComponents() { var persistedComponents = new Dictionary(); - var components = renderer.GetOrCreateWebRootComponentManager().GetRootComponents(); + var components = _renderer.GetOrCreateWebRootComponentManager().GetRootComponents(); var invocation = new ServerComponentInvocationSequence(); + foreach (var (id, componentKey, (componentType, parameters)) in components) { - var distributedRetention = circuitOptions.Value.PersistedCircuitInMemoryRetentionPeriod; - var localRetention = circuitOptions.Value.PersistedCircuitInMemoryRetentionPeriod; - var maxRetention = distributedRetention > localRetention ? distributedRetention : localRetention; - var marker = ComponentMarker.Create(ComponentMarker.ServerMarkerType, prerendered: false, componentKey); - serverComponentSerializer.SerializeInvocation(ref marker, invocation, componentType, parameters, maxRetention); + _serverComponentSerializer.SerializeInvocation(ref marker, invocation, componentType, parameters, _maxRetention); persistedComponents.Add(id, marker); } diff --git a/src/Components/Server/src/Circuits/ServerComponentDeserializer.cs b/src/Components/Server/src/Circuits/ServerComponentDeserializer.cs index eaebd8856968..dacc3fb5044d 100644 --- a/src/Components/Server/src/Circuits/ServerComponentDeserializer.cs +++ b/src/Components/Server/src/Circuits/ServerComponentDeserializer.cs @@ -249,6 +249,7 @@ private bool TryDeserializeServerComponent(ComponentMarker record, out ServerCom } catch (Exception e) { + // OR: Expired state Log.FailedToUnprotectDescriptor(_logger, e); result = default; return false; diff --git a/src/Components/Server/src/ComponentHub.cs b/src/Components/Server/src/ComponentHub.cs index c8a698071c7a..0cfb7a792ffd 100644 --- a/src/Components/Server/src/ComponentHub.cs +++ b/src/Components/Server/src/ComponentHub.cs @@ -184,6 +184,17 @@ public async Task UpdateRootComponents(string serializedComponentOperations, str persistedState.RootComponents, serializedComponentOperations); + if (operations == null) + { + // OR: Expired state + // There was an error, so kill the circuit. + await _circuitRegistry.TerminateAsync(circuitHost.CircuitId); + await NotifyClientError(Clients.Caller, "The persisted circuit state is invalid or expired."); + Context.Abort(); + + return; + } + store = new ProtectedPrerenderComponentApplicationStore(persistedState.ApplicationState, _dataProtectionProvider); } else @@ -334,6 +345,14 @@ public async ValueTask ResumeCircuit( Context.Abort(); return null; } + + // TODO (OR): Select solution variant + // Variant A: Server-side check in ResumeCircuit + if (!CircuitPersistenceManager.CheckRootComponentMarkers(_serverComponentSerializer, persistedCircuitState.RootComponents)) + { + Log.InvalidInputData(_logger); + return null; + } } else { diff --git a/src/Components/Server/test/Circuits/CircuitHostTest.cs b/src/Components/Server/test/Circuits/CircuitHostTest.cs index 670ba4427247..2f1b869ba751 100644 --- a/src/Components/Server/test/Circuits/CircuitHostTest.cs +++ b/src/Components/Server/test/Circuits/CircuitHostTest.cs @@ -429,10 +429,11 @@ public async Task SendPersistedStateToClient_WithSuccessfulInvocation_ReturnsTru var rootComponents = "mock-root-components"; var applicationState = "mock-application-state"; + var expiration = DateTimeOffset.UtcNow.Add(TimeSpan.FromMinutes(5)); var cancellationToken = new CancellationToken(); // Act - var result = await circuitHost.SendPersistedStateToClient(rootComponents, applicationState, cancellationToken); + var result = await circuitHost.SendPersistedStateToClient(rootComponents, applicationState, expiration, cancellationToken); // Assert Assert.True(result); @@ -463,10 +464,11 @@ public async Task SendPersistedStateToClient_WithFailedInvocation_ReturnsFalse() var rootComponents = "mock-root-components"; var applicationState = "mock-application-state"; + var expiration = DateTimeOffset.UtcNow.Add(TimeSpan.FromMinutes(5)); var cancellationToken = new CancellationToken(); // Act - var result = await circuitHost.SendPersistedStateToClient(rootComponents, applicationState, cancellationToken); + var result = await circuitHost.SendPersistedStateToClient(rootComponents, applicationState, expiration, cancellationToken); // Assert Assert.False(result); @@ -490,10 +492,11 @@ public async Task SendPersistedStateToClient_WithException_LogsAndReturnsFalse() var rootComponents = "mock-root-components"; var applicationState = "mock-application-state"; + var expiration = DateTimeOffset.UtcNow.Add(TimeSpan.FromMinutes(5)); var cancellationToken = new CancellationToken(); // Act - var result = await circuitHost.SendPersistedStateToClient(rootComponents, applicationState, cancellationToken); + var result = await circuitHost.SendPersistedStateToClient(rootComponents, applicationState, expiration, cancellationToken); // Assert Assert.False(result); @@ -514,10 +517,11 @@ public async Task SendPersistedStateToClient_WithDisconnectedClient_ReturnsFalse var rootComponents = "mock-root-components"; var applicationState = "mock-application-state"; + var expiration = DateTimeOffset.UtcNow.Add(TimeSpan.FromMinutes(5)); var cancellationToken = new CancellationToken(); // Act & Assert - Assert.False(await circuitHost.SendPersistedStateToClient(rootComponents, applicationState, cancellationToken)); + Assert.False(await circuitHost.SendPersistedStateToClient(rootComponents, applicationState, expiration, cancellationToken)); } [Fact] diff --git a/src/Components/Web.JS/src/Platform/Circuits/CircuitManager.ts b/src/Components/Web.JS/src/Platform/Circuits/CircuitManager.ts index cc2a536a4129..0faa4b227b93 100644 --- a/src/Components/Web.JS/src/Platform/Circuits/CircuitManager.ts +++ b/src/Components/Web.JS/src/Platform/Circuits/CircuitManager.ts @@ -19,6 +19,12 @@ import { showErrorNotification } from '../../BootErrors'; import { attachWebRendererInterop, detachWebRendererInterop } from '../../Rendering/WebRendererInteropMethods'; import { sendJSDataStream } from './CircuitStreamingInterop'; +interface PersistedCircuitState { + components: string; + applicationState: string; + expiration: number; +} + export class CircuitManager implements DotNet.DotNetCallDispatcher { private readonly _componentManager: RootComponentManager; @@ -53,7 +59,7 @@ export class CircuitManager implements DotNet.DotNetCallDispatcher { private _disconnectingState = new CircuitState('disconnecting'); - private _persistedCircuitState?: { components: string, applicationState: string }; + private _persistedCircuitState?: PersistedCircuitState; private _isFirstRender = true; @@ -72,6 +78,20 @@ export class CircuitManager implements DotNet.DotNetCallDispatcher { this._dispatcher = DotNet.attachDispatcher(this); } + private tryTakePersistedState(): PersistedCircuitState | undefined { + // TODO (OR): Select solution variant + // Variant B: Client-side check + if (this._persistedCircuitState && this._persistedCircuitState.expiration <= Date.now()) { + this._logger.log(LogLevel.Debug, 'Persisted circuit state has expired and will not be used.'); + this._persistedCircuitState = undefined; + return undefined; + } else { + const state = this._persistedCircuitState; + this._persistedCircuitState = undefined; + return state; + } + } + public start(): Promise { if (this.isDisposedOrDisposing()) { throw new Error('Cannot start a disposed circuit.'); @@ -139,14 +159,14 @@ export class CircuitManager implements DotNet.DotNetCallDispatcher { connection.on('JS.EndInvokeDotNet', this._dispatcher.endInvokeDotNetFromJS.bind(this._dispatcher)); connection.on('JS.ReceiveByteArray', this._dispatcher.receiveByteArray.bind(this._dispatcher)); - connection.on('JS.SavePersistedState', (circuitId: string, components: string, applicationState: string) => { + connection.on('JS.SavePersistedState', (circuitId: string, components: string, applicationState: string, expiration: number) => { if (!this._circuitId) { throw new Error('Circuit host not initialized.'); } if (circuitId !== this._circuitId) { throw new Error(`Received persisted state for circuit ID '${circuitId}', but the current circuit ID is '${this._circuitId}'.`); } - this._persistedCircuitState = { components, applicationState }; + this._persistedCircuitState = { components, applicationState, expiration }; return true; }); @@ -378,8 +398,7 @@ export class CircuitManager implements DotNet.DotNetCallDispatcher { } } - const persistedCircuitState = this._persistedCircuitState; - this._persistedCircuitState = undefined; + const persistedCircuitState = this.tryTakePersistedState(); const newCircuitId = await this._connection!.invoke( 'ResumeCircuit',