Skip to content

feat: add vpn start progress #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

163 changes: 162 additions & 1 deletion App/Models/RpcModel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using Coder.Desktop.App.Converters;
using Coder.Desktop.Vpn.Proto;

namespace Coder.Desktop.App.Models;
@@ -19,11 +22,168 @@ public enum VpnLifecycle
Stopping,
}

public enum VpnStartupStage
{
Unknown,
Initializing,
Downloading,
Finalizing,
}

public class VpnDownloadProgress
{
public ulong BytesWritten { get; set; } = 0;
public ulong? BytesTotal { get; set; } = null; // null means unknown total size

public double Progress
{
get
{
if (BytesTotal is > 0)
{
return (double)BytesWritten / BytesTotal.Value;
}
return 0.0;
}
}

public override string ToString()
{
// TODO: it would be nice if the two suffixes could match
var s = FriendlyByteConverter.FriendlyBytes(BytesWritten);
if (BytesTotal != null)
s += $" of {FriendlyByteConverter.FriendlyBytes(BytesTotal.Value)}";
else
s += " of unknown";
if (BytesTotal != null)
s += $" ({Progress:0%})";
return s;
}

public VpnDownloadProgress Clone()
{
return new VpnDownloadProgress
{
BytesWritten = BytesWritten,
BytesTotal = BytesTotal,
};
}

public static VpnDownloadProgress FromProto(StartProgressDownloadProgress proto)
{
return new VpnDownloadProgress
{
BytesWritten = proto.BytesWritten,
BytesTotal = proto.HasBytesTotal ? proto.BytesTotal : null,
};
}
}

public class VpnStartupProgress
{
public const string DefaultStartProgressMessage = "Starting Coder Connect...";

// Scale the download progress to an overall progress value between these
// numbers.
private const double DownloadProgressMin = 0.05;
private const double DownloadProgressMax = 0.80;

public VpnStartupStage Stage { get; init; } = VpnStartupStage.Unknown;
public VpnDownloadProgress? DownloadProgress { get; init; } = null;

// 0.0 to 1.0
public double Progress
{
get
{
switch (Stage)
{
case VpnStartupStage.Unknown:
case VpnStartupStage.Initializing:
return 0.0;
case VpnStartupStage.Downloading:
var progress = DownloadProgress?.Progress ?? 0.0;
return DownloadProgressMin + (DownloadProgressMax - DownloadProgressMin) * progress;
case VpnStartupStage.Finalizing:
return DownloadProgressMax;
default:
throw new ArgumentOutOfRangeException();
}
}
}

public override string ToString()
{
switch (Stage)
{
case VpnStartupStage.Unknown:
case VpnStartupStage.Initializing:
return DefaultStartProgressMessage;
case VpnStartupStage.Downloading:
var s = "Downloading Coder Connect binary...";
if (DownloadProgress is not null)
{
s += "\n" + DownloadProgress;
}

return s;
case VpnStartupStage.Finalizing:
return "Finalizing Coder Connect startup...";
default:
throw new ArgumentOutOfRangeException();
}
}

public VpnStartupProgress Clone()
{
return new VpnStartupProgress
{
Stage = Stage,
DownloadProgress = DownloadProgress?.Clone(),
};
}

public static VpnStartupProgress FromProto(StartProgress proto)
{
return new VpnStartupProgress
{
Stage = proto.Stage switch
{
StartProgressStage.Initializing => VpnStartupStage.Initializing,
StartProgressStage.Downloading => VpnStartupStage.Downloading,
StartProgressStage.Finalizing => VpnStartupStage.Finalizing,
_ => VpnStartupStage.Unknown,
},
DownloadProgress = proto.Stage is StartProgressStage.Downloading ?
VpnDownloadProgress.FromProto(proto.DownloadProgress) :
null,
};
}
}

