Skip to content

Commit feb26b0

Browse files
nickmshelleyjberkel
authored andcommitted
Add cursor type for more safety.
1 parent 9f1735a commit feb26b0

File tree

3 files changed

+89
-40
lines changed

3 files changed

+89
-40
lines changed

Sources/SQLite/Core/Statement.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,14 @@ public final class Statement {
191191

192192
}
193193

194+
extension Statement {
195+
196+
func rowCursorNext() throws -> [Binding?]? {
197+
return try step() ? Array(row) : nil
198+
}
199+
200+
}
201+
194202
extension Statement : Sequence {
195203

196204
public func makeIterator() -> Statement {

Sources/SQLite/Typed/Query.swift

Lines changed: 71 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -894,58 +894,89 @@ public struct Delete : ExpressionType {
894894

895895
}
896896

897+
public struct RowCursor {
898+
let statement: Statement
899+
let columnNames: [String: Int]
900+
901+
public func next() throws -> Row? {
902+
return try statement.rowCursorNext().flatMap { Row(columnNames, $0) }
903+
}
904+
905+
public func map<T>(_ transform: (Row) throws -> T) throws -> [T] {
906+
var elements = [T]()
907+
while true {
908+
if let row = try next() {
909+
elements.append(try transform(row))
910+
} else {
911+
break
912+
}
913+
}
914+
915+
return elements
916+
}
917+
}
918+
897919
extension Connection {
920+
921+
public func prepareCursor(_ query: QueryType) throws -> RowCursor {
922+
let expression = query.expression
923+
let statement = try prepare(expression.template, expression.bindings)
924+
return RowCursor(statement: statement, columnNames: try columnNamesForQuery(query))
925+
}
898926

899927
public func prepare(_ query: QueryType) throws -> AnySequence<Row> {
900928
let expression = query.expression
901929
let statement = try prepare(expression.template, expression.bindings)
902930

903-
let columnNames: [String: Int] = try {
904-
var (columnNames, idx) = ([String: Int](), 0)
905-
column: for each in query.clauses.select.columns {
906-
var names = each.expression.template.characters.split { $0 == "." }.map(String.init)
907-
let column = names.removeLast()
908-
let namespace = names.joined(separator: ".")
909-
910-
func expandGlob(_ namespace: Bool) -> ((QueryType) throws -> Void) {
911-
return { (query: QueryType) throws -> (Void) in
912-
var q = type(of: query).init(query.clauses.from.name, database: query.clauses.from.database)
913-
q.clauses.select = query.clauses.select
914-
let e = q.expression
915-
var names = try self.prepare(e.template, e.bindings).columnNames.map { $0.quote() }
916-
if namespace { names = names.map { "\(query.tableName().expression.template).\($0)" } }
917-
for name in names { columnNames[name] = idx; idx += 1 }
918-
}
919-
}
931+
let columnNames = try columnNamesForQuery(query)
920932

921-
if column == "*" {
922-
var select = query
923-
select.clauses.select = (false, [Expression<Void>(literal: "*") as Expressible])
924-
let queries = [select] + query.clauses.join.map { $0.query }
925-
if !namespace.isEmpty {
926-
for q in queries {
927-
if q.tableName().expression.template == namespace {
928-
try expandGlob(true)(q)
929-
continue column
930-
}
933+
return AnySequence {
934+
AnyIterator { statement.next().map { Row(columnNames, $0) } }
935+
}
936+
}
937+
938+
private func columnNamesForQuery(_ query: QueryType) throws -> [String: Int] {
939+
var (columnNames, idx) = ([String: Int](), 0)
940+
column: for each in query.clauses.select.columns {
941+
var names = each.expression.template.characters.split { $0 == "." }.map(String.init)
942+
let column = names.removeLast()
943+
let namespace = names.joined(separator: ".")
944+
945+
func expandGlob(_ namespace: Bool) -> ((QueryType) throws -> Void) {
946+
return { (query: QueryType) throws -> (Void) in
947+
var q = type(of: query).init(query.clauses.from.name, database: query.clauses.from.database)
948+
q.clauses.select = query.clauses.select
949+
let e = q.expression
950+
var names = try self.prepare(e.template, e.bindings).columnNames.map { $0.quote() }
951+
if namespace { names = names.map { "\(query.tableName().expression.template).\($0)" } }
952+
for name in names { columnNames[name] = idx; idx += 1 }
953+
}
954+
}
955+
956+
if column == "*" {
957+
var select = query
958+
select.clauses.select = (false, [Expression<Void>(literal: "*") as Expressible])
959+
let queries = [select] + query.clauses.join.map { $0.query }
960+
if !namespace.isEmpty {
961+
for q in queries {
962+
if q.tableName().expression.template == namespace {
963+
try expandGlob(true)(q)
964+
continue column
931965
}
932966
throw QueryError.noSuchTable(name: namespace)
933967
}
934-
for q in queries {
935-
try expandGlob(query.clauses.join.count > 0)(q)
936-
}
937-
continue
968+
fatalError("no such table: \(namespace)")
938969
}
939-
940-
columnNames[each.expression.template] = idx
941-
idx += 1
970+
for q in queries {
971+
try expandGlob(query.clauses.join.count > 0)(q)
972+
}
973+
continue
942974
}
943-
return columnNames
944-
}()
945-
946-
return AnySequence {
947-
AnyIterator { statement.next().map { Row(columnNames, $0) } }
975+
976+
columnNames[each.expression.template] = idx
977+
idx += 1
948978
}
979+
return columnNames
949980
}
950981

951982
public func scalar<V : Value>(_ query: ScalarQuery<V>) throws -> V {
@@ -971,7 +1002,7 @@ extension Connection {
9711002
}
9721003

9731004
public func pluck(_ query: QueryType) throws -> Row? {
974-
return try prepare(query.limit(1, query.clauses.limit?.offset)).makeIterator().next()
1005+
return try prepareCursor(query.limit(1, query.clauses.limit?.offset)).next()
9751006
}
9761007

9771008
/// Runs an `Insert` query.

Tests/SQLiteTests/QueryTests.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,16 @@ class QueryIntegrationTests : SQLiteTestCase {
343343
_ = user[users[managerId]]
344344
}
345345
}
346+
347+
func test_prepareCursor() {
348+
let names = ["a", "b", "c"]
349+
try! InsertUsers(names)
350+
351+
let emailColumn = Expression<String>("email")
352+
let emails = try! db.prepareCursor(users).map { $0[emailColumn] }
353+
354+
XCTAssertEqual(names.map({ "\($0)@example.com" }), emails.sorted())
355+
}
346356

347357
func test_select_optional() {
348358
let managerId = Expression<Int64?>("manager_id")

0 commit comments

Comments
 (0)