From 6d2985e75bb1a661b5522e972a27a14155af76d7 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 2 Jun 2025 14:56:31 +1000 Subject: [PATCH 1/5] feat: add vpn start progress --- .../.idea/projectSettingsUpdater.xml | 1 + App/Models/RpcModel.cs | 19 ++ .../PublishProfiles/win-arm64.pubxml | 12 - App/Properties/PublishProfiles/win-x64.pubxml | 12 - App/Properties/PublishProfiles/win-x86.pubxml | 12 - App/Services/RpcController.cs | 31 ++- App/ViewModels/TrayWindowViewModel.cs | 38 ++- .../Pages/TrayWindowLoginRequiredPage.xaml | 2 +- App/Views/Pages/TrayWindowMainPage.xaml | 11 +- Tests.Vpn.Service/DownloaderTest.cs | 50 +++- Vpn.Proto/vpn.proto | 16 +- Vpn.Service/Downloader.cs | 220 +++++++++++++++--- Vpn.Service/Manager.cs | 67 +++++- Vpn.Service/ManagerRpc.cs | 2 +- Vpn.Service/Program.cs | 11 +- Vpn.Service/TunnelSupervisor.cs | 6 +- Vpn/Speaker.cs | 2 +- 17 files changed, 421 insertions(+), 91 deletions(-) delete mode 100644 App/Properties/PublishProfiles/win-arm64.pubxml delete mode 100644 App/Properties/PublishProfiles/win-x64.pubxml delete mode 100644 App/Properties/PublishProfiles/win-x86.pubxml diff --git a/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml b/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml index 64af657..ef20cb0 100644 --- a/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml +++ b/.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml @@ -2,6 +2,7 @@ \ No newline at end of file diff --git a/App/Models/RpcModel.cs b/App/Models/RpcModel.cs index 034f405..33b4647 100644 --- a/App/Models/RpcModel.cs +++ b/App/Models/RpcModel.cs @@ -19,12 +19,30 @@ public enum VpnLifecycle Stopping, } +public class VpnStartupProgress +{ + public double Progress { get; set; } = 0.0; // 0.0 to 1.0 + public string Message { get; set; } = string.Empty; + + public VpnStartupProgress Clone() + { + return new VpnStartupProgress + { + Progress = Progress, + Message = Message, + }; + } +} + public class RpcModel { public RpcLifecycle RpcLifecycle { get; set; } = RpcLifecycle.Disconnected; public VpnLifecycle VpnLifecycle { get; set; } = VpnLifecycle.Unknown; + // Nullable because it is only set when the VpnLifecycle is Starting + public VpnStartupProgress? VpnStartupProgress { get; set; } + public IReadOnlyList Workspaces { get; set; } = []; public IReadOnlyList Agents { get; set; } = []; @@ -35,6 +53,7 @@ public RpcModel Clone() { RpcLifecycle = RpcLifecycle, VpnLifecycle = VpnLifecycle, + VpnStartupProgress = VpnStartupProgress?.Clone(), Workspaces = Workspaces, Agents = Agents, }; diff --git a/App/Properties/PublishProfiles/win-arm64.pubxml b/App/Properties/PublishProfiles/win-arm64.pubxml deleted file mode 100644 index ac9753e..0000000 --- a/App/Properties/PublishProfiles/win-arm64.pubxml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - FileSystem - ARM64 - win-arm64 - bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ - - diff --git a/App/Properties/PublishProfiles/win-x64.pubxml b/App/Properties/PublishProfiles/win-x64.pubxml deleted file mode 100644 index 942523b..0000000 --- a/App/Properties/PublishProfiles/win-x64.pubxml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - FileSystem - x64 - win-x64 - bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ - - diff --git a/App/Properties/PublishProfiles/win-x86.pubxml b/App/Properties/PublishProfiles/win-x86.pubxml deleted file mode 100644 index e763481..0000000 --- a/App/Properties/PublishProfiles/win-x86.pubxml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - FileSystem - x86 - win-x86 - bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ - - diff --git a/App/Services/RpcController.cs b/App/Services/RpcController.cs index 7beff66..b42c058 100644 --- a/App/Services/RpcController.cs +++ b/App/Services/RpcController.cs @@ -161,7 +161,12 @@ public async Task StartVpn(CancellationToken ct = default) throw new RpcOperationException( $"Cannot start VPN without valid credentials, current state: {credentials.State}"); - MutateState(state => { state.VpnLifecycle = VpnLifecycle.Starting; }); + MutateState(state => + { + state.VpnLifecycle = VpnLifecycle.Starting; + // Explicitly clear the startup progress. + state.VpnStartupProgress = null; + }); ServiceMessage reply; try @@ -251,6 +256,9 @@ private void MutateState(Action mutator) using (_stateLock.Lock()) { mutator(_state); + // Unset the startup progress if the VpnLifecycle is not Starting + if (_state.VpnLifecycle != VpnLifecycle.Starting) + _state.VpnStartupProgress = null; newState = _state.Clone(); } @@ -283,15 +291,32 @@ private void ApplyStatusUpdate(Status status) }); } + private void ApplyStartProgressUpdate(StartProgress message) + { + MutateState(state => + { + // MutateState will undo these changes if it doesn't believe we're + // in the "Starting" state. + state.VpnStartupProgress = new VpnStartupProgress + { + Progress = message.Progress, + Message = message.Message, + }; + }); + } + private void SpeakerOnReceive(ReplyableRpcMessage message) { switch (message.Message.MsgCase) { + case ServiceMessage.MsgOneofCase.Start: + case ServiceMessage.MsgOneofCase.Stop: case ServiceMessage.MsgOneofCase.Status: ApplyStatusUpdate(message.Message.Status); break; - case ServiceMessage.MsgOneofCase.Start: - case ServiceMessage.MsgOneofCase.Stop: + case ServiceMessage.MsgOneofCase.StartProgress: + ApplyStartProgressUpdate(message.Message.StartProgress); + break; case ServiceMessage.MsgOneofCase.None: default: // TODO: log unexpected message diff --git a/App/ViewModels/TrayWindowViewModel.cs b/App/ViewModels/TrayWindowViewModel.cs index d8b3182..cd3a641 100644 --- a/App/ViewModels/TrayWindowViewModel.cs +++ b/App/ViewModels/TrayWindowViewModel.cs @@ -29,7 +29,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost { private const int MaxAgents = 5; private const string DefaultDashboardUrl = "https://coder.com"; - private const string DefaultHostnameSuffix = ".coder"; + private const string DefaultStartProgressMessage = "Starting Coder Connect..."; private readonly IServiceProvider _services; private readonly IRpcController _rpcController; @@ -53,6 +53,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost [ObservableProperty] [NotifyPropertyChangedFor(nameof(ShowEnableSection))] + [NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))] [NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))] [NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))] [NotifyPropertyChangedFor(nameof(ShowAgentsSection))] @@ -63,6 +64,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost [ObservableProperty] [NotifyPropertyChangedFor(nameof(ShowEnableSection))] + [NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))] [NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))] [NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))] [NotifyPropertyChangedFor(nameof(ShowAgentsSection))] @@ -70,7 +72,25 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost [NotifyPropertyChangedFor(nameof(ShowFailedSection))] public partial string? VpnFailedMessage { get; set; } = null; - public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Started; + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(VpnStartProgressIsIndeterminate))] + [NotifyPropertyChangedFor(nameof(VpnStartProgressValueOrDefault))] + public partial int? VpnStartProgressValue { get; set; } = null; + + public int VpnStartProgressValueOrDefault => VpnStartProgressValue ?? 0; + + [ObservableProperty] + [NotifyPropertyChangedFor(nameof(VpnStartProgressMessageOrDefault))] + public partial string? VpnStartProgressMessage { get; set; } = null; + + public string VpnStartProgressMessageOrDefault => + string.IsNullOrEmpty(VpnStartProgressMessage) ? DefaultStartProgressMessage : VpnStartProgressMessage; + + public bool VpnStartProgressIsIndeterminate => VpnStartProgressValueOrDefault == 0; + + public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Starting and not VpnLifecycle.Started; + + public bool ShowVpnStartProgressSection => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Starting; public bool ShowWorkspacesHeader => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Started; @@ -170,6 +190,20 @@ private void UpdateFromRpcModel(RpcModel rpcModel) VpnLifecycle = rpcModel.VpnLifecycle; VpnSwitchActive = rpcModel.VpnLifecycle is VpnLifecycle.Starting or VpnLifecycle.Started; + // VpnStartupProgress is only set when the VPN is starting. + if (rpcModel.VpnLifecycle is VpnLifecycle.Starting && rpcModel.VpnStartupProgress != null) + { + // Convert 0.00-1.00 to 0-100. + var progress = (int)(rpcModel.VpnStartupProgress.Progress * 100); + VpnStartProgressValue = Math.Clamp(progress, 0, 100); + VpnStartProgressMessage = string.IsNullOrEmpty(rpcModel.VpnStartupProgress.Message) ? null : rpcModel.VpnStartupProgress.Message; + } + else + { + VpnStartProgressValue = null; + VpnStartProgressMessage = null; + } + // Add every known agent. HashSet workspacesWithAgents = []; List agents = []; diff --git a/App/Views/Pages/TrayWindowLoginRequiredPage.xaml b/App/Views/Pages/TrayWindowLoginRequiredPage.xaml index c1d69aa..171e292 100644 --- a/App/Views/Pages/TrayWindowLoginRequiredPage.xaml +++ b/App/Views/Pages/TrayWindowLoginRequiredPage.xaml @@ -36,7 +36,7 @@ diff --git a/App/Views/Pages/TrayWindowMainPage.xaml b/App/Views/Pages/TrayWindowMainPage.xaml index 283867d..f488454 100644 --- a/App/Views/Pages/TrayWindowMainPage.xaml +++ b/App/Views/Pages/TrayWindowMainPage.xaml @@ -43,6 +43,8 @@ + + + HorizontalContentAlignment="Left"> diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs index 986ce46..a47ffbc 100644 --- a/Tests.Vpn.Service/DownloaderTest.cs +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -2,6 +2,7 @@ using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Text; +using System.Threading.Channels; using Coder.Desktop.Vpn.Service; using Microsoft.Extensions.Logging.Abstractions; @@ -278,7 +279,7 @@ public async Task Download(CancellationToken ct) NullDownloadValidator.Instance, ct); await dlTask.Task; Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); - Assert.That(dlTask.BytesRead, Is.EqualTo(4)); + Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); Assert.That(dlTask.Progress, Is.EqualTo(1)); Assert.That(dlTask.IsCompleted, Is.True); Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); @@ -301,17 +302,56 @@ public async Task DownloadSameDest(CancellationToken ct) var dlTask0 = await startTask0; await dlTask0.Task; Assert.That(dlTask0.TotalBytes, Is.EqualTo(5)); - Assert.That(dlTask0.BytesRead, Is.EqualTo(5)); + Assert.That(dlTask0.BytesWritten, Is.EqualTo(5)); Assert.That(dlTask0.Progress, Is.EqualTo(1)); Assert.That(dlTask0.IsCompleted, Is.True); var dlTask1 = await startTask1; await dlTask1.Task; Assert.That(dlTask1.TotalBytes, Is.EqualTo(5)); - Assert.That(dlTask1.BytesRead, Is.EqualTo(5)); + Assert.That(dlTask1.BytesWritten, Is.EqualTo(5)); Assert.That(dlTask1.Progress, Is.EqualTo(1)); Assert.That(dlTask1.IsCompleted, Is.True); } + [Test(Description = "Download with X-Original-Content-Length")] + [CancelAfter(30_000)] + public async Task DownloadWithXOriginalContentLength(CancellationToken ct) + { + using var httpServer = new TestHttpServer(async ctx => + { + ctx.Response.StatusCode = 200; + ctx.Response.Headers.Add("X-Original-Content-Length", "6"); // wrong but should be used until complete + ctx.Response.ContentType = "text/plain"; + ctx.Response.ContentLength64 = 4; // This should be ignored. + await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct); + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + var manager = new Downloader(NullLogger.Instance); + var req = new HttpRequestMessage(HttpMethod.Get, url); + var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct); + + var progressChannel = Channel.CreateUnbounded(); + dlTask.ProgressChanged += (_, args) => + Assert.That(progressChannel.Writer.TryWrite(args), Is.True); + + await dlTask.Task; + Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); // should equal BytesWritten after completion + Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); + progressChannel.Writer.Complete(); + + var list = progressChannel.Reader.ReadAllAsync(ct).ToBlockingEnumerable(ct).ToList(); + Assert.That(list.Count, Is.GreaterThanOrEqualTo(2)); // there may be an item in the middle + // The first item should be the initial progress with 0 bytes written. + Assert.That(list[0].BytesWritten, Is.EqualTo(0)); + Assert.That(list[0].TotalBytes, Is.EqualTo(6)); // from X-Original-Content-Length + Assert.That(list[0].Progress, Is.EqualTo(0.0d)); + // The last item should be final progress with the actual total bytes. + Assert.That(list[^1].BytesWritten, Is.EqualTo(4)); + Assert.That(list[^1].TotalBytes, Is.EqualTo(4)); // from the actual bytes written + Assert.That(list[^1].Progress, Is.EqualTo(1.0d)); + } + [Test(Description = "Download with custom headers")] [CancelAfter(30_000)] public async Task WithHeaders(CancellationToken ct) @@ -347,7 +387,7 @@ public async Task DownloadExisting(CancellationToken ct) var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, NullDownloadValidator.Instance, ct); await dlTask.Task; - Assert.That(dlTask.BytesRead, Is.Zero); + Assert.That(dlTask.BytesWritten, Is.Zero); Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); Assert.That(File.GetLastWriteTime(destPath), Is.LessThan(DateTime.Now - TimeSpan.FromDays(1))); } @@ -368,7 +408,7 @@ public async Task DownloadExistingDifferentContent(CancellationToken ct) var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, NullDownloadValidator.Instance, ct); await dlTask.Task; - Assert.That(dlTask.BytesRead, Is.EqualTo(4)); + Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); Assert.That(File.GetLastWriteTime(destPath), Is.GreaterThan(DateTime.Now - TimeSpan.FromDays(1))); } diff --git a/Vpn.Proto/vpn.proto b/Vpn.Proto/vpn.proto index 2561a4b..fa2f003 100644 --- a/Vpn.Proto/vpn.proto +++ b/Vpn.Proto/vpn.proto @@ -60,7 +60,8 @@ message ServiceMessage { oneof msg { StartResponse start = 2; StopResponse stop = 3; - Status status = 4; // either in reply to a StatusRequest or broadcasted + Status status = 4; // either in reply to a StatusRequest or broadcasted + StartProgress start_progress = 5; // broadcasted during startup } } @@ -218,6 +219,19 @@ message StartResponse { string error_message = 2; } +// StartProgress is sent from the manager to the client to indicate the +// download/startup progress of the tunnel. This will be sent during the +// processing of a StartRequest before the StartResponse is sent. +// +// Note: this is currently a broadcasted message to all clients due to the +// inability to easily send messages to a specific client in the Speaker +// implementation. If clients are not expecting these messages, they +// should ignore them. +message StartProgress { + double progress = 1; // 0.0 to 1.0 + string message = 2; // human-readable status message, must be set +} + // StopRequest is a request from the manager to stop the tunnel. The tunnel replies with a // StopResponse. message StopRequest {} diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index 6a3108b..856a637 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -338,32 +338,81 @@ internal static async Task TaskOrCancellation(Task task, CancellationToken cance } } +public class DownloadProgressEvent +{ + // TODO: speed calculation would be nice + public ulong BytesWritten { get; init; } + public ulong? TotalBytes { get; init; } // null if unknown + public double? Progress { get; init; } // 0.0 - 1.0, null if unknown + + public override string ToString() + { + var s = FriendlyBytes(BytesWritten); + if (TotalBytes != null) + s += $" of {FriendlyBytes(TotalBytes.Value)}"; + else + s += " of unknown"; + if (Progress != null) + s += $" ({Progress:0%})"; + return s; + } + + private static readonly string[] ByteSuffixes = ["B", "KB", "MB", "GB", "TB", "PB", "EB"]; + + // Unfortunately this is copied from FriendlyByteConverter in App. Ideally + // it should go into some shared utilities project, but it's overkill to do + // that for a single tiny function until we have more shared code. + private static string FriendlyBytes(ulong bytes) + { + if (bytes == 0) + return $"0 {ByteSuffixes[0]}"; + + var place = Convert.ToInt32(Math.Floor(Math.Log(bytes, 1024))); + var num = Math.Round(bytes / Math.Pow(1024, place), 1); + return $"{num} {ByteSuffixes[place]}"; + } +} + /// -/// Downloads an Url to a file on disk. The download will be written to a temporary file first, then moved to the final +/// Downloads a Url to a file on disk. The download will be written to a temporary file first, then moved to the final /// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if /// it hasn't changed. /// public class DownloadTask { private const int BufferSize = 4096; + private const int ProgressUpdateDelayMs = 50; + private const string XOriginalContentLengthHeader = "X-Original-Content-Length"; // overrides Content-Length if available - private static readonly HttpClient HttpClient = new(); + private static readonly HttpClient HttpClient = new(new HttpClientHandler + { + AutomaticDecompression = DecompressionMethods.All, + }); private readonly string _destinationDirectory; private readonly ILogger _logger; private readonly RaiiSemaphoreSlim _semaphore = new(1, 1); private readonly IDownloadValidator _validator; - public readonly string DestinationPath; + private readonly string _destinationPath; + private readonly string _tempDestinationPath; + + // ProgressChanged events are always delayed by up to 50ms to avoid + // flooding. + // + // This will be called: + // - once after the request succeeds but before the read/write routine + // begins + // - occasionally while the file is being downloaded (at least 50ms apart) + // - once when the download is complete + public EventHandler? ProgressChanged; public readonly HttpRequestMessage Request; - public readonly string TempDestinationPath; - public ulong? TotalBytes { get; private set; } - public ulong BytesRead { get; private set; } public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync - - public double? Progress => TotalBytes == null ? null : (double)BytesRead / TotalBytes.Value; + public ulong BytesWritten { get; private set; } + public ulong? TotalBytes { get; private set; } + public double? Progress => TotalBytes == null ? null : (double)BytesWritten / TotalBytes.Value; public bool IsCompleted => Task.IsCompleted; internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator) @@ -374,17 +423,17 @@ internal DownloadTask(ILogger logger, HttpRequestMessage req, string destination if (string.IsNullOrWhiteSpace(destinationPath)) throw new ArgumentException("Destination path must not be empty", nameof(destinationPath)); - DestinationPath = Path.GetFullPath(destinationPath); - if (Path.EndsInDirectorySeparator(DestinationPath)) - throw new ArgumentException($"Destination path '{DestinationPath}' must not end in a directory separator", + _destinationPath = Path.GetFullPath(destinationPath); + if (Path.EndsInDirectorySeparator(_destinationPath)) + throw new ArgumentException($"Destination path '{_destinationPath}' must not end in a directory separator", nameof(destinationPath)); - _destinationDirectory = Path.GetDirectoryName(DestinationPath) + _destinationDirectory = Path.GetDirectoryName(_destinationPath) ?? throw new ArgumentException( - $"Destination path '{DestinationPath}' must have a parent directory", + $"Destination path '{_destinationPath}' must have a parent directory", nameof(destinationPath)); - TempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(DestinationPath) + + _tempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(_destinationPath) + ".download-" + Path.GetRandomFileName()); } @@ -406,9 +455,9 @@ private async Task Start(CancellationToken ct = default) // If the destination path exists, generate a Coder SHA1 ETag and send // it in the If-None-Match header to the server. - if (File.Exists(DestinationPath)) + if (File.Exists(_destinationPath)) { - await using var stream = File.OpenRead(DestinationPath); + await using var stream = File.OpenRead(_destinationPath); var etag = Convert.ToHexString(await SHA1.HashDataAsync(stream, ct)).ToLower(); Request.Headers.Add("If-None-Match", "\"" + etag + "\""); } @@ -419,11 +468,11 @@ private async Task Start(CancellationToken ct = default) _logger.LogInformation("File has not been modified, skipping download"); try { - await _validator.ValidateAsync(DestinationPath, ct); + await _validator.ValidateAsync(_destinationPath, ct); } catch (Exception e) { - _logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", DestinationPath); + _logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", _destinationPath); throw new Exception("Existing file failed validation after 304 Not Modified", e); } @@ -448,6 +497,26 @@ private async Task Start(CancellationToken ct = default) if (res.Content.Headers.ContentLength >= 0) TotalBytes = (ulong)res.Content.Headers.ContentLength; + // X-Original-Content-Length overrules Content-Length if set. + if (res.Headers.TryGetValues(XOriginalContentLengthHeader, out var headerValues)) + { + // If there are multiple we only look at the first one. + var headerValue = headerValues.ToList().FirstOrDefault(); + if (!string.IsNullOrEmpty(headerValue) && ulong.TryParse(headerValue, out var originalContentLength)) + TotalBytes = originalContentLength; + else + _logger.LogWarning( + "Failed to parse {XOriginalContentLengthHeader} header value '{HeaderValue}'", + XOriginalContentLengthHeader, headerValue); + } + + SendProgressUpdate(new DownloadProgressEvent + { + BytesWritten = 0, + TotalBytes = TotalBytes, + Progress = 0.0, + }); + await Download(res, ct); } @@ -459,11 +528,11 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) FileStream tempFile; try { - tempFile = File.Create(TempDestinationPath, BufferSize, FileOptions.SequentialScan); + tempFile = File.Create(_tempDestinationPath, BufferSize, FileOptions.SequentialScan); } catch (Exception e) { - _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", TempDestinationPath); + _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", _tempDestinationPath); throw; } @@ -476,13 +545,31 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) { await tempFile.WriteAsync(buffer.AsMemory(0, n), ct); sha1?.TransformBlock(buffer, 0, n, null, 0); - BytesRead += (ulong)n; + BytesWritten += (ulong)n; + await QueueProgressUpdate(new DownloadProgressEvent + { + BytesWritten = BytesWritten, + TotalBytes = TotalBytes, + Progress = Progress, + }, ct); } } - if (TotalBytes != null && BytesRead != TotalBytes) + // Clear any pending progress updates to ensure they won't be sent + // after the final update. + await ClearQueuedProgressUpdate(ct); + // Then write the final status update. + TotalBytes = BytesWritten; + SendProgressUpdate(new DownloadProgressEvent + { + BytesWritten = BytesWritten, + TotalBytes = BytesWritten, + Progress = 1.0, + }); + + if (TotalBytes != null && BytesWritten != TotalBytes) throw new IOException( - $"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesRead}"); + $"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesWritten}"); // Verify the ETag if it was sent by the server. if (res.Headers.Contains("ETag") && sha1 != null) @@ -497,26 +584,99 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) try { - await _validator.ValidateAsync(TempDestinationPath, ct); + await _validator.ValidateAsync(_tempDestinationPath, ct); } catch (Exception e) { _logger.LogWarning(e, "Downloaded file '{TempDestinationPath}' failed custom validation", - TempDestinationPath); + _tempDestinationPath); throw new HttpRequestException("Downloaded file failed validation", e); } - File.Move(TempDestinationPath, DestinationPath, true); + File.Move(_tempDestinationPath, _destinationPath, true); } - finally + catch { #if DEBUG _logger.LogWarning("Not deleting temporary file '{TempDestinationPath}' in debug mode", - TempDestinationPath); + _tempDestinationPath); #else - if (File.Exists(TempDestinationPath)) - File.Delete(TempDestinationPath); + try + { + if (File.Exists(TempDestinationPath)) + File.Delete(TempDestinationPath); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to delete temporary file '{TempDestinationPath}'", _tempDestinationPath); + } #endif + throw; } } + + // _progressEventLock protects _progressUpdateTask and _pendingProgressEvent. + private readonly RaiiSemaphoreSlim _progressEventLock = new(1, 1); + private readonly CancellationTokenSource _progressUpdateCts = new(); + private Task? _progressUpdateTask; + private DownloadProgressEvent? _pendingProgressEvent; + + // Can be called multiple times, but must not be called or in progress while + // SendQueuedProgressUpdateNow is called. + private async Task QueueProgressUpdate(DownloadProgressEvent e, CancellationToken ct) + { + using var _1 = await _progressEventLock.LockAsync(ct); + _pendingProgressEvent = e; + + if (_progressUpdateCts.IsCancellationRequested) + throw new InvalidOperationException("Progress update task was cancelled, cannot queue new progress update"); + + // Start a task with a 50ms delay unless one is already running. + var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _progressUpdateCts.Token); + cts.CancelAfter(TimeSpan.FromSeconds(5)); + _progressUpdateTask ??= Task.Delay(ProgressUpdateDelayMs, cts.Token) + .ContinueWith(t => + { + cts.Cancel(); + using var _2 = _progressEventLock.Lock(); + _progressUpdateTask = null; + if (t.IsFaulted || t.IsCanceled) return; + + var ev = _pendingProgressEvent; + if (ev != null) SendProgressUpdate(ev); + }, cts.Token); + } + + // Must only be called after all QueueProgressUpdate calls have completed. + private async Task ClearQueuedProgressUpdate(CancellationToken ct) + { + Task? t; + using (var _ = _progressEventLock.LockAsync(ct)) + { + await _progressUpdateCts.CancelAsync(); + t = _progressUpdateTask; + } + + // We can't continue to hold the lock here because the continuation + // grabs a lock. We don't need to worry about a new task spawning after + // this because the token is cancelled. + if (t == null) return; + try + { + await t.WaitAsync(ct); + } + catch (TaskCanceledException) + { + // Ignore + } + } + + private void SendProgressUpdate(DownloadProgressEvent e) + { + var handler = ProgressChanged; + if (handler == null) + return; + // Start a new task in the background to invoke the event. + _ = Task.Run(() => handler.Invoke(this, e)); + } } diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs index fc014c0..cf2bb8a 100644 --- a/Vpn.Service/Manager.cs +++ b/Vpn.Service/Manager.cs @@ -26,6 +26,10 @@ public interface IManager : IDisposable /// public class Manager : IManager { + // We scale the download progress to 0.00-0.90, and use 0.90-1.00 for the + // remainder of startup. + private const double DownloadProgressScale = 0.90; + private readonly ManagerConfig _config; private readonly IDownloader _downloader; private readonly ILogger _logger; @@ -131,6 +135,8 @@ private async ValueTask HandleClientMessageStart(ClientMessage me { try { + await BroadcastStartProgress(0.0, "Starting Coder Connect...", ct); + var serverVersion = await CheckServerVersionAndCredentials(message.Start.CoderUrl, message.Start.ApiToken, ct); if (_status == TunnelStatus.Started && _lastStartRequest != null && @@ -151,10 +157,14 @@ private async ValueTask HandleClientMessageStart(ClientMessage me _lastServerVersion = serverVersion; // TODO: each section of this operation needs a timeout + // Stop the tunnel if it's running so we don't have to worry about // permissions issues when replacing the binary. await _tunnelSupervisor.StopAsync(ct); + await DownloadTunnelBinaryAsync(message.Start.CoderUrl, serverVersion.SemVersion, ct); + + await BroadcastStartProgress(DownloadProgressScale, "Starting Coder Connect...", ct); await _tunnelSupervisor.StartAsync(_config.TunnelBinaryPath, HandleTunnelRpcMessage, HandleTunnelRpcError, ct); @@ -237,6 +247,9 @@ private void HandleTunnelRpcMessage(ReplyableRpcMessage CurrentStatus(CancellationToken ct = default) private async Task BroadcastStatus(TunnelStatus? newStatus = null, CancellationToken ct = default) { if (newStatus != null) _status = newStatus.Value; - await _managerRpc.BroadcastAsync(new ServiceMessage + await FallibleBroadcast(new ServiceMessage { Status = await CurrentStatus(ct), }, ct); } + private async Task FallibleBroadcast(ServiceMessage message, CancellationToken ct = default) + { + // Broadcast the messages out with a low timeout. If clients don't + // receive broadcasts in time, it's not a big deal. + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + //cts.CancelAfter(TimeSpan.FromMilliseconds(100)); + try + { + await _managerRpc.BroadcastAsync(message, cts.Token); + } + catch (Exception ex) + { + _logger.LogWarning(ex, "Could not broadcast low priority message to all RPC clients: {Message}", message); + } + } + private void HandleTunnelRpcError(Exception e) { _logger.LogError(e, "Manager<->Tunnel RPC error"); @@ -427,10 +456,44 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct); - // TODO: monitor and report progress when we have a mechanism to do so + var progressLock = new RaiiSemaphoreSlim(1, 1); + var progressBroadcastCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + downloadTask.ProgressChanged += (sender, ev) => + { + using var _ = progressLock.Lock(); + if (progressBroadcastCts.IsCancellationRequested) return; + _logger.LogInformation("Download progress: {ev}", ev); + + // Scale the progress value to be between 0.00 and 0.90. + var progress = ev.Progress * DownloadProgressScale ?? 0.0; + var message = $"Downloading Coder Connect binary...\n{ev}"; + BroadcastStartProgress(progress, message, progressBroadcastCts.Token).Wait(progressBroadcastCts.Token); + }; // Awaiting this will check the checksum (via the ETag) if the file // exists, and will also validate the signature and version. await downloadTask.Task; + + // Prevent any lagging progress events from being sent. + // ReSharper disable once PossiblyMistakenUseOfCancellationToken + using (await progressLock.LockAsync(ct)) + await progressBroadcastCts.CancelAsync(); + + // We don't send a broadcast here as we immediately send one in the + // parent routine. + _logger.LogInformation("Completed downloading VPN binary"); + } + + private async Task BroadcastStartProgress(double progress, string message, CancellationToken ct = default) + { + _logger.LogInformation("Start progress: {Progress:0%} - {Message}", progress, message); + await FallibleBroadcast(new ServiceMessage + { + StartProgress = new StartProgress + { + Progress = progress, + Message = message, + }, + }, ct); } } diff --git a/Vpn.Service/ManagerRpc.cs b/Vpn.Service/ManagerRpc.cs index c23752f..d922caf 100644 --- a/Vpn.Service/ManagerRpc.cs +++ b/Vpn.Service/ManagerRpc.cs @@ -133,7 +133,7 @@ public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct) try { var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); - cts.CancelAfter(5 * 1000); + cts.CancelAfter(TimeSpan.FromSeconds(2)); await client.Speaker.SendMessage(message, cts.Token); } catch (ObjectDisposedException) diff --git a/Vpn.Service/Program.cs b/Vpn.Service/Program.cs index fc61247..094875d 100644 --- a/Vpn.Service/Program.cs +++ b/Vpn.Service/Program.cs @@ -16,10 +16,12 @@ public static class Program #if !DEBUG private const string ServiceName = "Coder Desktop"; private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\VpnService"; + private const string DefaultLogLevel = "Information"; #else // This value matches Create-Service.ps1. private const string ServiceName = "Coder Desktop (Debug)"; private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\DebugVpnService"; + private const string DefaultLogLevel = "Debug"; #endif private const string ManagerConfigSection = "Manager"; @@ -81,6 +83,10 @@ private static async Task BuildAndRun(string[] args) builder.Services.AddSingleton(); // Services + builder.Services.AddHostedService(); + builder.Services.AddHostedService(); + + // Either run as a Windows service or a console application if (!Environment.UserInteractive) { MainLogger.Information("Running as a windows service"); @@ -91,9 +97,6 @@ private static async Task BuildAndRun(string[] args) MainLogger.Information("Running as a console application"); } - builder.Services.AddHostedService(); - builder.Services.AddHostedService(); - var host = builder.Build(); Log.Logger = (ILogger)host.Services.GetService(typeof(ILogger))!; MainLogger.Information("Application is starting"); @@ -108,7 +111,7 @@ private static void AddDefaultConfig(IConfigurationBuilder builder) ["Serilog:Using:0"] = "Serilog.Sinks.File", ["Serilog:Using:1"] = "Serilog.Sinks.Console", - ["Serilog:MinimumLevel"] = "Information", + ["Serilog:MinimumLevel"] = DefaultLogLevel, ["Serilog:Enrich:0"] = "FromLogContext", ["Serilog:WriteTo:0:Name"] = "File", diff --git a/Vpn.Service/TunnelSupervisor.cs b/Vpn.Service/TunnelSupervisor.cs index a323cac..6ff4f3b 100644 --- a/Vpn.Service/TunnelSupervisor.cs +++ b/Vpn.Service/TunnelSupervisor.cs @@ -100,17 +100,15 @@ public async Task StartAsync(string binPath, }; // TODO: maybe we should change the log format in the inner binary // to something without a timestamp - var outLogger = Log.ForContext("SourceContext", "coder-vpn.exe[OUT]"); - var errLogger = Log.ForContext("SourceContext", "coder-vpn.exe[ERR]"); _subprocess.OutputDataReceived += (_, args) => { if (!string.IsNullOrWhiteSpace(args.Data)) - outLogger.Debug("{Data}", args.Data); + _logger.LogDebug("stdout: {Data}", args.Data); }; _subprocess.ErrorDataReceived += (_, args) => { if (!string.IsNullOrWhiteSpace(args.Data)) - errLogger.Debug("{Data}", args.Data); + _logger.LogDebug("stderr: {Data}", args.Data); }; // Pass the other end of the pipes to the subprocess and dispose diff --git a/Vpn/Speaker.cs b/Vpn/Speaker.cs index d113a50..37ec554 100644 --- a/Vpn/Speaker.cs +++ b/Vpn/Speaker.cs @@ -123,7 +123,7 @@ public async Task StartAsync(CancellationToken ct = default) // Handshakes should always finish quickly, so enforce a 5s timeout. using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); cts.CancelAfter(TimeSpan.FromSeconds(5)); - await PerformHandshake(ct); + await PerformHandshake(cts.Token); // Start ReceiveLoop in the background. _receiveTask = ReceiveLoop(_cts.Token); From cd966bea1eb2a3b3cc07834e2384423aa597397a Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 2 Jun 2025 15:09:48 +1000 Subject: [PATCH 2/5] fixup! feat: add vpn start progress --- Vpn.Service/Downloader.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index 856a637..4e7e5b2 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -603,8 +603,8 @@ await QueueProgressUpdate(new DownloadProgressEvent #else try { - if (File.Exists(TempDestinationPath)) - File.Delete(TempDestinationPath); + if (File.Exists(_tempDestinationPath)) + File.Delete(_tempDestinationPath); } catch (Exception e) { From fb46593e59264f1ada25239d18d78c45eade04fd Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 6 Jun 2025 15:22:19 +1000 Subject: [PATCH 3/5] change to enums --- App/Models/RpcModel.cs | 135 +++++++++++++++++- .../PublishProfiles/win-arm64.pubxml | 12 ++ App/Properties/PublishProfiles/win-x64.pubxml | 12 ++ App/Properties/PublishProfiles/win-x86.pubxml | 12 ++ App/Services/RpcController.cs | 9 +- App/ViewModels/TrayWindowViewModel.cs | 5 +- Tests.Vpn.Service/DownloaderTest.cs | 4 +- Vpn.Proto/vpn.proto | 13 +- Vpn.Service/Downloader.cs | 18 ++- Vpn.Service/Manager.cs | 30 ++-- 10 files changed, 207 insertions(+), 43 deletions(-) create mode 100644 App/Properties/PublishProfiles/win-arm64.pubxml create mode 100644 App/Properties/PublishProfiles/win-x64.pubxml create mode 100644 App/Properties/PublishProfiles/win-x86.pubxml diff --git a/App/Models/RpcModel.cs b/App/Models/RpcModel.cs index 33b4647..426863b 100644 --- a/App/Models/RpcModel.cs +++ b/App/Models/RpcModel.cs @@ -1,4 +1,7 @@ +using System; using System.Collections.Generic; +using System.Diagnostics; +using Coder.Desktop.App.Converters; using Coder.Desktop.Vpn.Proto; namespace Coder.Desktop.App.Models; @@ -19,17 +22,141 @@ public enum VpnLifecycle Stopping, } +public enum VpnStartupStage +{ + Unknown, + Initializing, + Downloading, + Finalizing, +} + +public class VpnDownloadProgress +{ + public ulong BytesWritten { get; set; } = 0; + public ulong? BytesTotal { get; set; } = null; // null means unknown total size + + public double Progress + { + get + { + if (BytesTotal is > 0) + { + return (double)BytesWritten / BytesTotal.Value; + } + return 0.0; + } + } + + public override string ToString() + { + // TODO: it would be nice if the two suffixes could match + var s = FriendlyByteConverter.FriendlyBytes(BytesWritten); + if (BytesTotal != null) + s += $" of {FriendlyByteConverter.FriendlyBytes(BytesTotal.Value)}"; + else + s += " of unknown"; + if (BytesTotal != null) + s += $" ({Progress:0%})"; + return s; + } + + public VpnDownloadProgress Clone() + { + return new VpnDownloadProgress + { + BytesWritten = BytesWritten, + BytesTotal = BytesTotal, + }; + } + + public static VpnDownloadProgress FromProto(StartProgressDownloadProgress proto) + { + return new VpnDownloadProgress + { + BytesWritten = proto.BytesWritten, + BytesTotal = proto.HasBytesTotal ? proto.BytesTotal : null, + }; + } +} + public class VpnStartupProgress { - public double Progress { get; set; } = 0.0; // 0.0 to 1.0 - public string Message { get; set; } = string.Empty; + public const string DefaultStartProgressMessage = "Starting Coder Connect..."; + + // Scale the download progress to an overall progress value between these + // numbers. + private const double DownloadProgressMin = 0.05; + private const double DownloadProgressMax = 0.80; + + public VpnStartupStage Stage { get; set; } = VpnStartupStage.Unknown; + public VpnDownloadProgress? DownloadProgress { get; set; } = null; + + // 0.0 to 1.0 + public double Progress + { + get + { + switch (Stage) + { + case VpnStartupStage.Unknown: + case VpnStartupStage.Initializing: + return 0.0; + case VpnStartupStage.Downloading: + var progress = DownloadProgress?.Progress ?? 0.0; + return DownloadProgressMin + (DownloadProgressMax - DownloadProgressMin) * progress; + case VpnStartupStage.Finalizing: + return DownloadProgressMax; + default: + throw new ArgumentOutOfRangeException(); + } + } + } + + public override string ToString() + { + switch (Stage) + { + case VpnStartupStage.Unknown: + case VpnStartupStage.Initializing: + return DefaultStartProgressMessage; + case VpnStartupStage.Downloading: + var s = "Downloading Coder Connect binary..."; + if (DownloadProgress is not null) + { + s += "\n" + DownloadProgress; + } + + return s; + case VpnStartupStage.Finalizing: + return "Finalizing Coder Connect startup..."; + default: + throw new ArgumentOutOfRangeException(); + } + } public VpnStartupProgress Clone() { return new VpnStartupProgress { - Progress = Progress, - Message = Message, + Stage = Stage, + DownloadProgress = DownloadProgress?.Clone(), + }; + } + + public static VpnStartupProgress FromProto(StartProgress proto) + { + return new VpnStartupProgress + { + Stage = proto.Stage switch + { + StartProgressStage.Initializing => VpnStartupStage.Initializing, + StartProgressStage.Downloading => VpnStartupStage.Downloading, + StartProgressStage.Finalizing => VpnStartupStage.Finalizing, + _ => VpnStartupStage.Unknown, + }, + DownloadProgress = proto.Stage is StartProgressStage.Downloading ? + VpnDownloadProgress.FromProto(proto.DownloadProgress) : + null, }; } } diff --git a/App/Properties/PublishProfiles/win-arm64.pubxml b/App/Properties/PublishProfiles/win-arm64.pubxml new file mode 100644 index 0000000..ac9753e --- /dev/null +++ b/App/Properties/PublishProfiles/win-arm64.pubxml @@ -0,0 +1,12 @@ + + + + + FileSystem + ARM64 + win-arm64 + bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ + + diff --git a/App/Properties/PublishProfiles/win-x64.pubxml b/App/Properties/PublishProfiles/win-x64.pubxml new file mode 100644 index 0000000..942523b --- /dev/null +++ b/App/Properties/PublishProfiles/win-x64.pubxml @@ -0,0 +1,12 @@ + + + + + FileSystem + x64 + win-x64 + bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ + + diff --git a/App/Properties/PublishProfiles/win-x86.pubxml b/App/Properties/PublishProfiles/win-x86.pubxml new file mode 100644 index 0000000..e763481 --- /dev/null +++ b/App/Properties/PublishProfiles/win-x86.pubxml @@ -0,0 +1,12 @@ + + + + + FileSystem + x86 + win-x86 + bin\$(Configuration)\$(TargetFramework)\$(RuntimeIdentifier)\publish\ + + diff --git a/App/Services/RpcController.cs b/App/Services/RpcController.cs index b42c058..3345050 100644 --- a/App/Services/RpcController.cs +++ b/App/Services/RpcController.cs @@ -164,8 +164,7 @@ public async Task StartVpn(CancellationToken ct = default) MutateState(state => { state.VpnLifecycle = VpnLifecycle.Starting; - // Explicitly clear the startup progress. - state.VpnStartupProgress = null; + state.VpnStartupProgress = new VpnStartupProgress(); }); ServiceMessage reply; @@ -297,11 +296,7 @@ private void ApplyStartProgressUpdate(StartProgress message) { // MutateState will undo these changes if it doesn't believe we're // in the "Starting" state. - state.VpnStartupProgress = new VpnStartupProgress - { - Progress = message.Progress, - Message = message.Message, - }; + state.VpnStartupProgress = VpnStartupProgress.FromProto(message); }); } diff --git a/App/ViewModels/TrayWindowViewModel.cs b/App/ViewModels/TrayWindowViewModel.cs index cd3a641..820ff12 100644 --- a/App/ViewModels/TrayWindowViewModel.cs +++ b/App/ViewModels/TrayWindowViewModel.cs @@ -29,7 +29,6 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost { private const int MaxAgents = 5; private const string DefaultDashboardUrl = "https://coder.com"; - private const string DefaultStartProgressMessage = "Starting Coder Connect..."; private readonly IServiceProvider _services; private readonly IRpcController _rpcController; @@ -84,7 +83,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost public partial string? VpnStartProgressMessage { get; set; } = null; public string VpnStartProgressMessageOrDefault => - string.IsNullOrEmpty(VpnStartProgressMessage) ? DefaultStartProgressMessage : VpnStartProgressMessage; + string.IsNullOrEmpty(VpnStartProgressMessage) ? VpnStartupProgress.DefaultStartProgressMessage : VpnStartProgressMessage; public bool VpnStartProgressIsIndeterminate => VpnStartProgressValueOrDefault == 0; @@ -196,7 +195,7 @@ private void UpdateFromRpcModel(RpcModel rpcModel) // Convert 0.00-1.00 to 0-100. var progress = (int)(rpcModel.VpnStartupProgress.Progress * 100); VpnStartProgressValue = Math.Clamp(progress, 0, 100); - VpnStartProgressMessage = string.IsNullOrEmpty(rpcModel.VpnStartupProgress.Message) ? null : rpcModel.VpnStartupProgress.Message; + VpnStartProgressMessage = rpcModel.VpnStartupProgress.ToString(); } else { diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs index a47ffbc..b33f510 100644 --- a/Tests.Vpn.Service/DownloaderTest.cs +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -344,11 +344,11 @@ public async Task DownloadWithXOriginalContentLength(CancellationToken ct) Assert.That(list.Count, Is.GreaterThanOrEqualTo(2)); // there may be an item in the middle // The first item should be the initial progress with 0 bytes written. Assert.That(list[0].BytesWritten, Is.EqualTo(0)); - Assert.That(list[0].TotalBytes, Is.EqualTo(6)); // from X-Original-Content-Length + Assert.That(list[0].BytesTotal, Is.EqualTo(6)); // from X-Original-Content-Length Assert.That(list[0].Progress, Is.EqualTo(0.0d)); // The last item should be final progress with the actual total bytes. Assert.That(list[^1].BytesWritten, Is.EqualTo(4)); - Assert.That(list[^1].TotalBytes, Is.EqualTo(4)); // from the actual bytes written + Assert.That(list[^1].BytesTotal, Is.EqualTo(4)); // from the actual bytes written Assert.That(list[^1].Progress, Is.EqualTo(1.0d)); } diff --git a/Vpn.Proto/vpn.proto b/Vpn.Proto/vpn.proto index fa2f003..bace7e0 100644 --- a/Vpn.Proto/vpn.proto +++ b/Vpn.Proto/vpn.proto @@ -227,9 +227,18 @@ message StartResponse { // inability to easily send messages to a specific client in the Speaker // implementation. If clients are not expecting these messages, they // should ignore them. +enum StartProgressStage { + Initializing = 0; + Downloading = 1; + Finalizing = 2; +} +message StartProgressDownloadProgress { + uint64 bytes_written = 1; + optional uint64 bytes_total = 2; // unknown in some situations +} message StartProgress { - double progress = 1; // 0.0 to 1.0 - string message = 2; // human-readable status message, must be set + StartProgressStage stage = 1; + optional StartProgressDownloadProgress download_progress = 2; // only set when stage == Downloading } // StopRequest is a request from the manager to stop the tunnel. The tunnel replies with a diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index 4e7e5b2..a665ec4 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -342,14 +342,15 @@ public class DownloadProgressEvent { // TODO: speed calculation would be nice public ulong BytesWritten { get; init; } - public ulong? TotalBytes { get; init; } // null if unknown - public double? Progress { get; init; } // 0.0 - 1.0, null if unknown + public ulong? BytesTotal { get; init; } // null if unknown + + public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value; public override string ToString() { var s = FriendlyBytes(BytesWritten); - if (TotalBytes != null) - s += $" of {FriendlyBytes(TotalBytes.Value)}"; + if (BytesTotal != null) + s += $" of {FriendlyBytes(BytesTotal.Value)}"; else s += " of unknown"; if (Progress != null) @@ -513,8 +514,7 @@ private async Task Start(CancellationToken ct = default) SendProgressUpdate(new DownloadProgressEvent { BytesWritten = 0, - TotalBytes = TotalBytes, - Progress = 0.0, + BytesTotal = TotalBytes, }); await Download(res, ct); @@ -549,8 +549,7 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) await QueueProgressUpdate(new DownloadProgressEvent { BytesWritten = BytesWritten, - TotalBytes = TotalBytes, - Progress = Progress, + BytesTotal = TotalBytes, }, ct); } } @@ -563,8 +562,7 @@ await QueueProgressUpdate(new DownloadProgressEvent SendProgressUpdate(new DownloadProgressEvent { BytesWritten = BytesWritten, - TotalBytes = BytesWritten, - Progress = 1.0, + BytesTotal = BytesWritten, }); if (TotalBytes != null && BytesWritten != TotalBytes) diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs index cf2bb8a..0324ebb 100644 --- a/Vpn.Service/Manager.cs +++ b/Vpn.Service/Manager.cs @@ -26,10 +26,6 @@ public interface IManager : IDisposable /// public class Manager : IManager { - // We scale the download progress to 0.00-0.90, and use 0.90-1.00 for the - // remainder of startup. - private const double DownloadProgressScale = 0.90; - private readonly ManagerConfig _config; private readonly IDownloader _downloader; private readonly ILogger _logger; @@ -135,7 +131,7 @@ private async ValueTask HandleClientMessageStart(ClientMessage me { try { - await BroadcastStartProgress(0.0, "Starting Coder Connect...", ct); + await BroadcastStartProgress(StartProgressStage.Initializing, cancellationToken: ct); var serverVersion = await CheckServerVersionAndCredentials(message.Start.CoderUrl, message.Start.ApiToken, ct); @@ -164,7 +160,7 @@ private async ValueTask HandleClientMessageStart(ClientMessage me await DownloadTunnelBinaryAsync(message.Start.CoderUrl, serverVersion.SemVersion, ct); - await BroadcastStartProgress(DownloadProgressScale, "Starting Coder Connect...", ct); + await BroadcastStartProgress(StartProgressStage.Finalizing, cancellationToken: ct); await _tunnelSupervisor.StartAsync(_config.TunnelBinaryPath, HandleTunnelRpcMessage, HandleTunnelRpcError, ct); @@ -464,10 +460,14 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected if (progressBroadcastCts.IsCancellationRequested) return; _logger.LogInformation("Download progress: {ev}", ev); - // Scale the progress value to be between 0.00 and 0.90. - var progress = ev.Progress * DownloadProgressScale ?? 0.0; - var message = $"Downloading Coder Connect binary...\n{ev}"; - BroadcastStartProgress(progress, message, progressBroadcastCts.Token).Wait(progressBroadcastCts.Token); + var progress = new StartProgressDownloadProgress + { + BytesWritten = ev.BytesWritten, + }; + if (ev.BytesTotal != null) + progress.BytesTotal = ev.BytesTotal.Value; + BroadcastStartProgress(StartProgressStage.Downloading, progress, progressBroadcastCts.Token) + .Wait(progressBroadcastCts.Token); }; // Awaiting this will check the checksum (via the ETag) if the file @@ -484,16 +484,16 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected _logger.LogInformation("Completed downloading VPN binary"); } - private async Task BroadcastStartProgress(double progress, string message, CancellationToken ct = default) + private async Task BroadcastStartProgress(StartProgressStage stage, StartProgressDownloadProgress? downloadProgress = null, CancellationToken cancellationToken = default) { - _logger.LogInformation("Start progress: {Progress:0%} - {Message}", progress, message); + _logger.LogInformation("Start progress: {stage}", stage); await FallibleBroadcast(new ServiceMessage { StartProgress = new StartProgress { - Progress = progress, - Message = message, + Stage = stage, + DownloadProgress = downloadProgress, }, - }, ct); + }, cancellationToken); } } From 473164dca3c7830f43473438d77d8c26aacaf476 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 6 Jun 2025 15:59:21 +1000 Subject: [PATCH 4/5] rework download progress --- Tests.Vpn.Service/DownloaderTest.cs | 50 ++++----- Vpn.Service/Downloader.cs | 151 ++-------------------------- Vpn.Service/Manager.cs | 51 ++++++---- 3 files changed, 68 insertions(+), 184 deletions(-) diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs index b33f510..bb9b39c 100644 --- a/Tests.Vpn.Service/DownloaderTest.cs +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -2,7 +2,6 @@ using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Text; -using System.Threading.Channels; using Coder.Desktop.Vpn.Service; using Microsoft.Extensions.Logging.Abstractions; @@ -278,7 +277,7 @@ public async Task Download(CancellationToken ct) var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, NullDownloadValidator.Instance, ct); await dlTask.Task; - Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); + Assert.That(dlTask.BytesTotal, Is.EqualTo(4)); Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); Assert.That(dlTask.Progress, Is.EqualTo(1)); Assert.That(dlTask.IsCompleted, Is.True); @@ -301,13 +300,13 @@ public async Task DownloadSameDest(CancellationToken ct) NullDownloadValidator.Instance, ct); var dlTask0 = await startTask0; await dlTask0.Task; - Assert.That(dlTask0.TotalBytes, Is.EqualTo(5)); + Assert.That(dlTask0.BytesTotal, Is.EqualTo(5)); Assert.That(dlTask0.BytesWritten, Is.EqualTo(5)); Assert.That(dlTask0.Progress, Is.EqualTo(1)); Assert.That(dlTask0.IsCompleted, Is.True); var dlTask1 = await startTask1; await dlTask1.Task; - Assert.That(dlTask1.TotalBytes, Is.EqualTo(5)); + Assert.That(dlTask1.BytesTotal, Is.EqualTo(5)); Assert.That(dlTask1.BytesWritten, Is.EqualTo(5)); Assert.That(dlTask1.Progress, Is.EqualTo(1)); Assert.That(dlTask1.IsCompleted, Is.True); @@ -320,9 +319,9 @@ public async Task DownloadWithXOriginalContentLength(CancellationToken ct) using var httpServer = new TestHttpServer(async ctx => { ctx.Response.StatusCode = 200; - ctx.Response.Headers.Add("X-Original-Content-Length", "6"); // wrong but should be used until complete + ctx.Response.Headers.Add("X-Original-Content-Length", "4"); ctx.Response.ContentType = "text/plain"; - ctx.Response.ContentLength64 = 4; // This should be ignored. + // Don't set Content-Length. await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct); }); var url = new Uri(httpServer.BaseUrl + "/test"); @@ -331,25 +330,30 @@ public async Task DownloadWithXOriginalContentLength(CancellationToken ct) var req = new HttpRequestMessage(HttpMethod.Get, url); var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct); - var progressChannel = Channel.CreateUnbounded(); - dlTask.ProgressChanged += (_, args) => - Assert.That(progressChannel.Writer.TryWrite(args), Is.True); - await dlTask.Task; - Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); // should equal BytesWritten after completion + Assert.That(dlTask.BytesTotal, Is.EqualTo(4)); Assert.That(dlTask.BytesWritten, Is.EqualTo(4)); - progressChannel.Writer.Complete(); - - var list = progressChannel.Reader.ReadAllAsync(ct).ToBlockingEnumerable(ct).ToList(); - Assert.That(list.Count, Is.GreaterThanOrEqualTo(2)); // there may be an item in the middle - // The first item should be the initial progress with 0 bytes written. - Assert.That(list[0].BytesWritten, Is.EqualTo(0)); - Assert.That(list[0].BytesTotal, Is.EqualTo(6)); // from X-Original-Content-Length - Assert.That(list[0].Progress, Is.EqualTo(0.0d)); - // The last item should be final progress with the actual total bytes. - Assert.That(list[^1].BytesWritten, Is.EqualTo(4)); - Assert.That(list[^1].BytesTotal, Is.EqualTo(4)); // from the actual bytes written - Assert.That(list[^1].Progress, Is.EqualTo(1.0d)); + } + + [Test(Description = "Download with mismatched Content-Length")] + [CancelAfter(30_000)] + public async Task DownloadWithMismatchedContentLength(CancellationToken ct) + { + using var httpServer = new TestHttpServer(async ctx => + { + ctx.Response.StatusCode = 200; + ctx.Response.Headers.Add("X-Original-Content-Length", "5"); // incorrect + ctx.Response.ContentType = "text/plain"; + await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct); + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + var manager = new Downloader(NullLogger.Instance); + var req = new HttpRequestMessage(HttpMethod.Get, url); + var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct); + + var ex = Assert.ThrowsAsync(() => dlTask.Task); + Assert.That(ex.Message, Is.EqualTo("Downloaded file size does not match expected response content length: Expected=5, BytesWritten=4")); } [Test(Description = "Download with custom headers")] diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index a665ec4..c4a916f 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -338,42 +338,6 @@ internal static async Task TaskOrCancellation(Task task, CancellationToken cance } } -public class DownloadProgressEvent -{ - // TODO: speed calculation would be nice - public ulong BytesWritten { get; init; } - public ulong? BytesTotal { get; init; } // null if unknown - - public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value; - - public override string ToString() - { - var s = FriendlyBytes(BytesWritten); - if (BytesTotal != null) - s += $" of {FriendlyBytes(BytesTotal.Value)}"; - else - s += " of unknown"; - if (Progress != null) - s += $" ({Progress:0%})"; - return s; - } - - private static readonly string[] ByteSuffixes = ["B", "KB", "MB", "GB", "TB", "PB", "EB"]; - - // Unfortunately this is copied from FriendlyByteConverter in App. Ideally - // it should go into some shared utilities project, but it's overkill to do - // that for a single tiny function until we have more shared code. - private static string FriendlyBytes(ulong bytes) - { - if (bytes == 0) - return $"0 {ByteSuffixes[0]}"; - - var place = Convert.ToInt32(Math.Floor(Math.Log(bytes, 1024))); - var num = Math.Round(bytes / Math.Pow(1024, place), 1); - return $"{num} {ByteSuffixes[place]}"; - } -} - /// /// Downloads a Url to a file on disk. The download will be written to a temporary file first, then moved to the final /// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if @@ -381,8 +345,7 @@ private static string FriendlyBytes(ulong bytes) /// public class DownloadTask { - private const int BufferSize = 4096; - private const int ProgressUpdateDelayMs = 50; + private const int BufferSize = 64 * 1024; private const string XOriginalContentLengthHeader = "X-Original-Content-Length"; // overrides Content-Length if available private static readonly HttpClient HttpClient = new(new HttpClientHandler @@ -398,22 +361,13 @@ public class DownloadTask private readonly string _destinationPath; private readonly string _tempDestinationPath; - // ProgressChanged events are always delayed by up to 50ms to avoid - // flooding. - // - // This will be called: - // - once after the request succeeds but before the read/write routine - // begins - // - occasionally while the file is being downloaded (at least 50ms apart) - // - once when the download is complete - public EventHandler? ProgressChanged; - public readonly HttpRequestMessage Request; public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync + public bool DownloadStarted { get; private set; } // Whether we've received headers yet and started the actual download public ulong BytesWritten { get; private set; } - public ulong? TotalBytes { get; private set; } - public double? Progress => TotalBytes == null ? null : (double)BytesWritten / TotalBytes.Value; + public ulong? BytesTotal { get; private set; } + public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value; public bool IsCompleted => Task.IsCompleted; internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator) @@ -496,7 +450,7 @@ private async Task Start(CancellationToken ct = default) } if (res.Content.Headers.ContentLength >= 0) - TotalBytes = (ulong)res.Content.Headers.ContentLength; + BytesTotal = (ulong)res.Content.Headers.ContentLength; // X-Original-Content-Length overrules Content-Length if set. if (res.Headers.TryGetValues(XOriginalContentLengthHeader, out var headerValues)) @@ -504,24 +458,19 @@ private async Task Start(CancellationToken ct = default) // If there are multiple we only look at the first one. var headerValue = headerValues.ToList().FirstOrDefault(); if (!string.IsNullOrEmpty(headerValue) && ulong.TryParse(headerValue, out var originalContentLength)) - TotalBytes = originalContentLength; + BytesTotal = originalContentLength; else _logger.LogWarning( "Failed to parse {XOriginalContentLengthHeader} header value '{HeaderValue}'", XOriginalContentLengthHeader, headerValue); } - SendProgressUpdate(new DownloadProgressEvent - { - BytesWritten = 0, - BytesTotal = TotalBytes, - }); - await Download(res, ct); } private async Task Download(HttpResponseMessage res, CancellationToken ct) { + DownloadStarted = true; try { var sha1 = res.Headers.Contains("ETag") ? SHA1.Create() : null; @@ -546,28 +495,13 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct) await tempFile.WriteAsync(buffer.AsMemory(0, n), ct); sha1?.TransformBlock(buffer, 0, n, null, 0); BytesWritten += (ulong)n; - await QueueProgressUpdate(new DownloadProgressEvent - { - BytesWritten = BytesWritten, - BytesTotal = TotalBytes, - }, ct); } } - // Clear any pending progress updates to ensure they won't be sent - // after the final update. - await ClearQueuedProgressUpdate(ct); - // Then write the final status update. - TotalBytes = BytesWritten; - SendProgressUpdate(new DownloadProgressEvent - { - BytesWritten = BytesWritten, - BytesTotal = BytesWritten, - }); - - if (TotalBytes != null && BytesWritten != TotalBytes) + BytesTotal ??= BytesWritten; + if (BytesWritten != BytesTotal) throw new IOException( - $"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesWritten}"); + $"Downloaded file size does not match expected response content length: Expected={BytesTotal}, BytesWritten={BytesWritten}"); // Verify the ETag if it was sent by the server. if (res.Headers.Contains("ETag") && sha1 != null) @@ -612,69 +546,4 @@ await QueueProgressUpdate(new DownloadProgressEvent throw; } } - - // _progressEventLock protects _progressUpdateTask and _pendingProgressEvent. - private readonly RaiiSemaphoreSlim _progressEventLock = new(1, 1); - private readonly CancellationTokenSource _progressUpdateCts = new(); - private Task? _progressUpdateTask; - private DownloadProgressEvent? _pendingProgressEvent; - - // Can be called multiple times, but must not be called or in progress while - // SendQueuedProgressUpdateNow is called. - private async Task QueueProgressUpdate(DownloadProgressEvent e, CancellationToken ct) - { - using var _1 = await _progressEventLock.LockAsync(ct); - _pendingProgressEvent = e; - - if (_progressUpdateCts.IsCancellationRequested) - throw new InvalidOperationException("Progress update task was cancelled, cannot queue new progress update"); - - // Start a task with a 50ms delay unless one is already running. - var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _progressUpdateCts.Token); - cts.CancelAfter(TimeSpan.FromSeconds(5)); - _progressUpdateTask ??= Task.Delay(ProgressUpdateDelayMs, cts.Token) - .ContinueWith(t => - { - cts.Cancel(); - using var _2 = _progressEventLock.Lock(); - _progressUpdateTask = null; - if (t.IsFaulted || t.IsCanceled) return; - - var ev = _pendingProgressEvent; - if (ev != null) SendProgressUpdate(ev); - }, cts.Token); - } - - // Must only be called after all QueueProgressUpdate calls have completed. - private async Task ClearQueuedProgressUpdate(CancellationToken ct) - { - Task? t; - using (var _ = _progressEventLock.LockAsync(ct)) - { - await _progressUpdateCts.CancelAsync(); - t = _progressUpdateTask; - } - - // We can't continue to hold the lock here because the continuation - // grabs a lock. We don't need to worry about a new task spawning after - // this because the token is cancelled. - if (t == null) return; - try - { - await t.WaitAsync(ct); - } - catch (TaskCanceledException) - { - // Ignore - } - } - - private void SendProgressUpdate(DownloadProgressEvent e) - { - var handler = ProgressChanged; - if (handler == null) - return; - // Start a new task in the background to invoke the event. - _ = Task.Run(() => handler.Invoke(this, e)); - } } diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs index 0324ebb..886bb70 100644 --- a/Vpn.Service/Manager.cs +++ b/Vpn.Service/Manager.cs @@ -450,34 +450,46 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected _logger.LogDebug("Skipping tunnel binary version validation"); } + // Note: all ETag, signature and version validation is performed by the + // DownloadTask. var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct); - var progressLock = new RaiiSemaphoreSlim(1, 1); - var progressBroadcastCts = CancellationTokenSource.CreateLinkedTokenSource(ct); - downloadTask.ProgressChanged += (sender, ev) => + // Wait for the download to complete, sending progress updates every + // 50ms. + while (true) { - using var _ = progressLock.Lock(); - if (progressBroadcastCts.IsCancellationRequested) return; - _logger.LogInformation("Download progress: {ev}", ev); + // Wait for the download to complete, or for a short delay before + // we send a progress update. + var delayTask = Task.Delay(TimeSpan.FromMilliseconds(50), ct); + var winner = await Task.WhenAny([ + downloadTask.Task, + delayTask, + ]); + if (winner == downloadTask.Task) + break; + + // Task.WhenAny will not throw if the winner was cancelled, so + // check CT afterward and not beforehand. + ct.ThrowIfCancellationRequested(); + + if (!downloadTask.DownloadStarted) + // Don't send progress updates if we don't know what the + // progress is yet. + continue; var progress = new StartProgressDownloadProgress { - BytesWritten = ev.BytesWritten, + BytesWritten = downloadTask.BytesWritten, }; - if (ev.BytesTotal != null) - progress.BytesTotal = ev.BytesTotal.Value; - BroadcastStartProgress(StartProgressStage.Downloading, progress, progressBroadcastCts.Token) - .Wait(progressBroadcastCts.Token); - }; + if (downloadTask.BytesTotal != null) + progress.BytesTotal = downloadTask.BytesTotal.Value; - // Awaiting this will check the checksum (via the ETag) if the file - // exists, and will also validate the signature and version. - await downloadTask.Task; + await BroadcastStartProgress(StartProgressStage.Downloading, progress, ct); + } - // Prevent any lagging progress events from being sent. - // ReSharper disable once PossiblyMistakenUseOfCancellationToken - using (await progressLock.LockAsync(ct)) - await progressBroadcastCts.CancelAsync(); + // Await again to re-throw any exceptions that occurred during the + // download. + await downloadTask.Task; // We don't send a broadcast here as we immediately send one in the // parent routine. @@ -486,7 +498,6 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected private async Task BroadcastStartProgress(StartProgressStage stage, StartProgressDownloadProgress? downloadProgress = null, CancellationToken cancellationToken = default) { - _logger.LogInformation("Start progress: {stage}", stage); await FallibleBroadcast(new ServiceMessage { StartProgress = new StartProgress From 02bc40046adcfd296ca84d05338da278f4070602 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 6 Jun 2025 16:36:23 +1000 Subject: [PATCH 5/5] invariant in startup progress model --- App/Models/RpcModel.cs | 23 +++++++++++++++++++---- App/Services/RpcController.cs | 8 ++------ Vpn.Service/Manager.cs | 2 +- Vpn.Service/ManagerRpc.cs | 15 +++++++++++---- Vpn.Service/TunnelSupervisor.cs | 6 +++--- 5 files changed, 36 insertions(+), 18 deletions(-) diff --git a/App/Models/RpcModel.cs b/App/Models/RpcModel.cs index 426863b..08d2303 100644 --- a/App/Models/RpcModel.cs +++ b/App/Models/RpcModel.cs @@ -88,8 +88,8 @@ public class VpnStartupProgress private const double DownloadProgressMin = 0.05; private const double DownloadProgressMax = 0.80; - public VpnStartupStage Stage { get; set; } = VpnStartupStage.Unknown; - public VpnDownloadProgress? DownloadProgress { get; set; } = null; + public VpnStartupStage Stage { get; init; } = VpnStartupStage.Unknown; + public VpnDownloadProgress? DownloadProgress { get; init; } = null; // 0.0 to 1.0 public double Progress @@ -165,10 +165,25 @@ public class RpcModel { public RpcLifecycle RpcLifecycle { get; set; } = RpcLifecycle.Disconnected; - public VpnLifecycle VpnLifecycle { get; set; } = VpnLifecycle.Unknown; + public VpnLifecycle VpnLifecycle + { + get; + set + { + if (VpnLifecycle != value && value == VpnLifecycle.Starting) + // Reset the startup progress when the VPN lifecycle changes to + // Starting. + VpnStartupProgress = null; + field = value; + } + } // Nullable because it is only set when the VpnLifecycle is Starting - public VpnStartupProgress? VpnStartupProgress { get; set; } + public VpnStartupProgress? VpnStartupProgress + { + get => VpnLifecycle is VpnLifecycle.Starting ? field ?? new VpnStartupProgress() : null; + set; + } public IReadOnlyList Workspaces { get; set; } = []; diff --git a/App/Services/RpcController.cs b/App/Services/RpcController.cs index 3345050..168a1be 100644 --- a/App/Services/RpcController.cs +++ b/App/Services/RpcController.cs @@ -164,7 +164,6 @@ public async Task StartVpn(CancellationToken ct = default) MutateState(state => { state.VpnLifecycle = VpnLifecycle.Starting; - state.VpnStartupProgress = new VpnStartupProgress(); }); ServiceMessage reply; @@ -255,9 +254,6 @@ private void MutateState(Action mutator) using (_stateLock.Lock()) { mutator(_state); - // Unset the startup progress if the VpnLifecycle is not Starting - if (_state.VpnLifecycle != VpnLifecycle.Starting) - _state.VpnStartupProgress = null; newState = _state.Clone(); } @@ -294,8 +290,8 @@ private void ApplyStartProgressUpdate(StartProgress message) { MutateState(state => { - // MutateState will undo these changes if it doesn't believe we're - // in the "Starting" state. + // The model itself will ignore this value if we're not in the + // starting state. state.VpnStartupProgress = VpnStartupProgress.FromProto(message); }); } diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs index 886bb70..fdb62af 100644 --- a/Vpn.Service/Manager.cs +++ b/Vpn.Service/Manager.cs @@ -331,7 +331,7 @@ private async Task FallibleBroadcast(ServiceMessage message, CancellationToken c // Broadcast the messages out with a low timeout. If clients don't // receive broadcasts in time, it's not a big deal. using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); - //cts.CancelAfter(TimeSpan.FromMilliseconds(100)); + cts.CancelAfter(TimeSpan.FromMilliseconds(30)); try { await _managerRpc.BroadcastAsync(message, cts.Token); diff --git a/Vpn.Service/ManagerRpc.cs b/Vpn.Service/ManagerRpc.cs index d922caf..4920570 100644 --- a/Vpn.Service/ManagerRpc.cs +++ b/Vpn.Service/ManagerRpc.cs @@ -127,14 +127,20 @@ public async Task ExecuteAsync(CancellationToken stoppingToken) public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct) { + // Sends messages to all clients simultaneously and waits for them all + // to send or fail/timeout. + // // Looping over a ConcurrentDictionary is exception-safe, but any items // added or removed during the loop may or may not be included. - foreach (var (clientId, client) in _activeClients) + await Task.WhenAll(_activeClients.Select(async item => + { try { - var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); + // Enforce upper bound in case a CT with a timeout wasn't + // supplied. + using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct); cts.CancelAfter(TimeSpan.FromSeconds(2)); - await client.Speaker.SendMessage(message, cts.Token); + await item.Value.Speaker.SendMessage(message, cts.Token); } catch (ObjectDisposedException) { @@ -142,11 +148,12 @@ public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct) } catch (Exception e) { - _logger.LogWarning(e, "Failed to send message to client {ClientId}", clientId); + _logger.LogWarning(e, "Failed to send message to client {ClientId}", item.Key); // TODO: this should probably kill the client, but due to the // async nature of the client handling, calling Dispose // will not remove the client from the active clients list } + })); } private async Task HandleRpcClientAsync(ulong clientId, Speaker speaker, diff --git a/Vpn.Service/TunnelSupervisor.cs b/Vpn.Service/TunnelSupervisor.cs index 6ff4f3b..7dd6738 100644 --- a/Vpn.Service/TunnelSupervisor.cs +++ b/Vpn.Service/TunnelSupervisor.cs @@ -99,16 +99,16 @@ public async Task StartAsync(string binPath, }, }; // TODO: maybe we should change the log format in the inner binary - // to something without a timestamp + // to something without a timestamp _subprocess.OutputDataReceived += (_, args) => { if (!string.IsNullOrWhiteSpace(args.Data)) - _logger.LogDebug("stdout: {Data}", args.Data); + _logger.LogInformation("stdout: {Data}", args.Data); }; _subprocess.ErrorDataReceived += (_, args) => { if (!string.IsNullOrWhiteSpace(args.Data)) - _logger.LogDebug("stderr: {Data}", args.Data); + _logger.LogInformation("stderr: {Data}", args.Data); }; // Pass the other end of the pipes to the subprocess and dispose