public class RpcModel
{
public RpcLifecycle RpcLifecycle { get; set; } = RpcLifecycle.Disconnected;

public VpnLifecycle VpnLifecycle { get; set; } = VpnLifecycle.Unknown;
public VpnLifecycle VpnLifecycle
{
get;
set
{
if (VpnLifecycle != value && value == VpnLifecycle.Starting)
// Reset the startup progress when the VPN lifecycle changes to
// Starting.
VpnStartupProgress = null;
field = value;
}
}

// Nullable because it is only set when the VpnLifecycle is Starting
public VpnStartupProgress? VpnStartupProgress
{
get => VpnLifecycle is VpnLifecycle.Starting ? field ?? new VpnStartupProgress() : null;
set;
}

public IReadOnlyList<Workspace> Workspaces { get; set; } = [];

@@ -35,6 +195,7 @@ public RpcModel Clone()
{
RpcLifecycle = RpcLifecycle,
VpnLifecycle = VpnLifecycle,
VpnStartupProgress = VpnStartupProgress?.Clone(),
Workspaces = Workspaces,
Agents = Agents,
};
22 changes: 19 additions & 3 deletions App/Services/RpcController.cs
Original file line number Diff line number Diff line change
@@ -161,7 +161,10 @@ public async Task StartVpn(CancellationToken ct = default)
throw new RpcOperationException(
$"Cannot start VPN without valid credentials, current state: {credentials.State}");

MutateState(state => { state.VpnLifecycle = VpnLifecycle.Starting; });
MutateState(state =>
{
state.VpnLifecycle = VpnLifecycle.Starting;
});

ServiceMessage reply;
try
@@ -283,15 +286,28 @@ private void ApplyStatusUpdate(Status status)
});
}

private void ApplyStartProgressUpdate(StartProgress message)
{
MutateState(state =>
{
// The model itself will ignore this value if we're not in the
// starting state.
state.VpnStartupProgress = VpnStartupProgress.FromProto(message);
});
}

private void SpeakerOnReceive(ReplyableRpcMessage<ClientMessage, ServiceMessage> message)
{
switch (message.Message.MsgCase)
{
case ServiceMessage.MsgOneofCase.Start:
case ServiceMessage.MsgOneofCase.Stop:
case ServiceMessage.MsgOneofCase.Status:
ApplyStatusUpdate(message.Message.Status);
break;
case ServiceMessage.MsgOneofCase.Start:
case ServiceMessage.MsgOneofCase.Stop:
case ServiceMessage.MsgOneofCase.StartProgress:
ApplyStartProgressUpdate(message.Message.StartProgress);
break;
case ServiceMessage.MsgOneofCase.None:
default:
// TODO: log unexpected message
37 changes: 35 additions & 2 deletions App/ViewModels/TrayWindowViewModel.cs
Original file line number Diff line number Diff line change
@@ -29,7 +29,6 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost
{
private const int MaxAgents = 5;
private const string DefaultDashboardUrl = "https://coder.com";
private const string DefaultHostnameSuffix = ".coder";

private readonly IServiceProvider _services;
private readonly IRpcController _rpcController;
@@ -53,6 +52,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost

[ObservableProperty]
[NotifyPropertyChangedFor(nameof(ShowEnableSection))]
[NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))]
[NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))]
[NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))]
[NotifyPropertyChangedFor(nameof(ShowAgentsSection))]
@@ -63,14 +63,33 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost

[ObservableProperty]
[NotifyPropertyChangedFor(nameof(ShowEnableSection))]
[NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))]
[NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))]
[NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))]
[NotifyPropertyChangedFor(nameof(ShowAgentsSection))]
[NotifyPropertyChangedFor(nameof(ShowAgentOverflowButton))]
[NotifyPropertyChangedFor(nameof(ShowFailedSection))]
public partial string? VpnFailedMessage { get; set; } = null;

public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Started;
[ObservableProperty]
[NotifyPropertyChangedFor(nameof(VpnStartProgressIsIndeterminate))]
[NotifyPropertyChangedFor(nameof(VpnStartProgressValueOrDefault))]
public partial int? VpnStartProgressValue { get; set; } = null;

public int VpnStartProgressValueOrDefault => VpnStartProgressValue ?? 0;

[ObservableProperty]
[NotifyPropertyChangedFor(nameof(VpnStartProgressMessageOrDefault))]
public partial string? VpnStartProgressMessage { get; set; } = null;

public string VpnStartProgressMessageOrDefault =>
string.IsNullOrEmpty(VpnStartProgressMessage) ? VpnStartupProgress.DefaultStartProgressMessage : VpnStartProgressMessage;

public bool VpnStartProgressIsIndeterminate => VpnStartProgressValueOrDefault == 0;

public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Starting and not VpnLifecycle.Started;

public bool ShowVpnStartProgressSection => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Starting;

public bool ShowWorkspacesHeader => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Started;

