Skip to content

Commit 71dcf5a

Browse files
authored
improve IndexStore::listTests to include inherited tests (#335)
motivation: test discivery on linux depends on index store test listing changes: * refactor code to also scan for class inheritance * rollup test across inheritance in different files * deprecate single file API as it cannot effectively take inheritance into account rdar://59655518
1 parent 5176169 commit 71dcf5a

File tree

1 file changed

+155
-40
lines changed

1 file changed

+155
-40
lines changed

Sources/TSCUtility/IndexStore.swift

Lines changed: 155 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ public final class IndexStore {
4141
return IndexStore(impl)
4242
}
4343

44+
public func listTests(in objectFiles: [AbsolutePath]) throws -> [TestCaseClass] {
45+
return try impl.listTests(in: objectFiles)
46+
}
47+
48+
@available(*, deprecated, message: "use listTests(in:) instead")
4449
public func listTests(inObjectFile object: AbsolutePath) throws -> [TestCaseClass] {
4550
return try impl.listTests(inObjectFile: object)
4651
}
@@ -58,13 +63,10 @@ public final class IndexStoreAPI {
5863
}
5964

6065
private final class IndexStoreImpl {
61-
6266
typealias TestCaseClass = IndexStore.TestCaseClass
6367

6468
let api: IndexStoreAPIImpl
6569

66-
var fn: indexstore_functions_t { api.fn }
67-
6870
let store: indexstore_t
6971

7072
private init(store: indexstore_t, api: IndexStoreAPIImpl) {
@@ -79,47 +81,156 @@ private final class IndexStoreImpl {
7981
throw StringError("Unable to open store at \(path)")
8082
}
8183

84+
public func listTests(in objectFiles: [AbsolutePath]) throws -> [TestCaseClass] {
85+
var inheritance = [String: [String: String]]()
86+
var testMethods = [String: [String: [(name: String, async: Bool)]]]()
87+
88+
for objectFile in objectFiles {
89+
// Get the records of this object file.
90+
let unitReader = try self.api.call{ self.api.fn.unit_reader_create(store, unitName(object: objectFile), &$0) }
91+
let records = try getRecords(unitReader: unitReader)
92+
let moduleName = self.api.fn.unit_reader_get_module_name(unitReader).str
93+
for record in records {
94+
// get tests info
95+
let testsInfo = try self.getTestsInfo(record: record)
96+
// merge results across module
97+
for (className, parentClassName) in testsInfo.inheritance {
98+
inheritance[moduleName, default: [:]][className] = parentClassName
99+
}
100+
for (className, classTestMethods) in testsInfo.testMethods {
101+
testMethods[moduleName, default: [:]][className, default: []].append(contentsOf: classTestMethods)
102+
}
103+
}
104+
}
105+
106+
// merge across inheritance in module boundries
107+
func flatten(moduleName: String, className: String) -> [String: (name: String, async: Bool)] {
108+
var allMethods = [String: (name: String, async: Bool)]()
109+
110+
if let parentClassName = inheritance[moduleName]?[className] {
111+
let parentMethods = flatten(moduleName: moduleName, className: parentClassName)
112+
allMethods.merge(parentMethods, uniquingKeysWith: { (lhs, _) in lhs })
113+
}
114+
115+
for method in testMethods[moduleName]?[className] ?? [] {
116+
allMethods[method.name] = (name: method.name, async: method.async)
117+
}
118+
119+
return allMethods
120+
}
121+
122+
var testCaseClasses = [TestCaseClass]()
123+
for (moduleName, classMethods) in testMethods {
124+
for className in classMethods.keys {
125+
let methods = flatten(moduleName: moduleName, className: className)
126+
.map { (name, info) in TestCaseClass.TestMethod(name: name, isAsync: info.async) }
127+
.sorted()
128+
testCaseClasses.append(TestCaseClass(name: className, module: moduleName, testMethods: methods, methods: methods.map(\.name)))
129+
}
130+
}
131+
132+
return testCaseClasses
133+
}
134+
135+
136+
@available(*, deprecated, message: "use listTests(in:) instead")
82137
public func listTests(inObjectFile object: AbsolutePath) throws -> [TestCaseClass] {
83138
// Get the records of this object file.
84-
let unitReader = try api.call{ fn.unit_reader_create(store, unitName(object: object), &$0) }
139+
let unitReader = try api.call{ self.api.fn.unit_reader_create(store, unitName(object: object), &$0) }
85140
let records = try getRecords(unitReader: unitReader)
86141

87142
// Get the test classes.
88-
let testCaseClasses = try records.flatMap{ try self.getTestCaseClasses(forRecord: $0) }
89-
90-
// Fill the module name and return.
91-
let module = fn.unit_reader_get_module_name(unitReader).str
92-
return testCaseClasses.map {
93-
var c = $0
94-
c.module = module
95-
return c
143+
var inheritance = [String: String]()
144+
var testMethods = [String: [(name: String, async: Bool)]]()
145+
146+
for record in records {
147+
let testsInfo = try self.getTestsInfo(record: record)
148+
inheritance.merge(testsInfo.inheritance, uniquingKeysWith: { (lhs, _) in lhs })
149+
testMethods.merge(testsInfo.testMethods, uniquingKeysWith: { (lhs, _) in lhs })
150+
}
151+
152+
func flatten(className: String) -> [(method: String, async: Bool)] {
153+
var results = [(String, Bool)]()
154+
if let parentClassName = inheritance[className] {
155+
let parentMethods = flatten(className: parentClassName)
156+
results.append(contentsOf: parentMethods)
157+
}
158+
if let methods = testMethods[className] {
159+
results.append(contentsOf: methods)
160+
}
161+
return results
96162
}
163+
164+
let moduleName = self.api.fn.unit_reader_get_module_name(unitReader).str
165+
166+
var testCaseClasses = [TestCaseClass]()
167+
for className in testMethods.keys {
168+
let methods = flatten(className: className)
169+
.map { TestCaseClass.TestMethod(name: $0.method, isAsync: $0.async) }
170+
.sorted()
171+
testCaseClasses.append(TestCaseClass(name: className, module: moduleName, testMethods: methods, methods: methods.map(\.name)))
172+
}
173+
174+
return testCaseClasses
97175
}
98176

99-
private func getTestCaseClasses(forRecord record: String) throws -> [TestCaseClass] {
100-
let recordReader = try api.call{ fn.record_reader_create(store, record, &$0) }
177+
private func getTestsInfo(record: String) throws -> (inheritance: [String: String], testMethods: [String: [(name: String, async: Bool)]] ) {
178+
let recordReader = try api.call{ self.api.fn.record_reader_create(store, record, &$0) }
101179

102-
class TestCaseBuilder {
103-
var classToMethods: [String: Set<TestCaseClass.TestMethod>] = [:]
180+
// scan for inheritance
104181

105-
func add(className: String, method: TestCaseClass.TestMethod) {
106-
classToMethods[className, default: []].insert(method)
182+
let inheritanceRef = Ref([String: String](), api: self.api)
183+
let inheritancePointer = unsafeBitCast(Unmanaged.passUnretained(inheritanceRef), to: UnsafeMutableRawPointer.self)
184+
185+
_ = self.api.fn.record_reader_occurrences_apply_f(recordReader, inheritancePointer) { inheritancePointer , occ -> Bool in
186+
let inheritanceRef = Unmanaged<Ref<[String: String?]>>.fromOpaque(inheritancePointer!).takeUnretainedValue()
187+
let fn = inheritanceRef.api.fn
188+
189+
// Get the symbol.
190+
let sym = fn.occurrence_get_symbol(occ)
191+
let symbolProperties = fn.symbol_get_properties(sym)
192+
// We only care about symbols that are marked unit tests and are instance methods.
193+
if symbolProperties & UInt64(INDEXSTORE_SYMBOL_PROPERTY_UNITTEST.rawValue) == 0 {
194+
return true
195+
}
196+
if fn.symbol_get_kind(sym) != INDEXSTORE_SYMBOL_KIND_CLASS{
197+
return true
107198
}
108199

109-
func build() -> [TestCaseClass] {
110-
return classToMethods.map {
111-
let testMethods = Array($0.value).sorted()
112-
return TestCaseClass(name: $0.key, module: "", testMethods: testMethods, methods: testMethods.map(\.name))
200+
let parentClassName = fn.symbol_get_name(sym).str
201+
202+
let childClassNameRef = Ref("", api: inheritanceRef.api)
203+
let childClassNamePointer = unsafeBitCast(Unmanaged.passUnretained(childClassNameRef), to: UnsafeMutableRawPointer.self)
204+
_ = fn.occurrence_relations_apply_f(occ!, childClassNamePointer) { childClassNamePointer, relation in
205+
guard let relation = relation else { return true }
206+
let childClassNameRef = Unmanaged<Ref<String>>.fromOpaque(childClassNamePointer!).takeUnretainedValue()
207+
let fn = childClassNameRef.api.fn
208+
209+
// Look for the base class.
210+
if fn.symbol_relation_get_roles(relation) != UInt64(INDEXSTORE_SYMBOL_ROLE_REL_BASEOF.rawValue) {
211+
return true
113212
}
213+
214+
let childClassNameSym = fn.symbol_relation_get_symbol(relation)
215+
childClassNameRef.instance = fn.symbol_get_name(childClassNameSym).str
216+
return true
114217
}
218+
219+
if !childClassNameRef.instance.isEmpty {
220+
inheritanceRef.instance[childClassNameRef.instance] = parentClassName
221+
}
222+
223+
return true
115224
}
116225

117-
let builder = Ref(TestCaseBuilder(), api: api)
226+
// scan for methods
118227

119-
let ctx = unsafeBitCast(Unmanaged.passUnretained(builder), to: UnsafeMutableRawPointer.self)
120-
_ = fn.record_reader_occurrences_apply_f(recordReader, ctx) { ctx , occ -> Bool in
121-
let builder = Unmanaged<Ref<TestCaseBuilder>>.fromOpaque(ctx!).takeUnretainedValue()
122-
let fn = builder.api.fn
228+
let testMethodsRef = Ref([String: [(name: String, async: Bool)]](), api: api)
229+
let testMethodsPointer = unsafeBitCast(Unmanaged.passUnretained(testMethodsRef), to: UnsafeMutableRawPointer.self)
230+
231+
_ = self.api.fn.record_reader_occurrences_apply_f(recordReader, testMethodsPointer) { testMethodsPointer , occ -> Bool in
232+
let testMethodsRef = Unmanaged<Ref<[String: [(name: String, async: Bool)]]>>.fromOpaque(testMethodsPointer!).takeUnretainedValue()
233+
let fn = testMethodsRef.api.fn
123234

124235
// Get the symbol.
125236
let sym = fn.occurrence_get_symbol(occ)
@@ -132,41 +243,45 @@ private final class IndexStoreImpl {
132243
return true
133244
}
134245

135-
let className = Ref("", api: builder.api)
136-
let ctx = unsafeBitCast(Unmanaged.passUnretained(className), to: UnsafeMutableRawPointer.self)
246+
let classNameRef = Ref("", api: testMethodsRef.api)
247+
let classNamePointer = unsafeBitCast(Unmanaged.passUnretained(classNameRef), to: UnsafeMutableRawPointer.self)
137248

138-
_ = fn.occurrence_relations_apply_f(occ!, ctx) { ctx, relation in
249+
_ = fn.occurrence_relations_apply_f(occ!, classNamePointer) { classNamePointer, relation in
139250
guard let relation = relation else { return true }
140-
let className = Unmanaged<Ref<String>>.fromOpaque(ctx!).takeUnretainedValue()
141-
let fn = className.api.fn
251+
let classNameRef = Unmanaged<Ref<String>>.fromOpaque(classNamePointer!).takeUnretainedValue()
252+
let fn = classNameRef.api.fn
142253

143254
// Look for the class.
144255
if fn.symbol_relation_get_roles(relation) != UInt64(INDEXSTORE_SYMBOL_ROLE_REL_CHILDOF.rawValue) {
145256
return true
146257
}
147258

148-
let sym = fn.symbol_relation_get_symbol(relation)
149-
className.instance = fn.symbol_get_name(sym).str
259+
let classNameSym = fn.symbol_relation_get_symbol(relation)
260+
classNameRef.instance = fn.symbol_get_name(classNameSym).str
150261
return true
151262
}
152263

153-
if !className.instance.isEmpty {
264+
if !classNameRef.instance.isEmpty {
154265
let methodName = fn.symbol_get_name(sym).str
155266
let isAsync = symbolProperties & UInt64(INDEXSTORE_SYMBOL_PROPERTY_SWIFT_ASYNC.rawValue) != 0
156-
builder.instance.add(className: className.instance, method: TestCaseClass.TestMethod(name: methodName, isAsync: isAsync))
267+
testMethodsRef.instance[classNameRef.instance, default: []].append((name: methodName, async: isAsync))
157268
}
158269

159270
return true
160271
}
161272

162-
return builder.instance.build()
273+
return (
274+
inheritance: inheritanceRef.instance,
275+
testMethods: testMethodsRef.instance
276+
)
277+
163278
}
164279

165280
private func getRecords(unitReader: indexstore_unit_reader_t?) throws -> [String] {
166281
let builder = Ref([String](), api: api)
167282

168283
let ctx = unsafeBitCast(Unmanaged.passUnretained(builder), to: UnsafeMutableRawPointer.self)
169-
_ = fn.unit_reader_dependencies_apply_f(unitReader, ctx) { ctx , unit -> Bool in
284+
_ = self.api.fn.unit_reader_dependencies_apply_f(unitReader, ctx) { ctx , unit -> Bool in
170285
let store = Unmanaged<Ref<[String]>>.fromOpaque(ctx!).takeUnretainedValue()
171286
let fn = store.api.fn
172287
if fn.unit_dependency_get_kind(unit) == INDEXSTORE_UNIT_DEPENDENCY_RECORD {
@@ -181,12 +296,12 @@ private final class IndexStoreImpl {
181296
private func unitName(object: AbsolutePath) -> String {
182297
let initialSize = 64
183298
var buf = UnsafeMutablePointer<CChar>.allocate(capacity: initialSize)
184-
let len = fn.store_get_unit_name_from_output_path(store, object.pathString, buf, initialSize)
299+
let len = self.api.fn.store_get_unit_name_from_output_path(store, object.pathString, buf, initialSize)
185300

186301
if len + 1 > initialSize {
187302
buf.deallocate()
188303
buf = UnsafeMutablePointer<CChar>.allocate(capacity: len + 1)
189-
_ = fn.store_get_unit_name_from_output_path(store, object.pathString, buf, len + 1)
304+
_ = self.api.fn.store_get_unit_name_from_output_path(store, object.pathString, buf, len + 1)
190305
}
191306

192307
defer {

0 commit comments

Comments
 (0)