Skip to content

Commit a1c25ff

Browse files
authored
URLSessionWebSocketTask.receive() not finishing if server closes connection without a close packet (#4673)
Allow _WebSocketURLProtocol to detect a closing URL connection, and set its state to indicate this is a normal close of the connection. Kick off URLSessionWebSocketTask doPendingWork() if the connection is closed out due to an error, in order to respond to any pending Tasks parked in receive().
1 parent 9d8da69 commit a1c25ff

File tree

5 files changed

+168
-30
lines changed

5 files changed

+168
-30
lines changed

Sources/FoundationNetworking/URLSession/URLSessionTask.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,12 @@ open class URLSessionWebSocketTask : URLSessionTask {
703703
}
704704
}
705705

706+
open override var error: Error? {
707+
didSet {
708+
doPendingWork()
709+
}
710+
}
711+
706712
private var sendBuffer = [(Message, (Error?) -> Void)]()
707713
private var receiveBuffer = [Message]()
708714
private var receiveCompletionHandlers = [(Result<Message, Error>) -> Void]()

Sources/FoundationNetworking/URLSession/WebSocket/WebSocketURLProtocol.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol {
8888
return .completeTask
8989
}
9090

91+
override func completeTask() {
92+
if let webSocketTask = task as? URLSessionWebSocketTask {
93+
webSocketTask.close(code: .normalClosure, reason: nil)
94+
}
95+
super.completeTask()
96+
}
97+
9198
func sendWebSocketData(_ data: Data, flags: _EasyHandle.WebSocketFlags) throws {
9299
try easyHandle.sendWebSocketsData(data, flags: flags)
93100
}

Tests/Foundation/HTTPServer.swift

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -912,8 +912,28 @@ public class TestURLSessionServer: CustomStringConvertible {
912912
"Connection: Upgrade"]
913913