@@ -170,6 +189,20 @@ private void UpdateFromRpcModel(RpcModel rpcModel)
VpnLifecycle = rpcModel.VpnLifecycle;
VpnSwitchActive = rpcModel.VpnLifecycle is VpnLifecycle.Starting or VpnLifecycle.Started;

// VpnStartupProgress is only set when the VPN is starting.
if (rpcModel.VpnLifecycle is VpnLifecycle.Starting && rpcModel.VpnStartupProgress != null)
{
// Convert 0.00-1.00 to 0-100.
var progress = (int)(rpcModel.VpnStartupProgress.Progress * 100);
VpnStartProgressValue = Math.Clamp(progress, 0, 100);
VpnStartProgressMessage = rpcModel.VpnStartupProgress.ToString();
}
else
{
VpnStartProgressValue = null;
VpnStartProgressMessage = null;
}

// Add every known agent.
HashSet<ByteString> workspacesWithAgents = [];
List<AgentViewModel> agents = [];
2 changes: 1 addition & 1 deletion App/Views/Pages/TrayWindowLoginRequiredPage.xaml
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@
</HyperlinkButton>

<HyperlinkButton
Command="{x:Bind ViewModel.ExitCommand, Mode=OneWay}"
Command="{x:Bind ViewModel.ExitCommand}"
Margin="-12,-8,-12,-5"
HorizontalAlignment="Stretch"
HorizontalContentAlignment="Left">
11 changes: 10 additions & 1 deletion App/Views/Pages/TrayWindowMainPage.xaml
Original file line number Diff line number Diff line change
@@ -43,6 +43,8 @@
<ProgressRing
Grid.Column="1"
IsActive="{x:Bind ViewModel.VpnLifecycle, Converter={StaticResource ConnectingBoolConverter}, Mode=OneWay}"
IsIndeterminate="{x:Bind ViewModel.VpnStartProgressIsIndeterminate, Mode=OneWay}"
Value="{x:Bind ViewModel.VpnStartProgressValueOrDefault, Mode=OneWay}"
Width="24"
Height="24"
Margin="10,0"
@@ -74,6 +76,13 @@
Visibility="{x:Bind ViewModel.ShowEnableSection, Converter={StaticResource BoolToVisibilityConverter}, Mode=OneWay}"
Foreground="{ThemeResource SystemControlForegroundBaseMediumBrush}" />

<TextBlock
Text="{x:Bind ViewModel.VpnStartProgressMessageOrDefault, Mode=OneWay}"
TextWrapping="Wrap"
Margin="0,6,0,6"
Visibility="{x:Bind ViewModel.ShowVpnStartProgressSection, Converter={StaticResource BoolToVisibilityConverter}, Mode=OneWay}"
Foreground="{ThemeResource SystemControlForegroundBaseMediumBrush}" />

<TextBlock
Text="Workspaces"
FontWeight="semibold"
@@ -344,7 +353,7 @@
Command="{x:Bind ViewModel.ExitCommand, Mode=OneWay}"
Margin="-12,-8,-12,-5"
HorizontalAlignment="Stretch"
HorizontalContentAlignment="Left">
HorizontalContentAlignment="Left">

<TextBlock Text="Exit" Foreground="{ThemeResource DefaultTextForegroundThemeBrush}" />
</HyperlinkButton>
60 changes: 52 additions & 8 deletions Tests.Vpn.Service/DownloaderTest.cs
Original file line number Diff line number Diff line change
@@ -277,8 +277,8 @@ public async Task Download(CancellationToken ct)
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
NullDownloadValidator.Instance, ct);
await dlTask.Task;
Assert.That(dlTask.TotalBytes, Is.EqualTo(4));
Assert.That(dlTask.BytesRead, Is.EqualTo(4));
Assert.That(dlTask.BytesTotal, Is.EqualTo(4));
Assert.That(dlTask.BytesWritten, Is.EqualTo(4));
Assert.That(dlTask.Progress, Is.EqualTo(1));
Assert.That(dlTask.IsCompleted, Is.True);
Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
@@ -300,18 +300,62 @@ public async Task DownloadSameDest(CancellationToken ct)
NullDownloadValidator.Instance, ct);
var dlTask0 = await startTask0;
await dlTask0.Task;
Assert.That(dlTask0.TotalBytes, Is.EqualTo(5));
Assert.That(dlTask0.BytesRead, Is.EqualTo(5));
Assert.That(dlTask0.BytesTotal, Is.EqualTo(5));
Assert.That(dlTask0.BytesWritten, Is.EqualTo(5));
Assert.That(dlTask0.Progress, Is.EqualTo(1));
Assert.That(dlTask0.IsCompleted, Is.True);
var dlTask1 = await startTask1;
await dlTask1.Task;
Assert.That(dlTask1.TotalBytes, Is.EqualTo(5));
Assert.That(dlTask1.BytesRead, Is.EqualTo(5));
Assert.That(dlTask1.BytesTotal, Is.EqualTo(5));
Assert.That(dlTask1.BytesWritten, Is.EqualTo(5));
Assert.That(dlTask1.Progress, Is.EqualTo(1));
Assert.That(dlTask1.IsCompleted, Is.True);
}

