Skip to content

Commit 2a85578

Browse files
kardianosbradfitz
authored andcommitted
database/sql: support returning query database types
Creates a ColumnType structure that can be extended in to future. Allow drivers to implement what makes sense for the database. Fixes #16652 Change-Id: Ieb1fd64eac1460107b1d3474eba5201fa300a4ec Reviewed-on: https://go-review.googlesource.com/29961 Run-TryBot: Brad Fitzpatrick <[email protected]> TryBot-Result: Gobot Gobot <[email protected]> Reviewed-by: Brad Fitzpatrick <[email protected]>
1 parent 2ecaaf1 commit 2a85578

File tree

4 files changed

+259
-11
lines changed

4 files changed

+259
-11
lines changed

src/database/sql/driver/driver.go

+55
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ package driver
1111
import (
1212
"context"
1313
"errors"
14+
"reflect"
1415
)
1516

1617
// Value is a value that drivers must be able to handle.
@@ -239,6 +240,60 @@ type RowsNextResultSet interface {
239240
NextResultSet() error
240241
}
241242

243+
// RowsColumnTypeScanType may be implemented by Rows. It should return
244+
// the value type that can be used to scan types into. For example, the database
245+
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
246+
type RowsColumnTypeScanType interface {
247+
Rows
248+
ColumnTypeScanType(index int) reflect.Type
249+
}
250+
251+
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
252+
// database system type name without the length. Type names should be uppercase.
253+
// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
254+
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
255+
// "TIMESTAMP".
256+
type RowsColumnTypeDatabaseTypeName interface {
257+
Rows
258+
ColumnTypeDatabaseTypeName(index int) string
259+
}
260+
261+
// RowsColumnTypeLength may be implemented by Rows. It should return the length
262+
// of the column type if the column is a variable length type. If the column is
263+
// not a variable length type ok should return false.
264+
// If length is not limited other than system limits, it should return math.MaxInt64.
265+
// The following are examples of returned values for various types:
266+
// TEXT (math.MaxInt64, true)
267+
// varchar(10) (10, true)
268+
// nvarchar(10) (10, true)
269+
// decimal (0, false)
270+
// int (0, false)
271+
// bytea(30) (30, true)
272+
type RowsColumnTypeLength interface {
273+
Rows
274+
ColumnTypeLength(index int) (length int64, ok bool)
275+
}
276+
277+
// RowsColumnTypeNullable may be implemented by Rows. The nullable value should
278+
// be true if it is known the column may be null, or false if the column is known
279+
// to be not nullable.
280+
// If the column nullability is unknown, ok should be false.
281+
type RowsColumnTypeNullable interface {
282+
Rows
283+
ColumnTypeNullable(index int) (nullable, ok bool)
284+
}
285+
286+
// RowsColumnTypePrecisionScale may be implemented by Rows. It should return
287+
// the precision and scale for decimal types. If not applicable, ok should be false.
288+
// The following are examples of returned values for various types:
289+
// decimal(38, 4) (38, 4, true)
290+
// int (0, 0, false)
291+
// decimal (math.MaxInt64, math.MaxInt64, true)
292+
type RowsColumnTypePrecisionScale interface {
293+
Rows
294+
ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool)
295+
}
296+
242297
// Tx is a transaction.
243298
type Tx interface {
244299
Commit() error

src/database/sql/fakedb_test.go

+50-9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"fmt"
1212
"io"
1313
"log"
14+
"reflect"
1415
"sort"
1516
"strconv"
1617
"strings"
@@ -405,6 +406,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, err
405406
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
406407
}
407408
stmt.table = parts[0]
409+
408410
stmt.colName = strings.Split(parts[1], ",")
409411
for n, colspec := range strings.Split(parts[2], ",") {
410412
if colspec == "" {
@@ -725,6 +727,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
725727

726728
setMRows := make([][]*row, 0, 1)
727729
setColumns := make([][]string, 0, 1)
730+
setColType := make([][]string, 0, 1)
728731

729732
for {
730733
db.mu.Lock()
@@ -794,10 +797,16 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
794797
mrows = append(mrows, mrow)
795798
}
796799

800+
var colType []string
801+
for _, column := range s.colName {
802+
colType = append(colType, t.coltype[t.columnIndex(column)])
803+
}
804+
797805
t.mu.Unlock()
798806

799807
setMRows = append(setMRows, mrows)
800808
setColumns = append(setColumns, s.colName)
809+
setColType = append(setColType, colType)
801810

802811
if s.next == nil {
803812
break
@@ -806,10 +815,11 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
806815
}
807816

808817
cursor := &rowsCursor{
809-
posRow: -1,
810-
rows: setMRows,
811-
cols: setColumns,
812-
errPos: -1,
818+
posRow: -1,
819+
rows: setMRows,
820+
cols: setColumns,
821+
colType: setColType,
822+
errPos: -1,
813823
}
814824
return cursor, nil
815825
}
@@ -844,11 +854,12 @@ func (tx *fakeTx) Rollback() error {
844854
}
845855

846856
type rowsCursor struct {
847-
cols [][]string
848-
posSet int
849-
posRow int
850-
rows [][]*row
851-
closed bool
857+
cols [][]string
858+
colType [][]string
859+
posSet int
860+
posRow int
861+
rows [][]*row
862+
closed bool
852863

853864
// errPos and err are for making Next return early with error.
854865
errPos int
@@ -874,6 +885,10 @@ func (rc *rowsCursor) Columns() []string {
874885
return rc.cols[rc.posSet]
875886
}
876887

888+
func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
889+
return colTypeToReflectType(rc.colType[rc.posSet][index])
890+
}
891+
877892
var rowsCursorNextHook func(dest []driver.Value) error
878893

879894
func (rc *rowsCursor) Next(dest []driver.Value) error {
@@ -980,3 +995,29 @@ func converterForType(typ string) driver.ValueConverter {
980995
}
981996
panic("invalid fakedb column type of " + typ)
982997
}
998+
999+
func colTypeToReflectType(typ string) reflect.Type {
1000+
switch typ {
1001+
case "bool":
1002+
return reflect.TypeOf(false)
1003+
case "nullbool":
1004+
return reflect.TypeOf(NullBool{})
1005+
case "int32":
1006+
return reflect.TypeOf(int32(0))
1007+
case "string":
1008+
return reflect.TypeOf("")
1009+
case "nullstring":
1010+
return reflect.TypeOf(NullString{})
1011+
case "int64":
1012+
return reflect.TypeOf(int64(0))
1013+
case "nullint64":
1014+
return reflect.TypeOf(NullInt64{})
1015+
case "float64":
1016+
return reflect.TypeOf(float64(0))
1017+
case "nullfloat64":
1018+
return reflect.TypeOf(NullFloat64{})
1019+
case "datetime":
1020+
return reflect.TypeOf(time.Time{})
1021+
}
1022+
panic("invalid fakedb column type of " + typ)
1023+
}

src/database/sql/sql.go

+104-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"errors"
1919
"fmt"
2020
"io"
21+
"reflect"
2122
"runtime"
2223
"sort"
2324
"sync"
@@ -996,8 +997,8 @@ const maxBadConnRetries = 2
996997
// The caller must call the statement's Close method
997998
// when the statement is no longer needed.
998999
//
999-
// The provided context is for the preparation of the statment, not for the execution of
1000-
// the statement.
1000+
// The provided context is used for the preparation of the statement, not for the
1001+
// execution of the statement.
10011002
func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
10021003
var stmt *Stmt
10031004
var err error
@@ -2033,6 +2034,107 @@ func (rs *Rows) Columns() ([]string, error) {
20332034
return rs.rowsi.Columns(), nil
20342035
}
20352036

2037+
// ColumnTypes returns column information such as column type, length,
2038+
// and nullable. Some information may not be available from some drivers.
2039+
func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
2040+
if rs.isClosed() {
2041+
return nil, errors.New("sql: Rows are closed")
2042+
}
2043+
if rs.rowsi == nil {
2044+
return nil, errors.New("sql: no Rows available")
2045+
}
2046+
return rowsColumnInfoSetup(rs.rowsi), nil
2047+
}
2048+
2049+
// ColumnType contains the name and type of a column.
2050+
type ColumnType struct {
2051+
name string
2052+
2053+
hasNullable bool
2054+
hasLength bool
2055+
hasPrecisionScale bool
2056+
2057+
nullable bool
2058+
length int64
2059+
databaseType string
2060+
precision int64
2061+
scale int64
2062+
scanType reflect.Type
2063+
}
2064+
2065+
// Name returns the name or alias of the column.
2066+
func (ci *ColumnType) Name() string {
2067+
return ci.name
2068+
}
2069+
2070+
// Length returns the column type length for variable length column types such
2071+
// as text and binary field types. If the type length is unbounded the value will
2072+
// be math.MaxInt64 (any database limits will still apply).
2073+
// If the column type is not variable length, such as an int, or if not supported
2074+
// by the driver ok is false.
2075+
func (ci *ColumnType) Length() (length int64, ok bool) {
2076+
return ci.length, ci.hasLength
2077+
}
2078+
2079+
// DecimalSize returns the scale and precision of a decimal type.
2080+
// If not applicable or if not supported ok is false.
2081+
func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
2082+
return ci.precision, ci.scale, ci.hasPrecisionScale
2083+
}
2084+
2085+
// ScanType returns a Go type suitable for scanning into using Rows.Scan.
2086+
// If a driver does not support this property ScanType will return
2087+
// the type of an empty interface.
2088+
func (ci *ColumnType) ScanType() reflect.Type {
2089+
return ci.scanType
2090+
}
2091+
2092+
// Nullable returns whether the column may be null.
2093+
// If a driver does not support this property ok will be false.
2094+
func (ci *ColumnType) Nullable() (nullable, ok bool) {
2095+
return ci.nullable, ci.hasNullable
2096+
}
2097+
2098+
// DatabaseTypeName returns the database system name of the column type. If an empty
2099+
// string is returned the driver type name is not supported.
2100+
// Consult your driver documentation for a list of driver data types. Length specifiers
2101+
// are not included.
2102+
// Common type include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", "INT", "BIGINT".
2103+
func (ci *ColumnType) DatabaseTypeName() string {
2104+
return ci.databaseType
2105+
}
2106+
2107+
func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
2108+
names := rowsi.Columns()
2109+
2110+
list := make([]*ColumnType, len(names))
2111+
for i := range list {
2112+
ci := &ColumnType{
2113+
name: names[i],
2114+
}
2115+
list[i] = ci
2116+
2117+
if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
2118+
ci.scanType = prop.ColumnTypeScanType(i)
2119+
} else {
2120+
ci.scanType = reflect.TypeOf(new(interface{})).Elem()
2121+
}
2122+
if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
2123+
ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
2124+
}
2125+
if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
2126+
ci.length, ci.hasLength = prop.ColumnTypeLength(i)
2127+
}
2128+
if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
2129+
ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
2130+
}
2131+
if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
2132+
ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
2133+
}
2134+
}
2135+
return list
2136+
}
2137+
20362138
// Scan copies the columns in the current row into the values pointed
20372139
// at by dest. The number of values in dest must be the same as the
20382140
// number of columns in Rows.