914914
let expectFullRequestResponseTests: Bool
915+
let sendClosePacket: Bool
916+
let completeUpgrade: Bool
917+
915918
let uri = request.uri
916-
if uri.count > "/web-socket/".count {
919+
switch uri {
920+
case "/web-socket":
921+
expectFullRequestResponseTests = true
922+
completeUpgrade = true
923+
sendClosePacket = true
924+
case "/web-socket/semi-abrupt-close":
925+
expectFullRequestResponseTests = false
926+
completeUpgrade = true
927+
sendClosePacket = false
928+
case "/web-socket/abrupt-close":
929+
expectFullRequestResponseTests = false
930+
completeUpgrade = false
931+
sendClosePacket = false
932+
default:
933+
guard uri.count > "/web-socket/".count else {
934+
NSLog("Expected Sec-WebSocket-Protocol")
935+
throw InternalServerError.badHeaders
936+
}
917937
let expectedProtocol = String(uri.suffix(from: uri.index(uri.startIndex, offsetBy: "/web-socket/".count)))
918938
guard let receivedProtocolStr = request.getHeader(for: "Sec-WebSocket-Protocol"),
919939
expectedProtocol == receivedProtocolStr.components(separatedBy: ", ")[0] else {
@@ -922,10 +942,12 @@ public class TestURLSessionServer: CustomStringConvertible {
922942
}
923943
responseHeaders.append("Sec-WebSocket-Protocol: \(expectedProtocol)")
924944
expectFullRequestResponseTests = false
925-
} else {
926-
expectFullRequestResponseTests = true
945+
completeUpgrade = true
946+
sendClosePacket = true
927947
}
928948

949+
guard completeUpgrade else { return }
950+
929951
var upgradeResponse = _HTTPResponse(response: .SWITCHING_PROTOCOLS, headers: responseHeaders)
930952
// Lacking an available SHA1 implementation, we'll only include this response for a well-known key
931953
if "dGhlIHNhbXBsZSBub25jZQ==" == request.getHeader(for: "sec-websocket-key") {
@@ -940,10 +962,11 @@ public class TestURLSessionServer: CustomStringConvertible {
940962
let closePayload = Data([UInt8(closeCode >> 8),
941963
UInt8(closeCode & 0xFF)]) + closeReason
942964

965+
let pingPayload = "Hi".data(using: .utf8)!
966+
943967
if expectFullRequestResponseTests {
944968
let stringPayload = "Hello".data(using: .utf8)!
945969
let dataPayload = Data([0x20, 0x22, 0x10, 0x03])
946-
let pingPayload = "Hi".data(using: .utf8)!
947970

948971
// Receive a string message
949972
guard let stringFrame = try httpServer.tcpSocket.readData(),
@@ -981,32 +1004,36 @@ public class TestURLSessionServer: CustomStringConvertible {
9811004
}
9821005
// ... and pong it
9831006
try httpServer.tcpSocket.writeRawData(Data([0x8a, 0x00]))
984-
985-
// Send a ping
986-
let sendPingFrame = Data([0x89, UInt8(pingPayload.count)]) + pingPayload
987-
try httpServer.tcpSocket.writeRawData(sendPingFrame)
988-
// ... and receive its pong
989-
guard let pongFrame = try httpServer.tcpSocket.readData(),
990-
pongFrame.count == (2 + 4 + pingPayload.count),
991-
Data(pongFrame.prefix(2)) == Data([0x8a, (0x80 | UInt8(pingPayload.count))]),
992-
try unmaskedPayload(from: pongFrame) == pingPayload else {
993-
NSLog("Invalid pong frame")
994-
throw InternalServerError.badBody
995-
}
996-
997-
// Send a close
998-
let sendCloseFrame = Data([0x88, UInt8(closePayload.count)]) + closePayload
999-
try httpServer.tcpSocket.writeRawData(sendCloseFrame)
10001007
}
10011008

1002-
// Receive a close message
1003-
guard let closeFrame = try httpServer.tcpSocket.readData(),
1004-
closeFrame.count == (2 + 4 + closePayload.count),
1005-
Data(closeFrame.prefix(2)) == Data([0x88, (0x80 | UInt8(closePayload.count))]),
1006-
try unmaskedPayload(from: closeFrame) == closePayload else {
1007-
NSLog("Invalid close payload")
1009+
// Send a ping
1010+
let sendPingFrame = Data([0x89, UInt8(pingPayload.count)]) + pingPayload
1011+
try httpServer.tcpSocket.writeRawData(sendPingFrame)
1012+
// ... and receive its pong
1013+
guard let pongFrame = try httpServer.tcpSocket.readData(),
1014+
pongFrame.count == (2 + 4 + pingPayload.count),
1015+
Data(pongFrame.prefix(2)) == Data([0x8a, (0x80 | UInt8(pingPayload.count))]),
1016+
try unmaskedPayload(from: pongFrame) == pingPayload else {
1017+
NSLog("Invalid pong frame")
10081018
throw InternalServerError.badBody
10091019
}
1020+
1021+
if sendClosePacket {
1022+
if expectFullRequestResponseTests {
1023+
// Send a close
1024+
let sendCloseFrame = Data([0x88, UInt8(closePayload.count)]) + closePayload
1025+
try httpServer.tcpSocket.writeRawData(sendCloseFrame)
1026+
}
1027+
1028+
// Receive a close message
1029+
guard let closeFrame = try httpServer.tcpSocket.readData(),
1030+
closeFrame.count == (2 + 4 + closePayload.count),
1031+
Data(closeFrame.prefix(2)) == Data([0x88, (0x80 | UInt8(closePayload.count))]),
1032+
try unmaskedPayload(from: closeFrame) == closePayload else {
1033+
NSLog("Invalid close payload")
1034+
throw InternalServerError.badBody
1035+
}
1036+
}
10101037

10111038
} catch {
10121039
let badBodyCloseFrame = Data([0x88, 0x08, 0x03, 0xEA, 0x42, 0x75, 0x68, 0x42, 0x79, 0x65])

Tests/Foundation/Tests/TestDecimal.swift

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@
77
// See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
88
//
99

10-
import Foundation
11-
import XCTest
10+
#if NS_FOUNDATION_ALLOWS_TESTABLE_IMPORT
11+
#if canImport(SwiftFoundation) && !DEPLOYMENT_RUNTIME_OBJC
12+
@testable import SwiftFoundation
13+
#else
14+
@testable import Foundation
15+
#endif
16+
#endif
1217

1318
class TestDecimal: XCTestCase {
1419

Tests/Foundation/Tests/TestURLSession.swift

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,9 +1894,17 @@ class TestURLSession: LoopbackServerTest {
18941894
}
18951895

18961896
try await task.sendPing()
1897-
1897+
18981898
wait(for: [delegate.expectation], timeout: 50)
18991899

1900+
do {
1901+
_ = try await task.receive()
1902+
XCTFail("Expected to throw when receiving on closed task")
1903+
} catch {
1904+
let urlError = try XCTUnwrap(error as? URLError)
1905+
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
1906+
}
1907+
19001908
let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)",
19011909
"urlSession(_:webSocketTask:didCloseWith:reason:)",
19021910
"urlSession(_:task:didCompleteWithError:)" ]
@@ -1925,15 +1933,98 @@ class TestURLSession: LoopbackServerTest {
19251933
wait(for: [delegate.expectation], timeout: 50)
19261934

19271935
let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)",
1936+
"urlSession(_:webSocketTask:didCloseWith:reason:)",
19281937
"urlSession(_:task:didCompleteWithError:)" ]
19291938
XCTAssertEqual(delegate.callbacks.count, callbacks.count)
19301939
XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)")
19311940

19321941
XCTAssertEqual(task.closeCode, .normalClosure)
19331942
XCTAssertEqual(task.closeReason, "BuhBye".data(using: .utf8))
19341943
}
1935-
#endif
19361944

1945+
func test_webSocketAbruptClose() async throws {
1946+
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
1947+
guard URLSessionWebSocketTask.supportsWebSockets else {
1948+
print("libcurl lacks WebSockets support, skipping \(#function)")
1949+
return
1950+
}
1951+
1952+
let urlString = "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket/abrupt-close"
1953+
let url = try XCTUnwrap(URL(string: urlString))
1954+
let request = URLRequest(url: url)
1955+
1956+
let delegate = SessionDelegate(with: expectation(description: "\(urlString): Connect"))
1957+
let task = delegate.runWebSocketTask(with: request, timeoutInterval: 4)
1958+
1959+
do {
1960+
_ = try await task.receive()
1961+
XCTFail("Expected to throw when server closes connection")
1962+
} catch {
1963+
let urlError = try XCTUnwrap(error as? URLError)
1964+
XCTAssertEqual(urlError._nsError.code, NSURLErrorBadServerResponse)
1965+
}
1966+
1967+
wait(for: [delegate.expectation], timeout: 50)
1968+
1969+
do {
1970+
_ = try await task.receive()
1971+
XCTFail("Expected to throw when receiving on closed connection")
1972+
} catch {
1973+
let urlError = try XCTUnwrap(error as? URLError)
1974+
XCTAssertEqual(urlError._nsError.code, NSURLErrorBadServerResponse)
1975+
}
1976+
1977+
let callbacks = [ "urlSession(_:task:didCompleteWithError:)" ]
1978+
XCTAssertEqual(delegate.callbacks.count, callbacks.count)
1979+
XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)")
1980+
1981+
XCTAssertEqual(task.closeCode, .invalid)
1982+
XCTAssertEqual(task.closeReason, nil)
1983+
}
1984+
1985+
func test_webSocketSemiAbruptClose() async throws {
1986+
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
1987+
guard URLSessionWebSocketTask.supportsWebSockets else {
1988+
print("libcurl lacks WebSockets support, skipping \(#function)")
1989+
return
1990+
}
1991+
1992+
let urlString = "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket/semi-abrupt-close"
1993+
let url = try XCTUnwrap(URL(string: urlString))
1994+
let request = URLRequest(url: url)
1995+
1996+
let delegate = SessionDelegate(with: expectation(description: "\(urlString): Connect"))
1997+
let task = delegate.runWebSocketTask(with: request, timeoutInterval: 4)
1998+
1999+
do {
2000+
_ = try await task.receive()
2001+
XCTFail("Expected to throw when server closes connection")
2002+
} catch {
2003+
let urlError = try XCTUnwrap(error as? URLError)
2004+
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
2005+
}
2006+
2007+
wait(for: [delegate.expectation], timeout: 50)
2008+
2009+
do {
2010+
_ = try await task.receive()
2011+
XCTFail("Expected to throw when receiving on closed connection")
2012+
} catch {
2013+
let urlError = try XCTUnwrap(error as? URLError)
2014+
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
2015+
}
2016+
2017+
let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)",
2018+
"urlSession(_:webSocketTask:didCloseWith:reason:)",
2019+
"urlSession(_:task:didCompleteWithError:)" ]
2020+
XCTAssertEqual(delegate.callbacks.count, callbacks.count)
2021+
XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)")
2022+
2023+
XCTAssertEqual(task.closeCode, .normalClosure)
2024+
XCTAssertEqual(task.closeReason, nil)
2025+
}
2026+
#endif
2027+
19372028
static var allTests: [(String, (TestURLSession) -> () throws -> Void)] {
19382029
var retVal = [
19392030
("test_dataTaskWithURL", test_dataTaskWithURL),
@@ -2011,6 +2102,8 @@ class TestURLSession: LoopbackServerTest {
20112102
retVal.append(contentsOf: [
20122103
("test_webSocket", asyncTest(test_webSocket)),
20132104
("test_webSocketSpecificProtocol", asyncTest(test_webSocketSpecificProtocol)),
2105+
("test_webSocketAbruptClose", asyncTest(test_webSocketAbruptClose)),
2106+
("test_webSocketSemiAbruptClose", asyncTest(test_webSocketSemiAbruptClose)),
20142107
])
20152108
}
20162109
return retVal

0 commit comments

Comments
 (0)