Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit d49de5b

Browse files
authoredJun 6, 2025··
feat: add vpn start progress (#114)
1 parent 74b8658 commit d49de5b

File tree

14 files changed

+464
-68
lines changed

14 files changed

+464
-68
lines changed
 

‎.idea/.idea.Coder.Desktop/.idea/projectSettingsUpdater.xml

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎App/Models/RpcModel.cs

Lines changed: 162 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
using System;
12
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using Coder.Desktop.App.Converters;
25
using Coder.Desktop.Vpn.Proto;
36

47
namespace Coder.Desktop.App.Models;
@@ -19,11 +22,168 @@ public enum VpnLifecycle
1922
Stopping,
2023
}
2124

25+
public enum VpnStartupStage
26+
{
27+
Unknown,
28+
Initializing,
29+
Downloading,
30+
Finalizing,
31+
}
32+
33+
public class VpnDownloadProgress
34+
{
35+
public ulong BytesWritten { get; set; } = 0;
36+
public ulong? BytesTotal { get; set; } = null; // null means unknown total size
37+
38+
public double Progress
39+
{
40+
get
41+
{
42+
if (BytesTotal is > 0)
43+
{
44+
return (double)BytesWritten / BytesTotal.Value;
45+
}
46+
return 0.0;
47+
}
48+
}
49+
50+
public override string ToString()
51+
{
52+
// TODO: it would be nice if the two suffixes could match
53+
var s = FriendlyByteConverter.FriendlyBytes(BytesWritten);
54+
if (BytesTotal != null)
55+
s += $" of {FriendlyByteConverter.FriendlyBytes(BytesTotal.Value)}";
56+
else
57+
s += " of unknown";
58+
if (BytesTotal != null)
59+
s += $" ({Progress:0%})";
60+
return s;
61+
}
62+
63+
public VpnDownloadProgress Clone()
64+
{
65+
return new VpnDownloadProgress
66+
{
67+
BytesWritten = BytesWritten,
68+
BytesTotal = BytesTotal,
69+
};
70+
}
71+
72+
public static VpnDownloadProgress FromProto(StartProgressDownloadProgress proto)
73+
{
74+
return new VpnDownloadProgress
75+
{
76+
BytesWritten = proto.BytesWritten,
77+
BytesTotal = proto.HasBytesTotal ? proto.BytesTotal : null,
78+
};
79+
}
80+
}
81+
82+
public class VpnStartupProgress
83+
{
84+
public const string DefaultStartProgressMessage = "Starting Coder Connect...";
85+
86+
// Scale the download progress to an overall progress value between these
87+
// numbers.
88+
private const double DownloadProgressMin = 0.05;
89+
private const double DownloadProgressMax = 0.80;
90+
91+
public VpnStartupStage Stage { get; init; } = VpnStartupStage.Unknown;
92+
public VpnDownloadProgress? DownloadProgress { get; init; } = null;
93+
94+
// 0.0 to 1.0
95+
public double Progress
96+
{
97+
get
98+
{
99+
switch (Stage)
100+
{
101+
case VpnStartupStage.Unknown:
102+
case VpnStartupStage.Initializing:
103+
return 0.0;
104+
case VpnStartupStage.Downloading:
105+
var progress = DownloadProgress?.Progress ?? 0.0;
106+
return DownloadProgressMin + (DownloadProgressMax - DownloadProgressMin) * progress;
107+
case VpnStartupStage.Finalizing:
108+
return DownloadProgressMax;
109+
default:
110+
throw new ArgumentOutOfRangeException();
111+
}
112+
}
113+
}
114+
115+
public override string ToString()
116+
{
117+
switch (Stage)
118+
{
119+
case VpnStartupStage.Unknown:
120+
case VpnStartupStage.Initializing:
121+
return DefaultStartProgressMessage;
122+
case VpnStartupStage.Downloading:
123+
var s = "Downloading Coder Connect binary...";
124+
if (DownloadProgress is not null)
125+
{
126+
s += "\n" + DownloadProgress;
127+
}
128+
129+
return s;
130+
case VpnStartupStage.Finalizing:
131+
return "Finalizing Coder Connect startup...";
132+
default:
133+
throw new ArgumentOutOfRangeException();
134+
}
135+
}
136+
137+
public VpnStartupProgress Clone()
138+
{
139+
return new VpnStartupProgress
140+
{
141+
Stage = Stage,
142+
DownloadProgress = DownloadProgress?.Clone(),
143+
};
144+
}
145+
146+
public static VpnStartupProgress FromProto(StartProgress proto)
147+
{
148+
return new VpnStartupProgress
149+
{
150+
Stage = proto.Stage switch
151+
{
152+
StartProgressStage.Initializing => VpnStartupStage.Initializing,
153+
StartProgressStage.Downloading => VpnStartupStage.Downloading,
154+
StartProgressStage.Finalizing => VpnStartupStage.Finalizing,
155+
_ => VpnStartupStage.Unknown,
156+
},
157+
DownloadProgress = proto.Stage is StartProgressStage.Downloading ?
158+
VpnDownloadProgress.FromProto(proto.DownloadProgress) :
159+
null,
160+
};
161+
}
162+
}
163+
22164
public class RpcModel
23165
{
24166
public RpcLifecycle RpcLifecycle { get; set; } = RpcLifecycle.Disconnected;
25167

26-
public VpnLifecycle VpnLifecycle { get; set; } = VpnLifecycle.Unknown;
168+
public VpnLifecycle VpnLifecycle
169+
{
170+
get;
171+
set
172+
{
173+
if (VpnLifecycle != value && value == VpnLifecycle.Starting)
174+
// Reset the startup progress when the VPN lifecycle changes to
175+
// Starting.
176+
VpnStartupProgress = null;
177+
field = value;
178+
}
179+
}
180+
181+
// Nullable because it is only set when the VpnLifecycle is Starting
182+
public VpnStartupProgress? VpnStartupProgress
183+
{
184+
get => VpnLifecycle is VpnLifecycle.Starting ? field ?? new VpnStartupProgress() : null;
185+
set;
186+
}
27187

28188
public IReadOnlyList<Workspace> Workspaces { get; set; } = [];
29189

@@ -35,6 +195,7 @@ public RpcModel Clone()
35195
{
36196
RpcLifecycle = RpcLifecycle,
37197
VpnLifecycle = VpnLifecycle,
198+
VpnStartupProgress = VpnStartupProgress?.Clone(),
38199
Workspaces = Workspaces,
39200
Agents = Agents,
40201
};

‎App/Services/RpcController.cs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,10 @@ public async Task StartVpn(CancellationToken ct = default)
161161
throw new RpcOperationException(
162162
$"Cannot start VPN without valid credentials, current state: {credentials.State}");
163163

164-
MutateState(state => { state.VpnLifecycle = VpnLifecycle.Starting; });
164+
MutateState(state =>
165+
{
166+
state.VpnLifecycle = VpnLifecycle.Starting;
167+
});
165168

166169
ServiceMessage reply;
167170
try
@@ -283,15 +286,28 @@ private void ApplyStatusUpdate(Status status)
283286
});
284287
}
285288