[Test(Description = "Download with X-Original-Content-Length")]
[CancelAfter(30_000)]
public async Task DownloadWithXOriginalContentLength(CancellationToken ct)
{
using var httpServer = new TestHttpServer(async ctx =>
{
ctx.Response.StatusCode = 200;
ctx.Response.Headers.Add("X-Original-Content-Length", "4");
ctx.Response.ContentType = "text/plain";
// Don't set Content-Length.
await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct);
});
var url = new Uri(httpServer.BaseUrl + "/test");
var destPath = Path.Combine(_tempDir, "test");
var manager = new Downloader(NullLogger<Downloader>.Instance);
var req = new HttpRequestMessage(HttpMethod.Get, url);
var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct);

await dlTask.Task;
Assert.That(dlTask.BytesTotal, Is.EqualTo(4));
Assert.That(dlTask.BytesWritten, Is.EqualTo(4));
}

[Test(Description = "Download with mismatched Content-Length")]
[CancelAfter(30_000)]
public async Task DownloadWithMismatchedContentLength(CancellationToken ct)
{
using var httpServer = new TestHttpServer(async ctx =>
{
ctx.Response.StatusCode = 200;
ctx.Response.Headers.Add("X-Original-Content-Length", "5"); // incorrect
ctx.Response.ContentType = "text/plain";
await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct);
});
var url = new Uri(httpServer.BaseUrl + "/test");
var destPath = Path.Combine(_tempDir, "test");
var manager = new Downloader(NullLogger<Downloader>.Instance);
var req = new HttpRequestMessage(HttpMethod.Get, url);
var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct);

var ex = Assert.ThrowsAsync<IOException>(() => dlTask.Task);
Assert.That(ex.Message, Is.EqualTo("Downloaded file size does not match expected response content length: Expected=5, BytesWritten=4"));
}

[Test(Description = "Download with custom headers")]
[CancelAfter(30_000)]
public async Task WithHeaders(CancellationToken ct)
@@ -347,7 +391,7 @@ public async Task DownloadExisting(CancellationToken ct)
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
NullDownloadValidator.Instance, ct);
await dlTask.Task;
Assert.That(dlTask.BytesRead, Is.Zero);
Assert.That(dlTask.BytesWritten, Is.Zero);
Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
Assert.That(File.GetLastWriteTime(destPath), Is.LessThan(DateTime.Now - TimeSpan.FromDays(1)));
}
@@ -368,7 +412,7 @@ public async Task DownloadExistingDifferentContent(CancellationToken ct)
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
NullDownloadValidator.Instance, ct);
await dlTask.Task;
Assert.That(dlTask.BytesRead, Is.EqualTo(4));
Assert.That(dlTask.BytesWritten, Is.EqualTo(4));
Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
Assert.That(File.GetLastWriteTime(destPath), Is.GreaterThan(DateTime.Now - TimeSpan.FromDays(1)));
}
25 changes: 24 additions & 1 deletion Vpn.Proto/vpn.proto
Original file line number Diff line number Diff line change
@@ -60,7 +60,8 @@ message ServiceMessage {
oneof msg {
StartResponse start = 2;
StopResponse stop = 3;
Status status = 4; // either in reply to a StatusRequest or broadcasted
Status status = 4; // either in reply to a StatusRequest or broadcasted
StartProgress start_progress = 5; // broadcasted during startup
}
}

@@ -218,6 +219,28 @@ message StartResponse {
string error_message = 2;
}

