@@ -15,6 +15,9 @@ import LSPLogging
15
15
import LanguageServerProtocol
16
16
import SKSupport
17
17
import SourceKitD
18
+ import SwiftParser
19
+ @_spi ( SourceKitLSP) import SwiftRefactor
20
+ import SwiftSyntax
18
21
19
22
/// Represents a code-completion session for a given source location that can be efficiently
20
23
/// re-filtered by calling `update()`.
@@ -88,6 +91,7 @@ class CodeCompletionSession {
88
91
static func completionList(
89
92
sourcekitd: any SourceKitD ,
90
93
snapshot: DocumentSnapshot ,
94
+ syntaxTreeParseResult: IncrementalParseResult ,
91
95
completionPosition: Position ,
92
96
completionUtf8Offset: Int ,
93
97
cursorPosition: Position ,
@@ -119,6 +123,7 @@ class CodeCompletionSession {
119
123
let session = CodeCompletionSession (
120
124
sourcekitd: sourcekitd,
121
125
snapshot: snapshot,
126
+ syntaxTreeParseResult: syntaxTreeParseResult,
122
127
utf8Offset: completionUtf8Offset,
123
128
position: completionPosition,
124
129
compileCommand: compileCommand,
@@ -135,6 +140,7 @@ class CodeCompletionSession {
135
140
136
141
private let sourcekitd : any SourceKitD
137
142
private let snapshot : DocumentSnapshot
143
+ private let syntaxTreeParseResult : IncrementalParseResult
138
144
private let utf8StartOffset : Int
139
145
private let position : Position
140
146
private let compileCommand : SwiftCompileCommand ?
@@ -152,12 +158,14 @@ class CodeCompletionSession {
152
158
private init (
153
159
sourcekitd: any SourceKitD ,
154
160
snapshot: DocumentSnapshot ,
161
+ syntaxTreeParseResult: IncrementalParseResult ,
155
162
utf8Offset: Int ,
156
163
position: Position ,
157
164
compileCommand: SwiftCompileCommand ? ,
158
165
clientSupportsSnippets: Bool
159
166
) {
160
167
self . sourcekitd = sourcekitd
168
+ self . syntaxTreeParseResult = syntaxTreeParseResult
161
169
self . snapshot = snapshot
162
170
self . utf8StartOffset = utf8Offset
163
171
self . position = position
@@ -271,6 +279,54 @@ class CodeCompletionSession {
271
279
272
280
// MARK: - Helpers
273
281
282
+ private func expandClosurePlaceholders(
283
+ insertText: String ,
284
+ utf8CodeUnitsToErase: Int ,
285
+ requestPosition: Position
286
+ ) -> String ? {
287
+ guard insertText. contains ( " <# " ) && insertText. contains ( " -> " ) else {
288
+ // Fast path: There is no closure placeholder to expand
289
+ return nil
290
+ }
291
+ guard requestPosition. line < snapshot. lineTable. count else {
292
+ logger. error ( " Request position is past the last line " )
293
+ return nil
294
+ }
295
+
296
+ let indentationOfLine = snapshot. lineTable [ requestPosition. line] . prefix ( while: { $0. isWhitespace } )
297
+
298
+ let strippedPrefix : String
299
+ let exprToExpand : String
300
+ if insertText. starts ( with: " ?. " ) {
301
+ strippedPrefix = " ?. "
302
+ exprToExpand = indentationOfLine + String( insertText. dropFirst ( 2 ) )
303
+ } else {
304
+ strippedPrefix = " "
305
+ exprToExpand = indentationOfLine + insertText
306
+ }
307
+
308
+ var parser = Parser ( exprToExpand)
309
+ let expr = ExprSyntax . parse ( from: & parser)
310
+ guard let call = OutermostFunctionCallFinder . findOutermostFunctionCall ( in: expr) ,
311
+ let expandedCall = ExpandEditorPlaceholdersToTrailingClosures . refactor ( syntax: call)
312
+ else {
313
+ return nil
314
+ }
315
+
316
+ let bytesToExpand = Array ( exprToExpand. utf8)
317
+
318
+ var expandedBytes : [ UInt8 ] = [ ]
319
+ // Add the prefix that we stripped of to allow expression parsing
320
+ expandedBytes += strippedPrefix. utf8
321
+ // Add any part of the expression that didn't end up being part of the function call
322
+ expandedBytes += bytesToExpand [ 0 ..< call. position. utf8Offset]
323
+ // Add the expanded function call excluding the added `indentationOfLine`
324
+ expandedBytes += expandedCall. syntaxTextBytes [ indentationOfLine. utf8. count... ]
325
+ // Add any trailing text that didn't end up being part of the function call
326
+ expandedBytes += bytesToExpand [ call. endPosition. utf8Offset... ]
327
+ return String ( bytes: expandedBytes, encoding: . utf8)
328
+ }
329
+
274
330
private func completionsFromSKDResponse(
275
331
_ completions: SKDResponseArray ,
276
332
in snapshot: DocumentSnapshot ,
@@ -286,9 +342,19 @@ class CodeCompletionSession {
286
342
}
287
343
288
344
var filterName : String ? = value [ keys. name]
289
- let insertText : String ? = value [ keys. sourceText]
345
+ var insertText : String ? = value [ keys. sourceText]
290
346
let typeName : String ? = value [ sourcekitd. keys. typeName]
291
347
let docBrief : String ? = value [ sourcekitd. keys. docBrief]
348
+ let utf8CodeUnitsToErase : Int = value [ sourcekitd. keys. numBytesToErase] ?? 0
349
+
350
+ if let insertTextUnwrapped = insertText {
351
+ insertText =
352
+ expandClosurePlaceholders (
353
+ insertText: insertTextUnwrapped,
354
+ utf8CodeUnitsToErase: utf8CodeUnitsToErase,
355
+ requestPosition: requestPosition
356
+ ) ?? insertText
357
+ }
292
358
293
359
let text = insertText. map {
294
360
rewriteSourceKitPlaceholders ( inString: $0, clientSupportsSnippets: clientSupportsSnippets)
@@ -297,8 +363,6 @@ class CodeCompletionSession {
297
363
298
364
let textEdit : TextEdit ?
299
365
if let text = text {
300
- let utf8CodeUnitsToErase : Int = value [ sourcekitd. keys. numBytesToErase] ?? 0
301
-
302
366
textEdit = self . computeCompletionTextEdit (
303
367
completionPos: completionPos,
304
368
requestPosition: requestPosition,
@@ -411,3 +475,39 @@ extension CodeCompletionSession: CustomStringConvertible {
411
475
" \( uri. pseudoPath) : \( position) "
412
476
}
413
477
}
478
+
479
+ fileprivate class OutermostFunctionCallFinder : SyntaxAnyVisitor {
480
+ /// Once a `FunctionCallExprSyntax` has been visited, that syntax node.
481
+ var foundCall : FunctionCallExprSyntax ?
482
+
483
+ private func shouldVisit( _ node: some SyntaxProtocol ) -> Bool {
484
+ if foundCall != nil {
485
+ return false
486
+ }
487
+ return true
488
+ }
489
+
490
+ override func visitAny( _ node: Syntax ) -> SyntaxVisitorContinueKind {
491
+ guard shouldVisit ( node) else {
492
+ return . skipChildren
493
+ }
494
+ return . visitChildren
495
+ }
496
+
497
+ override func visit( _ node: FunctionCallExprSyntax ) -> SyntaxVisitorContinueKind {
498
+ guard shouldVisit ( node) else {
499
+ return . skipChildren
500
+ }
501
+ foundCall = node
502
+ return . skipChildren
503
+ }
504
+
505
+ /// Find the innermost `FunctionCallExprSyntax` that contains `position`.
506
+ static func findOutermostFunctionCall(
507
+ in tree: some SyntaxProtocol
508
+ ) -> FunctionCallExprSyntax ? {
509
+ let finder = OutermostFunctionCallFinder ( viewMode: . sourceAccurate)
510
+ finder. walk ( tree)
511
+ return finder. foundCall
512
+ }
513
+ }
0 commit comments