289+
private void ApplyStartProgressUpdate(StartProgress message)
290+
{
291+
MutateState(state =>
292+
{
293+
// The model itself will ignore this value if we're not in the
294+
// starting state.
295+
state.VpnStartupProgress = VpnStartupProgress.FromProto(message);
296+
});
297+
}
298+
286299
private void SpeakerOnReceive(ReplyableRpcMessage<ClientMessage, ServiceMessage> message)
287300
{
288301
switch (message.Message.MsgCase)
289302
{
303+
case ServiceMessage.MsgOneofCase.Start:
304+
case ServiceMessage.MsgOneofCase.Stop:
290305
case ServiceMessage.MsgOneofCase.Status:
291306
ApplyStatusUpdate(message.Message.Status);
292307
break;
293-
case ServiceMessage.MsgOneofCase.Start:
294-
case ServiceMessage.MsgOneofCase.Stop:
308+
case ServiceMessage.MsgOneofCase.StartProgress:
309+
ApplyStartProgressUpdate(message.Message.StartProgress);
310+
break;
295311
case ServiceMessage.MsgOneofCase.None:
296312
default:
297313
// TODO: log unexpected message

‎App/ViewModels/TrayWindowViewModel.cs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost
2929
{
3030
private const int MaxAgents = 5;
3131
private const string DefaultDashboardUrl = "https://coder.com";
32-
private const string DefaultHostnameSuffix = ".coder";
3332

3433
private readonly IServiceProvider _services;
3534
private readonly IRpcController _rpcController;
@@ -53,6 +52,7 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost
5352

5453
[ObservableProperty]
5554
[NotifyPropertyChangedFor(nameof(ShowEnableSection))]
55+
[NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))]
5656
[NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))]
5757
[NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))]
5858
[NotifyPropertyChangedFor(nameof(ShowAgentsSection))]
@@ -63,14 +63,33 @@ public partial class TrayWindowViewModel : ObservableObject, IAgentExpanderHost
6363

6464
[ObservableProperty]
6565
[NotifyPropertyChangedFor(nameof(ShowEnableSection))]
66+
[NotifyPropertyChangedFor(nameof(ShowVpnStartProgressSection))]
6667
[NotifyPropertyChangedFor(nameof(ShowWorkspacesHeader))]
6768
[NotifyPropertyChangedFor(nameof(ShowNoAgentsSection))]
6869
[NotifyPropertyChangedFor(nameof(ShowAgentsSection))]
6970
[NotifyPropertyChangedFor(nameof(ShowAgentOverflowButton))]
7071
[NotifyPropertyChangedFor(nameof(ShowFailedSection))]
7172
public partial string? VpnFailedMessage { get; set; } = null;
7273

