From 0fb9567a3f804b35ba3320090d0edb454da75564 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 9 Dec 2024 17:52:21 +0900 Subject: [PATCH 1/3] chore: add Vpn.Service app for Manager - Implements a basic .NET hosted service manager architecture with Microsoft.Extensions - Adds a Manager and ManagerService for handling Manager lifecycle - Adds a ManagerRpcService for managing the RPC server and passing requests to the Manager singleton - Adds a Downloader for handling singleflight for downloading files - Implements downloading with progress reporting, ETag validation, and Authenticode validation --- Coder.Desktop.sln | 20 +- Coder.Desktop.sln.DotSettings | 1 + .../RpcHeaderTest.cs | 0 .../RpcMessageTest.cs | 0 .../RpcRoleTest.cs | 0 .../RpcVersionTest.cs | 0 Tests.Vpn.Proto/Tests.Vpn.Proto.csproj | 35 ++ Tests.Vpn.Service/DownloaderTest.cs | 302 +++++++++++++++++ Tests.Vpn.Service/TestHttpServer.cs | 106 ++++++ Tests.Vpn.Service/Tests.Vpn.Service.csproj | 35 ++ {Tests/Vpn => Tests.Vpn}/SerdesTest.cs | 40 +-- {Tests/Vpn => Tests.Vpn}/SpeakerTest.cs | 92 ++--- .../Tests.Vpn.csproj | 18 +- Tests.Vpn/Utilities/TaskUtilitiesTest.cs | 141 ++++++++ Vpn.Proto/RpcRole.cs | 4 +- Vpn.Proto/Vpn.Proto.csproj | 4 +- Vpn.Service/Downloader.cs | 316 ++++++++++++++++++ Vpn.Service/Manager.cs | 103 ++++++ Vpn.Service/ManagerRpcService.cs | 131 ++++++++ Vpn.Service/ManagerService.cs | 33 ++ Vpn.Service/Program.cs | 10 + Vpn.Service/Vpn.Service.csproj | 21 ++ Vpn/Utilities/TaskUtilities.cs | 2 +- 23 files changed, 1336 insertions(+), 78 deletions(-) rename {Tests/Vpn.Proto => Tests.Vpn.Proto}/RpcHeaderTest.cs (100%) rename {Tests/Vpn.Proto => Tests.Vpn.Proto}/RpcMessageTest.cs (100%) rename {Tests/Vpn.Proto => Tests.Vpn.Proto}/RpcRoleTest.cs (100%) rename {Tests/Vpn.Proto => Tests.Vpn.Proto}/RpcVersionTest.cs (100%) create mode 100644 Tests.Vpn.Proto/Tests.Vpn.Proto.csproj create mode 100644 Tests.Vpn.Service/DownloaderTest.cs create mode 100644 Tests.Vpn.Service/TestHttpServer.cs create mode 100644 Tests.Vpn.Service/Tests.Vpn.Service.csproj rename {Tests/Vpn => Tests.Vpn}/SerdesTest.cs (75%) rename {Tests/Vpn => Tests.Vpn}/SpeakerTest.cs (89%) rename Tests/Tests.csproj => Tests.Vpn/Tests.Vpn.csproj (56%) create mode 100644 Tests.Vpn/Utilities/TaskUtilitiesTest.cs create mode 100644 Vpn.Service/Downloader.cs create mode 100644 Vpn.Service/Manager.cs create mode 100644 Vpn.Service/ManagerRpcService.cs create mode 100644 Vpn.Service/ManagerService.cs create mode 100644 Vpn.Service/Program.cs create mode 100644 Vpn.Service/Vpn.Service.csproj diff --git a/Coder.Desktop.sln b/Coder.Desktop.sln index 342963b..5c8fb15 100644 --- a/Coder.Desktop.sln +++ b/Coder.Desktop.sln @@ -5,7 +5,13 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn", "Vpn\Vp EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn.Proto", "Vpn.Proto\Vpn.Proto.csproj", "{318E78BB-E6AD-410F-8F3F-B680F6880293}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests", "Tests\Tests.csproj", "{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tests.Vpn", "Tests.Vpn\Tests.Vpn.csproj", "{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn.Service", "Vpn.Service\Vpn.Service.csproj", "{51B91794-0A2A-4F84-9935-8E17DD2AB260}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tests.Vpn.Proto", "Tests.Vpn.Proto\Tests.Vpn.Proto.csproj", "{AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tests.Vpn.Service", "Tests.Vpn.Service\Tests.Vpn.Service.csproj", "{D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -25,5 +31,17 @@ Global {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Debug|Any CPU.Build.0 = Debug|Any CPU {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Release|Any CPU.ActiveCfg = Release|Any CPU {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Release|Any CPU.Build.0 = Release|Any CPU + {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Debug|Any CPU.Build.0 = Debug|Any CPU + {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Release|Any CPU.ActiveCfg = Release|Any CPU + {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Release|Any CPU.Build.0 = Release|Any CPU + {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Release|Any CPU.Build.0 = Release|Any CPU + {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection EndGlobal diff --git a/Coder.Desktop.sln.DotSettings b/Coder.Desktop.sln.DotSettings index 636b95d..5aa4ae7 100644 --- a/Coder.Desktop.sln.DotSettings +++ b/Coder.Desktop.sln.DotSettings @@ -253,4 +253,5 @@ </Patterns> True + True True \ No newline at end of file diff --git a/Tests/Vpn.Proto/RpcHeaderTest.cs b/Tests.Vpn.Proto/RpcHeaderTest.cs similarity index 100% rename from Tests/Vpn.Proto/RpcHeaderTest.cs rename to Tests.Vpn.Proto/RpcHeaderTest.cs diff --git a/Tests/Vpn.Proto/RpcMessageTest.cs b/Tests.Vpn.Proto/RpcMessageTest.cs similarity index 100% rename from Tests/Vpn.Proto/RpcMessageTest.cs rename to Tests.Vpn.Proto/RpcMessageTest.cs diff --git a/Tests/Vpn.Proto/RpcRoleTest.cs b/Tests.Vpn.Proto/RpcRoleTest.cs similarity index 100% rename from Tests/Vpn.Proto/RpcRoleTest.cs rename to Tests.Vpn.Proto/RpcRoleTest.cs diff --git a/Tests/Vpn.Proto/RpcVersionTest.cs b/Tests.Vpn.Proto/RpcVersionTest.cs similarity index 100% rename from Tests/Vpn.Proto/RpcVersionTest.cs rename to Tests.Vpn.Proto/RpcVersionTest.cs diff --git a/Tests.Vpn.Proto/Tests.Vpn.Proto.csproj b/Tests.Vpn.Proto/Tests.Vpn.Proto.csproj new file mode 100644 index 0000000..54b7b33 --- /dev/null +++ b/Tests.Vpn.Proto/Tests.Vpn.Proto.csproj @@ -0,0 +1,35 @@ + + + + Coder.Desktop.Tests.Vpn.Proto + net8.0 + enable + enable + + false + true + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs new file mode 100644 index 0000000..c1e0335 --- /dev/null +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -0,0 +1,302 @@ +using System.Security.Cryptography; +using System.Text; +using Coder.Desktop.Vpn.Service; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Coder.Desktop.Tests.Vpn.Service; + +public class TestDownloadValidator(Exception e) : IDownloadValidator +{ + public Task ValidateAsync(string path, CancellationToken ct = default) + { + throw e; + } +} + +[TestFixture] +public class AuthenticodeDownloadValidatorTest +{ + [Test(Description = "Test an unsigned binary")] + [CancelAfter(30_000)] + public void Unsigned(CancellationToken ct) + { + // TODO: this + } + + [Test(Description = "Test an untrusted binary")] + [CancelAfter(30_000)] + public void Untrusted(CancellationToken ct) + { + // TODO: this + } + + [Test(Description = "Test an binary with a detached signature (catalog file)")] + [CancelAfter(30_000)] + public void DifferentCertTrusted(CancellationToken ct) + { + // notepad.exe uses a catalog file for its signature. + var ex = Assert.ThrowsAsync(() => + AuthenticodeDownloadValidator.Coder.ValidateAsync(@"C:\Windows\System32\notepad.exe", ct)); + Assert.That(ex.Message, + Does.Contain("File is not signed with an embedded Authenticode signature: Kind=Catalog")); + } + + [Test(Description = "Test a binary signed by a different certificate")] + [CancelAfter(30_000)] + public void DifferentCertUntrusted(CancellationToken ct) + { + // TODO: this + } + + [Test(Description = "Test a binary signed by Coder's certificate")] + [CancelAfter(30_000)] + public async Task CoderSigned(CancellationToken ct) + { + // TODO: this + await Task.CompletedTask; + } +} + +[TestFixture] +public class DownloaderTest +{ + // FYI, SetUp and TearDown get called before and after each test. + [SetUp] + public void Setup() + { + _tempDir = Path.Combine(Path.GetTempPath(), "Coder.Desktop.Tests.Vpn.Service_" + Path.GetRandomFileName()); + Directory.CreateDirectory(_tempDir); + } + + [TearDown] + public void TearDown() + { + Directory.Delete(_tempDir, true); + } + + private string _tempDir; + + private static TestHttpServer EchoServer() + { + // Create webserver that replies to `/xyz` with a test file containing + // `xyz`. + return new TestHttpServer(async ctx => + { + // Get the path without the leading slash. + var path = ctx.Request.Url!.AbsolutePath[1..]; + var pathBytes = Encoding.UTF8.GetBytes(path); + + // If the client sends an If-None-Match header with the correct ETag, + // return 304 Not Modified. + var etag = "\"" + Convert.ToHexString(SHA1.HashData(pathBytes)).ToLower() + "\""; + if (ctx.Request.Headers["If-None-Match"] == etag) + { + ctx.Response.StatusCode = 304; + return; + } + + ctx.Response.StatusCode = 200; + ctx.Response.Headers.Add("ETag", etag); + ctx.Response.ContentType = "text/plain"; + ctx.Response.ContentLength64 = pathBytes.Length; + await ctx.Response.OutputStream.WriteAsync(pathBytes); + }); + } + + [Test(Description = "Perform a download")] + [CancelAfter(30_000)] + public async Task Download(CancellationToken ct) + { + using var httpServer = EchoServer(); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, ct); + await dlTask.Task; + Assert.That(dlTask.TotalBytes, Is.EqualTo(4)); + Assert.That(dlTask.BytesRead, Is.EqualTo(4)); + Assert.That(dlTask.Progress, Is.EqualTo(1)); + Assert.That(dlTask.IsCompleted, Is.True); + Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); + } + + [Test(Description = "Download with custom headers")] + [CancelAfter(30_000)] + public async Task WithHeaders(CancellationToken ct) + { + using var httpServer = new TestHttpServer(ctx => + { + Assert.That(ctx.Request.Headers["X-Custom-Header"], Is.EqualTo("custom-value")); + ctx.Response.StatusCode = 200; + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + var req = new HttpRequestMessage(HttpMethod.Get, url); + req.Headers.Add("X-Custom-Header", "custom-value"); + var dlTask = await manager.StartDownloadAsync(req, destPath, NullDownloadValidator.Instance, ct); + await dlTask.Task; + } + + [Test(Description = "Perform a download against an existing identical file")] + [CancelAfter(30_000)] + public async Task DownloadExisting(CancellationToken ct) + { + using var httpServer = EchoServer(); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + // Create the destination file with a very old timestamp. + await File.WriteAllTextAsync(destPath, "test", ct); + File.SetLastWriteTime(destPath, DateTime.Now - TimeSpan.FromDays(365)); + + var manager = new Downloader(NullLogger.Instance); + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, ct); + await dlTask.Task; + Assert.That(dlTask.BytesRead, Is.Zero); + Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); + Assert.That(File.GetLastWriteTime(destPath), Is.LessThan(DateTime.Now - TimeSpan.FromDays(1))); + } + + [Test(Description = "Perform a download against an existing file with different content")] + [CancelAfter(30_000)] + public async Task DownloadExistingDifferentContent(CancellationToken ct) + { + using var httpServer = EchoServer(); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + // Create the destination file with a very old timestamp. + await File.WriteAllTextAsync(destPath, "TEST", ct); + File.SetLastWriteTime(destPath, DateTime.Now - TimeSpan.FromDays(365)); + + var manager = new Downloader(NullLogger.Instance); + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, ct); + await dlTask.Task; + Assert.That(dlTask.BytesRead, Is.EqualTo(4)); + Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); + Assert.That(File.GetLastWriteTime(destPath), Is.GreaterThan(DateTime.Now - TimeSpan.FromDays(1))); + } + + [Test(Description = "Unexpected response code from server")] + [CancelAfter(30_000)] + public void UnexpectedResponseCode(CancellationToken ct) + { + using var httpServer = new TestHttpServer(ctx => { ctx.Response.StatusCode = 404; }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + // The "outer" Task should fail. + var ex = Assert.ThrowsAsync(async () => + await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, ct)); + Assert.That(ex.Message, Does.Contain("404")); + } + + // TODO: It would be nice to have a test that tests mismatched + // Content-Length, but it seems HttpListener doesn't allow that. + + [Test(Description = "Mismatched ETag")] + [CancelAfter(30_000)] + public async Task MismatchedETag(CancellationToken ct) + { + using var httpServer = new TestHttpServer(ctx => + { + ctx.Response.StatusCode = 200; + ctx.Response.Headers.Add("ETag", "\"beef\""); + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + // The "inner" Task should fail. + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, ct); + var ex = Assert.ThrowsAsync(async () => await dlTask.Task); + Assert.That(ex.Message, Does.Contain("ETag does not match SHA1 hash of downloaded file").And.Contains("beef")); + } + + [Test(Description = "Timeout on response headers")] + [CancelAfter(30_000)] + public void CancelledOuter(CancellationToken ct) + { + using var httpServer = new TestHttpServer(async _ => { await Task.Delay(TimeSpan.FromSeconds(5), ct); }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + // The "outer" Task should fail. + var smallerCt = new CancellationTokenSource(TimeSpan.FromSeconds(1)).Token; + Assert.ThrowsAsync( + async () => await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, smallerCt)); + } + + [Test(Description = "Timeout on response body")] + [CancelAfter(30_000)] + public async Task CancelledInner(CancellationToken ct) + { + using var httpServer = new TestHttpServer(async ctx => + { + ctx.Response.StatusCode = 200; + await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct); + await ctx.Response.OutputStream.FlushAsync(ct); + await Task.Delay(TimeSpan.FromSeconds(5), ct); + }); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + // The "inner" Task should fail. + var smallerCt = new CancellationTokenSource(TimeSpan.FromSeconds(1)).Token; + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, smallerCt); + var ex = Assert.ThrowsAsync(async () => await dlTask.Task); + Assert.That(ex.CancellationToken, Is.EqualTo(smallerCt)); + } + + [Test(Description = "Validation failure")] + [CancelAfter(30_000)] + public async Task ValidationFailure(CancellationToken ct) + { + using var httpServer = EchoServer(); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + new TestDownloadValidator(new Exception("test exception")), ct); + + var ex = Assert.ThrowsAsync(async () => await dlTask.Task); + Assert.That(ex.Message, Does.Contain("Downloaded file failed validation")); + Assert.That(ex.InnerException, Is.Not.Null); + Assert.That(ex.InnerException!.Message, Is.EqualTo("test exception")); + } + + [Test(Description = "Validation failure on existing file")] + [CancelAfter(30_000)] + public async Task ValidationFailureExistingFile(CancellationToken ct) + { + using var httpServer = EchoServer(); + var url = new Uri(httpServer.BaseUrl + "/test"); + var destPath = Path.Combine(_tempDir, "test"); + await File.WriteAllTextAsync(destPath, "test", ct); + + var manager = new Downloader(NullLogger.Instance); + // The "outer" Task should fail because the inner task never starts. + var ex = Assert.ThrowsAsync(async () => + { + await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + new TestDownloadValidator(new Exception("test exception")), ct); + }); + Assert.That(ex.Message, Does.Contain("Existing file failed validation")); + Assert.That(ex.InnerException, Is.Not.Null); + Assert.That(ex.InnerException!.Message, Is.EqualTo("test exception")); + } +} diff --git a/Tests.Vpn.Service/TestHttpServer.cs b/Tests.Vpn.Service/TestHttpServer.cs new file mode 100644 index 0000000..d33697f --- /dev/null +++ b/Tests.Vpn.Service/TestHttpServer.cs @@ -0,0 +1,106 @@ +using System.Net; +using System.Text; + +namespace Coder.Desktop.Tests.Vpn.Service; + +public class TestHttpServer : IDisposable +{ + // IANA suggested range for dynamic or private ports + private const int MinPort = 49215; + private const int MaxPort = 65535; + private const int PortRangeSize = MaxPort - MinPort + 1; + + private readonly CancellationTokenSource _cts = new(); + private readonly Func _handler; + private readonly HttpListener _listener; + private readonly Thread _listenerThread; + + public string BaseUrl { get; private set; } + + public TestHttpServer(Action handler) : this(ctx => + { + handler(ctx); + return Task.CompletedTask; + }) + { + } + + public TestHttpServer(Func handler) + { + _handler = handler; + + // Yes, this is the best way to get an unused port using HttpListener. + // It sucks. + // + // This implementation picks a random start point between MinPort and + // MaxPort, then iterates through the entire range (wrapping around at + // the end) until it finds a free port. + var port = 0; + var random = new Random(); + var startPort = random.Next(MinPort, MaxPort + 1); + for (var i = 0; i < PortRangeSize; i++) + { + port = MinPort + (startPort - MinPort + i) % PortRangeSize; + + var attempt = new HttpListener(); + attempt.Prefixes.Add($"http://localhost:{port}/"); + try + { + attempt.Start(); + _listener = attempt; + break; + } + catch + { + // Listener disposes itself on failure + } + } + + if (_listener == null || port == 0) + throw new InvalidOperationException("Could not find a free port to listen on"); + BaseUrl = $"http://localhost:{port}"; + + _listenerThread = new Thread(() => + { + while (!_cts.Token.IsCancellationRequested) + try + { + var context = _listener.GetContext(); + Task.Run(() => HandleRequest(context)); + } + catch (HttpListenerException) when (_cts.Token.IsCancellationRequested) + { + break; + } + }); + + _listenerThread.Start(); + } + + public void Dispose() + { + _cts.Cancel(); + _listener.Stop(); + _listenerThread.Join(); + GC.SuppressFinalize(this); + } + + private async Task HandleRequest(HttpListenerContext context) + { + try + { + await _handler(context); + } + catch (Exception e) + { + await Console.Error.WriteLineAsync($"Exception while serving HTTP request: {e}"); + context.Response.StatusCode = 500; + var response = Encoding.UTF8.GetBytes($"Internal Server Error: {e.Message}"); + await context.Response.OutputStream.WriteAsync(response); + } + finally + { + context.Response.Close(); + } + } +} diff --git a/Tests.Vpn.Service/Tests.Vpn.Service.csproj b/Tests.Vpn.Service/Tests.Vpn.Service.csproj new file mode 100644 index 0000000..2fdfa76 --- /dev/null +++ b/Tests.Vpn.Service/Tests.Vpn.Service.csproj @@ -0,0 +1,35 @@ + + + + Coder.Desktop.Tests.Vpn.Service + net8.0-windows + enable + enable + + false + true + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + diff --git a/Tests/Vpn/SerdesTest.cs b/Tests.Vpn/SerdesTest.cs similarity index 75% rename from Tests/Vpn/SerdesTest.cs rename to Tests.Vpn/SerdesTest.cs index 7673d6a..cf2c480 100644 --- a/Tests/Vpn/SerdesTest.cs +++ b/Tests.Vpn/SerdesTest.cs @@ -9,8 +9,8 @@ namespace Coder.Desktop.Tests.Vpn; public class SerdesTest { [Test(Description = "Tests that writing and reading a message works")] - [Timeout(5_000)] - public async Task WriteReadMessage() + [CancelAfter(30_000)] + public async Task WriteReadMessage(CancellationToken ct) { var (stream1, stream2) = BidirectionalPipe.New(); var serdes = new Serdes(); @@ -19,14 +19,14 @@ public async Task WriteReadMessage() { Start = new StartRequest(), }; - await serdes.WriteMessage(stream1, msg); - var got = await serdes.ReadMessage(stream2); + await serdes.WriteMessage(stream1, msg, ct); + var got = await serdes.ReadMessage(stream2, ct); Assert.That(msg, Is.EqualTo(got)); } [Test(Description = "Tests that writing a message larger than 16 MiB throws an exception")] - [Timeout(5_000)] - public void WriteMessageTooLarge() + [CancelAfter(30_000)] + public void WriteMessageTooLarge(CancellationToken ct) { var (stream1, _) = BidirectionalPipe.New(); var serdes = new Serdes(); @@ -39,12 +39,12 @@ public void WriteMessageTooLarge() CoderUrl = "test", }, }; - Assert.ThrowsAsync(() => serdes.WriteMessage(stream1, msg)); + Assert.ThrowsAsync(() => serdes.WriteMessage(stream1, msg, ct)); } [Test(Description = "Tests that attempting to read a message larger than 16 MiB throws an exception")] - [Timeout(5_000)] - public async Task ReadMessageTooLarge() + [CancelAfter(30_000)] + public async Task ReadMessageTooLarge(CancellationToken ct) { var (stream1, stream2) = BidirectionalPipe.New(); var serdes = new Serdes(); @@ -53,13 +53,13 @@ public async Task ReadMessageTooLarge() // bail out immediately after reading the message length var lenBytes = new byte[4]; BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 0x1000001); - await stream1.WriteAsync(lenBytes); - Assert.ThrowsAsync(() => serdes.ReadMessage(stream2)); + await stream1.WriteAsync(lenBytes, ct); + Assert.ThrowsAsync(() => serdes.ReadMessage(stream2, ct)); } [Test(Description = "Read an empty (size 0) message from the stream")] - [Timeout(5_000)] - public async Task ReadEmptyMessage() + [CancelAfter(30_000)] + public async Task ReadEmptyMessage(CancellationToken ct) { var (stream1, stream2) = BidirectionalPipe.New(); var serdes = new Serdes(); @@ -67,23 +67,23 @@ public async Task ReadEmptyMessage() // Write an empty message. var lenBytes = new byte[4]; BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 0); - await stream1.WriteAsync(lenBytes); - var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2)); + await stream1.WriteAsync(lenBytes, ct); + var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2, ct)); Assert.That(ex.Message, Does.Contain("Received message size 0")); } [Test(Description = "Read an invalid/corrupt message from the stream")] - [Timeout(5_000)] - public async Task ReadInvalidMessage() + [CancelAfter(30_000)] + public async Task ReadInvalidMessage(CancellationToken ct) { var (stream1, stream2) = BidirectionalPipe.New(); var serdes = new Serdes(); var lenBytes = new byte[4]; BinaryPrimitives.WriteUInt32BigEndian(lenBytes, 1); - await stream1.WriteAsync(lenBytes); - await stream1.WriteAsync(new byte[1]); - var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2)); + await stream1.WriteAsync(lenBytes, ct); + await stream1.WriteAsync(new byte[1], ct); + var ex = Assert.ThrowsAsync(() => serdes.ReadMessage(stream2, ct)); Assert.That(ex.InnerException, Is.TypeOf(typeof(InvalidProtocolBufferException))); } } diff --git a/Tests/Vpn/SpeakerTest.cs b/Tests.Vpn/SpeakerTest.cs similarity index 89% rename from Tests/Vpn/SpeakerTest.cs rename to Tests.Vpn/SpeakerTest.cs index f06c62f..8966c02 100644 --- a/Tests/Vpn/SpeakerTest.cs +++ b/Tests.Vpn/SpeakerTest.cs @@ -88,13 +88,6 @@ internal class FailableStream : Stream private readonly TaskCompletionSource _writeTcs = new(); - public FailableStream(Stream inner, Exception? writeException, Exception? readException) - { - _inner = inner; - if (writeException != null) _writeTcs.SetException(writeException); - if (readException != null) _readTcs.SetException(readException); - } - public override bool CanRead => _inner.CanRead; public override bool CanSeek => _inner.CanSeek; public override bool CanWrite => _inner.CanWrite; @@ -106,6 +99,13 @@ public override long Position set => _inner.Position = value; } + public FailableStream(Stream inner, Exception? writeException, Exception? readException) + { + _inner = inner; + if (writeException != null) _writeTcs.SetException(writeException); + if (readException != null) _readTcs.SetException(readException); + } + public void SetWriteException(Exception ex) { _writeTcs.SetException(ex); @@ -172,8 +172,8 @@ public override async ValueTask WriteAsync(ReadOnlyMemory buffer, public class SpeakerTest { [Test(Description = "Send a message from speaker1 to speaker2, receive it, and send a reply back")] - [Timeout(30_000)] - public async Task SendReceiveReplyReceive() + [CancelAfter(30_000)] + public async Task SendReceiveReplyReceive(CancellationToken ct) { var (stream1, stream2) = BidirectionalPipe.New(); @@ -190,14 +190,14 @@ public async Task SendReceiveReplyReceive() speaker2.Error += ex => { Assert.Fail($"speaker2 error: {ex}"); }; // Start both speakers simultaneously - Task.WaitAll(speaker1.StartAsync(), speaker2.StartAsync()); + await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); // Send a normal message from speaker2 to speaker1 await speaker2.SendMessage(new TunnelMessage { PeerUpdate = new PeerUpdate(), - }); - var receivedMessage = await speaker1Ch.Reader.ReadAsync(); + }, ct); + var receivedMessage = await speaker1Ch.Reader.ReadAsync(ct); Assert.That(receivedMessage.RpcField, Is.Null); // not a request Assert.That(receivedMessage.Message.PeerUpdate, Is.Not.Null); @@ -209,10 +209,10 @@ await speaker2.SendMessage(new TunnelMessage ApiToken = "test", CoderUrl = "test", }, - }); + }, ct); // Receive the message in speaker2 - var message = await speaker2Ch.Reader.ReadAsync(); + var message = await speaker2Ch.Reader.ReadAsync(ct); Assert.That(message.RpcField, Is.Not.Null); Assert.That(message.RpcField!.MsgId, Is.Not.EqualTo(0)); Assert.That(message.RpcField!.ResponseTo, Is.EqualTo(0)); @@ -225,7 +225,7 @@ await message.SendReply(new TunnelMessage { Success = true, }, - }); + }, ct); // Receive the reply in speaker1 by awaiting sendTask var reply = await sendTask; @@ -236,8 +236,8 @@ await message.SendReply(new TunnelMessage } [Test(Description = "Encounter a write error during handshake")] - [Timeout(30_000)] - public async Task WriteError() + [CancelAfter(30_000)] + public async Task WriteError(CancellationToken ct) { var (stream1, _) = BidirectionalPipe.New(); var writeEx = new IOException("Test write error"); @@ -245,13 +245,13 @@ public async Task WriteError() await using var speaker = new Speaker(failStream); - var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync()); + var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync(ct)); Assert.That(gotEx, Is.EqualTo(writeEx)); } [Test(Description = "Encounter a read error during handshake")] - [Timeout(30_000)] - public async Task ReadError() + [CancelAfter(30_000)] + public async Task ReadError(CancellationToken ct) { var (stream1, _) = BidirectionalPipe.New(); var readEx = new IOException("Test read error"); @@ -259,28 +259,28 @@ public async Task ReadError() await using var speaker = new Speaker(failStream); - var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync()); + var gotEx = Assert.ThrowsAsync(() => speaker.StartAsync(ct)); Assert.That(gotEx, Is.EqualTo(readEx)); } [Test(Description = "Receive a header that exceeds 256 bytes")] - [Timeout(30_000)] - public async Task ReadLargeHeader() + [CancelAfter(30_000)] + public async Task ReadLargeHeader(CancellationToken ct) { var (stream1, stream2) = BidirectionalPipe.New(); await using var speaker1 = new Speaker(stream1); var header = new byte[257]; for (var i = 0; i < header.Length; i++) header[i] = (byte)'a'; - await stream2.WriteAsync(header); + await stream2.WriteAsync(header, ct); - var gotEx = Assert.ThrowsAsync(() => speaker1.StartAsync()); + var gotEx = Assert.ThrowsAsync(() => speaker1.StartAsync(ct)); Assert.That(gotEx.Message, Does.Contain("Header malformed or too large")); } [Test(Description = "Receive an invalid header")] - [Timeout(30_000)] - public async Task ReceiveInvalidHeader() + [CancelAfter(30_000)] + public async Task ReceiveInvalidHeader(CancellationToken ct) { var cases = new Dictionary { @@ -302,9 +302,9 @@ public async Task ReceiveInvalidHeader() var (stream1, stream2) = BidirectionalPipe.New(); await using var speaker1 = new Speaker(stream1); - await stream2.WriteAsync(Encoding.UTF8.GetBytes(header)); + await stream2.WriteAsync(Encoding.UTF8.GetBytes(header), ct); - var gotEx = Assert.CatchAsync(() => speaker1.StartAsync(), $"header: '{header}'"); + var gotEx = Assert.CatchAsync(() => speaker1.StartAsync(ct), $"header: '{header}'"); Assert.That(gotEx.Message, Does.Contain(expectedOuter), $"header: '{header}'"); if (expectedInner is null) { @@ -318,8 +318,8 @@ public async Task ReceiveInvalidHeader() } [Test(Description = "Encounter a write error during message send")] - [Timeout(30_000)] - public async Task SendMessageWriteError() + [CancelAfter(30_000)] + public async Task SendMessageWriteError(CancellationToken ct) { var (stream1, stream2) = BidirectionalPipe.New(); var failStream = new FailableStream(stream1, null, null); @@ -330,7 +330,7 @@ public async Task SendMessageWriteError() await using var speaker2 = new Speaker(stream2); speaker2.Receive += msg => Assert.Fail($"speaker2 received message: {msg}"); speaker2.Error += ex => Assert.Fail($"speaker2 error: {ex}"); - await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); var writeEx = new IOException("Test write error"); failStream.SetWriteException(writeEx); @@ -338,13 +338,13 @@ public async Task SendMessageWriteError() var gotEx = Assert.ThrowsAsync(() => speaker1.SendMessage(new ManagerMessage { Start = new StartRequest(), - })); + }, ct)); Assert.That(gotEx, Is.EqualTo(writeEx)); } [Test(Description = "Encounter a read error during message receive")] - [Timeout(30_000)] - public async Task ReceiveMessageReadError() + [CancelAfter(30_000)] + public async Task ReceiveMessageReadError(CancellationToken ct) { var (stream1, stream2) = BidirectionalPipe.New(); var failStream = new FailableStream(stream1, null, null); @@ -359,13 +359,13 @@ public async Task ReceiveMessageReadError() await using var speaker2 = new Speaker(stream2); speaker2.Receive += msg => Assert.Fail($"speaker2 received message: {msg}"); speaker2.Error += ex => Assert.Fail($"speaker2 error: {ex}"); - await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); // Now the handshake is complete, cause all reads to fail var readEx = new IOException("Test write error"); failStream.SetReadException(readEx); - var gotEx = await errorCh.Reader.ReadAsync(); + var gotEx = await errorCh.Reader.ReadAsync(ct); Assert.That(gotEx, Is.EqualTo(readEx)); // The receive loop should be stopped within a timely fashion. @@ -377,24 +377,24 @@ public async Task ReceiveMessageReadError() } else { - var delayTask = Task.Delay(TimeSpan.FromSeconds(5)); + var delayTask = Task.Delay(TimeSpan.FromSeconds(5), ct); await Task.WhenAny(receiveLoopTask, delayTask); Assert.That(receiveLoopTask.IsCompleted, Is.True); } } [Test(Description = "Handle dispose while receive loop is running")] - [Timeout(30_000)] - public async Task DisposeWhileReceiveLoopRunning() + [CancelAfter(30_000)] + public async Task DisposeWhileReceiveLoopRunning(CancellationToken ct) { var (stream1, stream2) = BidirectionalPipe.New(); var speaker1 = new Speaker(stream1); await using var speaker2 = new Speaker(stream2); - await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); // Dispose should happen in a timely fashion var disposeTask = speaker1.DisposeAsync(); - var delayTask = Task.Delay(TimeSpan.FromSeconds(5)); + var delayTask = Task.Delay(TimeSpan.FromSeconds(5), ct); await Task.WhenAny(disposeTask.AsTask(), delayTask); Assert.That(disposeTask.IsCompleted, Is.True); @@ -408,19 +408,19 @@ public async Task DisposeWhileReceiveLoopRunning() } [Test(Description = "Handle dispose while a message is awaiting a reply")] - [Timeout(30_000)] - public async Task DisposeWhileAwaitingReply() + [CancelAfter(30_000)] + public async Task DisposeWhileAwaitingReply(CancellationToken ct) { var (stream1, stream2) = BidirectionalPipe.New(); var speaker1 = new Speaker(stream1); await using var speaker2 = new Speaker(stream2); - await Task.WhenAll(speaker1.StartAsync(), speaker2.StartAsync()); + await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); // Send a message from speaker1 to speaker2 var sendTask = speaker1.SendRequestAwaitReply(new ManagerMessage { Start = new StartRequest(), - }); + }, ct); // Dispose speaker1 await speaker1.DisposeAsync(); diff --git a/Tests/Tests.csproj b/Tests.Vpn/Tests.Vpn.csproj similarity index 56% rename from Tests/Tests.csproj rename to Tests.Vpn/Tests.Vpn.csproj index cccd5dc..f6f2776 100644 --- a/Tests/Tests.csproj +++ b/Tests.Vpn/Tests.Vpn.csproj @@ -1,7 +1,7 @@ - Coder.Desktop.Tests + Coder.Desktop.Tests.Vpn net8.0 enable enable @@ -11,11 +11,17 @@ - - - - - + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + diff --git a/Tests.Vpn/Utilities/TaskUtilitiesTest.cs b/Tests.Vpn/Utilities/TaskUtilitiesTest.cs new file mode 100644 index 0000000..a6a4583 --- /dev/null +++ b/Tests.Vpn/Utilities/TaskUtilitiesTest.cs @@ -0,0 +1,141 @@ +using Coder.Desktop.Vpn.Utilities; + +namespace Coder.Desktop.Tests.Vpn.Utilities; + +[TestFixture] +public class TaskUtilitiesTest +{ + [Test(Description = "CancellableWhenAll with no tasks should complete immediately")] + [Timeout(30_000)] + public void CancellableWhenAll_NoTasks() + { + var task = TaskUtilities.CancellableWhenAll(new CancellationTokenSource()); + Assert.That(task.IsCompleted, Is.True); + } + + [Test(Description = "CancellableWhenAll with a single task should complete")] + [Timeout(30_000)] + public async Task CancellableWhenAll_SingleTask() + { + var innerTask = new TaskCompletionSource(); + var task = TaskUtilities.CancellableWhenAll(new CancellationTokenSource(), innerTask.Task); + Assert.That(task.IsCompleted, Is.False); + innerTask.SetResult(); + await task; + } + + [Test(Description = "CancellableWhenAll with a single task that faults should propagate the exception")] + [Timeout(30_000)] + public void CancellableWhenAll_SingleTaskFault() + { + var cts = new CancellationTokenSource(); + var innerTask = new TaskCompletionSource(); + var task = TaskUtilities.CancellableWhenAll(cts, innerTask.Task); + Assert.That(task.IsCompleted, Is.False); + innerTask.SetException(new InvalidOperationException("Test")); + Assert.ThrowsAsync(async () => await task); + Assert.That(cts.IsCancellationRequested, Is.True); + } + + [Test(Description = "CancellableWhenAll with a single task that is canceled should propagate the cancellation")] + [Timeout(30_000)] + public void CancellableWhenAll_SingleTaskCanceled() + { + var cts = new CancellationTokenSource(); + var innerTask = new TaskCompletionSource(); + var task = TaskUtilities.CancellableWhenAll(cts, innerTask.Task); + Assert.That(task.IsCompleted, Is.False); + innerTask.SetCanceled(); + Assert.ThrowsAsync(async () => await task); + Assert.That(cts.IsCancellationRequested, Is.True); + } + + [Test(Description = "CancellableWhenAll with multiple tasks should complete when all tasks are completed")] + [Timeout(30_000)] + public async Task CancellableWhenAll_MultipleTasks() + { + var cts = new CancellationTokenSource(); + var innerTask1 = new TaskCompletionSource(); + var innerTask2 = new TaskCompletionSource(); + + var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task); + Assert.That(task.IsCompleted, Is.False); + // This dance of awaiting a newly added continuation task before + // completing the TCS is to ensure that the original continuation task + // finished since it's inlinable. + var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { }); + innerTask1.SetResult(); + await task1ContinueTask; + Assert.That(task.IsCompleted, Is.False); + var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { }); + innerTask2.SetResult(); + await task2ContinueTask; + await task; + } + + [Test(Description = "CancellableWhenAll with multiple tasks that fault should propagate the first exception only")] + [Timeout(30_000)] + public async Task CancellableWhenAll_MultipleTasksFault() + { + var cts = new CancellationTokenSource(); + var innerTask1 = new TaskCompletionSource(); + var innerTask2 = new TaskCompletionSource(); + + var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task); + Assert.That(task.IsCompleted, Is.False); + var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { }); + innerTask1.SetException(new Exception("Test1")); + await task1ContinueTask; + Assert.That(task.IsCompleted, Is.False); + var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { }); + innerTask2.SetException(new Exception("Test2")); + await task2ContinueTask; + var ex = Assert.ThrowsAsync(async () => await task); + Assert.That(ex.Message, Is.EqualTo("Test1")); + } + + [Test(Description = "CancellableWhenAll with an exception and a cancellation should propagate the first thing")] + [Timeout(30_000)] + public async Task CancellableWhenAll_MultipleTasksFaultAndCanceled() + { + var cts = new CancellationTokenSource(); + var innerTask1 = new TaskCompletionSource(); + var innerTask2 = new TaskCompletionSource(); + var innertask3 = Task.CompletedTask; + + var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task, innertask3); + Assert.That(task.IsCompleted, Is.False); + var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { }); + innerTask1.SetException(new Exception("Test1")); + await task1ContinueTask; + Assert.That(task.IsCompleted, Is.False); + Assert.That(cts.IsCancellationRequested, Is.True); + var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { }); + innerTask2.SetCanceled(); + await task2ContinueTask; + var ex = Assert.ThrowsAsync(async () => await task); + Assert.That(ex.Message, Is.EqualTo("Test1")); + } + + [Test(Description = "CancellableWhenAll with a cancellation and an exception should propagate the first thing")] + [Timeout(30_000)] + public async Task CancellableWhenAll_MultipleTasksCanceledAndFault() + { + var cts = new CancellationTokenSource(); + var innerTask1 = new TaskCompletionSource(); + var innerTask2 = new TaskCompletionSource(); + var innertask3 = Task.CompletedTask; + + var task = TaskUtilities.CancellableWhenAll(cts, innerTask1.Task, innerTask2.Task, innertask3); + Assert.That(task.IsCompleted, Is.False); + var task1ContinueTask = innerTask1.Task.ContinueWith(_ => { }); + innerTask1.SetCanceled(); + await task1ContinueTask; + Assert.That(task.IsCompleted, Is.False); + Assert.That(cts.IsCancellationRequested, Is.True); + var task2ContinueTask = innerTask2.Task.ContinueWith(_ => { }); + innerTask2.SetException(new Exception("Test2")); + await task2ContinueTask; + Assert.ThrowsAsync(async () => await task); + } +} diff --git a/Vpn.Proto/RpcRole.cs b/Vpn.Proto/RpcRole.cs index 9190281..69f4b48 100644 --- a/Vpn.Proto/RpcRole.cs +++ b/Vpn.Proto/RpcRole.cs @@ -8,6 +8,8 @@ public sealed class RpcRole public const string Manager = "manager"; public const string Tunnel = "tunnel"; + private string Role { get; } + public RpcRole(string role) { if (role != Manager && role != Tunnel) throw new ArgumentException($"Unknown role '{role}'"); @@ -15,8 +17,6 @@ public RpcRole(string role) Role = role; } - private string Role { get; } - public override string ToString() { return Role; diff --git a/Vpn.Proto/Vpn.Proto.csproj b/Vpn.Proto/Vpn.Proto.csproj index 5380bd4..6acb12e 100644 --- a/Vpn.Proto/Vpn.Proto.csproj +++ b/Vpn.Proto/Vpn.Proto.csproj @@ -12,8 +12,8 @@ - - + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs new file mode 100644 index 0000000..4a9542b --- /dev/null +++ b/Vpn.Service/Downloader.cs @@ -0,0 +1,316 @@ +using System.Collections.Concurrent; +using System.Net; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Extensions.Logging; +using Microsoft.Security.Extensions; + +namespace Coder.Desktop.Vpn.Service; + +public interface IDownloader +{ + Task StartDownloadAsync(HttpRequestMessage req, string destinationPath, IDownloadValidator validator, + CancellationToken ct = default); +} + +public interface IDownloadValidator +{ + /// + /// Validates the downloaded file at the given path. This method should throw an exception if the file is invalid. + /// + /// The path of the file + /// Cancellation token + Task ValidateAsync(string path, CancellationToken ct = default); +} + +public class NullDownloadValidator : IDownloadValidator +{ + public static NullDownloadValidator Instance => new(); + + public Task ValidateAsync(string path, CancellationToken ct = default) + { + return Task.CompletedTask; + } +} + +public class AuthenticodeDownloadValidator : IDownloadValidator +{ + private readonly string _expectedName; + + public static AuthenticodeDownloadValidator Coder => new("Coder Technologies Inc."); + + public AuthenticodeDownloadValidator(string expectedName) + { + if (string.IsNullOrWhiteSpace(expectedName)) + throw new ArgumentException("Expected name must not be empty", nameof(expectedName)); + _expectedName = expectedName; + } + + public async Task ValidateAsync(string path, CancellationToken ct = default) + { + FileSignatureInfo fileSigInfo; + await using (var fileStream = File.OpenRead(path)) + { + fileSigInfo = FileSignatureInfo.GetFromFileStream(fileStream); + } + + if (fileSigInfo.State != SignatureState.SignedAndTrusted) + throw new Exception( + $"File is not signed and trusted with an Authenticode signature: State={fileSigInfo.State}"); + + // Coder will only use embedded signatures because we are downloading + // individual binaries and not installers which can ship catalog files. + if (fileSigInfo.Kind != SignatureKind.Embedded) + throw new Exception($"File is not signed with an embedded Authenticode signature: Kind={fileSigInfo.Kind}"); + + var actualName = fileSigInfo.SigningCertificate.GetNameInfo(X509NameType.SimpleName, false); + if (actualName != _expectedName) + throw new Exception( + $"File is signed by an unexpected certificate: ExpectedName='{_expectedName}', ActualName='{actualName}'"); + } +} + +/// +/// Handles downloading files from the internet. Downloads are performed asynchronously using DownloadTask. +/// Single-flight is provided to avoid performing the same download multiple times. +/// +public class Downloader : IDownloader +{ + private readonly ConcurrentDictionary _downloads = new(); + private readonly ILogger _logger; + + // ReSharper disable once ConvertToPrimaryConstructor + public Downloader(ILogger logger) + { + _logger = logger; + } + + /// + /// Starts a download with the given request. The If-None-Match header will be set to the SHA1 ETag of any existing + /// file in the destination location. + /// + /// Request message + /// Path to write file to (will be overwritten) + /// Validator for the downloaded file + /// Cancellation token + /// A DownloadTask representing the ongoing download operation after it starts + public async Task StartDownloadAsync(HttpRequestMessage req, string destinationPath, + IDownloadValidator validator, CancellationToken ct = default) + { + while (true) + { + var task = _downloads.GetOrAdd(destinationPath, + _ => new DownloadTask(_logger, req, destinationPath, validator)); + await task.EnsureStartedAsync(ct); + + // If the existing (or new) task is for the same URL, return it. + if (task.Request.RequestUri == req.RequestUri) + return task; + + // If the existing task is for a different URL, await its completion + // then retry the loop to create a new task. This could potentially + // get stuck if there are a lot of download operations for different + // URLs and the same destination path, but in our use case this + // shouldn't happen unless the user keeps changing the access URL. + _logger.LogWarning( + "Download for '{DestinationPath}' is already in progress, but is for a different Url - awaiting completion", + destinationPath); + await task.Task; + } + } +} + +/// +/// Downloads an Url to a file on disk. The download will be written to a temporary file first, then moved to the final +/// destination. The SHA1 of any existing file will be calculated and used as an ETag to avoid downloading the file if +/// it hasn't changed. +/// +public class DownloadTask +{ + private const int BufferSize = 4096; + + private static readonly HttpClient HttpClient = new(); + private readonly string _destinationDirectory; + + private readonly ILogger _logger; + + private readonly SemaphoreSlim _semaphore = new(1, 1); + private readonly IDownloadValidator _validator; + public readonly string DestinationPath; + + public readonly HttpRequestMessage Request; + public readonly string TempDestinationPath; + + public ulong? TotalBytes { get; private set; } + public ulong BytesRead { get; private set; } + public Task Task { get; private set; } = null!; // Set in EnsureStartedAsync + + public double? Progress => TotalBytes == null ? null : (double)BytesRead / TotalBytes.Value; + public bool IsCompleted => Task.IsCompleted; + + internal DownloadTask(ILogger logger, HttpRequestMessage req, string destinationPath, IDownloadValidator validator) + { + _logger = logger; + Request = req; + _validator = validator; + + if (string.IsNullOrWhiteSpace(destinationPath)) + throw new ArgumentException("Destination path must not be empty", nameof(destinationPath)); + DestinationPath = Path.GetFullPath(destinationPath); + if (Path.EndsInDirectorySeparator(DestinationPath)) + throw new ArgumentException($"Destination path '{DestinationPath}' must not end in a directory separator", + nameof(destinationPath)); + + _destinationDirectory = Path.GetDirectoryName(DestinationPath) + ?? throw new ArgumentException( + $"Destination path '{DestinationPath}' must have a parent directory", + nameof(destinationPath)); + + TempDestinationPath = Path.Combine(_destinationDirectory, "." + Path.GetFileName(DestinationPath) + + ".download-" + Path.GetRandomFileName()); + } + + internal async Task EnsureStartedAsync(CancellationToken ct = default) + { + await _semaphore.WaitAsync(ct); + try + { + if (Task == null!) + Task = await StartDownloadAsync(ct); + } + finally + { + _semaphore.Release(); + } + + return Task; + } + + /// + /// Starts downloading the file. The request will be performed in this task, but once started, the task will complete + /// and the download will continue in the background. The provided CancellationToken can be used to cancel the + /// download. + /// + private async Task StartDownloadAsync(CancellationToken ct = default) + { + Directory.CreateDirectory(_destinationDirectory); + + // If the destination path exists, generate a Coder SHA1 ETag and send + // it in the If-None-Match header to the server. + if (File.Exists(DestinationPath)) + { + await using var stream = File.OpenRead(DestinationPath); + var etag = Convert.ToHexString(await SHA1.HashDataAsync(stream, ct)).ToLower(); + Request.Headers.Add("If-None-Match", "\"" + etag + "\""); + } + + var res = await HttpClient.SendAsync(Request, HttpCompletionOption.ResponseHeadersRead, ct); + if (res.StatusCode == HttpStatusCode.NotModified) + { + _logger.LogInformation("File has not been modified, skipping download"); + try + { + await _validator.ValidateAsync(DestinationPath, ct); + } + catch (Exception e) + { + _logger.LogWarning(e, "Existing file '{DestinationPath}' failed custom validation", DestinationPath); + throw new Exception("Existing file failed validation after 304 Not Modified", e); + } + + Task = Task.CompletedTask; + return Task; + } + + if (res.StatusCode != HttpStatusCode.OK) + { + _logger.LogWarning("Failed to download file '{Request.RequestUri}': {StatusCode} {ReasonPhrase}", + Request.RequestUri, res.StatusCode, + res.ReasonPhrase); + throw new HttpRequestException( + $"Failed to download file '{Request.RequestUri}': {(int)res.StatusCode} {res.ReasonPhrase}"); + } + + if (res.Content == null) + { + _logger.LogWarning("File {Request.RequestUri} has no content", Request.RequestUri); + throw new HttpRequestException("Response has no content"); + } + + if (res.Content.Headers.ContentLength >= 0) + TotalBytes = (ulong)res.Content.Headers.ContentLength; + + FileStream tempFile; + try + { + tempFile = File.Create(TempDestinationPath, BufferSize, + FileOptions.Asynchronous | FileOptions.SequentialScan); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", TempDestinationPath); + throw; + } + + Task = DownloadAsync(res, tempFile, ct); + return Task; + } + + private async Task DownloadAsync(HttpResponseMessage res, FileStream tempFile, CancellationToken ct) + { + try + { + var sha1 = res.Headers.Contains("ETag") ? SHA1.Create() : null; + await using (tempFile) + { + var stream = await res.Content.ReadAsStreamAsync(ct); + var buffer = new byte[BufferSize]; + int n; + while ((n = await stream.ReadAsync(buffer, ct)) > 0) + { + await tempFile.WriteAsync(buffer.AsMemory(0, n), ct); + sha1?.TransformBlock(buffer, 0, n, null, 0); + BytesRead += (ulong)n; + } + } + + if (TotalBytes != null && BytesRead != TotalBytes) + throw new IOException( + $"Downloaded file size does not match response Content-Length: Content-Length={TotalBytes}, BytesRead={BytesRead}"); + + // Verify the ETag if it was sent by the server. + if (res.Headers.Contains("ETag") && sha1 != null) + { + var etag = res.Headers.ETag!.Tag.Trim('"'); + _ = sha1.TransformFinalBlock([], 0, 0); + var hashStr = Convert.ToHexString(sha1.Hash!).ToLower(); + if (etag != hashStr) + throw new HttpRequestException( + $"ETag does not match SHA1 hash of downloaded file: ETag='{etag}', Local='{hashStr}'"); + } + + try + { + await _validator.ValidateAsync(TempDestinationPath, ct); + } + catch (Exception e) + { + _logger.LogWarning(e, "Downloaded file '{TempDestinationPath}' failed custom validation", + TempDestinationPath); + throw new HttpRequestException("Downloaded file failed validation", e); + } + + File.Move(TempDestinationPath, DestinationPath, true); + } + finally + { +#if DEBUG + _logger.LogWarning("Not deleting temporary file '{TempDestinationPath}' in debug mode", + TempDestinationPath); +#else + if (File.Exists(TempDestinationPath)) + File.Delete(TempDestinationPath); +#endif + } + } +} diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs new file mode 100644 index 0000000..0767dc3 --- /dev/null +++ b/Vpn.Service/Manager.cs @@ -0,0 +1,103 @@ +using System.Runtime.InteropServices; +using Coder.Desktop.Vpn.Proto; +using Microsoft.Extensions.Logging; + +namespace Coder.Desktop.Vpn.Service; + +public interface IManager +{ + public Task HandleClientRpcMessage(ReplyableRpcMessage message, + CancellationToken ct = default); + + public Task StopAsync(CancellationToken ct = default); +} + +public class Manager : IManager, IAsyncDisposable +{ + private const string DestinationPath = "C:\\coder-vpn.exe"; + private readonly IDownloader _downloader; + + private readonly ILogger _logger; + + // ReSharper disable once ConvertToPrimaryConstructor + public Manager(ILogger logger, IDownloader downloader) + { + _logger = logger; + _downloader = downloader; + } + + public async ValueTask DisposeAsync() + { + await Task.CompletedTask; + GC.SuppressFinalize(this); + } + + public async Task HandleClientRpcMessage(ReplyableRpcMessage message, + CancellationToken ct = default) + { + switch (message.Message.MsgCase) + { + default: + _logger.LogWarning("Received unknown message type {MessageType}", message.Message.MsgCase); + break; + } + } + + public async Task StopAsync(CancellationToken ct = default) + { + // TODO: implement once we have process supervision + await Task.CompletedTask; + } + + /// + /// Returns the architecture of the current system. + /// + /// A golang architecture string for the binary + /// Unsupported architecture + private static string SystemArchitecture() + { + return RuntimeInformation.ProcessArchitecture switch + { + Architecture.X64 => "amd64", + Architecture.Arm64 => "arm64", + // We only support amd64 and arm64 on Windows currently. + _ => throw new PlatformNotSupportedException( + "Unsupported architecture. Coder only supports amd64 and arm64."), + }; + } + + /// + /// Fetches the "/bin/coder-windows-{architecture}.exe" binary from the given base URL and writes it to the + /// destination path after validating the signature and checksum. + /// + /// + /// + /// + private async Task DownloadVPNClientAsync(string baseUrl, CancellationToken ct = default) + { + var architecture = SystemArchitecture(); + Uri url; + try + { + url = new Uri(baseUrl, UriKind.Absolute); + if (url.PathAndQuery != "/") + throw new ArgumentException("Base URL must not contain a path", nameof(baseUrl)); + url = new Uri(url, $"/bin/coder-windows-{architecture}.exe"); + } + catch (Exception e) + { + throw new ArgumentException($"Invalid base URL '{baseUrl}'", e); + } + + _logger.LogInformation("Downloading VPN binary from '{url}' to '{DestinationPath}'", url, DestinationPath); + var downloadTask = + await _downloader.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), DestinationPath, + AuthenticodeDownloadValidator.Coder, ct); + + // TODO: monitor and report progress when we have a mechanism to do so + + // Awaiting this will check the checksum (via the ETag) if provided, + // and will also validate the signature using the validator we supplied. + await downloadTask.Task; + } +} diff --git a/Vpn.Service/ManagerRpcService.cs b/Vpn.Service/ManagerRpcService.cs new file mode 100644 index 0000000..228fc30 --- /dev/null +++ b/Vpn.Service/ManagerRpcService.cs @@ -0,0 +1,131 @@ +using System.Collections.Concurrent; +using System.IO.Pipes; +using Coder.Desktop.Vpn.Proto; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +namespace Coder.Desktop.Vpn.Service; + +/// +/// Provides a named pipe server for communication between multiple RpcRole.Client and RpcRole.Manager. +/// +public class ManagerRpcService : BackgroundService, IAsyncDisposable +{ + // TODO: make configurable with registry? + private const string PipeName = "Coder.Desktop.Vpn"; + private readonly ConcurrentDictionary _activeClientTasks = new(); + + private readonly CancellationTokenSource _cts = new(); + + private readonly ILogger _logger; + private readonly IManager _manager; + + // ReSharper disable once ConvertToPrimaryConstructor + public ManagerRpcService(ILogger logger, IManager manager) + { + _logger = logger; + _manager = manager; + } + + public async ValueTask DisposeAsync() + { + await _cts.CancelAsync(); + while (!_activeClientTasks.IsEmpty) await Task.WhenAny(_activeClientTasks.Values); + _cts.Dispose(); + GC.SuppressFinalize(this); + } + + public override async Task StopAsync(CancellationToken cancellationToken) + { + await _cts.CancelAsync(); + while (!_activeClientTasks.IsEmpty) await Task.WhenAny(_activeClientTasks.Values); + } + + /// + /// Starts the named pipe server, listens for incoming connections and starts handling them asynchronously. + /// + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + _logger.LogInformation(@"Starting continuous named pipe RPC server at \\.\pipe\{PipeName}", PipeName); + using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(stoppingToken, _cts.Token); + while (!linkedCts.IsCancellationRequested) + { + _logger.LogDebug($"Creating named pipe server {PipeName}"); + var pipeServer = new NamedPipeServerStream(PipeName, PipeDirection.InOut, + NamedPipeServerStream.MaxAllowedServerInstances, PipeTransmissionMode.Byte, PipeOptions.Asynchronous); + + try + { + try + { + _logger.LogDebug("Waiting for named pipe client connection"); + await pipeServer.WaitForConnectionAsync(linkedCts.Token); + } + finally + { + await pipeServer.DisposeAsync(); + } + + _logger.LogInformation("Handling named pipe client connection"); + var clientTask = HandleRpcClientAsync(pipeServer, linkedCts.Token); + _activeClientTasks.TryAdd(clientTask.Id, clientTask); + _ = clientTask.ContinueWith(RpcClientContinuation, CancellationToken.None); + } + catch (OperationCanceledException) + { + throw; + } + catch (Exception e) + { + _logger.LogWarning(e, "Failed to accept named pipe client"); + } + } + } + + private async Task HandleRpcClientAsync(NamedPipeServerStream pipeServer, CancellationToken ct) + { + var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); + await using (pipeServer) + { + // TODO: use ClientMessage once it's ready + await using var speaker = new Speaker(pipeServer); + + var tcs = new TaskCompletionSource(); + var activeTasks = new ConcurrentDictionary(); + speaker.Receive += msg => + { + var task = HandleRpcMessageAsync(msg, linkedCts.Token); + activeTasks.TryAdd(task.Id, task); + task.ContinueWith(t => + { + if (t.IsFaulted) + _logger.LogWarning(t.Exception, "Client RPC message handler task faulted"); + activeTasks.TryRemove(t.Id, out _); + }, CancellationToken.None); + }; + speaker.Error += tcs.SetException; + await using (ct.Register(() => tcs.SetCanceled(ct))) + { + await speaker.StartAsync(ct); + await tcs.Task; + await linkedCts.CancelAsync(); + while (!activeTasks.IsEmpty) + await Task.WhenAny(activeTasks.Values); + } + } + } + + private void RpcClientContinuation(Task task) + { + if (task.IsFaulted) + _logger.LogWarning(task.Exception, "Client RPC task faulted"); + _activeClientTasks.TryRemove(task.Id, out _); + } + + private async Task HandleRpcMessageAsync(ReplyableRpcMessage message, + CancellationToken ct) + { + _logger.LogInformation("Received RPC message: {Message}", message.Message); + await _manager.HandleClientRpcMessage(message, ct); + } +} diff --git a/Vpn.Service/ManagerService.cs b/Vpn.Service/ManagerService.cs new file mode 100644 index 0000000..b7b2e34 --- /dev/null +++ b/Vpn.Service/ManagerService.cs @@ -0,0 +1,33 @@ +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +namespace Coder.Desktop.Vpn.Service; + +/// +/// Wraps Manager to provide a BackgroundService that informs the singleton Manager to shut down when stop is +/// requested. +/// +public class ManagerService : BackgroundService +{ + private readonly ILogger _logger; + private readonly IManager _manager; + + // ReSharper disable once ConvertToPrimaryConstructor + public ManagerService(ILogger logger, IManager manager) + { + _logger = logger; + _manager = manager; + } + + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + // Block until the service is stopped. + await Task.Delay(-1, stoppingToken); + } + + public override async Task StopAsync(CancellationToken cancellationToken) + { + _logger.LogInformation("Informing Manager to stop"); + await _manager.StopAsync(cancellationToken); + } +} diff --git a/Vpn.Service/Program.cs b/Vpn.Service/Program.cs new file mode 100644 index 0000000..24b55cf --- /dev/null +++ b/Vpn.Service/Program.cs @@ -0,0 +1,10 @@ +using Coder.Desktop.Vpn.Service; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +var builder = Host.CreateApplicationBuilder(args); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddHostedService(); +builder.Services.AddHostedService(); +builder.Build().Run(); diff --git a/Vpn.Service/Vpn.Service.csproj b/Vpn.Service/Vpn.Service.csproj new file mode 100644 index 0000000..33ee897 --- /dev/null +++ b/Vpn.Service/Vpn.Service.csproj @@ -0,0 +1,21 @@ + + + + Coder.Desktop.Vpn.Service + Exe + net8.0-windows + enable + enable + + + + + + + + + + + + + diff --git a/Vpn/Utilities/TaskUtilities.cs b/Vpn/Utilities/TaskUtilities.cs index 8a2bfdb..4105c9e 100644 --- a/Vpn/Utilities/TaskUtilities.cs +++ b/Vpn/Utilities/TaskUtilities.cs @@ -1,6 +1,6 @@ namespace Coder.Desktop.Vpn.Utilities; -internal static class TaskUtilities +public static class TaskUtilities { /// /// Waits for all tasks to complete, but cancels the provided CancellationTokenSource if any task is canceled or From de6f40e7090f122f527a396d79501ca5f4d1492b Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Thu, 12 Dec 2024 19:25:02 +0900 Subject: [PATCH 2/3] Add configuration, remove RpcRole, add CoderSdk --- Coder.Desktop.sln | 22 +- Coder.Desktop.sln.DotSettings | 1 + CoderSdk/CoderApiClient.cs | 81 ++++++ CoderSdk/CoderSdk.csproj | 9 + CoderSdk/Deployment.cs | 22 ++ CoderSdk/Users.cs | 17 ++ Tests.Vpn.Proto/RpcHeaderTest.cs | 9 +- Tests.Vpn.Proto/RpcMessageTest.cs | 26 +- Tests.Vpn.Proto/RpcRoleTest.cs | 22 -- Tests.Vpn.Service/DownloaderTest.cs | 53 ++++ Tests.Vpn/SerdesTest.cs | 11 +- Tests.Vpn/SpeakerTest.cs | 95 +------- Tests.Vpn/Tests.Vpn.csproj | 1 - Vpn.Proto/RpcHeader.cs | 14 +- Vpn.Proto/RpcMessage.cs | 55 ++++- Vpn.Proto/RpcRole.cs | 56 ----- Vpn.Proto/vpn.proto | 26 +- Vpn.Service/Downloader.cs | 64 +++-- Vpn.Service/Manager.cs | 152 ++++++++++-- Vpn.Service/ManagerConfig.cs | 16 ++ Vpn.Service/ManagerRpcService.cs | 22 +- Vpn.Service/Program.cs | 20 ++ Vpn.Service/RegistryConfigurationSource.cs | 23 ++ Vpn.Service/TunnelSupervisor.cs | 271 +++++++++++++++++++++ Vpn.Service/Vpn.Service.csproj | 3 + Vpn/Serdes.cs | 24 +- Vpn/Speaker.cs | 8 +- Vpn/Utilities/BidirectionalPipe.cs | 92 +++++++ Vpn/Utilities/RaiiSemaphoreSlim.cs | 30 +++ Vpn/Vpn.csproj | 4 + 30 files changed, 960 insertions(+), 289 deletions(-) create mode 100644 CoderSdk/CoderApiClient.cs create mode 100644 CoderSdk/CoderSdk.csproj create mode 100644 CoderSdk/Deployment.cs create mode 100644 CoderSdk/Users.cs delete mode 100644 Tests.Vpn.Proto/RpcRoleTest.cs delete mode 100644 Vpn.Proto/RpcRole.cs create mode 100644 Vpn.Service/ManagerConfig.cs create mode 100644 Vpn.Service/RegistryConfigurationSource.cs create mode 100644 Vpn.Service/TunnelSupervisor.cs create mode 100644 Vpn/Utilities/BidirectionalPipe.cs create mode 100644 Vpn/Utilities/RaiiSemaphoreSlim.cs diff --git a/Coder.Desktop.sln b/Coder.Desktop.sln index 5c8fb15..c5fc598 100644 --- a/Coder.Desktop.sln +++ b/Coder.Desktop.sln @@ -5,13 +5,15 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn", "Vpn\Vp EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn.Proto", "Vpn.Proto\Vpn.Proto.csproj", "{318E78BB-E6AD-410F-8F3F-B680F6880293}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tests.Vpn", "Tests.Vpn\Tests.Vpn.csproj", "{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}" -EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Vpn.Service", "Vpn.Service\Vpn.Service.csproj", "{51B91794-0A2A-4F84-9935-8E17DD2AB260}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tests.Vpn.Proto", "Tests.Vpn.Proto\Tests.Vpn.Proto.csproj", "{AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests.Vpn", "Tests.Vpn\Tests.Vpn.csproj", "{D247B2E7-38A0-4A69-A710-7E8FAA7B807E}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests.Vpn.Proto", "Tests.Vpn.Proto\Tests.Vpn.Proto.csproj", "{AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.Tests.Vpn.Service", "Tests.Vpn.Service\Tests.Vpn.Service.csproj", "{D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tests.Vpn.Service", "Tests.Vpn.Service\Tests.Vpn.Service.csproj", "{D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Coder.Desktop.CoderSdk", "CoderSdk\CoderSdk.csproj", "{A3D2B2B3-A051-46BD-A190-5487A9F24C28}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -27,14 +29,14 @@ Global {318E78BB-E6AD-410F-8F3F-B680F6880293}.Debug|Any CPU.Build.0 = Debug|Any CPU {318E78BB-E6AD-410F-8F3F-B680F6880293}.Release|Any CPU.ActiveCfg = Release|Any CPU {318E78BB-E6AD-410F-8F3F-B680F6880293}.Release|Any CPU.Build.0 = Release|Any CPU - {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Debug|Any CPU.Build.0 = Debug|Any CPU - {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Release|Any CPU.ActiveCfg = Release|Any CPU - {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Release|Any CPU.Build.0 = Release|Any CPU {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Debug|Any CPU.Build.0 = Debug|Any CPU {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Release|Any CPU.ActiveCfg = Release|Any CPU {51B91794-0A2A-4F84-9935-8E17DD2AB260}.Release|Any CPU.Build.0 = Release|Any CPU + {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D247B2E7-38A0-4A69-A710-7E8FAA7B807E}.Release|Any CPU.Build.0 = Release|Any CPU {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Debug|Any CPU.Build.0 = Debug|Any CPU {AA3EEFF4-414B-4A83-8ACF-188C3C61CCE1}.Release|Any CPU.ActiveCfg = Release|Any CPU @@ -43,5 +45,9 @@ Global {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Debug|Any CPU.Build.0 = Debug|Any CPU {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Release|Any CPU.ActiveCfg = Release|Any CPU {D32E5FE1-C251-4A08-8EBE-B8D4F18A36F1}.Release|Any CPU.Build.0 = Release|Any CPU + {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A3D2B2B3-A051-46BD-A190-5487A9F24C28}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection EndGlobal diff --git a/Coder.Desktop.sln.DotSettings b/Coder.Desktop.sln.DotSettings index 5aa4ae7..70c3a3b 100644 --- a/Coder.Desktop.sln.DotSettings +++ b/Coder.Desktop.sln.DotSettings @@ -254,4 +254,5 @@ True True + True True \ No newline at end of file diff --git a/CoderSdk/CoderApiClient.cs b/CoderSdk/CoderApiClient.cs new file mode 100644 index 0000000..90343f3 --- /dev/null +++ b/CoderSdk/CoderApiClient.cs @@ -0,0 +1,81 @@ +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace CoderSdk; + +/// +/// Changes names from PascalCase to snake_case. +/// +internal class SnakeCaseNamingPolicy : JsonNamingPolicy +{ + public override string ConvertName(string name) + { + return string.Concat( + name.Select((x, i) => i > 0 && char.IsUpper(x) ? "_" + char.ToLower(x) : char.ToLower(x).ToString()) + ); + } +} + +/// +/// Provides a limited selection of API methods for a Coder instance. +/// +public partial class CoderApiClient +{ + // TODO: allow adding headers + private readonly HttpClient _httpClient = new(); + private readonly JsonSerializerOptions _jsonOptions; + + public CoderApiClient(string baseUrl) + { + var url = new Uri(baseUrl, UriKind.Absolute); + if (url.PathAndQuery != "/") + throw new ArgumentException($"Base URL '{baseUrl}' must not contain a path", nameof(baseUrl)); + _httpClient.BaseAddress = url; + _jsonOptions = new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true, + PropertyNamingPolicy = new SnakeCaseNamingPolicy(), + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + }; + } + + public CoderApiClient(string baseUrl, string token) : this(baseUrl) + { + SetSessionToken(token); + } + + public void SetSessionToken(string token) + { + _httpClient.DefaultRequestHeaders.Remove("Coder-Session-Token"); + _httpClient.DefaultRequestHeaders.Add("Coder-Session-Token", token); + } + + private async Task SendRequestAsync(HttpMethod method, string path, + object? payload, CancellationToken ct = default) + { + try + { + var request = new HttpRequestMessage(method, path); + + if (payload is not null) + { + var json = JsonSerializer.Serialize(payload, _jsonOptions); + request.Content = new StringContent(json, Encoding.UTF8, "application/json"); + } + + var res = await _httpClient.SendAsync(request, ct); + // TODO: this should be improved to try and parse a codersdk.Error response + res.EnsureSuccessStatusCode(); + + var content = await res.Content.ReadAsStringAsync(ct); + var data = JsonSerializer.Deserialize(content, _jsonOptions); + if (data is null) throw new JsonException("Deserialized response is null"); + return data; + } + catch (Exception e) + { + throw new Exception($"API Request: {method} {path} (req body: {payload is not null})", e); + } + } +} diff --git a/CoderSdk/CoderSdk.csproj b/CoderSdk/CoderSdk.csproj new file mode 100644 index 0000000..3a63532 --- /dev/null +++ b/CoderSdk/CoderSdk.csproj @@ -0,0 +1,9 @@ + + + + net8.0 + enable + enable + + + diff --git a/CoderSdk/Deployment.cs b/CoderSdk/Deployment.cs new file mode 100644 index 0000000..b00d49f --- /dev/null +++ b/CoderSdk/Deployment.cs @@ -0,0 +1,22 @@ +namespace CoderSdk; + +public class BuildInfo +{ + public string ExternalUrl { get; set; } = ""; + public string Version { get; set; } = ""; + public string DashboardUrl { get; set; } = ""; + public bool Telemetry { get; set; } = false; + public bool WorkspaceProxy { get; set; } = false; + public string AgentApiVersion { get; set; } = ""; + public string ProvisionerApiVersion { get; set; } = ""; + public string UpgradeMessage { get; set; } = ""; + public string DeploymentId { get; set; } = ""; +} + +public partial class CoderApiClient +{ + public Task GetBuildInfo(CancellationToken ct = default) + { + return SendRequestAsync(HttpMethod.Get, "/api/v2/buildinfo", null, ct); + } +} diff --git a/CoderSdk/Users.cs b/CoderSdk/Users.cs new file mode 100644 index 0000000..58ff474 --- /dev/null +++ b/CoderSdk/Users.cs @@ -0,0 +1,17 @@ +namespace CoderSdk; + +public class User +{ + public const string Me = "me"; + + // TODO: fill out more fields + public string Username { get; set; } = ""; +} + +public partial class CoderApiClient +{ + public Task GetUser(string user, CancellationToken ct = default) + { + return SendRequestAsync(HttpMethod.Get, $"/api/v2/users/{user}", null, ct); + } +} diff --git a/Tests.Vpn.Proto/RpcHeaderTest.cs b/Tests.Vpn.Proto/RpcHeaderTest.cs index 8e19d0e..55edeea 100644 --- a/Tests.Vpn.Proto/RpcHeaderTest.cs +++ b/Tests.Vpn.Proto/RpcHeaderTest.cs @@ -11,14 +11,14 @@ public void Valid() { var headerStr = "codervpn manager 1.3,2.1"; var header = RpcHeader.Parse(headerStr); - Assert.That(header.Role.ToString(), Is.EqualTo(RpcRole.Manager)); + Assert.That(header.Role, Is.EqualTo("manager")); Assert.That(header.VersionList, Is.EqualTo(new RpcVersionList(new RpcVersion(1, 3), new RpcVersion(2, 1)))); Assert.That(header.ToString(), Is.EqualTo(headerStr + "\n")); Assert.That(header.ToBytes().ToArray(), Is.EqualTo(Encoding.UTF8.GetBytes(headerStr + "\n"))); headerStr = "codervpn tunnel 1.0"; header = RpcHeader.Parse(headerStr); - Assert.That(header.Role.ToString(), Is.EqualTo(RpcRole.Tunnel)); + Assert.That(header.Role, Is.EqualTo("tunnel")); Assert.That(header.VersionList, Is.EqualTo(new RpcVersionList(new RpcVersion(1, 0)))); Assert.That(header.ToString(), Is.EqualTo(headerStr + "\n")); Assert.That(header.ToBytes().ToArray(), Is.EqualTo(Encoding.UTF8.GetBytes(headerStr + "\n"))); @@ -35,7 +35,8 @@ public void ParseInvalid() Assert.That(ex.Message, Does.Contain("Wrong number of parts")); ex = Assert.Throws(() => RpcHeader.Parse("cats manager 1.0")); Assert.That(ex.Message, Does.Contain("Invalid preamble")); - ex = Assert.Throws(() => RpcHeader.Parse("codervpn cats 1.0")); - Assert.That(ex.Message, Does.Contain("Unknown role 'cats'")); + // RpcHeader doesn't care about the role string as long as it isn't empty. + ex = Assert.Throws(() => RpcHeader.Parse("codervpn 1.0")); + Assert.That(ex.Message, Does.Contain("Invalid role in header string")); } } diff --git a/Tests.Vpn.Proto/RpcMessageTest.cs b/Tests.Vpn.Proto/RpcMessageTest.cs index 36de12d..e254120 100644 --- a/Tests.Vpn.Proto/RpcMessageTest.cs +++ b/Tests.Vpn.Proto/RpcMessageTest.cs @@ -6,18 +6,16 @@ namespace Coder.Desktop.Tests.Vpn.Proto; public class RpcRoleAttributeTest { [Test] - public void Valid() + public void Ok() { - var role = new RpcRoleAttribute(RpcRole.Manager); - Assert.That(role.Role.ToString(), Is.EqualTo(RpcRole.Manager)); - role = new RpcRoleAttribute(RpcRole.Tunnel); - Assert.That(role.Role.ToString(), Is.EqualTo(RpcRole.Tunnel)); - } - - [Test] - public void Invalid() - { - Assert.Throws(() => _ = new RpcRoleAttribute("cats")); + var role = new RpcRoleAttribute("manager"); + Assert.That(role.Role, Is.EqualTo("manager")); + role = new RpcRoleAttribute("tunnel"); + Assert.That(role.Role, Is.EqualTo("tunnel")); + role = new RpcRoleAttribute("service"); + Assert.That(role.Role, Is.EqualTo("service")); + role = new RpcRoleAttribute("client"); + Assert.That(role.Role, Is.EqualTo("client")); } } @@ -33,7 +31,9 @@ public void GetRole() Assert.That(ex.Message, Does.Contain("Message type 'Coder.Desktop.Vpn.Proto.RPC' does not have a RpcRoleAttribute")); - Assert.That(ManagerMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Manager)); - Assert.That(TunnelMessage.GetRole().ToString(), Is.EqualTo(RpcRole.Tunnel)); + Assert.That(ManagerMessage.GetRole(), Is.EqualTo("manager")); + Assert.That(TunnelMessage.GetRole(), Is.EqualTo("tunnel")); + Assert.That(ServiceMessage.GetRole(), Is.EqualTo("service")); + Assert.That(ClientMessage.GetRole(), Is.EqualTo("client")); } } diff --git a/Tests.Vpn.Proto/RpcRoleTest.cs b/Tests.Vpn.Proto/RpcRoleTest.cs deleted file mode 100644 index f39d5cb..0000000 --- a/Tests.Vpn.Proto/RpcRoleTest.cs +++ /dev/null @@ -1,22 +0,0 @@ -using Coder.Desktop.Vpn.Proto; - -namespace Coder.Desktop.Tests.Vpn.Proto; - -[TestFixture] -public class RpcRoleTest -{ - [Test(Description = "Instantiate a RpcRole with a valid name")] - public void ValidRole() - { - var role = new RpcRole(RpcRole.Manager); - Assert.That(role.ToString(), Is.EqualTo(RpcRole.Manager)); - role = new RpcRole(RpcRole.Tunnel); - Assert.That(role.ToString(), Is.EqualTo(RpcRole.Tunnel)); - } - - [Test(Description = "Try to instantiate a RpcRole with an invalid name")] - public void InvalidRole() - { - Assert.Throws(() => _ = new RpcRole("cats")); - } -} diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs index c1e0335..952b80b 100644 --- a/Tests.Vpn.Service/DownloaderTest.cs +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -57,6 +57,59 @@ public async Task CoderSigned(CancellationToken ct) } } +[TestFixture] +public class AssemblyVersionDownloadValidatorTest +{ + [Test(Description = "No version on binary")] + [CancelAfter(30_000)] + public void NoVersion(CancellationToken ct) + { + // TODO: this + } + + [Test(Description = "Version mismatch")] + [CancelAfter(30_000)] + public void VersionMismatch(CancellationToken ct) + { + // TODO: this + } + + [Test(Description = "Version match")] + [CancelAfter(30_000)] + public async Task VersionMatch(CancellationToken ct) + { + // TODO: this + await Task.CompletedTask; + } +} + +[TestFixture] +public class CombinationDownloadValidatorTest +{ + [Test(Description = "All validators pass")] + [CancelAfter(30_000)] + public async Task AllPass(CancellationToken ct) + { + var validator = new CombinationDownloadValidator( + NullDownloadValidator.Instance, + NullDownloadValidator.Instance + ); + await validator.ValidateAsync("test", ct); + } + + [Test(Description = "A validator fails")] + [CancelAfter(30_000)] + public void Fail(CancellationToken ct) + { + var validator = new CombinationDownloadValidator( + NullDownloadValidator.Instance, + new TestDownloadValidator(new Exception("test exception")) + ); + var ex = Assert.ThrowsAsync(() => validator.ValidateAsync("test", ct)); + Assert.That(ex.Message, Is.EqualTo("test exception")); + } +} + [TestFixture] public class DownloaderTest { diff --git a/Tests.Vpn/SerdesTest.cs b/Tests.Vpn/SerdesTest.cs index cf2c480..3266f14 100644 --- a/Tests.Vpn/SerdesTest.cs +++ b/Tests.Vpn/SerdesTest.cs @@ -1,6 +1,7 @@ using System.Buffers.Binary; using Coder.Desktop.Vpn; using Coder.Desktop.Vpn.Proto; +using Coder.Desktop.Vpn.Utilities; using Google.Protobuf; namespace Coder.Desktop.Tests.Vpn; @@ -12,7 +13,7 @@ public class SerdesTest [CancelAfter(30_000)] public async Task WriteReadMessage(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var serdes = new Serdes(); var msg = new ManagerMessage @@ -28,7 +29,7 @@ public async Task WriteReadMessage(CancellationToken ct) [CancelAfter(30_000)] public void WriteMessageTooLarge(CancellationToken ct) { - var (stream1, _) = BidirectionalPipe.New(); + var (stream1, _) = BidirectionalPipe.NewInMemory(); var serdes = new Serdes(); var msg = new ManagerMessage @@ -46,7 +47,7 @@ public void WriteMessageTooLarge(CancellationToken ct) [CancelAfter(30_000)] public async Task ReadMessageTooLarge(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var serdes = new Serdes(); // In this test we don't actually write a message as the parser should @@ -61,7 +62,7 @@ public async Task ReadMessageTooLarge(CancellationToken ct) [CancelAfter(30_000)] public async Task ReadEmptyMessage(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var serdes = new Serdes(); // Write an empty message. @@ -76,7 +77,7 @@ public async Task ReadEmptyMessage(CancellationToken ct) [CancelAfter(30_000)] public async Task ReadInvalidMessage(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var serdes = new Serdes(); var lenBytes = new byte[4]; diff --git a/Tests.Vpn/SpeakerTest.cs b/Tests.Vpn/SpeakerTest.cs index 8966c02..51950f7 100644 --- a/Tests.Vpn/SpeakerTest.cs +++ b/Tests.Vpn/SpeakerTest.cs @@ -1,84 +1,12 @@ -using System.Buffers; -using System.IO.Pipelines; using System.Reflection; using System.Text; using System.Threading.Channels; using Coder.Desktop.Vpn; using Coder.Desktop.Vpn.Proto; +using Coder.Desktop.Vpn.Utilities; namespace Coder.Desktop.Tests.Vpn; -#region BidrectionalPipe - -internal class BidirectionalPipe(PipeReader reader, PipeWriter writer) : Stream -{ - public override bool CanRead => true; - public override bool CanSeek => false; - public override bool CanWrite => true; - public override long Length => -1; - - public override long Position - { - get => -1; - set => throw new NotImplementedException("BidirectionalPipe does not support setting position"); - } - - public static (BidirectionalPipe, BidirectionalPipe) New() - { - var pipe1 = new Pipe(); - var pipe2 = new Pipe(); - return (new BidirectionalPipe(pipe1.Reader, pipe2.Writer), new BidirectionalPipe(pipe2.Reader, pipe1.Writer)); - } - - public override void Flush() - { - } - - public override int Read(byte[] buffer, int offset, int count) - { - return ReadAsync(buffer, offset, count).GetAwaiter().GetResult(); - } - - public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct) - { - var result = await reader.ReadAtLeastAsync(1, ct); - var n = Math.Min((int)result.Buffer.Length, count); - // Copy result.Buffer[0:n] to buffer[offset:offset+n] - result.Buffer.Slice(0, n).CopyTo(buffer.AsMemory(offset, n).Span); - if (!result.IsCompleted) reader.AdvanceTo(result.Buffer.GetPosition(n)); - return n; - } - - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotImplementedException("BidirectionalPipe does not support seeking"); - } - - public override void SetLength(long value) - { - throw new NotImplementedException("BidirectionalPipe does not support setting length"); - } - - public override void Write(byte[] buffer, int offset, int count) - { - WriteAsync(buffer, offset, count).GetAwaiter().GetResult(); - } - - public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) - { - await writer.WriteAsync(buffer.AsMemory(offset, count), ct); - } - - protected override void Dispose(bool disposing) - { - base.Dispose(disposing); - writer.Complete(); - reader.Complete(); - } -} - -#endregion - #region FailableStream internal class FailableStream : Stream @@ -175,7 +103,7 @@ public class SpeakerTest [CancelAfter(30_000)] public async Task SendReceiveReplyReceive(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); await using var speaker1 = new Speaker(stream1); var speaker1Ch = Channel @@ -239,7 +167,7 @@ await message.SendReply(new TunnelMessage [CancelAfter(30_000)] public async Task WriteError(CancellationToken ct) { - var (stream1, _) = BidirectionalPipe.New(); + var (stream1, _) = BidirectionalPipe.NewInMemory(); var writeEx = new IOException("Test write error"); var failStream = new FailableStream(stream1, writeEx, null); @@ -253,7 +181,7 @@ public async Task WriteError(CancellationToken ct) [CancelAfter(30_000)] public async Task ReadError(CancellationToken ct) { - var (stream1, _) = BidirectionalPipe.New(); + var (stream1, _) = BidirectionalPipe.NewInMemory(); var readEx = new IOException("Test read error"); var failStream = new FailableStream(stream1, null, readEx); @@ -267,7 +195,7 @@ public async Task ReadError(CancellationToken ct) [CancelAfter(30_000)] public async Task ReadLargeHeader(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); await using var speaker1 = new Speaker(stream1); var header = new byte[257]; @@ -286,7 +214,8 @@ public async Task ReceiveInvalidHeader(CancellationToken ct) { { "invalid\n", ("Failed to parse peer header", "Wrong number of parts in header string") }, { "cats tunnel 1.0\n", ("Failed to parse peer header", "Invalid preamble in header string") }, - { "codervpn cats 1.0\n", ("Failed to parse peer header", "Unknown role 'cats'") }, + { "codervpn 1.0\n", ("Failed to parse peer header", "Invalid role in header string") }, + { "codervpn cats 1.0\n", ("Expected peer role 'tunnel' but got 'cats'", null) }, { "codervpn manager 1.0\n", ("Expected peer role 'tunnel' but got 'manager'", null) }, { "codervpn tunnel 1000.1\n", @@ -299,7 +228,7 @@ public async Task ReceiveInvalidHeader(CancellationToken ct) foreach (var (header, (expectedOuter, expectedInner)) in cases) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); await using var speaker1 = new Speaker(stream1); await stream2.WriteAsync(Encoding.UTF8.GetBytes(header), ct); @@ -321,7 +250,7 @@ public async Task ReceiveInvalidHeader(CancellationToken ct) [CancelAfter(30_000)] public async Task SendMessageWriteError(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var failStream = new FailableStream(stream1, null, null); await using var speaker1 = new Speaker(failStream); @@ -346,7 +275,7 @@ public async Task SendMessageWriteError(CancellationToken ct) [CancelAfter(30_000)] public async Task ReceiveMessageReadError(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var failStream = new FailableStream(stream1, null, null); // Speaker1 is bound to failStream and will write an error to errorCh @@ -387,7 +316,7 @@ public async Task ReceiveMessageReadError(CancellationToken ct) [CancelAfter(30_000)] public async Task DisposeWhileReceiveLoopRunning(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var speaker1 = new Speaker(stream1); await using var speaker2 = new Speaker(stream2); await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); @@ -411,7 +340,7 @@ public async Task DisposeWhileReceiveLoopRunning(CancellationToken ct) [CancelAfter(30_000)] public async Task DisposeWhileAwaitingReply(CancellationToken ct) { - var (stream1, stream2) = BidirectionalPipe.New(); + var (stream1, stream2) = BidirectionalPipe.NewInMemory(); var speaker1 = new Speaker(stream1); await using var speaker2 = new Speaker(stream2); await Task.WhenAll(speaker1.StartAsync(ct), speaker2.StartAsync(ct)); diff --git a/Tests.Vpn/Tests.Vpn.csproj b/Tests.Vpn/Tests.Vpn.csproj index f6f2776..df00e81 100644 --- a/Tests.Vpn/Tests.Vpn.csproj +++ b/Tests.Vpn/Tests.Vpn.csproj @@ -22,7 +22,6 @@ runtime; build; native; contentfiles; analyzers; buildtransitive - diff --git a/Vpn.Proto/RpcHeader.cs b/Vpn.Proto/RpcHeader.cs index 0b840db..2fc3fb0 100644 --- a/Vpn.Proto/RpcHeader.cs +++ b/Vpn.Proto/RpcHeader.cs @@ -3,15 +3,15 @@ namespace Coder.Desktop.Vpn.Proto; /// -/// A header to write or read from a stream to identify the speaker's role and version. +/// A header to write or read from a stream to identify the peer role and version. /// -/// Role of the speaker -/// Version of the speaker -public class RpcHeader(RpcRole role, RpcVersionList versionList) +/// Role of the peer +/// Version of the peer +public class RpcHeader(string role, RpcVersionList versionList) { private const string Preamble = "codervpn"; - public RpcRole Role { get; } = role; + public string Role { get; } = role; public RpcVersionList VersionList { get; } = versionList; /// @@ -25,10 +25,10 @@ public static RpcHeader Parse(string header) var parts = header.Split(' '); if (parts.Length != 3) throw new ArgumentException($"Wrong number of parts in header string '{header}'"); if (parts[0] != Preamble) throw new ArgumentException($"Invalid preamble in header string '{header}'"); + if (string.IsNullOrEmpty(parts[1])) throw new ArgumentException($"Invalid role in header string '{header}'"); - var role = new RpcRole(parts[1]); var versionList = RpcVersionList.Parse(parts[2]); - return new RpcHeader(role, versionList); + return new RpcHeader(parts[1], versionList); } /// diff --git a/Vpn.Proto/RpcMessage.cs b/Vpn.Proto/RpcMessage.cs index c44168c..2d1350c 100644 --- a/Vpn.Proto/RpcMessage.cs +++ b/Vpn.Proto/RpcMessage.cs @@ -6,9 +6,16 @@ namespace Coder.Desktop.Vpn.Proto; [AttributeUsage(AttributeTargets.Class, Inherited = false)] public class RpcRoleAttribute(string role) : Attribute { - public RpcRole Role { get; } = new(role); + public string Role { get; } = role; } +/// +/// IRpcMessageCompatibleWith is a marker interface that indicates that a +/// message type can be used to peer with another message type. +/// +/// +public interface IRpcMessageCompatibleWith; + /// /// Represents an actual over-the-wire message type. /// @@ -36,9 +43,9 @@ public abstract class RpcMessage where T : IMessage /// /// Gets the RpcRole of the message type from it's RpcRole attribute. /// - /// + /// The role string /// The message type does not have an RpcRoleAttribute - public static RpcRole GetRole() + public static string GetRole() { var type = typeof(T); var attr = type.GetCustomAttribute(); @@ -47,8 +54,8 @@ public static RpcRole GetRole() } } -[RpcRole(RpcRole.Manager)] -public partial class ManagerMessage : RpcMessage +[RpcRole("manager")] +public partial class ManagerMessage : RpcMessage, IRpcMessageCompatibleWith { public override RPC? RpcField { @@ -64,8 +71,8 @@ public override void Validate() } } -[RpcRole(RpcRole.Tunnel)] -public partial class TunnelMessage : RpcMessage +[RpcRole("tunnel")] +public partial class TunnelMessage : RpcMessage, IRpcMessageCompatibleWith { public override RPC? RpcField { @@ -80,3 +87,37 @@ public override void Validate() if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type"); } } + +[RpcRole("service")] +public partial class ServiceMessage : RpcMessage, IRpcMessageCompatibleWith +{ + public override RPC? RpcField + { + get => Rpc; + set => Rpc = value; + } + + public override ServiceMessage Message => this; + + public override void Validate() + { + if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type"); + } +} + +[RpcRole("client")] +public partial class ClientMessage : RpcMessage, IRpcMessageCompatibleWith +{ + public override RPC? RpcField + { + get => Rpc; + set => Rpc = value; + } + + public override ClientMessage Message => this; + + public override void Validate() + { + if (MsgCase == MsgOneofCase.None) throw new ArgumentException("Message does not contain inner message type"); + } +} diff --git a/Vpn.Proto/RpcRole.cs b/Vpn.Proto/RpcRole.cs deleted file mode 100644 index 69f4b48..0000000 --- a/Vpn.Proto/RpcRole.cs +++ /dev/null @@ -1,56 +0,0 @@ -namespace Coder.Desktop.Vpn.Proto; - -/// -/// Represents a role that either side of the connection can fulfil. -/// -public sealed class RpcRole -{ - public const string Manager = "manager"; - public const string Tunnel = "tunnel"; - - private string Role { get; } - - public RpcRole(string role) - { - if (role != Manager && role != Tunnel) throw new ArgumentException($"Unknown role '{role}'"); - - Role = role; - } - - public override string ToString() - { - return Role; - } - - #region SpeakerRole equality - - public static bool operator ==(RpcRole a, RpcRole b) - { - return a.Equals(b); - } - - public static bool operator !=(RpcRole a, RpcRole b) - { - return !a.Equals(b); - } - - private bool Equals(RpcRole other) - { - return Role == other.Role; - } - - public override bool Equals(object? obj) - { - if (obj is null) return false; - if (ReferenceEquals(this, obj)) return true; - if (obj.GetType() != GetType()) return false; - return Equals((RpcRole)obj); - } - - public override int GetHashCode() - { - return Role.GetHashCode(); - } - - #endregion -} diff --git a/Vpn.Proto/vpn.proto b/Vpn.Proto/vpn.proto index 33a3ff4..a03978a 100644 --- a/Vpn.Proto/vpn.proto +++ b/Vpn.Proto/vpn.proto @@ -44,6 +44,24 @@ message TunnelMessage { } } +// ClientMessage is a message from the client (to the service). +message ClientMessage { + RPC rpc = 1; + oneof msg { + StartRequest start = 2; + StopRequest stop = 3; + } +} + +// ServiceMessage is a message from the service (to the client). +message ServiceMessage { + RPC rpc = 1; + oneof msg { + StartResponse start = 2; + StopResponse stop = 3; + } +} + // Log is a log message generated by the tunnel. The manager should log it to the system log. It is // one-way tunnel -> manager with no response. message Log { @@ -105,7 +123,7 @@ message Agent { bytes id = 1; // UUID string name = 2; bytes workspace_id = 3; // UUID - string fqdn = 4; + repeated string fqdn = 4; repeated string ip_addrs = 5; // last_handshake is the primary indicator of whether we are connected to a peer. Zero value or // anything longer than 5 minutes ago means there is a problem. @@ -179,6 +197,12 @@ message StartRequest { int32 tunnel_file_descriptor = 1; string coder_url = 2; string api_token = 3; + // Additional HTTP headers added to all requests + message Header { + string name = 1; + string value = 2; + } + repeated Header headers = 4; } message StartResponse { diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index 4a9542b..f55cdf7 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -1,7 +1,9 @@ using System.Collections.Concurrent; +using System.Diagnostics; using System.Net; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; +using Coder.Desktop.Vpn.Utilities; using Microsoft.Extensions.Logging; using Microsoft.Security.Extensions; @@ -33,19 +35,13 @@ public Task ValidateAsync(string path, CancellationToken ct = default) } } -public class AuthenticodeDownloadValidator : IDownloadValidator +/// +/// Ensures the downloaded binary is signed by the expected authenticode organization. +/// +public class AuthenticodeDownloadValidator(string expectedName) : IDownloadValidator { - private readonly string _expectedName; - public static AuthenticodeDownloadValidator Coder => new("Coder Technologies Inc."); - public AuthenticodeDownloadValidator(string expectedName) - { - if (string.IsNullOrWhiteSpace(expectedName)) - throw new ArgumentException("Expected name must not be empty", nameof(expectedName)); - _expectedName = expectedName; - } - public async Task ValidateAsync(string path, CancellationToken ct = default) { FileSignatureInfo fileSigInfo; @@ -63,10 +59,39 @@ public async Task ValidateAsync(string path, CancellationToken ct = default) if (fileSigInfo.Kind != SignatureKind.Embedded) throw new Exception($"File is not signed with an embedded Authenticode signature: Kind={fileSigInfo.Kind}"); + // TODO: check that it's an extended validation certificate + var actualName = fileSigInfo.SigningCertificate.GetNameInfo(X509NameType.SimpleName, false); - if (actualName != _expectedName) + if (actualName != expectedName) throw new Exception( - $"File is signed by an unexpected certificate: ExpectedName='{_expectedName}', ActualName='{actualName}'"); + $"File is signed by an unexpected certificate: ExpectedName='{expectedName}', ActualName='{actualName}'"); + } +} + +public class AssemblyVersionDownloadValidator(string expectedAssemblyVersion) : IDownloadValidator +{ + public Task ValidateAsync(string path, CancellationToken ct = default) + { + var info = FileVersionInfo.GetVersionInfo(path); + if (string.IsNullOrEmpty(info.ProductVersion)) + throw new Exception("File ProductVersion is empty or null, was the binary compiled correctly?"); + if (info.ProductVersion != expectedAssemblyVersion) + throw new Exception( + $"File ProductVersion is '{info.ProductVersion}', but expected '{expectedAssemblyVersion}'"); + return Task.CompletedTask; + } +} + +/// +/// Combines multiple download validators into a single validator. All validators will be run in order. +/// +/// Validators to run +public class CombinationDownloadValidator(params IDownloadValidator[] validators) : IDownloadValidator +{ + public async Task ValidateAsync(string path, CancellationToken ct = default) + { + foreach (var validator in validators) + await validator.ValidateAsync(path, ct); } } @@ -134,7 +159,7 @@ public class DownloadTask private readonly ILogger _logger; - private readonly SemaphoreSlim _semaphore = new(1, 1); + private readonly RaiiSemaphoreSlim _semaphore = new(1, 1); private readonly IDownloadValidator _validator; public readonly string DestinationPath; @@ -172,16 +197,9 @@ internal DownloadTask(ILogger logger, HttpRequestMessage req, string destination internal async Task EnsureStartedAsync(CancellationToken ct = default) { - await _semaphore.WaitAsync(ct); - try - { - if (Task == null!) - Task = await StartDownloadAsync(ct); - } - finally - { - _semaphore.Release(); - } + using var _ = await _semaphore.LockAsync(ct); + if (Task == null!) + Task = await StartDownloadAsync(ct); return Task; } diff --git a/Vpn.Service/Manager.cs b/Vpn.Service/Manager.cs index 0767dc3..0f11f34 100644 --- a/Vpn.Service/Manager.cs +++ b/Vpn.Service/Manager.cs @@ -1,42 +1,68 @@ using System.Runtime.InteropServices; using Coder.Desktop.Vpn.Proto; +using CoderSdk; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Semver; namespace Coder.Desktop.Vpn.Service; -public interface IManager +public interface IManager : IDisposable { - public Task HandleClientRpcMessage(ReplyableRpcMessage message, + public Task HandleClientRpcMessage(ReplyableRpcMessage message, CancellationToken ct = default); public Task StopAsync(CancellationToken ct = default); } -public class Manager : IManager, IAsyncDisposable +/// +/// Manager provides handling for RPC requests from the client and from the tunnel. +/// +public class Manager : IManager { - private const string DestinationPath = "C:\\coder-vpn.exe"; - private readonly IDownloader _downloader; + // TODO: determine a suitable value for this + private const string ServerVersionRange = ">=0.0.0"; + private readonly ManagerConfig _config; + private readonly IDownloader _downloader; private readonly ILogger _logger; + private readonly ITunnelSupervisor _tunnelSupervisor; // ReSharper disable once ConvertToPrimaryConstructor - public Manager(ILogger logger, IDownloader downloader) + public Manager(IOptions config, ILogger logger, IDownloader downloader, + ITunnelSupervisor tunnelSupervisor) { + _config = config.Value; _logger = logger; _downloader = downloader; + _tunnelSupervisor = tunnelSupervisor; } - public async ValueTask DisposeAsync() + public void Dispose() { - await Task.CompletedTask; GC.SuppressFinalize(this); } - public async Task HandleClientRpcMessage(ReplyableRpcMessage message, + /// + /// Processes a message sent from a Client to the ManagerRpcService over the codervpn RPC protocol. + /// + /// Client message + /// Cancellation token + public async Task HandleClientRpcMessage(ReplyableRpcMessage message, CancellationToken ct = default) { + _logger.LogInformation("ClientMessage: {MessageType}", message.Message.MsgCase); + // TODO: break out each into it's own method? switch (message.Message.MsgCase) { + case ClientMessage.MsgOneofCase.Start: + // TODO: these sub-methods should be managed by some Task list and cancelled/awaited on stop + await HandleClientMessageStart(message, ct); + break; + case ClientMessage.MsgOneofCase.Stop: + await HandleClientMessageStop(message, ct); + break; + case ClientMessage.MsgOneofCase.None: default: _logger.LogWarning("Received unknown message type {MessageType}", message.Message.MsgCase); break; @@ -45,8 +71,86 @@ public async Task HandleClientRpcMessage(ReplyableRpcMessage message, + CancellationToken ct) + { + try + { + // TODO: if the credentials and URL are identical and the server + // version hasn't changed we should not do anything + // TODO: this should be broken out into it's own method + _logger.LogInformation("ClientMessage.Start: testing server '{ServerUrl}'", message.Message.Start.CoderUrl); + var client = new CoderApiClient(message.Message.Start.CoderUrl, message.Message.Start.ApiToken); + var buildInfo = await client.GetBuildInfo(ct); + _logger.LogInformation("ClientMessage.Start: server version '{ServerVersion}'", buildInfo.Version); + var serverVersion = SemVersion.Parse(buildInfo.Version); + if (!serverVersion.Satisfies(ServerVersionRange)) + throw new InvalidOperationException( + $"Server version '{serverVersion}' is not within required server version range '{ServerVersionRange}'"); + var user = await client.GetUser(User.Me, ct); + _logger.LogInformation("ClientMessage.Start: authenticated as '{Username}'", user.Username); + + await DownloadTunnelBinaryAsync(message.Message.Start.CoderUrl, serverVersion, ct); + await _tunnelSupervisor.StartAsync(_config.TunnelBinaryPath, HandleTunnelRpcMessage, + HandleTunnelRpcError, + ct); + } + catch (Exception e) + { + _logger.LogWarning(e, "ClientMessage.Start: Failed to start VPN client"); + await message.SendReply(new ServiceMessage + { + Start = new StartResponse + { + Success = false, + ErrorMessage = e.Message, + }, + }, ct); + } + } + + private async Task HandleClientMessageStop(ReplyableRpcMessage message, + CancellationToken ct) + { + try + { + // This will handle sending the Stop message for us. + await _tunnelSupervisor.StopAsync(ct); + } + catch (Exception e) + { + _logger.LogWarning(e, "ClientMessage.Stop: Failed to stop VPN client"); + await message.SendReply(new ServiceMessage + { + Stop = new StopResponse + { + Success = false, + ErrorMessage = e.Message, + }, + }, ct); + } + } + + private void HandleTunnelRpcMessage(ReplyableRpcMessage message) + { + // TODO: this + } + + private void HandleTunnelRpcError(Exception e) + { + // TODO: this probably happens during an ongoing start or stop operation, and we should definitely ignore those + _logger.LogError(e, "Manager<->Tunnel RPC error"); + try + { + _tunnelSupervisor.StopAsync(); + } + catch (Exception e2) + { + _logger.LogError(e2, "Failed to stop tunnel supervisor after RPC error"); + } } /// @@ -56,13 +160,14 @@ public async Task StopAsync(CancellationToken ct = default) /// Unsupported architecture private static string SystemArchitecture() { + // ReSharper disable once SwitchExpressionHandlesSomeKnownEnumValuesWithExceptionInDefault return RuntimeInformation.ProcessArchitecture switch { Architecture.X64 => "amd64", Architecture.Arm64 => "arm64", // We only support amd64 and arm64 on Windows currently. _ => throw new PlatformNotSupportedException( - "Unsupported architecture. Coder only supports amd64 and arm64."), + $"Unsupported architecture '{RuntimeInformation.ProcessArchitecture}'. Coder only supports amd64 and arm64."), }; } @@ -70,10 +175,12 @@ private static string SystemArchitecture() /// Fetches the "/bin/coder-windows-{architecture}.exe" binary from the given base URL and writes it to the /// destination path after validating the signature and checksum. /// - /// - /// - /// - private async Task DownloadVPNClientAsync(string baseUrl, CancellationToken ct = default) + /// Server base URL to download the binary from + /// The version of the server to expect in the binary + /// Cancellation token + /// If the base URL is invalid + private async Task DownloadTunnelBinaryAsync(string baseUrl, SemVersion expectedVersion, + CancellationToken ct = default) { var architecture = SystemArchitecture(); Uri url; @@ -89,10 +196,15 @@ private async Task DownloadVPNClientAsync(string baseUrl, CancellationToken ct = throw new ArgumentException($"Invalid base URL '{baseUrl}'", e); } - _logger.LogInformation("Downloading VPN binary from '{url}' to '{DestinationPath}'", url, DestinationPath); - var downloadTask = - await _downloader.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), DestinationPath, - AuthenticodeDownloadValidator.Coder, ct); + _logger.LogInformation("Downloading VPN binary from '{url}' to '{DestinationPath}'", url, + _config.TunnelBinaryPath); + var req = new HttpRequestMessage(HttpMethod.Get, url); + var validators = new CombinationDownloadValidator( + AuthenticodeDownloadValidator.Coder, + new AssemblyVersionDownloadValidator( + $"{expectedVersion.Major}.{expectedVersion.Minor}.{expectedVersion.Patch}.0") + ); + var downloadTask = await _downloader.StartDownloadAsync(req, _config.TunnelBinaryPath, validators, ct); // TODO: monitor and report progress when we have a mechanism to do so diff --git a/Vpn.Service/ManagerConfig.cs b/Vpn.Service/ManagerConfig.cs new file mode 100644 index 0000000..906a0b8 --- /dev/null +++ b/Vpn.Service/ManagerConfig.cs @@ -0,0 +1,16 @@ +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; + +namespace Coder.Desktop.Vpn.Service; + +[SuppressMessage("ReSharper", "AutoPropertyCanBeMadeGetOnly.Global")] +public class ManagerConfig +{ + [Required] + [RegularExpression(@"^([a-zA-Z0-9_-]+\.)*[a-zA-Z0-9_-]+$")] + public string ServiceRpcPipeName { get; set; } = "Coder.Desktop.Vpn"; + + // TODO: pick a better default path + [Required] + public string TunnelBinaryPath { get; set; } = @"C:\coder-vpn.exe"; +} diff --git a/Vpn.Service/ManagerRpcService.cs b/Vpn.Service/ManagerRpcService.cs index 228fc30..0cfaed1 100644 --- a/Vpn.Service/ManagerRpcService.cs +++ b/Vpn.Service/ManagerRpcService.cs @@ -3,6 +3,7 @@ using Coder.Desktop.Vpn.Proto; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; namespace Coder.Desktop.Vpn.Service; @@ -11,18 +12,16 @@ namespace Coder.Desktop.Vpn.Service; /// public class ManagerRpcService : BackgroundService, IAsyncDisposable { - // TODO: make configurable with registry? - private const string PipeName = "Coder.Desktop.Vpn"; private readonly ConcurrentDictionary _activeClientTasks = new(); - + private readonly ManagerConfig _config; private readonly CancellationTokenSource _cts = new(); - private readonly ILogger _logger; private readonly IManager _manager; // ReSharper disable once ConvertToPrimaryConstructor - public ManagerRpcService(ILogger logger, IManager manager) + public ManagerRpcService(IOptions config, ILogger logger, IManager manager) { + _config = config.Value; _logger = logger; _manager = manager; } @@ -46,19 +45,19 @@ public override async Task StopAsync(CancellationToken cancellationToken) /// protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - _logger.LogInformation(@"Starting continuous named pipe RPC server at \\.\pipe\{PipeName}", PipeName); + _logger.LogInformation(@"Starting continuous named pipe RPC server at \\.\pipe\{PipeName}", + _config.ServiceRpcPipeName); using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(stoppingToken, _cts.Token); while (!linkedCts.IsCancellationRequested) { - _logger.LogDebug($"Creating named pipe server {PipeName}"); - var pipeServer = new NamedPipeServerStream(PipeName, PipeDirection.InOut, + var pipeServer = new NamedPipeServerStream(_config.ServiceRpcPipeName, PipeDirection.InOut, NamedPipeServerStream.MaxAllowedServerInstances, PipeTransmissionMode.Byte, PipeOptions.Asynchronous); try { try { - _logger.LogDebug("Waiting for named pipe client connection"); + _logger.LogDebug("Waiting for new named pipe client connection"); await pipeServer.WaitForConnectionAsync(linkedCts.Token); } finally @@ -87,8 +86,7 @@ private async Task HandleRpcClientAsync(NamedPipeServerStream pipeServer, Cancel var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(ct, _cts.Token); await using (pipeServer) { - // TODO: use ClientMessage once it's ready - await using var speaker = new Speaker(pipeServer); + await using var speaker = new Speaker(pipeServer); var tcs = new TaskCompletionSource(); var activeTasks = new ConcurrentDictionary(); @@ -122,7 +120,7 @@ private void RpcClientContinuation(Task task) _activeClientTasks.TryRemove(task.Id, out _); } - private async Task HandleRpcMessageAsync(ReplyableRpcMessage message, + private async Task HandleRpcMessageAsync(ReplyableRpcMessage message, CancellationToken ct) { _logger.LogInformation("Received RPC message: {Message}", message.Message); diff --git a/Vpn.Service/Program.cs b/Vpn.Service/Program.cs index 24b55cf..78fbff2 100644 --- a/Vpn.Service/Program.cs +++ b/Vpn.Service/Program.cs @@ -1,10 +1,30 @@ using Coder.Desktop.Vpn.Service; +using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using Microsoft.Win32; var builder = Host.CreateApplicationBuilder(args); + +// Configuration sources +builder.Configuration.Sources.Clear(); +(builder.Configuration as IConfigurationBuilder).Add( + new RegistryConfigurationSource(Registry.LocalMachine, @"SOFTWARE\Coder\Coder VPN")); +builder.Configuration.AddEnvironmentVariables("CODER_MANAGER_"); +builder.Configuration.AddCommandLine(args); + +// Options types (these get registered as IOptions singletons) +builder.Services.AddOptions() + .Bind(builder.Configuration.GetSection("Manager")) + .ValidateDataAnnotations(); + +// Singletons builder.Services.AddSingleton(); +builder.Services.AddSingleton(); builder.Services.AddSingleton(); + +// Services builder.Services.AddHostedService(); builder.Services.AddHostedService(); + builder.Build().Run(); diff --git a/Vpn.Service/RegistryConfigurationSource.cs b/Vpn.Service/RegistryConfigurationSource.cs new file mode 100644 index 0000000..3e0ff5f --- /dev/null +++ b/Vpn.Service/RegistryConfigurationSource.cs @@ -0,0 +1,23 @@ +using Microsoft.Extensions.Configuration; +using Microsoft.Win32; + +namespace Coder.Desktop.Vpn.Service; + +public class RegistryConfigurationSource(RegistryKey root, string subKeyName) : IConfigurationSource +{ + public IConfigurationProvider Build(IConfigurationBuilder builder) + { + return new RegistryConfigurationProvider(root, subKeyName); + } +} + +public class RegistryConfigurationProvider(RegistryKey root, string subKeyName) : ConfigurationProvider +{ + public override void Load() + { + using var key = root.OpenSubKey(subKeyName); + if (key == null) return; + + foreach (var valueName in key.GetValueNames()) Data[valueName] = key.GetValue(valueName)?.ToString(); + } +} diff --git a/Vpn.Service/TunnelSupervisor.cs b/Vpn.Service/TunnelSupervisor.cs new file mode 100644 index 0000000..9ea5b05 --- /dev/null +++ b/Vpn.Service/TunnelSupervisor.cs @@ -0,0 +1,271 @@ +using System.Diagnostics; +using System.IO.Pipes; +using Coder.Desktop.Vpn.Proto; +using Coder.Desktop.Vpn.Utilities; +using Microsoft.Extensions.Logging; + +namespace Coder.Desktop.Vpn.Service; + +public interface ITunnelSupervisor : IAsyncDisposable +{ + /// + /// Starts the tunnel subprocess with the given executable path. If the subprocess is already running, this method will + /// kill it first. + /// + /// Path to the executable + /// Handler to call with each RPC message + /// + /// Handler for permanent errors from the RPC Speaker. The recipient should call StopAsync after + /// receiving this. + /// + /// Cancellation token + public Task StartAsync(string binPath, + Speaker.OnReceiveDelegate messageHandler, + Speaker.OnErrorDelegate errorHandler, + CancellationToken ct = default); + + /// + /// Stops the tunnel subprocess. If the subprocess is not running, this method does nothing. + /// + /// + /// + public Task StopAsync(CancellationToken ct = default); +} + +/// +/// Launches and supervises the tunnel subprocess. Provides RPC communication with the subprocess. +/// +public class TunnelSupervisor : ITunnelSupervisor +{ + private readonly CancellationTokenSource _cts = new(); + private readonly ILogger _logger; + private readonly SemaphoreSlim _operationLock = new(1, 1); + private AnonymousPipeServerStream? _inPipe; + private AnonymousPipeServerStream? _outPipe; + private Speaker? _speaker; + + private Process? _subprocess; + + // ReSharper disable once ConvertToPrimaryConstructor + public TunnelSupervisor(ILogger logger) + { + _logger = logger; + } + + public async Task StartAsync(string binPath, + Speaker.OnReceiveDelegate messageHandler, + Speaker.OnErrorDelegate errorHandler, + CancellationToken ct = default) + { + _logger.LogInformation("StartAsync(\"{binPath}\")", binPath); + if (!await _operationLock.WaitAsync(0, ct)) + throw new InvalidOperationException( + "Another TunnelSupervisor Start or Stop operation is already in progress"); + + try + { + await CleanupAsync(ct); + + _outPipe = new AnonymousPipeServerStream(PipeDirection.Out, HandleInheritability.Inheritable); + _inPipe = new AnonymousPipeServerStream(PipeDirection.In, HandleInheritability.Inheritable); + _subprocess = new Process + { + StartInfo = new ProcessStartInfo + { + FileName = binPath, + ArgumentList = { "vpn-daemon", "run" }, + UseShellExecute = false, + CreateNoWindow = true, + }, + }; + + // Pass the other end of the pipes to the subprocess and dispose + // the local copies. + _subprocess.StartInfo.Environment.Add("CODER_VPN_DAEMON_RPC_READ_HANDLE", + _outPipe.GetClientHandleAsString()); + _subprocess.StartInfo.Environment.Add("CODER_VPN_DAEMON_RPC_WRITE_HANDLE", + _inPipe.GetClientHandleAsString()); + _outPipe.DisposeLocalCopyOfClientHandle(); + _inPipe.DisposeLocalCopyOfClientHandle(); + + _logger.LogInformation("StartAsync: starting subprocess"); + _subprocess.Start(); + _logger.LogInformation("StartAsync: subprocess started"); + + // We don't use the supplied CancellationToken here because we want it to only apply to the startup + // procedure. + _ = _subprocess.WaitForExitAsync(_cts.Token).ContinueWith(OnProcessExited, CancellationToken.None); + + // Start the RPC Speaker. + try + { + var stream = new BidirectionalPipe(_inPipe, _outPipe); + _speaker = new Speaker(stream); + _speaker.Receive += messageHandler; + _speaker.Error += errorHandler; + // Handshakes already have a 5-second timeout. + await _speaker.StartAsync(ct); + } + catch (Exception e) + { + throw new Exception("Failed to start RPC Speaker on pipes to subprocess", e); + } + } + catch (Exception e) + { + _logger.LogError(e, "StartAsync: failed to start or connect to subprocess"); + await CleanupAsync(ct); + throw; + } + finally + { + _operationLock.Release(); + } + } + + public async Task StopAsync(CancellationToken ct = default) + { + _logger.LogInformation("StopAsync()"); + if (!await _operationLock.WaitAsync(0, ct)) + throw new InvalidOperationException( + "Another TunnelSupervisor Start or Stop operation is already in progress"); + + try + { + await CleanupAsync(ct); + } + finally + { + _operationLock.Release(); + } + } + + public async ValueTask DisposeAsync() + { + _cts.Dispose(); + await CleanupAsync(); + GC.SuppressFinalize(this); + } + + private async Task OnProcessExited(Task task) + { + if (task.IsFaulted) + { + _logger.LogError(task.Exception, "OnProcessExited: subprocess exited with an exception"); + return; + } + + if (!await _operationLock.WaitAsync(0)) _logger.LogInformation("OnProcessExited: subprocess exited"); + + try + { + await CleanupAsync(); + _logger.LogInformation("OnProcessExited: subprocess exited with code {ExitCode}", + _subprocess?.ExitCode ?? -1); + } + finally + { + _operationLock.Release(); + } + } + + /// + /// Cleans up the pipes and the subprocess if it's still running. This method should not be called without holding the + /// semaphore. + /// + private async Task CleanupAsync(CancellationToken ct = default) + { + if (_speaker != null) + { + try + { + _logger.LogInformation("CleanupAsync: Sending stop message to subprocess"); + var stopCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + stopCts.CancelAfter(5000); + await _speaker.SendRequestAwaitReply(new ManagerMessage + { + Stop = new StopRequest(), + }, stopCts.Token); + } + catch (Exception e) + { + _logger.LogError(e, "CleanupAsync: Failed to send stop message to subprocess"); + } + + try + { + _logger.LogInformation("CleanupAsync: Disposing _speaker"); + await _speaker.DisposeAsync(); + } + catch (Exception e) + { + _logger.LogError(e, "CleanupAsync: Failed to stop/dispose _speaker"); + } + finally + { + _speaker = null; + } + } + + if (_outPipe != null) + { + _logger.LogInformation("CleanupAsync: Disposing _outPipe"); + try + { + await _outPipe.DisposeAsync(); + } + catch (Exception e) + { + _logger.LogError(e, "CleanupAsync: Failed to dispose _outPipe"); + } + finally + { + _outPipe = null; + } + } + + if (_inPipe != null) + { + _logger.LogInformation("CleanupAsync: Disposing _inPipe"); + try + { + await _inPipe.DisposeAsync(); + } + catch (Exception e) + { + _logger.LogError(e, "CleanupAsync: Failed to dispose _inPipe"); + } + finally + { + _inPipe = null; + } + } + + if (_subprocess != null) + try + { + if (!_subprocess.HasExited) + { + // TODO: is there a nicer way we can do this? + _logger.LogInformation("CleanupAsync: Killing un-exited _subprocess"); + _subprocess.Kill(); + // Since we just killed the process ideally it should exit + // immediately. + var exitCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + exitCts.CancelAfter(5000); + await _subprocess.WaitForExitAsync(exitCts.Token); + } + + _logger.LogInformation("CleanupAsync: Disposing _subprocess"); + _subprocess.Dispose(); + } + catch (Exception e) + { + _logger.LogError(e, "CleanupAsync: Failed to kill/dispose _subprocess"); + } + finally + { + _subprocess = null; + } + } +} diff --git a/Vpn.Service/Vpn.Service.csproj b/Vpn.Service/Vpn.Service.csproj index 33ee897..e6da70d 100644 --- a/Vpn.Service/Vpn.Service.csproj +++ b/Vpn.Service/Vpn.Service.csproj @@ -11,10 +11,13 @@ + + + diff --git a/Vpn/Serdes.cs b/Vpn/Serdes.cs index 317417b..00837b7 100644 --- a/Vpn/Serdes.cs +++ b/Vpn/Serdes.cs @@ -1,32 +1,10 @@ using System.Buffers.Binary; using Coder.Desktop.Vpn.Proto; +using Coder.Desktop.Vpn.Utilities; using Google.Protobuf; namespace Coder.Desktop.Vpn; -/// -/// RaiiSemaphoreSlim is a wrapper around SemaphoreSlim that provides RAII-style locking. -/// -internal class RaiiSemaphoreSlim(int initialCount, int maxCount) -{ - private readonly SemaphoreSlim _semaphore = new(initialCount, maxCount); - - public async ValueTask LockAsync(CancellationToken ct = default) - { - await _semaphore.WaitAsync(ct); - return new Lock(_semaphore); - } - - private class Lock(SemaphoreSlim semaphore) : IDisposable - { - public void Dispose() - { - semaphore.Release(); - GC.SuppressFinalize(this); - } - } -} - /// /// Serdes provides serialization and deserialization of messages read from a Stream. /// diff --git a/Vpn/Speaker.cs b/Vpn/Speaker.cs index 5bccbe4..8d06e0b 100644 --- a/Vpn/Speaker.cs +++ b/Vpn/Speaker.cs @@ -18,8 +18,8 @@ public class RpcVersionCompatibilityException(RpcVersionList localVersion, RpcVe /// Speaker to use for sending reply /// Original received message public class ReplyableRpcMessage(Speaker speaker, TR message) : RpcMessage - where TS : RpcMessage, IMessage - where TR : RpcMessage, IMessage, new() + where TS : RpcMessage, IRpcMessageCompatibleWith, IMessage + where TR : RpcMessage, IRpcMessageCompatibleWith, IMessage, new() { public override RPC? RpcField { @@ -51,8 +51,8 @@ public async Task SendReply(TS reply, CancellationToken ct = default) /// The message type for sent messages /// The message type for received messages public class Speaker : IAsyncDisposable - where TS : RpcMessage, IMessage - where TR : RpcMessage, IMessage, new() + where TS : RpcMessage, IRpcMessageCompatibleWith, IMessage + where TR : RpcMessage, IRpcMessageCompatibleWith, IMessage, new() { public delegate void OnErrorDelegate(Exception e); diff --git a/Vpn/Utilities/BidirectionalPipe.cs b/Vpn/Utilities/BidirectionalPipe.cs new file mode 100644 index 0000000..7bc4b4d --- /dev/null +++ b/Vpn/Utilities/BidirectionalPipe.cs @@ -0,0 +1,92 @@ +using System.IO.Pipelines; + +namespace Coder.Desktop.Vpn.Utilities; + +/// +/// BidirectionalPipe implements Stream using a read-only Stream and a write-only Stream. +/// +/// The stream to perform reads from +/// The stream to write data to +public class BidirectionalPipe(Stream reader, Stream writer) : Stream +{ + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Length => -1; + + public override long Position + { + get => -1; + set => throw new NotImplementedException("BidirectionalPipe does not support setting position"); + } + + /// + /// Creates a new pair of BidirectionalPipes that are connected to each other using buffered in-memory pipes. + /// + /// Two pipes connected to each other + public static (BidirectionalPipe, BidirectionalPipe) NewInMemory() + { + var pipe1 = new Pipe(); + var pipe2 = new Pipe(); + return ( + new BidirectionalPipe(pipe1.Reader.AsStream(), pipe2.Writer.AsStream()), + new BidirectionalPipe(pipe2.Reader.AsStream(), pipe1.Writer.AsStream()) + ); + } + + public override void Flush() + { + writer.Flush(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return reader.Read(buffer, offset, count); + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct) + { +#pragma warning disable CA1835 + return await reader.ReadAsync(buffer, offset, count, ct); +#pragma warning restore CA1835 + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + return reader.ReadAsync(buffer, cancellationToken); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException("BidirectionalPipe does not support seeking"); + } + + public override void SetLength(long value) + { + throw new NotImplementedException("BidirectionalPipe does not support setting length"); + } + + public override void Write(byte[] buffer, int offset, int count) + { + writer.Write(buffer, offset, count); + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) + { +#pragma warning disable CA1835 + await writer.WriteAsync(buffer, offset, count, ct); +#pragma warning restore CA1835 + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + return writer.WriteAsync(buffer, cancellationToken); + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + writer.Dispose(); + reader.Dispose(); + } +} diff --git a/Vpn/Utilities/RaiiSemaphoreSlim.cs b/Vpn/Utilities/RaiiSemaphoreSlim.cs new file mode 100644 index 0000000..4e94ef2 --- /dev/null +++ b/Vpn/Utilities/RaiiSemaphoreSlim.cs @@ -0,0 +1,30 @@ +namespace Coder.Desktop.Vpn.Utilities; + +/// +/// RaiiSemaphoreSlim is a wrapper around SemaphoreSlim that provides RAII-style locking. +/// +public class RaiiSemaphoreSlim(int initialCount, int maxCount) : IDisposable +{ + private readonly SemaphoreSlim _semaphore = new(initialCount, maxCount); + + public void Dispose() + { + _semaphore.Dispose(); + GC.SuppressFinalize(this); + } + + public async ValueTask LockAsync(CancellationToken ct = default) + { + await _semaphore.WaitAsync(ct); + return new Lock(_semaphore); + } + + private class Lock(SemaphoreSlim semaphore) : IDisposable + { + public void Dispose() + { + semaphore.Release(); + GC.SuppressFinalize(this); + } + } +} diff --git a/Vpn/Vpn.csproj b/Vpn/Vpn.csproj index bcef1b5..22b585f 100644 --- a/Vpn/Vpn.csproj +++ b/Vpn/Vpn.csproj @@ -11,4 +11,8 @@ + + + + From 05278200cde3be39ae1c4587d6691980cc9f52a9 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 16 Dec 2024 21:57:53 +0900 Subject: [PATCH 3/3] PR comments --- Coder.Desktop.sln.DotSettings | 1 + Tests.Vpn.Service/DownloaderTest.cs | 11 ++++-- Vpn.Proto/RpcHeader.cs | 16 ++++++--- Vpn.Proto/RpcMessage.cs | 9 +++-- Vpn.Proto/RpcVersion.cs | 16 ++++++--- Vpn.Service/Downloader.cs | 39 +++++++++++++++++----- Vpn.Service/ManagerRpcService.cs | 3 +- Vpn.Service/RegistryConfigurationSource.cs | 26 ++++++++++++--- Vpn/Speaker.cs | 34 +++++++++++++------ Vpn/Utilities/BidirectionalPipe.cs | 33 +++++++++++------- Vpn/Utilities/RaiiSemaphoreSlim.cs | 20 ++++++++--- 11 files changed, 153 insertions(+), 55 deletions(-) diff --git a/Coder.Desktop.sln.DotSettings b/Coder.Desktop.sln.DotSettings index 70c3a3b..176e490 100644 --- a/Coder.Desktop.sln.DotSettings +++ b/Coder.Desktop.sln.DotSettings @@ -255,4 +255,5 @@ True True True + True True \ No newline at end of file diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs index 952b80b..ae3a0a0 100644 --- a/Tests.Vpn.Service/DownloaderTest.cs +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -5,11 +5,18 @@ namespace Coder.Desktop.Tests.Vpn.Service; -public class TestDownloadValidator(Exception e) : IDownloadValidator +public class TestDownloadValidator : IDownloadValidator { + private readonly Exception _e; + + public TestDownloadValidator(Exception e) + { + _e = e; + } + public Task ValidateAsync(string path, CancellationToken ct = default) { - throw e; + throw _e; } } diff --git a/Vpn.Proto/RpcHeader.cs b/Vpn.Proto/RpcHeader.cs index 2fc3fb0..cf7ffcc 100644 --- a/Vpn.Proto/RpcHeader.cs +++ b/Vpn.Proto/RpcHeader.cs @@ -5,14 +5,20 @@ namespace Coder.Desktop.Vpn.Proto; /// /// A header to write or read from a stream to identify the peer role and version. /// -/// Role of the peer -/// Version of the peer -public class RpcHeader(string role, RpcVersionList versionList) +public class RpcHeader { private const string Preamble = "codervpn"; - public string Role { get; } = role; - public RpcVersionList VersionList { get; } = versionList; + public string Role { get; } + public RpcVersionList VersionList { get; } + + /// Role of the peer + /// Version of the peer + public RpcHeader(string role, RpcVersionList versionList) + { + Role = role; + VersionList = versionList; + } /// /// Parse a header string into a SpeakerHeader. diff --git a/Vpn.Proto/RpcMessage.cs b/Vpn.Proto/RpcMessage.cs index 2d1350c..bfe4d82 100644 --- a/Vpn.Proto/RpcMessage.cs +++ b/Vpn.Proto/RpcMessage.cs @@ -4,9 +4,14 @@ namespace Coder.Desktop.Vpn.Proto; [AttributeUsage(AttributeTargets.Class, Inherited = false)] -public class RpcRoleAttribute(string role) : Attribute +public class RpcRoleAttribute : Attribute { - public string Role { get; } = role; + public string Role { get; } + + public RpcRoleAttribute(string role) + { + Role = role; + } } /// diff --git a/Vpn.Proto/RpcVersion.cs b/Vpn.Proto/RpcVersion.cs index a9b1914..574768d 100644 --- a/Vpn.Proto/RpcVersion.cs +++ b/Vpn.Proto/RpcVersion.cs @@ -3,14 +3,20 @@ /// /// A version of the RPC API. Can be compared other versions to determine compatibility between two peers. /// -/// The major version of the peer -/// The minor version of the peer -public class RpcVersion(ulong major, ulong minor) +public class RpcVersion { public static readonly RpcVersion Current = new(1, 0); - public ulong Major { get; } = major; - public ulong Minor { get; } = minor; + public ulong Major { get; } + public ulong Minor { get; } + + /// The major version of the peer + /// The minor version of the peer + public RpcVersion(ulong major, ulong minor) + { + Major = major; + Minor = minor; + } /// /// Parse a string in the format "major.minor" into an ApiVersion. diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index f55cdf7..83eda24 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -38,10 +38,17 @@ public Task ValidateAsync(string path, CancellationToken ct = default) /// /// Ensures the downloaded binary is signed by the expected authenticode organization. /// -public class AuthenticodeDownloadValidator(string expectedName) : IDownloadValidator +public class AuthenticodeDownloadValidator : IDownloadValidator { + private readonly string _expectedName; + public static AuthenticodeDownloadValidator Coder => new("Coder Technologies Inc."); + public AuthenticodeDownloadValidator(string expectedName) + { + _expectedName = expectedName; + } + public async Task ValidateAsync(string path, CancellationToken ct = default) { FileSignatureInfo fileSigInfo; @@ -62,22 +69,29 @@ public async Task ValidateAsync(string path, CancellationToken ct = default) // TODO: check that it's an extended validation certificate var actualName = fileSigInfo.SigningCertificate.GetNameInfo(X509NameType.SimpleName, false); - if (actualName != expectedName) + if (actualName != _expectedName) throw new Exception( - $"File is signed by an unexpected certificate: ExpectedName='{expectedName}', ActualName='{actualName}'"); + $"File is signed by an unexpected certificate: ExpectedName='{_expectedName}', ActualName='{actualName}'"); } } -public class AssemblyVersionDownloadValidator(string expectedAssemblyVersion) : IDownloadValidator +public class AssemblyVersionDownloadValidator : IDownloadValidator { + private readonly string _expectedAssemblyVersion; + + public AssemblyVersionDownloadValidator(string expectedAssemblyVersion) + { + _expectedAssemblyVersion = expectedAssemblyVersion; + } + public Task ValidateAsync(string path, CancellationToken ct = default) { var info = FileVersionInfo.GetVersionInfo(path); if (string.IsNullOrEmpty(info.ProductVersion)) throw new Exception("File ProductVersion is empty or null, was the binary compiled correctly?"); - if (info.ProductVersion != expectedAssemblyVersion) + if (info.ProductVersion != _expectedAssemblyVersion) throw new Exception( - $"File ProductVersion is '{info.ProductVersion}', but expected '{expectedAssemblyVersion}'"); + $"File ProductVersion is '{info.ProductVersion}', but expected '{_expectedAssemblyVersion}'"); return Task.CompletedTask; } } @@ -85,12 +99,19 @@ public Task ValidateAsync(string path, CancellationToken ct = default) /// /// Combines multiple download validators into a single validator. All validators will be run in order. /// -/// Validators to run -public class CombinationDownloadValidator(params IDownloadValidator[] validators) : IDownloadValidator +public class CombinationDownloadValidator : IDownloadValidator { + private readonly IDownloadValidator[] _validators; + + /// Validators to run + public CombinationDownloadValidator(params IDownloadValidator[] validators) + { + _validators = validators; + } + public async Task ValidateAsync(string path, CancellationToken ct = default) { - foreach (var validator in validators) + foreach (var validator in _validators) await validator.ValidateAsync(path, ct); } } diff --git a/Vpn.Service/ManagerRpcService.cs b/Vpn.Service/ManagerRpcService.cs index 0cfaed1..ce2b17e 100644 --- a/Vpn.Service/ManagerRpcService.cs +++ b/Vpn.Service/ManagerRpcService.cs @@ -18,12 +18,11 @@ public class ManagerRpcService : BackgroundService, IAsyncDisposable private readonly ILogger _logger; private readonly IManager _manager; - // ReSharper disable once ConvertToPrimaryConstructor public ManagerRpcService(IOptions config, ILogger logger, IManager manager) { - _config = config.Value; _logger = logger; _manager = manager; + _config = config.Value; } public async ValueTask DisposeAsync() diff --git a/Vpn.Service/RegistryConfigurationSource.cs b/Vpn.Service/RegistryConfigurationSource.cs index 3e0ff5f..7ac2764 100644 --- a/Vpn.Service/RegistryConfigurationSource.cs +++ b/Vpn.Service/RegistryConfigurationSource.cs @@ -3,19 +3,37 @@ namespace Coder.Desktop.Vpn.Service; -public class RegistryConfigurationSource(RegistryKey root, string subKeyName) : IConfigurationSource +public class RegistryConfigurationSource : IConfigurationSource { + private readonly RegistryKey _root; + private readonly string _subKeyName; + + public RegistryConfigurationSource(RegistryKey root, string subKeyName) + { + _root = root; + _subKeyName = subKeyName; + } + public IConfigurationProvider Build(IConfigurationBuilder builder) { - return new RegistryConfigurationProvider(root, subKeyName); + return new RegistryConfigurationProvider(_root, _subKeyName); } } -public class RegistryConfigurationProvider(RegistryKey root, string subKeyName) : ConfigurationProvider +public class RegistryConfigurationProvider : ConfigurationProvider { + private readonly RegistryKey _root; + private readonly string _subKeyName; + + public RegistryConfigurationProvider(RegistryKey root, string subKeyName) + { + _root = root; + _subKeyName = subKeyName; + } + public override void Load() { - using var key = root.OpenSubKey(subKeyName); + using var key = _root.OpenSubKey(_subKeyName); if (key == null) return; foreach (var valueName in key.GetValueNames()) Data[valueName] = key.GetValue(valueName)?.ToString(); diff --git a/Vpn/Speaker.cs b/Vpn/Speaker.cs index 8d06e0b..4c6ef3c 100644 --- a/Vpn/Speaker.cs +++ b/Vpn/Speaker.cs @@ -9,29 +9,43 @@ namespace Coder.Desktop.Vpn; /// /// Thrown when the two peers are incompatible with each other. /// -public class RpcVersionCompatibilityException(RpcVersionList localVersion, RpcVersionList remoteVersion) - : Exception($"No RPC versions are compatible: local={localVersion}, remote={remoteVersion}"); +public class RpcVersionCompatibilityException : Exception +{ + public RpcVersionCompatibilityException(RpcVersionList localVersion, RpcVersionList remoteVersion) : base( + $"No RPC versions are compatible: local={localVersion}, remote={remoteVersion}") + { + } +} /// /// Wraps a RpcMessage to allow easily sending a reply via the Speaker. /// -/// Speaker to use for sending reply -/// Original received message -public class ReplyableRpcMessage(Speaker speaker, TR message) : RpcMessage +public class ReplyableRpcMessage : RpcMessage where TS : RpcMessage, IRpcMessageCompatibleWith, IMessage where TR : RpcMessage, IRpcMessageCompatibleWith, IMessage, new() { + private readonly TR _message; + private readonly Speaker _speaker; + public override RPC? RpcField { - get => message.RpcField; - set => message.RpcField = value; + get => _message.RpcField; + set => _message.RpcField = value; } - public override TR Message => message; + public override TR Message => _message; + + /// Speaker to use for sending reply + /// Original received message + public ReplyableRpcMessage(Speaker speaker, TR message) + { + _speaker = speaker; + _message = message; + } public override void Validate() { - message.Validate(); + _message.Validate(); } /// @@ -41,7 +55,7 @@ public override void Validate() /// Optional cancellation token public async Task SendReply(TS reply, CancellationToken ct = default) { - await speaker.SendReply(message, reply, ct); + await _speaker.SendReply(_message, reply, ct); } } diff --git a/Vpn/Utilities/BidirectionalPipe.cs b/Vpn/Utilities/BidirectionalPipe.cs index 7bc4b4d..72e633b 100644 --- a/Vpn/Utilities/BidirectionalPipe.cs +++ b/Vpn/Utilities/BidirectionalPipe.cs @@ -5,10 +5,11 @@ namespace Coder.Desktop.Vpn.Utilities; /// /// BidirectionalPipe implements Stream using a read-only Stream and a write-only Stream. /// -/// The stream to perform reads from -/// The stream to write data to -public class BidirectionalPipe(Stream reader, Stream writer) : Stream +public class BidirectionalPipe : Stream { + private readonly Stream _reader; + private readonly Stream _writer; + public override bool CanRead => true; public override bool CanSeek => false; public override bool CanWrite => true; @@ -20,6 +21,14 @@ public override long Position set => throw new NotImplementedException("BidirectionalPipe does not support setting position"); } + /// The stream to perform reads from + /// The stream to write data to + public BidirectionalPipe(Stream reader, Stream writer) + { + _reader = reader; + _writer = writer; + } + /// /// Creates a new pair of BidirectionalPipes that are connected to each other using buffered in-memory pipes. /// @@ -36,24 +45,24 @@ public static (BidirectionalPipe, BidirectionalPipe) NewInMemory() public override void Flush() { - writer.Flush(); + _writer.Flush(); } public override int Read(byte[] buffer, int offset, int count) { - return reader.Read(buffer, offset, count); + return _reader.Read(buffer, offset, count); } public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken ct) { #pragma warning disable CA1835 - return await reader.ReadAsync(buffer, offset, count, ct); + return await _reader.ReadAsync(buffer, offset, count, ct); #pragma warning restore CA1835 } public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) { - return reader.ReadAsync(buffer, cancellationToken); + return _reader.ReadAsync(buffer, cancellationToken); } public override long Seek(long offset, SeekOrigin origin) @@ -68,25 +77,25 @@ public override void SetLength(long value) public override void Write(byte[] buffer, int offset, int count) { - writer.Write(buffer, offset, count); + _writer.Write(buffer, offset, count); } public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken ct) { #pragma warning disable CA1835 - await writer.WriteAsync(buffer, offset, count, ct); + await _writer.WriteAsync(buffer, offset, count, ct); #pragma warning restore CA1835 } public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) { - return writer.WriteAsync(buffer, cancellationToken); + return _writer.WriteAsync(buffer, cancellationToken); } protected override void Dispose(bool disposing) { base.Dispose(disposing); - writer.Dispose(); - reader.Dispose(); + _writer.Dispose(); + _reader.Dispose(); } } diff --git a/Vpn/Utilities/RaiiSemaphoreSlim.cs b/Vpn/Utilities/RaiiSemaphoreSlim.cs index 4e94ef2..f4ecee6 100644 --- a/Vpn/Utilities/RaiiSemaphoreSlim.cs +++ b/Vpn/Utilities/RaiiSemaphoreSlim.cs @@ -3,9 +3,14 @@ namespace Coder.Desktop.Vpn.Utilities; /// /// RaiiSemaphoreSlim is a wrapper around SemaphoreSlim that provides RAII-style locking. /// -public class RaiiSemaphoreSlim(int initialCount, int maxCount) : IDisposable +public class RaiiSemaphoreSlim : IDisposable { - private readonly SemaphoreSlim _semaphore = new(initialCount, maxCount); + private readonly SemaphoreSlim _semaphore; + + public RaiiSemaphoreSlim(int initialCount, int maxCount) + { + _semaphore = new SemaphoreSlim(initialCount, maxCount); + } public void Dispose() { @@ -19,11 +24,18 @@ public async ValueTask LockAsync(CancellationToken ct = default) return new Lock(_semaphore); } - private class Lock(SemaphoreSlim semaphore) : IDisposable + private class Lock : IDisposable { + private readonly SemaphoreSlim _semaphore1; + + public Lock(SemaphoreSlim semaphore) + { + _semaphore1 = semaphore; + } + public void Dispose() { - semaphore.Release(); + _semaphore1.Release(); GC.SuppressFinalize(this); } }