// StartProgress is sent from the manager to the client to indicate the
// download/startup progress of the tunnel. This will be sent during the
// processing of a StartRequest before the StartResponse is sent.
//
// Note: this is currently a broadcasted message to all clients due to the
// inability to easily send messages to a specific client in the Speaker
// implementation. If clients are not expecting these messages, they
// should ignore them.
enum StartProgressStage {
Initializing = 0;
Downloading = 1;
Finalizing = 2;
}
message StartProgressDownloadProgress {
uint64 bytes_written = 1;
optional uint64 bytes_total = 2; // unknown in some situations
}
message StartProgress {
StartProgressStage stage = 1;
optional StartProgressDownloadProgress download_progress = 2; // only set when stage == Downloading
}

// StopRequest is a request from the manager to stop the tunnel. The tunnel replies with a
// StopResponse.
message StopRequest {}
91 changes: 59 additions & 32 deletions Vpn.Service/Downloader.cs
Original file line number Diff line number Diff line change
@@ -339,31 +339,35 @@ internal static async Task TaskOrCancellation(Task task, CancellationToken cance
}

/// <summary>
/// Downloads an Url to a file on disk. The download will be written to a temporary file first, then moved to the final
/// Downloads a Url to a file on disk. The download will be written to a temporary file first, then moved to the final
/// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if
/// it hasn't changed.
/// </summary>
public class DownloadTask
{
private const int BufferSize = 4096;
private const int BufferSize = 64 * 1024;
private const string XOriginalContentLengthHeader = "X-Original-Content-Length"; // overrides Content-Length if available

private static readonly HttpClient HttpClient = new();
private static readonly HttpClient HttpClient = new(new HttpClientHandler
{
AutomaticDecompression = DecompressionMethods.All,
});
private readonly string _destinationDirectory;

private readonly ILogger _logger;

private readonly RaiiSemaphoreSlim _semaphore = new(1, 1);
private readonly IDownloadValidator _validator;
public readonly string DestinationPath;
private readonly string _destinationPath;
private readonly string _tempDestinationPath;

public readonly HttpRequestMessage Request;
public readonly string TempDestinationPath;

public ulong? TotalBytes { get; private set; }
public ulong BytesRead { get; private set; }
public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync

public double? Progress => TotalBytes == null ? null : (double)BytesRead / TotalBytes.Value;
public bool DownloadStarted { get; private set; } // Whether we've received headers yet and started the actual download
public ulong BytesWritten { get; private set; }
public ulong? BytesTotal { get; private set; }
public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value;
public bool IsCompleted => Task.IsCompleted;

internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator)
@@ -374,17 +378,17 @@ internal DownloadTask(ILogger logger, HttpRequestMessage req, string destination

if (string.IsNullOrWhiteSpace(destinationPath))
throw new ArgumentException("Destination path must not be empty", nameof(destinationPath));
DestinationPath = Path.GetFullPath(destinationPath);
if (Path.EndsInDirectorySeparator(DestinationPath))
throw new ArgumentException($"Destination path '{DestinationPath}' must not end in a directory separator",
_destinationPath = Path.GetFullPath(destinationPath);
if (Path.EndsInDirectorySeparator(_destinationPath))
throw new ArgumentException($"Destination path '{_destinationPath}' must not end in a directory separator",
nameof(destinationPath));

_destinationDirectory = Path.GetDirectoryName(DestinationPath)
_destinationDirectory = Path.GetDirectoryName(_destinationPath)
?? throw new ArgumentException(
$"Destination path '{DestinationPath}' must have a parent directory",
$"Destination path '{_destinationPath}' must have a parent directory",
nameof(destinationPath));

TempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(DestinationPath) +
_tempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(_destinationPath) +
".download-" + Path.GetRandomFileName());
}

@@ -406,9 +410,9 @@ private async Task Start(CancellationToken ct = default)

// If the destination path exists, generate a Coder SHA1 ETag and send
// it in the If-None-Match header to the server.
if (File.Exists(DestinationPath))
if (File.Exists(_destinationPath))
{
await using var stream = File.OpenRead(DestinationPath);
await using var stream = File.OpenRead(_destinationPath);
var etag = Convert.ToHexString(await SHA1.HashDataAsync(stream, ct)).ToLower();
Request.Headers.Add("If-None-Match", "\"" + etag + "\"");
}
@@ -419,11 +423,11 @@ private async Task Start(CancellationToken ct = default)
_logger.LogInformation("File has not been modified, skipping download");
try
{
await _validator.ValidateAsync(DestinationPath, ct);
await _validator.ValidateAsync(_destinationPath, ct);
}
catch (Exception e)
{
_logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", DestinationPath);
_logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", _destinationPath);
throw new Exception("Existing file failed validation after 304 Not Modified", e);
}