73-
public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Started;
74+
[ObservableProperty]
75+
[NotifyPropertyChangedFor(nameof(VpnStartProgressIsIndeterminate))]
76+
[NotifyPropertyChangedFor(nameof(VpnStartProgressValueOrDefault))]
77+
public partial int? VpnStartProgressValue { get; set; } = null;
78+
79+
public int VpnStartProgressValueOrDefault => VpnStartProgressValue ?? 0;
80+
81+
[ObservableProperty]
82+
[NotifyPropertyChangedFor(nameof(VpnStartProgressMessageOrDefault))]
83+
public partial string? VpnStartProgressMessage { get; set; } = null;
84+
85+
public string VpnStartProgressMessageOrDefault =>
86+
string.IsNullOrEmpty(VpnStartProgressMessage) ? VpnStartupProgress.DefaultStartProgressMessage : VpnStartProgressMessage;
87+
88+
public bool VpnStartProgressIsIndeterminate => VpnStartProgressValueOrDefault == 0;
89+
90+
public bool ShowEnableSection => VpnFailedMessage is null && VpnLifecycle is not VpnLifecycle.Starting and not VpnLifecycle.Started;
91+
92+
public bool ShowVpnStartProgressSection => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Starting;
7493

7594
public bool ShowWorkspacesHeader => VpnFailedMessage is null && VpnLifecycle is VpnLifecycle.Started;
7695

@@ -170,6 +189,20 @@ private void UpdateFromRpcModel(RpcModel rpcModel)
170189
VpnLifecycle = rpcModel.VpnLifecycle;
171190
VpnSwitchActive = rpcModel.VpnLifecycle is VpnLifecycle.Starting or VpnLifecycle.Started;
172191

192+
// VpnStartupProgress is only set when the VPN is starting.
193+
if (rpcModel.VpnLifecycle is VpnLifecycle.Starting && rpcModel.VpnStartupProgress != null)
194+
{
195+
// Convert 0.00-1.00 to 0-100.
196+
var progress = (int)(rpcModel.VpnStartupProgress.Progress * 100);
197+
VpnStartProgressValue = Math.Clamp(progress, 0, 100);
198+
VpnStartProgressMessage = rpcModel.VpnStartupProgress.ToString();
199+
}
200+
else
201+
{
202+
VpnStartProgressValue = null;
203+
VpnStartProgressMessage = null;
204+
}
205+
173206
// Add every known agent.
174207
HashSet<ByteString> workspacesWithAgents = [];
175208
List<AgentViewModel> agents = [];

‎App/Views/Pages/TrayWindowLoginRequiredPage.xaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
</HyperlinkButton>
3737

3838
<HyperlinkButton
39-
Command="{x:Bind ViewModel.ExitCommand, Mode=OneWay}"
39+
Command="{x:Bind ViewModel.ExitCommand}"
4040
Margin="-12,-8,-12,-5"
4141
HorizontalAlignment="Stretch"
4242
HorizontalContentAlignment="Left">

‎App/Views/Pages/TrayWindowMainPage.xaml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
<ProgressRing
4444
Grid.Column="1"
4545
IsActive="{x:Bind ViewModel.VpnLifecycle, Converter={StaticResource ConnectingBoolConverter}, Mode=OneWay}"
46+
IsIndeterminate="{x:Bind ViewModel.VpnStartProgressIsIndeterminate, Mode=OneWay}"
47+
Value="{x:Bind ViewModel.VpnStartProgressValueOrDefault, Mode=OneWay}"
4648
Width="24"
4749
Height="24"
4850
Margin="10,0"
@@ -74,6 +76,13 @@
7476
Visibility="{x:Bind ViewModel.ShowEnableSection, Converter={StaticResource BoolToVisibilityConverter}, Mode=OneWay}"
7577
Foreground="{ThemeResource SystemControlForegroundBaseMediumBrush}" />
7678

79+
<TextBlock
80+
Text="{x:Bind ViewModel.VpnStartProgressMessageOrDefault, Mode=OneWay}"
81+
TextWrapping="Wrap"
82+
Margin="0,6,0,6"
83+
Visibility="{x:Bind ViewModel.ShowVpnStartProgressSection, Converter={StaticResource BoolToVisibilityConverter}, Mode=OneWay}"
84+
Foreground="{ThemeResource SystemControlForegroundBaseMediumBrush}" />
85+
7786
<TextBlock
7887
Text="Workspaces"
7988
FontWeight="semibold"
@@ -344,7 +353,7 @@
344353
Command="{x:Bind ViewModel.ExitCommand, Mode=OneWay}"
345354
Margin="-12,-8,-12,-5"
346355
HorizontalAlignment="Stretch"
347-
HorizontalContentAlignment="Left">
356+
HorizontalContentAlignment="Left">
348357

349358
<TextBlock Text="Exit" Foreground="{ThemeResource DefaultTextForegroundThemeBrush}" />
350359
</HyperlinkButton>

