Skip to content

Commit d647104

Browse files
committed
Fix crash when using WebSockets with URLSession.shared (swiftlang#5128)
1 parent 22af71d commit d647104

File tree

2 files changed

+155
-3
lines changed

2 files changed

+155
-3
lines changed

Sources/FoundationNetworking/URLSession/WebSocket/WebSocketURLProtocol.swift

+10-3
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,16 @@ internal class _WebSocketURLProtocol: _HTTPURLProtocol {
129129
guard let t = self.task else {
130130
fatalError("Cannot notify")
131131
}
132-
guard case .taskDelegate = t.session.behaviour(for: self.task!),
133-
let task = self.task as? URLSessionWebSocketTask else {
134-
fatalError("WebSocket internal invariant violated")
132+
switch t.session.behaviour(for: t) {
133+
case .noDelegate:
134+
break
135+
case .taskDelegate:
136+
break
137+
default:
138+
fatalError("Unexpected behaviour for URLSessionWebSocketTask")
139+
}
140+
guard let task = t as? URLSessionWebSocketTask else {
141+
fatalError("Cast to URLSessionWebSocketTask failed")
135142
}
136143

137144
// Buffer the response message in the task

Tests/Foundation/TestURLSession.swift

+145
Original file line numberDiff line numberDiff line change
@@ -2148,6 +2148,151 @@ final class TestURLSession: LoopbackServerTest, @unchecked Sendable {
21482148
XCTAssertEqual(delegate.callbacks.count, callbacks.count)
21492149
XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)")
21502150
}
2151+
2152+
func test_webSocketShared() async throws {
2153+
guard #available(macOS 12, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
2154+
guard URLSessionWebSocketTask.supportsWebSockets else {
2155+
print("libcurl lacks WebSockets support, skipping \(#function)")
2156+
return
2157+
}
2158+
2159+
let urlString = "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket"
2160+
let url = try XCTUnwrap(URL(string: urlString))
2161+
2162+
let task = URLSession.shared.webSocketTask(with: url)
2163+
task.resume()
2164+
2165+
// We interleave sending and receiving, as the test HTTPServer implementation is barebones, and can't handle receiving more than one frame at a time. So, this back-and-forth acts as a gating mechanism
2166+
try await task.send(.string("Hello"))
2167+
2168+
let stringMessage = try await task.receive()
2169+
switch stringMessage {
2170+
case .string(let str):
2171+
XCTAssert(str == "Hello")
2172+
default:
2173+
XCTFail("Unexpected String Message")
2174+
}
2175+
2176+
try await task.send(.data(Data([0x20, 0x22, 0x10, 0x03])))
2177+
2178+
let dataMessage = try await task.receive()
2179+
switch dataMessage {
2180+
case .data(let data):
2181+
XCTAssert(data == Data([0x20, 0x22, 0x10, 0x03]))
2182+
default:
2183+
XCTFail("Unexpected Data Message")
2184+
}
2185+
2186+
do {
2187+
try await task.sendPing()
2188+
// Server hasn't closed the connection yet
2189+
} catch {
2190+
// Server closed the connection before we could process the pong
2191+
let urlError = try XCTUnwrap(error as? URLError)
2192+
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
2193+
}
2194+
}
2195+
2196+
func test_webSocketCompletions() async throws {
2197+
guard #available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }
2198+
guard URLSessionWebSocketTask.supportsWebSockets else {
2199+
print("libcurl lacks WebSockets support, skipping \(#function)")
2200+
return
2201+
}
2202+
2203+
let urlString = "ws://127.0.0.1:\(TestURLSession.serverPort)/web-socket"
2204+
let url = try XCTUnwrap(URL(string: urlString))
2205+
let request = URLRequest(url: url)
2206+
2207+
let delegate = SessionDelegate(with: expectation(description: "\(urlString): Connect"))
2208+
let task = delegate.runWebSocketTask(with: request, timeoutInterval: 4)
2209+
2210+
// We interleave sending and receiving, as the test HTTPServer implementation is barebones, and can't handle receiving more than one frame at a time. So, this back-and-forth acts as a gating mechanism
2211+
2212+
let didCompleteSendingString = expectation(description: "Did complete sending a string")
2213+
task.send(.string("Hello")) { error in
2214+
XCTAssertNil(error)
2215+
didCompleteSendingString.fulfill()
2216+
}
2217+
await fulfillment(of: [didCompleteSendingString], timeout: 5.0)
2218+
2219+
let didCompleteReceivingString = expectation(description: "Did complete receiving a string")
2220+
task.receive { result in
2221+
switch result {
2222+
case .failure(let error):
2223+
XCTFail()
2224+
case .success(let stringMessage):
2225+
switch stringMessage {
2226+
case .string(let str):
2227+
XCTAssert(str == "Hello")
2228+
default:
2229+
XCTFail("Unexpected String Message")
2230+
}
2231+
}
2232+
didCompleteReceivingString.fulfill()
2233+
}
2234+
await fulfillment(of: [didCompleteReceivingString], timeout: 5.0)
2235+
2236+
let didCompleteSendingData = expectation(description: "Did complete sending data")
2237+
task.send(.data(Data([0x20, 0x22, 0x10, 0x03]))) { error in
2238+
XCTAssertNil(error)
2239+
didCompleteSendingData.fulfill()
2240+
}
2241+
await fulfillment(of: [didCompleteSendingData], timeout: 5.0)
2242+
2243+
let didCompleteReceivingData = expectation(description: "Did complete receiving data")
2244+
task.receive { result in
2245+
switch result {
2246+
case .failure(let error):
2247+
XCTFail()
2248+
case .success(let dataMessage):
2249+
switch dataMessage {
2250+
case .data(let data):
2251+
XCTAssert(data == Data([0x20, 0x22, 0x10, 0x03]))
2252+
default:
2253+
XCTFail("Unexpected Data Message")
2254+
}
2255+
}
2256+
didCompleteReceivingData.fulfill()
2257+
}
2258+
await fulfillment(of: [didCompleteReceivingData], timeout: 5.0)
2259+
2260+
let didCompleteSendingPing = expectation(description: "Did complete sending ping")
2261+
task.sendPing { error in
2262+
if let error {
2263+
// Server closed the connection before we could process the pong
2264+
if let urlError = error as? URLError {
2265+
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
2266+
} else {
2267+
XCTFail("Unexpecter error type")
2268+
}
2269+
}
2270+
didCompleteSendingPing.fulfill()
2271+
}
2272+
await fulfillment(of: [delegate.expectation, didCompleteSendingPing], timeout: 50.0)
2273+
2274+
let didCompleteReceiving = expectation(description: "Did complete receiving")
2275+
task.receive { result in
2276+
switch result {
2277+
case .failure(let error):
2278+
if let urlError = error as? URLError {
2279+
XCTAssertEqual(urlError._nsError.code, NSURLErrorNetworkConnectionLost)
2280+
} else {
2281+
XCTFail("Unexpecter error type")
2282+
}
2283+
case .success:
2284+
XCTFail("Expected to throw when receiving on closed task")
2285+
}
2286+
didCompleteReceiving.fulfill()
2287+
}
2288+
await fulfillment(of: [didCompleteReceiving], timeout: 5.0)
2289+
2290+
let callbacks = [ "urlSession(_:webSocketTask:didOpenWithProtocol:)",
2291+
"urlSession(_:webSocketTask:didCloseWith:reason:)",
2292+
"urlSession(_:task:didCompleteWithError:)" ]
2293+
XCTAssertEqual(delegate.callbacks.count, callbacks.count)
2294+
XCTAssertEqual(delegate.callbacks, callbacks, "Callbacks for \(#function)")
2295+
}
21512296

21522297
func test_webSocketSpecificProtocol() async throws {
21532298
guard #available(macOS 12, iOS 13.0, watchOS 6.0, tvOS 13.0, *) else { return }

0 commit comments

Comments
 (0)