@@ -446,24 +450,38 @@ private async Task Start(CancellationToken ct = default)
}

if (res.Content.Headers.ContentLength >= 0)
TotalBytes = (ulong)res.Content.Headers.ContentLength;
BytesTotal = (ulong)res.Content.Headers.ContentLength;

// X-Original-Content-Length overrules Content-Length if set.
if (res.Headers.TryGetValues(XOriginalContentLengthHeader, out var headerValues))
{
// If there are multiple we only look at the first one.
var headerValue = headerValues.ToList().FirstOrDefault();
if (!string.IsNullOrEmpty(headerValue) && ulong.TryParse(headerValue, out var originalContentLength))
BytesTotal = originalContentLength;
else
_logger.LogWarning(
"Failed to parse {XOriginalContentLengthHeader} header value '{HeaderValue}'",
XOriginalContentLengthHeader, headerValue);
}

await Download(res, ct);
}

private async Task Download(HttpResponseMessage res, CancellationToken ct)
{
DownloadStarted = true;
try
{
var sha1 = res.Headers.Contains("ETag") ? SHA1.Create() : null;
FileStream tempFile;
try
{
tempFile = File.Create(TempDestinationPath, BufferSize, FileOptions.SequentialScan);
tempFile = File.Create(_tempDestinationPath, BufferSize, FileOptions.SequentialScan);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", TempDestinationPath);
_logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", _tempDestinationPath);
throw;
}

@@ -476,13 +494,14 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct)
{
await tempFile.WriteAsync(buffer.AsMemory(0, n), ct);
sha1?.TransformBlock(buffer, 0, n, null, 0);
BytesRead += (ulong)n;
BytesWritten += (ulong)n;
}
}

if (TotalBytes != null && BytesRead != TotalBytes)
BytesTotal ??= BytesWritten;
if (BytesWritten != BytesTotal)
throw new IOException(
$"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesRead}");
$"Downloaded file size does not match expected response content length: Expected={BytesTotal}, BytesWritten={BytesWritten}");

// Verify the ETag if it was sent by the server.
if (res.Headers.Contains("ETag") && sha1 != null)
@@ -497,26 +516,34 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct)

try
{
await _validator.ValidateAsync(TempDestinationPath, ct);
await _validator.ValidateAsync(_tempDestinationPath, ct);
}
catch (Exception e)
{
_logger.LogWarning(e, "Downloaded file '{TempDestinationPath}' failed custom validation",
TempDestinationPath);
_tempDestinationPath);
throw new HttpRequestException("Downloaded file failed validation", e);
}