‎Tests.Vpn.Service/DownloaderTest.cs

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ public async Task Download(CancellationToken ct)
277277
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
278278
NullDownloadValidator.Instance, ct);
279279
await dlTask.Task;
280-
Assert.That(dlTask.TotalBytes, Is.EqualTo(4));
281-
Assert.That(dlTask.BytesRead, Is.EqualTo(4));
280+
Assert.That(dlTask.BytesTotal, Is.EqualTo(4));
281+
Assert.That(dlTask.BytesWritten, Is.EqualTo(4));
282282
Assert.That(dlTask.Progress, Is.EqualTo(1));
283283
Assert.That(dlTask.IsCompleted, Is.True);
284284
Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
@@ -300,18 +300,62 @@ public async Task DownloadSameDest(CancellationToken ct)
300300
NullDownloadValidator.Instance, ct);
301301
var dlTask0 = await startTask0;
302302
await dlTask0.Task;
303-
Assert.That(dlTask0.TotalBytes, Is.EqualTo(5));
304-
Assert.That(dlTask0.BytesRead, Is.EqualTo(5));
303+
Assert.That(dlTask0.BytesTotal, Is.EqualTo(5));
304+
Assert.That(dlTask0.BytesWritten, Is.EqualTo(5));
305305
Assert.That(dlTask0.Progress, Is.EqualTo(1));
306306
Assert.That(dlTask0.IsCompleted, Is.True);
307307
var dlTask1 = await startTask1;
308308
await dlTask1.Task;
309-
Assert.That(dlTask1.TotalBytes, Is.EqualTo(5));
310-
Assert.That(dlTask1.BytesRead, Is.EqualTo(5));
309+
Assert.That(dlTask1.BytesTotal, Is.EqualTo(5));
310+
Assert.That(dlTask1.BytesWritten, Is.EqualTo(5));
311311
Assert.That(dlTask1.Progress, Is.EqualTo(1));
312312
Assert.That(dlTask1.IsCompleted, Is.True);
313313
}
314314

315+
[Test(Description = "Download with X-Original-Content-Length")]
316+
[CancelAfter(30_000)]
317+
public async Task DownloadWithXOriginalContentLength(CancellationToken ct)
318+
{
319+
using var httpServer = new TestHttpServer(async ctx =>
320+
{
321+
ctx.Response.StatusCode = 200;
322+
ctx.Response.Headers.Add("X-Original-Content-Length", "4");
323+
ctx.Response.ContentType = "text/plain";
324+
// Don't set Content-Length.
325+
await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct);
326+
});
327+
var url = new Uri(httpServer.BaseUrl + "/test");
328+
var destPath = Path.Combine(_tempDir, "test");
329+
var manager = new Downloader(NullLogger<Downloader>.Instance);
330+
var req = new HttpRequestMessage(HttpMethod.Get, url);
331+
var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct);
332+
333+
await dlTask.Task;
334+
Assert.That(dlTask.BytesTotal, Is.EqualTo(4));
335+
Assert.That(dlTask.BytesWritten, Is.EqualTo(4));
336+
}
337+
338+
[Test(Description = "Download with mismatched Content-Length")]
339+
[CancelAfter(30_000)]
340+
public async Task DownloadWithMismatchedContentLength(CancellationToken ct)
341+
{
342+
using var httpServer = new TestHttpServer(async ctx =>
343+
{
344+
ctx.Response.StatusCode = 200;
345+
ctx.Response.Headers.Add("X-Original-Content-Length", "5"); // incorrect
346+
ctx.Response.ContentType = "text/plain";
347+
await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct);
348+
});
349+
var url = new Uri(httpServer.BaseUrl + "/test");
350+
var destPath = Path.Combine(_tempDir, "test");
351+
var manager = new Downloader(NullLogger<Downloader>.Instance);
352+
var req = new HttpRequestMessage(HttpMethod.Get, url);
353+
var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct);
354+
355+
var ex = Assert.ThrowsAsync<IOException>(() => dlTask.Task);
356+
Assert.That(ex.Message, Is.EqualTo("Downloaded file size does not match expected response content length: Expected=5, BytesWritten=4"));
357+
}
358+
315359
[Test(Description = "Download with custom headers")]
316360
[CancelAfter(30_000)]
317361
public async Task WithHeaders(CancellationToken ct)
@@ -347,7 +391,7 @@ public async Task DownloadExisting(CancellationToken ct)
347391
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
348392
NullDownloadValidator.Instance, ct);
349393
await dlTask.Task;
350-
Assert.That(dlTask.BytesRead, Is.Zero);
394+
Assert.That(dlTask.BytesWritten, Is.Zero);
351395
Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
352396
Assert.That(File.GetLastWriteTime(destPath), Is.LessThan(DateTime.Now - TimeSpan.FromDays(1)));
353397
}
@@ -368,7 +412,7 @@ public async Task DownloadExistingDifferentContent(CancellationToken ct)
368412
var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath,
369413
NullDownloadValidator.Instance, ct);
370414
await dlTask.Task;
371-
Assert.That(dlTask.BytesRead, Is.EqualTo(4));
415+
Assert.That(dlTask.BytesWritten, Is.EqualTo(4));
372416
Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test"));
373417
Assert.That(File.GetLastWriteTime(destPath), Is.GreaterThan(DateTime.Now - TimeSpan.FromDays(1)));
374418
}

‎Vpn.Proto/vpn.proto

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ message ServiceMessage {
6060
oneof msg {
6161
StartResponse start = 2;
6262
StopResponse stop = 3;
63-
Status status = 4; // either in reply to a StatusRequest or broadcasted
63+
Status status = 4; // either in reply to a StatusRequest or broadcasted
64+
StartProgress start_progress = 5; // broadcasted during startup
6465
}
6566
}
6667

@@ -218,6 +219,28 @@ message StartResponse {
218219
string error_message = 2;
219220
}
220221

