diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs index b30e3e4..986ce46 100644 --- a/Tests.Vpn.Service/DownloaderTest.cs +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -442,23 +442,28 @@ public async Task CancelledWaitingForOther(CancellationToken ct) [CancelAfter(30_000)] public async Task CancelledInner(CancellationToken ct) { + var httpCts = CancellationTokenSource.CreateLinkedTokenSource(ct); + var taskCts = CancellationTokenSource.CreateLinkedTokenSource(ct); using var httpServer = new TestHttpServer(async ctx => { ctx.Response.StatusCode = 200; - await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), ct); - await ctx.Response.OutputStream.FlushAsync(ct); - await Task.Delay(TimeSpan.FromSeconds(5), ct); + await ctx.Response.OutputStream.WriteAsync("test"u8.ToArray(), httpCts.Token); + await ctx.Response.OutputStream.FlushAsync(httpCts.Token); + // wait up to 5 seconds. + await Task.Delay(TimeSpan.FromSeconds(5), httpCts.Token); }); var url = new Uri(httpServer.BaseUrl + "/test"); var destPath = Path.Combine(_tempDir, "test"); var manager = new Downloader(NullLogger.Instance); // The "inner" Task should fail. - var smallerCt = new CancellationTokenSource(TimeSpan.FromSeconds(1)).Token; + var taskCt = taskCts.Token; var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, - NullDownloadValidator.Instance, smallerCt); + NullDownloadValidator.Instance, taskCt); + await taskCts.CancelAsync(); var ex = Assert.ThrowsAsync(async () => await dlTask.Task); - Assert.That(ex.CancellationToken, Is.EqualTo(smallerCt)); + Assert.That(ex.CancellationToken, Is.EqualTo(taskCt)); + await httpCts.CancelAsync(); } [Test(Description = "Validation failure")] diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index 467c9af..c7b94c6 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -453,27 +453,25 @@ private async Task Start(CancellationToken ct = default) if (res.Content.Headers.ContentLength >= 0) TotalBytes = (ulong)res.Content.Headers.ContentLength; - FileStream tempFile; - try - { - tempFile = File.Create(TempDestinationPath, BufferSize, - FileOptions.Asynchronous | FileOptions.SequentialScan); - } - catch (Exception e) - { - _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", TempDestinationPath); - throw; - } - - await Download(res, tempFile, ct); + await Download(res, ct); return; } - private async Task Download(HttpResponseMessage res, FileStream tempFile, CancellationToken ct) + private async Task Download(HttpResponseMessage res, CancellationToken ct) { try { var sha1 = res.Headers.Contains("ETag") ? SHA1.Create() : null; + FileStream tempFile; + try + { + tempFile = File.Create(TempDestinationPath, BufferSize, FileOptions.SequentialScan); + } + catch (Exception e) + { + _logger.LogError(e, "Failed to create temporary file '{TempDestinationPath}'", TempDestinationPath); + throw; + } await using (tempFile) { var stream = await res.Content.ReadAsStreamAsync(ct);