File.Move(TempDestinationPath, DestinationPath, true);
File.Move(_tempDestinationPath, _destinationPath, true);
}
finally
catch
{
#if DEBUG
_logger.LogWarning("Not deleting temporary file '{TempDestinationPath}' in debug mode",
TempDestinationPath);
_tempDestinationPath);
#else
if (File.Exists(TempDestinationPath))
File.Delete(TempDestinationPath);
try
{
if (File.Exists(_tempDestinationPath))
File.Delete(_tempDestinationPath);
}
catch (Exception e)
{
_logger.LogError(e, "Failed to delete temporary file '{TempDestinationPath}'", _tempDestinationPath);
}
#endif
throw;
}
}
}
82 changes: 78 additions & 4 deletions Vpn.Service/Manager.cs
Original file line number Diff line number Diff line change
@@ -131,6 +131,8 @@ private async ValueTask<StartResponse> HandleClientMessageStart(ClientMessage me
{
try
{
await BroadcastStartProgress(StartProgressStage.Initializing, cancellationToken: ct);

var serverVersion =
await CheckServerVersionAndCredentials(message.Start.CoderUrl, message.Start.ApiToken, ct);
if (_status == TunnelStatus.Started && _lastStartRequest != null &&
@@ -151,10 +153,14 @@ private async ValueTask<StartResponse> HandleClientMessageStart(ClientMessage me
_lastServerVersion = serverVersion;

// TODO: each section of this operation needs a timeout

// Stop the tunnel if it's running so we don't have to worry about
// permissions issues when replacing the binary.
await _tunnelSupervisor.StopAsync(ct);

await DownloadTunnelBinaryAsync(message.Start.CoderUrl, serverVersion.SemVersion, ct);

await BroadcastStartProgress(StartProgressStage.Finalizing, cancellationToken: ct);
await _tunnelSupervisor.StartAsync(_config.TunnelBinaryPath, HandleTunnelRpcMessage,
HandleTunnelRpcError,
ct);
@@ -237,6 +243,9 @@ private void HandleTunnelRpcMessage(ReplyableRpcMessage<ManagerMessage, TunnelMe
_logger.LogWarning("Received unexpected message reply type {MessageType}", message.Message.MsgCase);
break;
case TunnelMessage.MsgOneofCase.Log:
// Ignored. We already log stdout/stderr from the tunnel
// binary.
break;
case TunnelMessage.MsgOneofCase.NetworkSettings:
_logger.LogWarning("Received message type {MessageType} that is not expected on Windows",
message.Message.MsgCase);
@@ -311,12 +320,28 @@ private async ValueTask<Status> CurrentStatus(CancellationToken ct = default)
private async Task BroadcastStatus(TunnelStatus? newStatus = null, CancellationToken ct = default)
{
if (newStatus != null) _status = newStatus.Value;
await _managerRpc.BroadcastAsync(new ServiceMessage
await FallibleBroadcast(new ServiceMessage
{
Status = await CurrentStatus(ct),
}, ct);
}

private async Task FallibleBroadcast(ServiceMessage message, CancellationToken ct = default)
{
// Broadcast the messages out with a low timeout. If clients don't
// receive broadcasts in time, it's not a big deal.
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
cts.CancelAfter(TimeSpan.FromMilliseconds(30));
try
{
await _managerRpc.BroadcastAsync(message, cts.Token);
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Could not broadcast low priority message to all RPC clients: {Message}", message);
}
}

private void HandleTunnelRpcError(Exception e)
{
_logger.LogError(e, "Manager<->Tunnel RPC error");
@@ -425,12 +450,61 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected
_logger.LogDebug("Skipping tunnel binary version validation");
}

// Note: all ETag, signature and version validation is performed by the
// DownloadTask.
var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct);

// TODO: monitor and report progress when we have a mechanism to do so
// Wait for the download to complete, sending progress updates every
// 50ms.
while (true)
{
// Wait for the download to complete, or for a short delay before
// we send a progress update.
var delayTask = Task.Delay(TimeSpan.FromMilliseconds(50), ct);
var winner = await Task.WhenAny([
downloadTask.Task,
delayTask,
]);
if (winner == downloadTask.Task)
break;

// Task.WhenAny will not throw if the winner was cancelled, so
// check CT afterward and not beforehand.
ct.ThrowIfCancellationRequested();

if (!downloadTask.DownloadStarted)
// Don't send progress updates if we don't know what the
// progress is yet.
continue;

var progress = new StartProgressDownloadProgress
{
BytesWritten = downloadTask.BytesWritten,
};
if (downloadTask.BytesTotal != null)
progress.BytesTotal = downloadTask.BytesTotal.Value;

// Awaiting this will check the checksum (via the ETag) if the file
// exists, and will also validate the signature and version.
await BroadcastStartProgress(StartProgressStage.Downloading, progress, ct);
}

// Await again to re-throw any exceptions that occurred during the
// download.
await downloadTask.Task;

// We don't send a broadcast here as we immediately send one in the
// parent routine.
_logger.LogInformation("Completed downloading VPN binary");
}

private async Task BroadcastStartProgress(StartProgressStage stage, StartProgressDownloadProgress? downloadProgress = null, CancellationToken cancellationToken = default)
{
await FallibleBroadcast(new ServiceMessage
{
StartProgress = new StartProgress
{
Stage = stage,
DownloadProgress = downloadProgress,
},
}, cancellationToken);
}
}
17 changes: 12 additions & 5 deletions Vpn.Service/ManagerRpc.cs
Original file line number Diff line number Diff line change
@@ -127,26 +127,33 @@ public async Task ExecuteAsync(CancellationToken stoppingToken)

public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct)
{
// Sends messages to all clients simultaneously and waits for them all
// to send or fail/timeout.
//
// Looping over a ConcurrentDictionary is exception-safe, but any items
// added or removed during the loop may or may not be included.
foreach (var (clientId, client) in _activeClients)
await Task.WhenAll(_activeClients.Select(async item =>
{
try
{
var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
cts.CancelAfter(5 * 1000);
await client.Speaker.SendMessage(message, cts.Token);
// Enforce upper bound in case a CT with a timeout wasn't
// supplied.
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
cts.CancelAfter(TimeSpan.FromSeconds(2));
await item.Value.Speaker.SendMessage(message, cts.Token);
}
catch (ObjectDisposedException)
{
// The speaker was likely closed while we were iterating.
}
catch (Exception e)
{
_logger.LogWarning(e, "Failed to send message to client {ClientId}", clientId);
_logger.LogWarning(e, "Failed to send message to client {ClientId}", item.Key);
// TODO: this should probably kill the client, but due to the
// async nature of the client handling, calling Dispose
// will not remove the client from the active clients list
}
}));
}

private async Task HandleRpcClientAsync(ulong clientId, Speaker<ServiceMessage, ClientMessage> speaker,
11 changes: 7 additions & 4 deletions Vpn.Service/Program.cs
Original file line number Diff line number Diff line change
@@ -16,10 +16,12 @@ public static class Program
#if !DEBUG
private const string ServiceName = "Coder Desktop";
private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\VpnService";
private const string DefaultLogLevel = "Information";
#else
// This value matches Create-Service.ps1.
private const string ServiceName = "Coder Desktop (Debug)";
private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\DebugVpnService";
private const string DefaultLogLevel = "Debug";
#endif

private const string ManagerConfigSection = "Manager";
@@ -81,6 +83,10 @@ private static async Task BuildAndRun(string[] args)
builder.Services.AddSingleton<ITelemetryEnricher, TelemetryEnricher>();

// Services
builder.Services.AddHostedService<ManagerService>();
builder.Services.AddHostedService<ManagerRpcService>();

// Either run as a Windows service or a console application
if (!Environment.UserInteractive)
{
MainLogger.Information("Running as a windows service");
@@ -91,9 +97,6 @@ private static async Task BuildAndRun(string[] args)
MainLogger.Information("Running as a console application");
}

builder.Services.AddHostedService<ManagerService>();
builder.Services.AddHostedService<ManagerRpcService>();

var host = builder.Build();
Log.Logger = (ILogger)host.Services.GetService(typeof(ILogger))!;
MainLogger.Information("Application is starting");
@@ -108,7 +111,7 @@ private static void AddDefaultConfig(IConfigurationBuilder builder)
["Serilog:Using:0"] = "Serilog.Sinks.File",
["Serilog:Using:1"] = "Serilog.Sinks.Console",

["Serilog:MinimumLevel"] = "Information",
["Serilog:MinimumLevel"] = DefaultLogLevel,
["Serilog:Enrich:0"] = "FromLogContext",

["Serilog:WriteTo:0:Name"] = "File",
8 changes: 3 additions & 5 deletions Vpn.Service/TunnelSupervisor.cs
Original file line number Diff line number Diff line change
@@ -99,18 +99,16 @@ public async Task StartAsync(string binPath,
},
};
// TODO: maybe we should change the log format in the inner binary
// to something without a timestamp
var outLogger = Log.ForContext("SourceContext", "coder-vpn.exe[OUT]");
var errLogger = Log.ForContext("SourceContext", "coder-vpn.exe[ERR]");
// to something without a timestamp
_subprocess.OutputDataReceived += (_, args) =>
{
if (!string.IsNullOrWhiteSpace(args.Data))
outLogger.Debug("{Data}", args.Data);
_logger.LogInformation("stdout: {Data}", args.Data);
};
_subprocess.ErrorDataReceived += (_, args) =>
{
if (!string.IsNullOrWhiteSpace(args.Data))
errLogger.Debug("{Data}", args.Data);
_logger.LogInformation("stderr: {Data}", args.Data);
};

// Pass the other end of the pipes to the subprocess and dispose
2 changes: 1 addition & 1 deletion Vpn/Speaker.cs
Original file line number Diff line number Diff line change
@@ -123,7 +123,7 @@ public async Task StartAsync(CancellationToken ct = default)
// Handshakes should always finish quickly, so enforce a 5s timeout.
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token);
cts.CancelAfter(TimeSpan.FromSeconds(5));
await PerformHandshake(ct);
await PerformHandshake(cts.Token);

// Start ReceiveLoop in the background.
_receiveTask = ReceiveLoop(_cts.Token);