222+
// StartProgress is sent from the manager to the client to indicate the
223+
// download/startup progress of the tunnel. This will be sent during the
224+
// processing of a StartRequest before the StartResponse is sent.
225+
//
226+
// Note: this is currently a broadcasted message to all clients due to the
227+
// inability to easily send messages to a specific client in the Speaker
228+
// implementation. If clients are not expecting these messages, they
229+
// should ignore them.
230+
enum StartProgressStage {
231+
Initializing = 0;
232+
Downloading = 1;
233+
Finalizing = 2;
234+
}
235+
message StartProgressDownloadProgress {
236+
uint64 bytes_written = 1;
237+
optional uint64 bytes_total = 2; // unknown in some situations
238+
}
239+
message StartProgress {
240+
StartProgressStage stage = 1;
241+
optional StartProgressDownloadProgress download_progress = 2; // only set when stage == Downloading
242+
}
243+
221244
// StopRequest is a request from the manager to stop the tunnel. The tunnel replies with a
222245
// StopResponse.
223246
message StopRequest {}

‎Vpn.Service/Downloader.cs

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -339,31 +339,35 @@ internal static async Task TaskOrCancellation(Task task, CancellationToken cance
339339
}
340340

341341
/// <summary>
342-
/// Downloads an Url to a file on disk. The download will be written to a temporary file first, then moved to the final
342+
/// Downloads a Url to a file on disk. The download will be written to a temporary file first, then moved to the final
343343
/// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if
344344
/// it hasn't changed.
345345
/// </summary>
346346
public class DownloadTask
347347
{
348-
private const int BufferSize = 4096;
348+
private const int BufferSize = 64 * 1024;
349+
private const string XOriginalContentLengthHeader = "X-Original-Content-Length"; // overrides Content-Length if available
349350

350-
private static readonly HttpClient HttpClient = new();
351+
private static readonly HttpClient HttpClient = new(new HttpClientHandler
352+
{
353+
AutomaticDecompression = DecompressionMethods.All,
354+
});
351355
private readonly string _destinationDirectory;
352356

353357
private readonly ILogger _logger;
354358

355359
private readonly RaiiSemaphoreSlim _semaphore = new(1, 1);
356360
private readonly IDownloadValidator _validator;
357-
public readonly string DestinationPath;
361+
private readonly string _destinationPath;
362+
private readonly string _tempDestinationPath;
358363

359364
public readonly HttpRequestMessage Request;
360-
public readonly string TempDestinationPath;
361365

362-
public ulong? TotalBytes { get; private set; }
363-
public ulong BytesRead { get; private set; }
364366
public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync
365-
366-
public double? Progress => TotalBytes == null ? null : (double)BytesRead / TotalBytes.Value;
367+
public bool DownloadStarted { get; private set; } // Whether we've received headers yet and started the actual download
368+
public ulong BytesWritten { get; private set; }
369+
public ulong? BytesTotal { get; private set; }
370+
public double? Progress => BytesTotal == null ? null : (double)BytesWritten / BytesTotal.Value;
367371
public bool IsCompleted => Task.IsCompleted;
368372

369373
internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator)
@@ -374,17 +378,17 @@ internal DownloadTask(ILogger logger, HttpRequestMessage req, string destination
374378

375379
if (string.IsNullOrWhiteSpace(destinationPath))
376380
throw new ArgumentException("Destination path must not be empty", nameof(destinationPath));
377-
DestinationPath = Path.GetFullPath(destinationPath);
378-
if (Path.EndsInDirectorySeparator(DestinationPath))
379-
throw new ArgumentException($"Destination path '{DestinationPath}' must not end in a directory separator",
381+
_destinationPath = Path.GetFullPath(destinationPath);
382+
if (Path.EndsInDirectorySeparator(_destinationPath))
383+
throw new ArgumentException($"Destination path '{_destinationPath}' must not end in a directory separator",
380384
nameof(destinationPath));
381385

382-
_destinationDirectory = Path.GetDirectoryName(DestinationPath)
386+
_destinationDirectory = Path.GetDirectoryName(_destinationPath)
383387
?? throw new ArgumentException(
384-
$"Destination path '{DestinationPath}' must have a parent directory",
388+
$"Destination path '{_destinationPath}' must have a parent directory",
385389
nameof(destinationPath));
386390

387-
TempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(DestinationPath) +
391+
_tempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(_destinationPath) +
388392
".download-" + Path.GetRandomFileName());
389393
}
390394

@@ -406,9 +410,9 @@ private async Task Start(CancellationToken ct = default)
406410

407411
// If the destination path exists, generate a Coder SHA1 ETag and send
408412
// it in the If-None-Match header to the server.
409-
if (File.Exists(DestinationPath))
413+
if (File.Exists(_destinationPath))
410414
{
411-
await using var stream = File.OpenRead(DestinationPath);
415+
await using var stream = File.OpenRead(_destinationPath);
412416
var etag = Convert.ToHexString(await SHA1.HashDataAsync(stream, ct)).ToLower();
413417
Request.Headers.Add("If-None-Match", "\"" + etag + "\"");
414418
}
@@ -419,11 +423,11 @@ private async Task Start(CancellationToken ct = default)
419423
_logger.LogInformation("File has not been modified, skipping download");
420424
try
421425
{
422-
await _validator.ValidateAsync(DestinationPath, ct);
426+
await _validator.ValidateAsync(_destinationPath, ct);
423427
}
424428
catch (Exception e)
425429
{
426-
_logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", DestinationPath);
430+
_logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", _destinationPath);
427431
throw new Exception("Existing file failed validation after 304 Not Modified", e);
428432
}
429433