src/database/sql/sql_test.go

+50
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,56 @@ func TestRowsColumns(t *testing.T) {
499499
}
500500
}
501501

502+
func TestRowsColumnTypes(t *testing.T) {
503+
db := newTestDB(t, "people")
504+
defer closeDB(t, db)
505+
rows, err := db.Query("SELECT|people|age,name|")
506+
if err != nil {
507+
t.Fatalf("Query: %v", err)
508+
}
509+
tt, err := rows.ColumnTypes()
510+
if err != nil {
511+
t.Fatalf("ColumnTypes: %v", err)
512+
}
513+
514+
types := make([]reflect.Type, len(tt))
515+
for i, tp := range tt {
516+
st := tp.ScanType()
517+
if st == nil {
518+
t.Errorf("scantype is null for column %q", tp.Name())
519+
continue
520+
}
521+
types[i] = st
522+
}
523+
values := make([]interface{}, len(tt))
524+
for i := range values {
525+
values[i] = reflect.New(types[i]).Interface()
526+
}
527+
ct := 0
528+
for rows.Next() {
529+
err = rows.Scan(values...)
530+
if err != nil {
531+
t.Fatalf("failed to scan values in %v", err)
532+
}
533+
ct++
534+
if ct == 0 {
535+
if values[0].(string) != "Bob" {
536+
t.Errorf("Expected Bob, got %v", values[0])
537+
}
538+
if values[1].(int) != 2 {
539+
t.Errorf("Expected 2, got %v", values[1])
540+
}
541+
}
542+
}
543+
if ct != 3 {
544+
t.Errorf("expected 3 rows, got %d", ct)
545+
}
546+
547+
if err := rows.Close(); err != nil {
548+
t.Errorf("error closing rows: %s", err)
549+
}
550+
}
551+
502552
func TestQueryRow(t *testing.T) {
503553
db := newTestDB(t, "people")
504554
defer closeDB(t, db)

0 commit comments

Comments
 (0)