Skip to content

Commit eaafb67

Browse files
authored
Merge pull request #1602 from spevans/pr_urlsession_fixes
2 parents b3cf307 + 079fbf4 commit eaafb67

File tree

2 files changed

+138
-72
lines changed

2 files changed

+138
-72
lines changed

TestFoundation/HTTPServer.swift

Lines changed: 116 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ extension UInt16 {
4040
}
4141

4242
class _TCPSocket {
43-
43+
44+
private let sendFlags: CInt
4445
private var listenSocket: Int32!
4546
private var socketAddress = UnsafeMutablePointer<sockaddr_in>.allocate(capacity: 1)
4647
private var connectionSocket: Int32!
@@ -55,27 +56,37 @@ class _TCPSocket {
5556

5657
private func attempt(_ name: String, file: String = #file, line: UInt = #line, valid: (CInt) -> Bool, _ b: @autoclosure () -> CInt) throws -> CInt {
5758
let r = b()
58-
guard valid(r) else { throw ServerError(operation: name, errno: r, file: file, line: line) }
59+
guard valid(r) else {
60+
throw ServerError(operation: name, errno: errno, file: file, line: line)
61+
}
5962
return r
6063
}
6164

6265
public private(set) var port: UInt16
6366

6467
init(port: UInt16?) throws {
65-
#if os(Linux) && !os(Android)
68+
#if canImport(Darwin)
69+
sendFlags = 0
70+
#else
71+
sendFlags = CInt(MSG_NOSIGNAL)
72+
#endif
73+
74+
#if os(Linux) && !os(Android)
6675
let SOCKSTREAM = Int32(SOCK_STREAM.rawValue)
67-
#else
76+
#else
6877
let SOCKSTREAM = SOCK_STREAM
69-
#endif
78+
#endif
7079
self.port = port ?? 0
7180
listenSocket = try attempt("socket", valid: isNotNegative, socket(AF_INET, SOCKSTREAM, Int32(IPPROTO_TCP)))
72-
var on: Int = 1
73-
_ = try attempt("setsockopt", valid: isZero, setsockopt(listenSocket, SOL_SOCKET, SO_REUSEADDR, &on, socklen_t(MemoryLayout<Int>.size)))
81+
var on: CInt = 1
82+
_ = try attempt("setsockopt", valid: isZero, setsockopt(listenSocket, SOL_SOCKET, SO_REUSEADDR, &on, socklen_t(MemoryLayout<CInt>.size)))
83+
7484
let sa = createSockaddr(port)
7585
socketAddress.initialize(to: sa)
7686
try socketAddress.withMemoryRebound(to: sockaddr.self, capacity: MemoryLayout<sockaddr>.size, {
7787
let addr = UnsafePointer<sockaddr>($0)
7888
_ = try attempt("bind", valid: isZero, bind(listenSocket, addr, socklen_t(MemoryLayout<sockaddr>.size)))
89+
_ = try attempt("listen", valid: isZero, listen(listenSocket, SOMAXCONN))
7990
})
8091

8192
var actualSA = sockaddr_in()
@@ -104,18 +115,21 @@ class _TCPSocket {
104115
}
105116

106117
func acceptConnection(notify: ServerSemaphore) throws {
107-
_ = try attempt("listen", valid: isZero, listen(listenSocket, SOMAXCONN))
108118
try socketAddress.withMemoryRebound(to: sockaddr.self, capacity: MemoryLayout<sockaddr>.size, {
109119
let addr = UnsafeMutablePointer<sockaddr>($0)
110120
var sockLen = socklen_t(MemoryLayout<sockaddr>.size)
111-
notify.signal()
112121
connectionSocket = try attempt("accept", valid: isNotNegative, accept(listenSocket, addr, &sockLen))
122+
#if canImport(Dawin)
123+
// Disable SIGPIPEs when writing to closed sockets
124+
var on: CInt = 1
125+
_ = try attempt("setsockopt", valid: isZero, setsockopt(connectionSocket, SOL_SOCKET, SO_NOSIGPIPE, &on, socklen_t(MemoryLayout<CInt>.size)))
126+
#endif
113127
})
114128
}
115129

116130
func readData() throws -> String {
117131
var buffer = [UInt8](repeating: 0, count: 4096)
118-
_ = try attempt("read", valid: isNotNegative, CInt(read(connectionSocket, &buffer, 4096)))
132+
_ = try attempt("read", valid: isNotNegative, CInt(read(connectionSocket, &buffer, buffer.count)))
119133
return String(cString: &buffer)
120134
}
121135

@@ -129,33 +143,39 @@ class _TCPSocket {
129143

130144
func writeRawData(_ data: Data) throws {
131145
_ = try data.withUnsafeBytes { ptr in
132-
try attempt("write", valid: isNotNegative, CInt(write(connectionSocket, ptr, data.count)))
146+
try attempt("send", valid: isNotNegative, CInt(send(connectionSocket, ptr, data.count, sendFlags)))
133147
}
134148
}
135149

136150
func writeData(header: String, body: String, sendDelay: TimeInterval? = nil, bodyChunks: Int? = nil) throws {
137-
var header = Array(header.utf8)
138-
_ = try attempt("write", valid: isNotNegative, CInt(write(connectionSocket, &header, header.count)))
139-
151+
var _header = Array(header.utf8)
152+
_ = try attempt("send", valid: isNotNegative, CInt(send(connectionSocket, &_header, _header.count, sendFlags)))
153+
140154
if let sendDelay = sendDelay, let bodyChunks = bodyChunks {
141155
let count = max(1, Int(Double(body.utf8.count) / Double(bodyChunks)))
142156
let texts = split(body, count)
143157

144158
for item in texts {
145159
sleep(UInt32(sendDelay))
146160
var bytes = Array(item.utf8)
147-
_ = try attempt("write", valid: isNotNegative, CInt(write(connectionSocket, &bytes, bytes.count)))
161+
_ = try attempt("send", valid: isNotNegative, CInt(send(connectionSocket, &bytes, bytes.count, sendFlags)))
148162
}
149163
} else {
150164
var bytes = Array(body.utf8)
151-
_ = try attempt("write", valid: isNotNegative, CInt(write(connectionSocket, &bytes, bytes.count)))
165+
_ = try attempt("send", valid: isNotNegative, CInt(send(connectionSocket, &bytes, bytes.count, sendFlags)))
152166
}
153167
}
154168

155-
func shutdown() {
169+
func closeClient() {
156170
if let connectionSocket = self.connectionSocket {
157171
close(connectionSocket)
172+
self.connectionSocket = nil
158173
}
174+
}
175+
176+
func shutdownListener() {
177+
closeClient()
178+
shutdown(listenSocket, CInt(SHUT_RDWR))
159179
close(listenSocket)
160180
}
161181
}
@@ -182,32 +202,23 @@ class _HTTPServer {
182202
}
183203

184204
public func stop() {
185-
socket.shutdown()
205+
socket.closeClient()
206+
socket.shutdownListener()
186207
}
187208

188209
public func request() throws -> _HTTPRequest {
189210
return try _HTTPRequest(request: socket.readData())
190211
}
191212

192213
public func respond(with response: _HTTPResponse, startDelay: TimeInterval? = nil, sendDelay: TimeInterval? = nil, bodyChunks: Int? = nil) throws {
193-
let semaphore = DispatchSemaphore(value: 0)
194-
let deadlineTime: DispatchTime
195-
196-
if let startDelay = startDelay {
197-
deadlineTime = .now() + .seconds(Int(startDelay))
198-
} else {
199-
deadlineTime = .now()
214+
if let delay = startDelay {
215+
Thread.sleep(forTimeInterval: delay)
200216
}
201-
202-
DispatchQueue.global().asyncAfter(deadline: deadlineTime) {
203-
do {
204-
try self.socket.writeData(header: response.header, body: response.body, sendDelay: sendDelay, bodyChunks: bodyChunks)
205-
semaphore.signal()
206-
} catch { }
217+
do {
218+
try self.socket.writeData(header: response.header, body: response.body, sendDelay: sendDelay, bodyChunks: bodyChunks)
219+
} catch {
207220
}
208-
semaphore.wait()
209-
210-
}
221+
}
211222

212223
func respondWithBrokenResponses(uri: String) throws {
213224
let responseData: Data
@@ -285,11 +296,13 @@ struct _HTTPRequest {
285296
let headers: [String]
286297

287298
public init(request: String) {
288-
let lines = request.components(separatedBy: _HTTPUtils.CRLF2)[0].components(separatedBy: _HTTPUtils.CRLF)
289-
headers = Array(lines[0...lines.count-2])
290-
method = Method(rawValue: headers[0].components(separatedBy: " ")[0])!
291-
uri = headers[0].components(separatedBy: " ")[1]
292-
body = lines.last!
299+
let headerEnd = (request as NSString).range(of: _HTTPUtils.CRLF2)
300+
let header = (request as NSString).substring(to: headerEnd.location)
301+
headers = header.components(separatedBy: _HTTPUtils.CRLF)
302+
let action = headers[0]
303+
method = Method(rawValue: action.components(separatedBy: " ")[0])!
304+
uri = action.components(separatedBy: " ")[1]
305+
body = (request as NSString).substring(from: headerEnd.location + headerEnd.length)
293306
}
294307

295308
public func getCommaSeparatedHeaders() -> String {
@@ -300,6 +313,16 @@ struct _HTTPRequest {
300313
return allHeaders
301314
}
302315

316+
public func getHeader(for key: String) -> String? {
317+
let lookup = key.lowercased()
318+
for header in headers {
319+
let parts = header.components(separatedBy: ":")
320+
if parts[0].lowercased() == lookup {
321+
return parts[1].trimmingCharacters(in: CharacterSet(charactersIn: " "))
322+
}
323+
}
324+
return nil
325+
}
303326
}
304327

305328
struct _HTTPResponse {
@@ -349,13 +372,16 @@ public class TestURLSessionServer {
349372
self.sendDelay = sendDelay
350373
self.bodyChunks = bodyChunks
351374
}
352-
public func start(started: ServerSemaphore) throws {
353-
started.signal()
354-
try httpServer.listen(notify: started)
355-
}
356-
375+
357376
public func readAndRespond() throws {
358377
let req = try httpServer.request()
378+
379+
if let value = req.getHeader(for: "x-pause") {
380+
if let wait = Double(value), wait > 0 {
381+
Thread.sleep(forTimeInterval: wait)
382+
}
383+
}
384+
359385
if req.uri.hasPrefix("/LandOfTheLostCities/") {
360386
/* these are all misbehaving servers */
361387
try httpServer.respondWithBrokenResponses(uri: req.uri)
@@ -402,7 +428,7 @@ public class TestURLSessionServer {
402428

403429
if uri == "/requestCookies" {
404430
let text = request.getCommaSeparatedHeaders()
405-
return _HTTPResponse(response: .OK, headers: "Content-Length: \(text.data(using: .utf8)!.count)\r\nSet-Cookie: fr=anjd&232; Max-Age=7776000; path=/; domain=127.0.0.1; secure; httponly\r\nSet-Cookie: nm=sddf&232; Max-Age=7776000; path=/; domain=.swift.org; secure; httponly\r\n", body: text)
431+
return _HTTPResponse(response: .OK, headers: "Content-Length: \(text.data(using: .utf8)!.count)\r\nSet-Cookie: fr=anjd&232; Max-Age=7776000; path=/\r\nSet-Cookie: nm=sddf&232; Max-Age=7776000; path=/; domain=.swift.org; secure; httponly\r\n", body: text)
406432
}
407433

408434
if uri == "/setCookies" {
@@ -482,8 +508,8 @@ extension ServerError : CustomStringConvertible {
482508
public class ServerSemaphore {
483509
let dispatchSemaphore = DispatchSemaphore(value: 0)
484510

485-
public func wait() {
486-
dispatchSemaphore.wait()
511+
public func wait(timeout: DispatchTime) -> DispatchTimeoutResult {
512+
return dispatchSemaphore.wait(timeout: timeout)
487513
}
488514

489515
public func signal() {
@@ -495,6 +521,11 @@ class LoopbackServerTest : XCTestCase {
495521
private static let staticSyncQ = DispatchQueue(label: "org.swift.TestFoundation.HTTPServer.StaticSyncQ")
496522

497523
private static var _serverPort: Int = -1
524+
private static let serverReady = ServerSemaphore()
525+
private static var _serverActive = false
526+
private static var testServer: TestURLSessionServer? = nil
527+
528+
498529
static var serverPort: Int {
499530
get {
500531
return staticSyncQ.sync { _serverPort }
@@ -504,30 +535,56 @@ class LoopbackServerTest : XCTestCase {
504535
}
505536
}
506537

538+
static var serverActive: Bool {
539+
get { return staticSyncQ.sync { _serverActive } }
540+
set { staticSyncQ.sync { _serverActive = newValue }}
541+
}
542+
543+
static func terminateServer() {
544+
serverActive = false
545+
testServer?.stop()
546+
testServer = nil
547+
}
548+
507549
override class func setUp() {
508550
super.setUp()
509551
func runServer(with condition: ServerSemaphore, startDelay: TimeInterval? = nil, sendDelay: TimeInterval? = nil, bodyChunks: Int? = nil) throws {
510-
while true {
511-
let test = try TestURLSessionServer(port: nil, startDelay: startDelay, sendDelay: sendDelay, bodyChunks: bodyChunks)
512-
serverPort = Int(test.port)
513-
try test.start(started: condition)
514-
try test.readAndRespond()
515-
serverPort = -2
516-
test.stop()
552+
let server = try TestURLSessionServer(port: nil, startDelay: startDelay, sendDelay: sendDelay, bodyChunks: bodyChunks)
553+
testServer = server
554+
serverPort = Int(server.port)
555+
serverReady.signal()
556+
serverActive = true
557+
558+
while serverActive {
559+
do {
560+
try server.httpServer.listen(notify: condition)
561+
try server.readAndRespond()
562+
server.httpServer.socket.closeClient()
563+
} catch {
564+
}
517565
}
566+
serverPort = -2
567+
518568
}
519569

520-
let serverReady = ServerSemaphore()
521570
globalDispatchQueue.async {
522571
do {
523572
try runServer(with: serverReady)
524-
525573
} catch {
526-
XCTAssertTrue(true)
527-
return
528574
}
529575
}
530576

531-
serverReady.wait()
577+
let timeout = DispatchTime(uptimeNanoseconds: DispatchTime.now().uptimeNanoseconds + 2_000_000_000)
578+
579+
while serverPort == -2 {
580+
guard serverReady.wait(timeout: timeout) == .success else {
581+
fatalError("Timedout waiting for server to be ready")
582+
}
583+
}
584+
}
585+
586+
override class func tearDown() {
587+
super.tearDown()
588+
terminateServer()
532589
}
533590
}

0 commit comments

Comments
 (0)