@@ -446,24 +450,38 @@ private async Task Start(CancellationToken ct = default)
446450
}
447451

448452
if (res.Content.Headers.ContentLength >= 0)
449-
TotalBytes = (ulong)res.Content.Headers.ContentLength;
453+
BytesTotal = (ulong)res.Content.Headers.ContentLength;
454+
455+
// X-Original-Content-Length overrules Content-Length if set.
456+
if (res.Headers.TryGetValues(XOriginalContentLengthHeader, out var headerValues))
457+
{
458+
// If there are multiple we only look at the first one.
459+
var headerValue = headerValues.ToList().FirstOrDefault();
460+
if (!string.IsNullOrEmpty(headerValue) && ulong.TryParse(headerValue, out var originalContentLength))
461+
BytesTotal = originalContentLength;
462+
else
463+
_logger.LogWarning(
464+
"Failed to parse {XOriginalContentLengthHeader} header value '{HeaderValue}'",
465+
XOriginalContentLengthHeader, headerValue);
466+
}
450467

451468
await Download(res, ct);
452469
}
453470

454471
private async Task Download(HttpResponseMessage res, CancellationToken ct)
455472
{
473+
DownloadStarted = true;
456474
try
457475
{
458476
var sha1 = res.Headers.Contains("ETag") ? SHA1.Create() : null;
459477
FileStream tempFile;
460478
try
461479
{
462-
tempFile = File.Create(TempDestinationPath, BufferSize, FileOptions.SequentialScan);
480+
tempFile = File.Create(_tempDestinationPath, BufferSize, FileOptions.SequentialScan);
463481
}
464482
catch (Exception e)
465483
{
466-
_logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", TempDestinationPath);
484+
_logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", _tempDestinationPath);
467485
throw;
468486
}
469487

@@ -476,13 +494,14 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct)
476494
{
477495
await tempFile.WriteAsync(buffer.AsMemory(0, n), ct);
478496
sha1?.TransformBlock(buffer, 0, n, null, 0);
479-
BytesRead += (ulong)n;
497+
BytesWritten += (ulong)n;
480498
}
481499
}
482500

483-
if (TotalBytes != null && BytesRead != TotalBytes)
501+
BytesTotal ??= BytesWritten;
502+
if (BytesWritten != BytesTotal)
484503
throw new IOException(
485-
$"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesRead}");
504+
$"Downloaded file size does not match expected response content length: Expected={BytesTotal}, BytesWritten={BytesWritten}");
486505

487506
// Verify the ETag if it was sent by the server.
488507
if (res.Headers.Contains("ETag") && sha1 != null)
@@ -497,26 +516,34 @@ private async Task Download(HttpResponseMessage res, CancellationToken ct)
497516

498517
try
499518
{
500-
await _validator.ValidateAsync(TempDestinationPath, ct);
519+
await _validator.ValidateAsync(_tempDestinationPath, ct);
501520
}
502521
catch (Exception e)
503522
{
504523
_logger.LogWarning(e, "Downloaded file '{TempDestinationPath}' failed custom validation",
505-
TempDestinationPath);
524+
_tempDestinationPath);
506525
throw new HttpRequestException("Downloaded file failed validation", e);
507526
}
508527

509-
File.Move(TempDestinationPath, DestinationPath, true);
528+
File.Move(_tempDestinationPath, _destinationPath, true);
510529
}
511-
finally
530+
catch
512531
{
513532
#if DEBUG
514533
_logger.LogWarning("Not deleting temporary file '{TempDestinationPath}' in debug mode",
515-
TempDestinationPath);
534+
_tempDestinationPath);
516535
#else
517-
if (File.Exists(TempDestinationPath))
518-
File.Delete(TempDestinationPath);
536+
try
537+
{
538+
if (File.Exists(_tempDestinationPath))
539+
File.Delete(_tempDestinationPath);
540+
}
541+
catch (Exception e)
542+
{
543+
_logger.LogError(e, "Failed to delete temporary file '{TempDestinationPath}'", _tempDestinationPath);
544+
}
519545
#endif
546+
throw;
520547
}
521548
}
522549
}

‎Vpn.Service/Manager.cs

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ private async ValueTask<StartResponse> HandleClientMessageStart(ClientMessage me
131131
{
132132
try
133133
{
134+
await BroadcastStartProgress(StartProgressStage.Initializing, cancellationToken: ct);
135+
134136
var serverVersion =
135137
await CheckServerVersionAndCredentials(message.Start.CoderUrl, message.Start.ApiToken, ct);
136138
if (_status == TunnelStatus.Started && _lastStartRequest != null &&
@@ -151,10 +153,14 @@ private async ValueTask<StartResponse> HandleClientMessageStart(ClientMessage me
151153
_lastServerVersion = serverVersion;
152154

153155
// TODO: each section of this operation needs a timeout
156+
154157
// Stop the tunnel if it's running so we don't have to worry about
155158
// permissions issues when replacing the binary.
156159
await _tunnelSupervisor.StopAsync(ct);
160+
157161
await DownloadTunnelBinaryAsync(message.Start.CoderUrl, serverVersion.SemVersion, ct);
162+
163+
await BroadcastStartProgress(StartProgressStage.Finalizing, cancellationToken: ct);
158164
await _tunnelSupervisor.StartAsync(_config.TunnelBinaryPath, HandleTunnelRpcMessage,
159165
HandleTunnelRpcError,
160166
ct);
@@ -237,6 +243,9 @@ private void HandleTunnelRpcMessage(ReplyableRpcMessage<ManagerMessage, TunnelMe
237243
_logger.LogWarning("Received unexpected message reply type {MessageType}", message.Message.MsgCase);
238244
break;
239245
case TunnelMessage.MsgOneofCase.Log:
246+
// Ignored. We already log stdout/stderr from the tunnel
247+
// binary.
248+
break;
240249
case TunnelMessage.MsgOneofCase.NetworkSettings:
241250
_logger.LogWarning("Received message type {MessageType} that is not expected on Windows",
242251
message.Message.MsgCase);
@@ -311,12 +320,28 @@ private async ValueTask<Status> CurrentStatus(CancellationToken ct = default)
311320
private async Task BroadcastStatus(TunnelStatus? newStatus = null, CancellationToken ct = default)
312321
{
313322
if (newStatus != null) _status = newStatus.Value;
314-
await _managerRpc.BroadcastAsync(new ServiceMessage
323+
await FallibleBroadcast(new ServiceMessage
315324
{
316325
Status = await CurrentStatus(ct),
317326
}, ct);
318327
}
319328

329+
private async Task FallibleBroadcast(ServiceMessage message, CancellationToken ct = default)
330+
{
331+
// Broadcast the messages out with a low timeout. If clients don't
332+
// receive broadcasts in time, it's not a big deal.
333+
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
334+
cts.CancelAfter(TimeSpan.FromMilliseconds(30));
335+
try
336+
{
337+
await _managerRpc.BroadcastAsync(message, cts.Token);
338+
}
339+
catch (Exception ex)
340+
{
341+
_logger.LogWarning(ex, "Could not broadcast low priority message to all RPC clients: {Message}", message);
342+
}
343+
}
344+
320345
private void HandleTunnelRpcError(Exception e)
321346
{
322347
_logger.LogError(e, "Manager<->Tunnel RPC error");
@@ -425,12 +450,61 @@ private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expected
425450
_logger.LogDebug("Skipping tunnel binary version validation");
426451
}
427452

453+
// Note: all ETag, signature and version validation is performed by the
454+
// DownloadTask.
428455
var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct);
429456

430-
// TODO: monitor and report progress when we have a mechanism to do so
457+
// Wait for the download to complete, sending progress updates every
458+
// 50ms.
459+
while (true)
460+
{
461+
// Wait for the download to complete, or for a short delay before
462+
// we send a progress update.
463+
var delayTask = Task.Delay(TimeSpan.FromMilliseconds(50), ct);
464+
var winner = await Task.WhenAny([
465+
downloadTask.Task,
466+
delayTask,
467+
]);
468+
if (winner == downloadTask.Task)
469+
break;
470+
471+
// Task.WhenAny will not throw if the winner was cancelled, so
472+
// check CT afterward and not beforehand.
473+
ct.ThrowIfCancellationRequested();
474+
475+
if (!downloadTask.DownloadStarted)
476+
// Don't send progress updates if we don't know what the
477+
// progress is yet.
478+
continue;
479+
480+
var progress = new StartProgressDownloadProgress
481+
{
482+
BytesWritten = downloadTask.BytesWritten,
483+
};
484+
if (downloadTask.BytesTotal != null)
485+
progress.BytesTotal = downloadTask.BytesTotal.Value;
431486

432-
// Awaiting this will check the checksum (via the ETag) if the file
433-
// exists, and will also validate the signature and version.
487+
await BroadcastStartProgress(StartProgressStage.Downloading, progress, ct);
488+
}
489+
490+
// Await again to re-throw any exceptions that occurred during the
491+
// download.
434492
await downloadTask.Task;
493+
494+
// We don't send a broadcast here as we immediately send one in the
495+
// parent routine.
496+
_logger.LogInformation("Completed downloading VPN binary");
497+
}
498+
499+
private async Task BroadcastStartProgress(StartProgressStage stage, StartProgressDownloadProgress? downloadProgress = null, CancellationToken cancellationToken = default)
500+
{
501+
await FallibleBroadcast(new ServiceMessage
502+
{
503+
StartProgress = new StartProgress
504+
{
505+
Stage = stage,
506+
DownloadProgress = downloadProgress,
507+
},
508+
}, cancellationToken);
435509
}
436510
}

‎Vpn.Service/ManagerRpc.cs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,26 +127,33 @@ public async Task ExecuteAsync(CancellationToken stoppingToken)
127127

128128
public async Task BroadcastAsync(ServiceMessage message, CancellationToken ct)
129129
{
130+
// Sends messages to all clients simultaneously and waits for them all
131+
// to send or fail/timeout.
132+
//
130133
// Looping over a ConcurrentDictionary is exception-safe, but any items
131134
// added or removed during the loop may or may not be included.
132-
foreach (var (clientId, client) in _activeClients)
135+
await Task.WhenAll(_activeClients.Select(async item =>
136+
{
133137
try
134138
{
135-
var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
136-
cts.CancelAfter(5 * 1000);
137-
await client.Speaker.SendMessage(message, cts.Token);
139+
// Enforce upper bound in case a CT with a timeout wasn't
140+
// supplied.
141+
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct);
142+
cts.CancelAfter(TimeSpan.FromSeconds(2));
143+
await item.Value.Speaker.SendMessage(message, cts.Token);
138144
}
139145
catch (ObjectDisposedException)
140146
{
141147
// The speaker was likely closed while we were iterating.
142148
}
143149
catch (Exception e)
144150
{
145-
_logger.LogWarning(e, "Failed to send message to client {ClientId}", clientId);
151+
_logger.LogWarning(e, "Failed to send message to client {ClientId}", item.Key);
146152
// TODO: this should probably kill the client, but due to the
147153
// async nature of the client handling, calling Dispose
148154
// will not remove the client from the active clients list
149155
}
156+
}));
150157
}
151158

152159
private async Task HandleRpcClientAsync(ulong clientId, Speaker<ServiceMessage, ClientMessage> speaker,

‎Vpn.Service/Program.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ public static class Program
1616
#if !DEBUG
1717
private const string ServiceName = "Coder Desktop";
1818
private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\VpnService";
19+
private const string DefaultLogLevel = "Information";
1920
#else
2021
// This value matches Create-Service.ps1.
2122
private const string ServiceName = "Coder Desktop (Debug)";
2223
private const string ConfigSubKey = @"SOFTWARE\Coder Desktop\DebugVpnService";
24+
private const string DefaultLogLevel = "Debug";
2325
#endif
2426

2527
private const string ManagerConfigSection = "Manager";
@@ -81,6 +83,10 @@ private static async Task BuildAndRun(string[] args)
8183
builder.Services.AddSingleton<ITelemetryEnricher, TelemetryEnricher>();
8284

8385
// Services
86+
builder.Services.AddHostedService<ManagerService>();
87+
builder.Services.AddHostedService<ManagerRpcService>();
88+
89+
// Either run as a Windows service or a console application
8490
if (!Environment.UserInteractive)
8591
{
8692
MainLogger.Information("Running as a windows service");
@@ -91,9 +97,6 @@ private static async Task BuildAndRun(string[] args)
9197
MainLogger.Information("Running as a console application");
9298
}
9399

94-
builder.Services.AddHostedService<ManagerService>();
95-
builder.Services.AddHostedService<ManagerRpcService>();
96-
97100
var host = builder.Build();
98101
Log.Logger = (ILogger)host.Services.GetService(typeof(ILogger))!;
99102
MainLogger.Information("Application is starting");
@@ -108,7 +111,7 @@ private static void AddDefaultConfig(IConfigurationBuilder builder)
108111
["Serilog:Using:0"] = "Serilog.Sinks.File",
109112
["Serilog:Using:1"] = "Serilog.Sinks.Console",
110113

111-
["Serilog:MinimumLevel"] = "Information",
114+
["Serilog:MinimumLevel"] = DefaultLogLevel,
112115
["Serilog:Enrich:0"] = "FromLogContext",
113116

114117
["Serilog:WriteTo:0:Name"] = "File",

‎Vpn.Service/TunnelSupervisor.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,16 @@ public async Task StartAsync(string binPath,
9999
},
100100
};
101101
// TODO: maybe we should change the log format in the inner binary
102-
// to something without a timestamp
103-
var outLogger = Log.ForContext("SourceContext", "coder-vpn.exe[OUT]");
104-
var errLogger = Log.ForContext("SourceContext", "coder-vpn.exe[ERR]");
102+
// to something without a timestamp
105103
_subprocess.OutputDataReceived += (_, args) =>
106104
{
107105
if (!string.IsNullOrWhiteSpace(args.Data))
108-
outLogger.Debug("{Data}", args.Data);
106+
_logger.LogInformation("stdout: {Data}", args.Data);
109107
};
110108
_subprocess.ErrorDataReceived += (_, args) =>
111109
{
112110
if (!string.IsNullOrWhiteSpace(args.Data))
113-
errLogger.Debug("{Data}", args.Data);
111+
_logger.LogInformation("stderr: {Data}", args.Data);
114112
};
115113

116114
// Pass the other end of the pipes to the subprocess and dispose

‎Vpn/Speaker.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ public async Task StartAsync(CancellationToken ct = default)
123123
// Handshakes should always finish quickly, so enforce a 5s timeout.
124124
using var cts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token);
125125
cts.CancelAfter(TimeSpan.FromSeconds(5));
126-
await PerformHandshake(ct);
126+
await PerformHandshake(cts.Token);
127127

128128
// Start ReceiveLoop in the background.
129129
_receiveTask = ReceiveLoop(_cts.Token);

0 commit comments

Comments
 (0)
Please sign in to comment.