diff --git a/CHANGELOG.md b/CHANGELOG.md index 4459f978a..ad693e0f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. - Public API with request object types (#126) - Support decimal type in msgpack (#96) - Support datetime type in msgpack (#118) +- Prepared SQL statements (#117) ### Changed @@ -30,6 +31,7 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. ### Fixed - Build with OpenSSL < 1.1.1 (#194) +- Add `ExecuteAsync` and `ExecuteTyped` to common connector interface (#62) ## [1.6.0] - 2022-06-01 diff --git a/connection.go b/connection.go index 871890673..6de1e9d01 100644 --- a/connection.go +++ b/connection.go @@ -993,6 +993,13 @@ func (conn *Connection) nextRequestId() (requestId uint32) { // An error is returned if the request was formed incorrectly, or failed to // create the future. func (conn *Connection) Do(req Request) *Future { + if connectedReq, ok := req.(ConnectedRequest); ok { + if connectedReq.Conn() != conn { + fut := NewFuture() + fut.SetError(fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool")) + return fut + } + } return conn.send(req) } @@ -1009,3 +1016,13 @@ func (conn *Connection) OverrideSchema(s *Schema) { conn.Schema = s } } + +// NewPrepared passes a sql statement to Tarantool for preparation synchronously. +func (conn *Connection) NewPrepared(expr string) (*Prepared, error) { + req := NewPrepareRequest(expr) + resp, err := conn.Do(req).Get() + if err != nil { + return nil, err + } + return NewPreparedFromResponse(conn, resp) +} diff --git a/connection_pool/config.lua b/connection_pool/config.lua index b1492dd13..fb3859297 100644 --- a/connection_pool/config.lua +++ b/connection_pool/config.lua @@ -21,6 +21,21 @@ box.once("init", function() parts = {{ field = 1, type = 'string' }}, if_not_exists = true }) + + local sp = box.schema.space.create('SQL_TEST', { + id = 521, + if_not_exists = true, + format = { + {name = "NAME0", type = "unsigned"}, + {name = "NAME1", type = "string"}, + {name = "NAME2", type = "string"}, + } + }) + sp:create_index('primary', {type = 'tree', parts = {1, 'uint'}, if_not_exists = true}) + sp:insert{1, "test", "test"} + -- grants for sql tests + box.schema.user.grant('test', 'create,read,write,drop,alter', 'space') + box.schema.user.grant('test', 'create', 'sequence') end) local function simple_incr(a) diff --git a/connection_pool/connection_pool.go b/connection_pool/connection_pool.go index 0752038f0..ad2e936cc 100644 --- a/connection_pool/connection_pool.go +++ b/connection_pool/connection_pool.go @@ -12,6 +12,7 @@ package connection_pool import ( "errors" + "fmt" "log" "sync/atomic" "time" @@ -525,7 +526,16 @@ func (connPool *ConnectionPool) EvalAsync(expr string, args interface{}, userMod } // Do sends the request and returns a future. +// For requests that belong to an only one connection (e.g. Unprepare or ExecutePrepared) +// the argument of type Mode is unused. func (connPool *ConnectionPool) Do(req tarantool.Request, userMode Mode) *tarantool.Future { + if connectedReq, ok := req.(tarantool.ConnectedRequest); ok { + conn, _ := connPool.getConnectionFromPool(connectedReq.Conn().Addr()) + if conn == nil { + return newErrorFuture(fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool")) + } + return connectedReq.Conn().Do(req) + } conn, err := connPool.getNextConnection(userMode) if err != nil { return newErrorFuture(err) @@ -788,3 +798,12 @@ func newErrorFuture(err error) *tarantool.Future { fut.SetError(err) return fut } + +// NewPrepared passes a sql statement to Tarantool for preparation synchronously. +func (connPool *ConnectionPool) NewPrepared(expr string, userMode Mode) (*tarantool.Prepared, error) { + conn, err := connPool.getNextConnection(userMode) + if err != nil { + return nil, err + } + return conn.NewPrepared(expr) +} diff --git a/connection_pool/connection_pool_test.go b/connection_pool/connection_pool_test.go index 00337005d..2e462e4eb 100644 --- a/connection_pool/connection_pool_test.go +++ b/connection_pool/connection_pool_test.go @@ -1,8 +1,10 @@ package connection_pool_test import ( + "fmt" "log" "os" + "reflect" "strings" "testing" "time" @@ -1276,6 +1278,97 @@ func TestDo(t *testing.T) { require.NotNilf(t, resp, "response is nil after Ping") } +func TestNewPrepared(t *testing.T) { + test_helpers.SkipIfSQLUnsupported(t) + + roles := []bool{true, true, false, true, false} + + err := test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := connection_pool.Connect(servers, connOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + stmt, err := connPool.NewPrepared("SELECT NAME0, NAME1 FROM SQL_TEST WHERE NAME0=:id AND NAME1=:name;", connection_pool.RO) + require.Nilf(t, err, "fail to prepare statement: %v", err) + + if connPool.GetPoolInfo()[stmt.Conn.Addr()].ConnRole != connection_pool.RO { + t.Errorf("wrong role for the statement's connection") + } + + executeReq := tarantool.NewExecutePreparedRequest(stmt) + unprepareReq := tarantool.NewUnprepareRequest(stmt) + + resp, err := connPool.Do(executeReq.Args([]interface{}{1, "test"}), connection_pool.ANY).Get() + if err != nil { + t.Fatalf("failed to execute prepared: %v", err) + } + if resp == nil { + t.Fatalf("nil response") + } + if resp.Code != tarantool.OkCode { + t.Fatalf("failed to execute prepared: code %d", resp.Code) + } + if reflect.DeepEqual(resp.Data[0], []interface{}{1, "test"}) { + t.Error("Select with named arguments failed") + } + if resp.MetaData[0].FieldType != "unsigned" || + resp.MetaData[0].FieldName != "NAME0" || + resp.MetaData[1].FieldType != "string" || + resp.MetaData[1].FieldName != "NAME1" { + t.Error("Wrong metadata") + } + + // the second argument for unprepare request is unused - it already belongs to some connection + resp, err = connPool.Do(unprepareReq, connection_pool.ANY).Get() + if err != nil { + t.Errorf("failed to unprepare prepared statement: %v", err) + } + if resp.Code != tarantool.OkCode { + t.Errorf("failed to unprepare prepared statement: code %d", resp.Code) + } + + _, err = connPool.Do(unprepareReq, connection_pool.ANY).Get() + if err == nil { + t.Errorf("the statement must be already unprepared") + } + require.Contains(t, err.Error(), "Prepared statement with id") + + _, err = connPool.Do(executeReq, connection_pool.ANY).Get() + if err == nil { + t.Errorf("the statement must be already unprepared") + } + require.Contains(t, err.Error(), "Prepared statement with id") +} + +func TestDoWithStrangerConn(t *testing.T) { + expectedErr := fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool") + + roles := []bool{true, true, false, true, false} + + err := test_helpers.SetClusterRO(servers, connOpts, roles) + require.Nilf(t, err, "fail to set roles for cluster") + + connPool, err := connection_pool.Connect(servers, connOpts) + require.Nilf(t, err, "failed to connect") + require.NotNilf(t, connPool, "conn is nil after Connect") + + defer connPool.Close() + + req := test_helpers.NewStrangerRequest() + + _, err = connPool.Do(req, connection_pool.ANY).Get() + if err == nil { + t.Fatalf("nil error catched") + } + if err.Error() != expectedErr.Error() { + t.Fatalf("Unexpected error catched") + } +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/connection_pool/example_test.go b/connection_pool/example_test.go index faf97bc66..08995d03e 100644 --- a/connection_pool/example_test.go +++ b/connection_pool/example_test.go @@ -548,3 +548,28 @@ func ExampleConnectionPool_Do() { // Ping Data [] // Ping Error } + +func ExampleConnectionPool_NewPrepared() { + pool, err := examplePool(testRoles) + if err != nil { + fmt.Println(err) + } + defer pool.Close() + + stmt, err := pool.NewPrepared("SELECT 1", connection_pool.ANY) + if err != nil { + fmt.Println(err) + } + + executeReq := tarantool.NewExecutePreparedRequest(stmt) + unprepareReq := tarantool.NewUnprepareRequest(stmt) + + _, err = pool.Do(executeReq, connection_pool.ANY).Get() + if err != nil { + fmt.Printf("Failed to execute prepared stmt") + } + _, err = pool.Do(unprepareReq, connection_pool.ANY).Get() + if err != nil { + fmt.Printf("Failed to prepare") + } +} diff --git a/connector.go b/connector.go index cd77d2c4c..3084b9124 100644 --- a/connector.go +++ b/connector.go @@ -30,6 +30,7 @@ type Connector interface { Call16Typed(functionName string, args interface{}, result interface{}) (err error) Call17Typed(functionName string, args interface{}, result interface{}) (err error) EvalTyped(expr string, args interface{}, result interface{}) (err error) + ExecuteTyped(expr string, args interface{}, result interface{}) (SQLInfo, []ColumnMetaData, error) SelectAsync(space, index interface{}, offset, limit, iterator uint32, key interface{}) *Future InsertAsync(space interface{}, tuple interface{}) *Future @@ -41,6 +42,9 @@ type Connector interface { Call16Async(functionName string, args interface{}) *Future Call17Async(functionName string, args interface{}) *Future EvalAsync(expr string, args interface{}) *Future + ExecuteAsync(expr string, args interface{}) *Future + + NewPrepared(expr string) (*Prepared, error) Do(req Request) (fut *Future) } diff --git a/const.go b/const.go index 2ad7f3b76..3d0d7424f 100644 --- a/const.go +++ b/const.go @@ -12,6 +12,7 @@ const ( UpsertRequestCode = 9 Call17RequestCode = 10 /* call in >= 1.7 format */ ExecuteRequestCode = 11 + PrepareRequestCode = 13 PingRequestCode = 64 SubscribeRequestCode = 66 @@ -31,9 +32,11 @@ const ( KeyData = 0x30 KeyError = 0x31 KeyMetaData = 0x32 + KeyBindCount = 0x34 KeySQLText = 0x40 KeySQLBind = 0x41 KeySQLInfo = 0x42 + KeyStmtID = 0x43 KeyFieldName = 0x00 KeyFieldType = 0x01 diff --git a/errors.go b/errors.go index a6895ccbc..760a59c5e 100644 --- a/errors.go +++ b/errors.go @@ -1,8 +1,6 @@ package tarantool -import ( - "fmt" -) +import "fmt" // Error is wrapper around error returned by Tarantool. type Error struct { diff --git a/example_test.go b/example_test.go index 0e5487a53..65dc971a0 100644 --- a/example_test.go +++ b/example_test.go @@ -651,3 +651,43 @@ func ExampleConnection_Execute() { fmt.Println("MetaData", resp.MetaData) fmt.Println("SQL Info", resp.SQLInfo) } + +// To use prepared statements to query a tarantool instance, call NewPrepared. +func ExampleConnection_NewPrepared() { + // Tarantool supports SQL since version 2.0.0 + isLess, err := test_helpers.IsTarantoolVersionLess(2, 0, 0) + if err != nil || isLess { + return + } + + server := "127.0.0.1:3013" + opts := tarantool.Opts{ + Timeout: 500 * time.Millisecond, + Reconnect: 1 * time.Second, + MaxReconnects: 3, + User: "test", + Pass: "test", + } + conn, err := tarantool.Connect(server, opts) + if err != nil { + fmt.Printf("Failed to connect: %s", err.Error()) + } + + stmt, err := conn.NewPrepared("SELECT 1") + if err != nil { + fmt.Printf("Failed to connect: %s", err.Error()) + } + + executeReq := tarantool.NewExecutePreparedRequest(stmt) + unprepareReq := tarantool.NewUnprepareRequest(stmt) + + _, err = conn.Do(executeReq).Get() + if err != nil { + fmt.Printf("Failed to execute prepared stmt") + } + + _, err = conn.Do(unprepareReq).Get() + if err != nil { + fmt.Printf("Failed to prepare") + } +} diff --git a/export_test.go b/export_test.go index 8bdfb9812..315f444de 100644 --- a/export_test.go +++ b/export_test.go @@ -75,3 +75,21 @@ func RefImplEvalBody(enc *msgpack.Encoder, expr string, args interface{}) error func RefImplExecuteBody(enc *msgpack.Encoder, expr string, args interface{}) error { return fillExecute(enc, expr, args) } + +// RefImplPrepareBody is reference implementation for filling of an prepare +// request's body. +func RefImplPrepareBody(enc *msgpack.Encoder, expr string) error { + return fillPrepare(enc, expr) +} + +// RefImplUnprepareBody is reference implementation for filling of an execute prepared +// request's body. +func RefImplExecutePreparedBody(enc *msgpack.Encoder, stmt Prepared, args interface{}) error { + return fillExecutePrepared(enc, stmt, args) +} + +// RefImplUnprepareBody is reference implementation for filling of an unprepare +// request's body. +func RefImplUnprepareBody(enc *msgpack.Encoder, stmt Prepared) error { + return fillUnprepare(enc, stmt) +} diff --git a/multi/config.lua b/multi/config.lua index 2b745185a..5d75da513 100644 --- a/multi/config.lua +++ b/multi/config.lua @@ -13,6 +13,21 @@ rawset(_G, 'get_cluster_nodes', get_cluster_nodes) box.once("init", function() box.schema.user.create('test', { password = 'test' }) box.schema.user.grant('test', 'read,write,execute', 'universe') + + local sp = box.schema.space.create('SQL_TEST', { + id = 521, + if_not_exists = true, + format = { + {name = "NAME0", type = "unsigned"}, + {name = "NAME1", type = "string"}, + {name = "NAME2", type = "string"}, + } + }) + sp:create_index('primary', {type = 'tree', parts = {1, 'uint'}, if_not_exists = true}) + sp:insert{1, "test", "test"} + -- grants for sql tests + box.schema.user.grant('test', 'create,read,write,drop,alter', 'space') + box.schema.user.grant('test', 'create', 'sequence') end) local function simple_incr(a) diff --git a/multi/multi.go b/multi/multi.go index 3a98342a7..03531a817 100644 --- a/multi/multi.go +++ b/multi/multi.go @@ -13,6 +13,7 @@ package multi import ( "errors" + "fmt" "sync" "sync/atomic" "time" @@ -419,6 +420,11 @@ func (connMulti *ConnectionMulti) EvalTyped(expr string, args interface{}, resul return connMulti.getCurrentConnection().EvalTyped(expr, args, result) } +// ExecuteTyped passes sql expression to Tarantool for execution. +func (connMulti *ConnectionMulti) ExecuteTyped(expr string, args interface{}, result interface{}) (tarantool.SQLInfo, []tarantool.ColumnMetaData, error) { + return connMulti.getCurrentConnection().ExecuteTyped(expr, args, result) +} + // SelectAsync sends select request to Tarantool and returns Future. func (connMulti *ConnectionMulti) SelectAsync(space, index interface{}, offset, limit, iterator uint32, key interface{}) *tarantool.Future { return connMulti.getCurrentConnection().SelectAsync(space, index, offset, limit, iterator, key) @@ -482,7 +488,26 @@ func (connMulti *ConnectionMulti) EvalAsync(expr string, args interface{}) *tara return connMulti.getCurrentConnection().EvalAsync(expr, args) } +// ExecuteAsync passes sql expression to Tarantool for execution. +func (connMulti *ConnectionMulti) ExecuteAsync(expr string, args interface{}) *tarantool.Future { + return connMulti.getCurrentConnection().ExecuteAsync(expr, args) +} + +// NewPrepared passes a sql statement to Tarantool for preparation synchronously. +func (connMulti *ConnectionMulti) NewPrepared(expr string) (*tarantool.Prepared, error) { + return connMulti.getCurrentConnection().NewPrepared(expr) +} + // Do sends the request and returns a future. func (connMulti *ConnectionMulti) Do(req tarantool.Request) *tarantool.Future { + if connectedReq, ok := req.(tarantool.ConnectedRequest); ok { + _, belongs := connMulti.getConnectionFromPool(connectedReq.Conn().Addr()) + if !belongs { + fut := tarantool.NewFuture() + fut.SetError(fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool")) + return fut + } + return connectedReq.Conn().Do(req) + } return connMulti.getCurrentConnection().Do(req) } diff --git a/multi/multi_test.go b/multi/multi_test.go index 0f84cdb4d..628a2ab28 100644 --- a/multi/multi_test.go +++ b/multi/multi_test.go @@ -1,11 +1,15 @@ package multi import ( + "fmt" "log" "os" + "reflect" "testing" "time" + "github.com/stretchr/testify/require" + "github.com/tarantool/go-tarantool" "github.com/tarantool/go-tarantool/test_helpers" ) @@ -231,6 +235,89 @@ func TestCall17(t *testing.T) { } } +func TestNewPrepared(t *testing.T) { + test_helpers.SkipIfSQLUnsupported(t) + + multiConn, err := Connect([]string{server1, server2}, connOpts) + if err != nil { + t.Fatalf("Failed to connect: %s", err.Error()) + } + if multiConn == nil { + t.Fatalf("conn is nil after Connect") + } + defer multiConn.Close() + + stmt, err := multiConn.NewPrepared("SELECT NAME0, NAME1 FROM SQL_TEST WHERE NAME0=:id AND NAME1=:name;") + require.Nilf(t, err, "fail to prepare statement: %v", err) + + executeReq := tarantool.NewExecutePreparedRequest(stmt) + unprepareReq := tarantool.NewUnprepareRequest(stmt) + + resp, err := multiConn.Do(executeReq.Args([]interface{}{1, "test"})).Get() + if err != nil { + t.Fatalf("failed to execute prepared: %v", err) + } + if resp == nil { + t.Fatalf("nil response") + } + if resp.Code != tarantool.OkCode { + t.Fatalf("failed to execute prepared: code %d", resp.Code) + } + if reflect.DeepEqual(resp.Data[0], []interface{}{1, "test"}) { + t.Error("Select with named arguments failed") + } + if resp.MetaData[0].FieldType != "unsigned" || + resp.MetaData[0].FieldName != "NAME0" || + resp.MetaData[1].FieldType != "string" || + resp.MetaData[1].FieldName != "NAME1" { + t.Error("Wrong metadata") + } + + // the second argument for unprepare request is unused - it already belongs to some connection + resp, err = multiConn.Do(unprepareReq).Get() + if err != nil { + t.Errorf("failed to unprepare prepared statement: %v", err) + } + if resp.Code != tarantool.OkCode { + t.Errorf("failed to unprepare prepared statement: code %d", resp.Code) + } + + _, err = multiConn.Do(unprepareReq).Get() + if err == nil { + t.Errorf("the statement must be already unprepared") + } + require.Contains(t, err.Error(), "Prepared statement with id") + + _, err = multiConn.Do(executeReq).Get() + if err == nil { + t.Errorf("the statement must be already unprepared") + } + require.Contains(t, err.Error(), "Prepared statement with id") +} + +func TestDoWithStrangerConn(t *testing.T) { + expectedErr := fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool") + + multiConn, err := Connect([]string{server1, server2}, connOpts) + if err != nil { + t.Fatalf("Failed to connect: %s", err.Error()) + } + if multiConn == nil { + t.Fatalf("conn is nil after Connect") + } + defer multiConn.Close() + + req := test_helpers.NewStrangerRequest() + + _, err = multiConn.Do(req).Get() + if err == nil { + t.Fatalf("nil error catched") + } + if err.Error() != expectedErr.Error() { + t.Fatalf("Unexpected error catched") + } +} + // runTestMain is a body of TestMain function // (see https://pkg.go.dev/testing#hdr-Main). // Using defer + os.Exit is not works so TestMain body diff --git a/prepared.go b/prepared.go new file mode 100644 index 000000000..9508f0546 --- /dev/null +++ b/prepared.go @@ -0,0 +1,138 @@ +package tarantool + +import ( + "fmt" + + "gopkg.in/vmihailenco/msgpack.v2" +) + +// PreparedID is a type for Prepared Statement ID +type PreparedID uint64 + +// Prepared is a type for handling prepared statements +// +// Since 1.7.0 +type Prepared struct { + StatementID PreparedID + MetaData []ColumnMetaData + ParamCount uint64 + Conn *Connection +} + +// NewPreparedFromResponse constructs a Prepared object. +func NewPreparedFromResponse(conn *Connection, resp *Response) (*Prepared, error) { + if resp == nil { + return nil, fmt.Errorf("pased nil response") + } + if resp.Data == nil { + return nil, fmt.Errorf("response Data is nil") + } + if len(resp.Data) == 0 { + return nil, fmt.Errorf("response Data format is wrong") + } + stmt, ok := resp.Data[0].(*Prepared) + if !ok { + return nil, fmt.Errorf("response Data format is wrong") + } + stmt.Conn = conn + return stmt, nil +} + +// PrepareRequest helps you to create a prepare request object for execution +// by a Connection. +type PrepareRequest struct { + baseRequest + expr string +} + +// NewPrepareRequest returns a new empty PrepareRequest. +func NewPrepareRequest(expr string) *PrepareRequest { + req := new(PrepareRequest) + req.requestCode = PrepareRequestCode + req.expr = expr + return req +} + +// Body fills an encoder with the execute request body. +func (req *PrepareRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + return fillPrepare(enc, req.expr) +} + +// UnprepareRequest helps you to create an unprepare request object for +// execution by a Connection. +type UnprepareRequest struct { + baseRequest + stmt *Prepared +} + +// NewUnprepareRequest returns a new empty UnprepareRequest. +func NewUnprepareRequest(stmt *Prepared) *UnprepareRequest { + req := new(UnprepareRequest) + req.requestCode = PrepareRequestCode + req.stmt = stmt + return req +} + +// Conn returns the Connection object the request belongs to +func (req *UnprepareRequest) Conn() *Connection { + return req.stmt.Conn +} + +// Body fills an encoder with the execute request body. +func (req *UnprepareRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + return fillUnprepare(enc, *req.stmt) +} + +// ExecutePreparedRequest helps you to create an execute prepared request +// object for execution by a Connection. +type ExecutePreparedRequest struct { + baseRequest + stmt *Prepared + args interface{} +} + +// NewExecutePreparedRequest returns a new empty preparedExecuteRequest. +func NewExecutePreparedRequest(stmt *Prepared) *ExecutePreparedRequest { + req := new(ExecutePreparedRequest) + req.requestCode = ExecuteRequestCode + req.stmt = stmt + req.args = []interface{}{} + return req +} + +// Conn returns the Connection object the request belongs to +func (req *ExecutePreparedRequest) Conn() *Connection { + return req.stmt.Conn +} + +// Args sets the args for execute the prepared request. +// Note: default value is empty. +func (req *ExecutePreparedRequest) Args(args interface{}) *ExecutePreparedRequest { + req.args = args + return req +} + +// Body fills an encoder with the execute request body. +func (req *ExecutePreparedRequest) Body(res SchemaResolver, enc *msgpack.Encoder) error { + return fillExecutePrepared(enc, *req.stmt, req.args) +} + +func fillPrepare(enc *msgpack.Encoder, expr string) error { + enc.EncodeMapLen(1) + enc.EncodeUint64(KeySQLText) + return enc.EncodeString(expr) +} + +func fillUnprepare(enc *msgpack.Encoder, stmt Prepared) error { + enc.EncodeMapLen(1) + enc.EncodeUint64(KeyStmtID) + return enc.EncodeUint64(uint64(stmt.StatementID)) +} + +func fillExecutePrepared(enc *msgpack.Encoder, stmt Prepared, args interface{}) error { + enc.EncodeMapLen(2) + enc.EncodeUint64(KeyStmtID) + enc.EncodeUint64(uint64(stmt.StatementID)) + enc.EncodeUint64(KeySQLBind) + return encodeSQLBind(enc, args) +} diff --git a/request.go b/request.go index 3b6b33f07..a83094145 100644 --- a/request.go +++ b/request.go @@ -539,6 +539,14 @@ type Request interface { Body(resolver SchemaResolver, enc *msgpack.Encoder) error } +// ConnectedRequest is an interface that provides the info about a Connection +// the request belongs to. +type ConnectedRequest interface { + Request + // Conn returns a Connection the request belongs to. + Conn() *Connection +} + type baseRequest struct { requestCode int32 } diff --git a/request_test.go b/request_test.go index f0da3f865..7c1805155 100644 --- a/request_test.go +++ b/request_test.go @@ -5,6 +5,8 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" + . "github.com/tarantool/go-tarantool" "gopkg.in/vmihailenco/msgpack.v2" ) @@ -20,6 +22,8 @@ const validExpr = "any string" // We don't check the value here. const defaultSpace = 0 // And valid too. const defaultIndex = 0 // And valid too. +var validStmt *Prepared = &Prepared{StatementID: 1, Conn: &Connection{}} + type ValidSchemeResolver struct { } @@ -168,6 +172,9 @@ func TestRequestsCodes(t *testing.T) { {req: NewEvalRequest(validExpr), code: EvalRequestCode}, {req: NewExecuteRequest(validExpr), code: ExecuteRequestCode}, {req: NewPingRequest(), code: PingRequestCode}, + {req: NewPrepareRequest(validExpr), code: PrepareRequestCode}, + {req: NewUnprepareRequest(validStmt), code: PrepareRequestCode}, + {req: NewExecutePreparedRequest(validStmt), code: ExecuteRequestCode}, } for _, test := range tests { @@ -517,3 +524,64 @@ func TestExecuteRequestSetters(t *testing.T) { Args(args) assertBodyEqual(t, refBuf.Bytes(), req) } + +func TestPrepareRequestDefaultValues(t *testing.T) { + var refBuf bytes.Buffer + + refEnc := msgpack.NewEncoder(&refBuf) + err := RefImplPrepareBody(refEnc, validExpr) + if err != nil { + t.Errorf("An unexpected RefImplPrepareBody() error: %q", err.Error()) + return + } + + req := NewPrepareRequest(validExpr) + assertBodyEqual(t, refBuf.Bytes(), req) +} + +func TestUnprepareRequestDefaultValues(t *testing.T) { + var refBuf bytes.Buffer + + refEnc := msgpack.NewEncoder(&refBuf) + err := RefImplUnprepareBody(refEnc, *validStmt) + if err != nil { + t.Errorf("An unexpected RefImplUnprepareBody() error: %q", err.Error()) + return + } + + req := NewUnprepareRequest(validStmt) + assert.Equal(t, req.Conn(), validStmt.Conn) + assertBodyEqual(t, refBuf.Bytes(), req) +} + +func TestExecutePreparedRequestSetters(t *testing.T) { + args := []interface{}{uint(11)} + var refBuf bytes.Buffer + + refEnc := msgpack.NewEncoder(&refBuf) + err := RefImplExecutePreparedBody(refEnc, *validStmt, args) + if err != nil { + t.Errorf("An unexpected RefImplExecutePreparedBody() error: %q", err.Error()) + return + } + + req := NewExecutePreparedRequest(validStmt). + Args(args) + assert.Equal(t, req.Conn(), validStmt.Conn) + assertBodyEqual(t, refBuf.Bytes(), req) +} + +func TestExecutePreparedRequestDefaultValues(t *testing.T) { + var refBuf bytes.Buffer + + refEnc := msgpack.NewEncoder(&refBuf) + err := RefImplExecutePreparedBody(refEnc, *validStmt, []interface{}{}) + if err != nil { + t.Errorf("An unexpected RefImplExecutePreparedBody() error: %q", err.Error()) + return + } + + req := NewExecutePreparedRequest(validStmt) + assert.Equal(t, req.Conn(), validStmt.Conn) + assertBodyEqual(t, refBuf.Bytes(), req) +} diff --git a/response.go b/response.go index 80b38849b..3fd7322b0 100644 --- a/response.go +++ b/response.go @@ -147,6 +147,7 @@ func (resp *Response) decodeHeader(d *msgpack.Decoder) (err error) { func (resp *Response) decodeBody() (err error) { if resp.buf.Len() > 2 { var l int + var stmtID, bindCount uint64 d := msgpack.NewDecoder(&resp.buf) if l, err = d.DecodeMapLen(); err != nil { return err @@ -178,12 +179,28 @@ func (resp *Response) decodeBody() (err error) { if err = d.Decode(&resp.MetaData); err != nil { return err } + case KeyStmtID: + if stmtID, err = d.DecodeUint64(); err != nil { + return err + } + case KeyBindCount: + if bindCount, err = d.DecodeUint64(); err != nil { + return err + } default: if err = d.Skip(); err != nil { return err } } } + if stmtID != 0 { + stmt := &Prepared{ + StatementID: PreparedID(stmtID), + ParamCount: bindCount, + MetaData: resp.MetaData, + } + resp.Data = []interface{}{stmt} + } if resp.Code != OkCode && resp.Code != PushCode { resp.Code &^= ErrorCodeBit err = Error{resp.Code, resp.Error} diff --git a/tarantool_test.go b/tarantool_test.go index 64e9c7942..06771338c 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" . "github.com/tarantool/go-tarantool" "github.com/tarantool/go-tarantool/test_helpers" @@ -136,6 +138,61 @@ func BenchmarkClientSerialTyped(b *testing.B) { } } +func BenchmarkClientSerialSQL(b *testing.B) { + test_helpers.SkipIfSQLUnsupported(b) + + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + spaceNo := 519 + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Errorf("Failed to replace: %s", err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := conn.Execute("SELECT NAME0,NAME1,NAME2 FROM SQL_TEST WHERE NAME0=?", []interface{}{uint(1111)}) + if err != nil { + b.Errorf("Select failed: %s", err.Error()) + break + } + } +} + +func BenchmarkClientSerialSQLPrepared(b *testing.B) { + test_helpers.SkipIfSQLUnsupported(b) + + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + spaceNo := 519 + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Errorf("Failed to replace: %s", err) + } + + stmt, err := conn.NewPrepared("SELECT NAME0,NAME1,NAME2 FROM SQL_TEST WHERE NAME0=?") + if err != nil { + b.Fatalf("failed to prepare a SQL statement") + } + executeReq := NewExecutePreparedRequest(stmt) + unprepareReq := NewUnprepareRequest(stmt) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := conn.Do(executeReq.Args([]interface{}{uint(1111)})).Get() + if err != nil { + b.Errorf("Select failed: %s", err.Error()) + break + } + } + _, err = conn.Do(unprepareReq).Get() + if err != nil { + b.Fatalf("failed to unprepare a SQL statement") + } +} + func BenchmarkClientFuture(b *testing.B) { var err error @@ -398,21 +455,14 @@ func BenchmarkClientLargeSelectParallel(b *testing.B) { }) } -func BenchmarkSQLParallel(b *testing.B) { - // Tarantool supports SQL since version 2.0.0 - isLess, err := test_helpers.IsTarantoolVersionLess(2, 0, 0) - if err != nil { - b.Fatal("Could not check the Tarantool version") - } - if isLess { - b.Skip() - } +func BenchmarkClientParallelSQL(b *testing.B) { + test_helpers.SkipIfSQLUnsupported(b) conn := test_helpers.ConnectWithValidation(b, server, opts) defer conn.Close() spaceNo := 519 - _, err = conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) if err != nil { b.Errorf("No connection available") } @@ -429,21 +479,49 @@ func BenchmarkSQLParallel(b *testing.B) { }) } -func BenchmarkSQLSerial(b *testing.B) { - // Tarantool supports SQL since version 2.0.0 - isLess, err := test_helpers.IsTarantoolVersionLess(2, 0, 0) +func BenchmarkClientParallelSQLPrepared(b *testing.B) { + test_helpers.SkipIfSQLUnsupported(b) + + conn := test_helpers.ConnectWithValidation(b, server, opts) + defer conn.Close() + + spaceNo := 519 + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + if err != nil { + b.Errorf("No connection available") + } + + stmt, err := conn.NewPrepared("SELECT NAME0,NAME1,NAME2 FROM SQL_TEST WHERE NAME0=?") if err != nil { - b.Fatal("Could not check the Tarantool version") + b.Fatalf("failed to prepare a SQL statement") } - if isLess { - b.Skip() + executeReq := NewExecutePreparedRequest(stmt) + unprepareReq := NewUnprepareRequest(stmt) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := conn.Do(executeReq.Args([]interface{}{uint(1111)})).Get() + if err != nil { + b.Errorf("Select failed: %s", err.Error()) + break + } + } + }) + _, err = conn.Do(unprepareReq).Get() + if err != nil { + b.Fatalf("failed to unprepare a SQL statement") } +} + +func BenchmarkSQLSerial(b *testing.B) { + test_helpers.SkipIfSQLUnsupported(b) conn := test_helpers.ConnectWithValidation(b, server, opts) defer conn.Close() spaceNo := 519 - _, err = conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) + _, err := conn.Replace(spaceNo, []interface{}{uint(1111), "hello", "world"}) if err != nil { b.Errorf("Failed to replace: %s", err) } @@ -882,14 +960,7 @@ const ( ) func TestSQL(t *testing.T) { - // Tarantool supports SQL since version 2.0.0 - isLess, err := test_helpers.IsTarantoolVersionLess(2, 0, 0) - if err != nil { - t.Fatalf("Could not check the Tarantool version") - } - if isLess { - t.Skip() - } + test_helpers.SkipIfSQLUnsupported(t) type testCase struct { Query string @@ -1061,14 +1132,7 @@ func TestSQL(t *testing.T) { } func TestSQLTyped(t *testing.T) { - // Tarantool supports SQL since version 2.0.0 - isLess, err := test_helpers.IsTarantoolVersionLess(2, 0, 0) - if err != nil { - t.Fatal("Could not check the Tarantool version") - } - if isLess { - t.Skip() - } + test_helpers.SkipIfSQLUnsupported(t) conn := test_helpers.ConnectWithValidation(t, server, opts) defer conn.Close() @@ -1090,20 +1154,13 @@ func TestSQLTyped(t *testing.T) { } func TestSQLBindings(t *testing.T) { + test_helpers.SkipIfSQLUnsupported(t) + // Data for test table testData := map[int]string{ 1: "test", } - // Tarantool supports SQL since version 2.0.0 - isLess, err := test_helpers.IsTarantoolVersionLess(2, 0, 0) - if err != nil { - t.Fatal("Could not check the Tarantool version") - } - if isLess { - t.Skip() - } - var resp *Response conn := test_helpers.ConnectWithValidation(t, server, opts) @@ -1150,7 +1207,7 @@ func TestSQLBindings(t *testing.T) { } for _, bind := range namedSQLBinds { - resp, err = conn.Execute(selectNamedQuery2, bind) + resp, err := conn.Execute(selectNamedQuery2, bind) if err != nil { t.Fatalf("Failed to Execute: %s", err.Error()) } @@ -1168,7 +1225,7 @@ func TestSQLBindings(t *testing.T) { } } - resp, err = conn.Execute(selectPosQuery2, sqlBind5) + resp, err := conn.Execute(selectPosQuery2, sqlBind5) if err != nil { t.Fatalf("Failed to Execute: %s", err.Error()) } @@ -1204,21 +1261,14 @@ func TestSQLBindings(t *testing.T) { } func TestStressSQL(t *testing.T) { - // Tarantool supports SQL since version 2.0.0 - isLess, err := test_helpers.IsTarantoolVersionLess(2, 0, 0) - if err != nil { - t.Fatalf("Could not check the Tarantool version") - } - if isLess { - t.Skip() - } + test_helpers.SkipIfSQLUnsupported(t) var resp *Response conn := test_helpers.ConnectWithValidation(t, server, opts) defer conn.Close() - resp, err = conn.Execute(createTableQuery, []interface{}{}) + resp, err := conn.Execute(createTableQuery, []interface{}{}) if err != nil { t.Fatalf("Failed to Execute: %s", err.Error()) } @@ -1308,6 +1358,122 @@ func TestStressSQL(t *testing.T) { } } +func TestNewPrepared(t *testing.T) { + test_helpers.SkipIfSQLUnsupported(t) + + conn := test_helpers.ConnectWithValidation(t, server, opts) + defer conn.Close() + + stmt, err := conn.NewPrepared(selectNamedQuery2) + if err != nil { + t.Errorf("failed to prepare: %v", err) + } + + executeReq := NewExecutePreparedRequest(stmt) + unprepareReq := NewUnprepareRequest(stmt) + + resp, err := conn.Do(executeReq.Args([]interface{}{1, "test"})).Get() + if err != nil { + t.Errorf("failed to execute prepared: %v", err) + } + if resp.Code != OkCode { + t.Errorf("failed to execute prepared: code %d", resp.Code) + } + if reflect.DeepEqual(resp.Data[0], []interface{}{1, "test"}) { + t.Error("Select with named arguments failed") + } + if resp.MetaData[0].FieldType != "unsigned" || + resp.MetaData[0].FieldName != "NAME0" || + resp.MetaData[1].FieldType != "string" || + resp.MetaData[1].FieldName != "NAME1" { + t.Error("Wrong metadata") + } + + resp, err = conn.Do(unprepareReq).Get() + if err != nil { + t.Errorf("failed to unprepare prepared statement: %v", err) + } + if resp.Code != OkCode { + t.Errorf("failed to unprepare prepared statement: code %d", resp.Code) + } + + _, err = conn.Do(unprepareReq).Get() + if err == nil { + t.Errorf("the statement must be already unprepared") + } + require.Contains(t, err.Error(), "Prepared statement with id") + + _, err = conn.Do(executeReq).Get() + if err == nil { + t.Errorf("the statement must be already unprepared") + } + require.Contains(t, err.Error(), "Prepared statement with id") + + prepareReq := NewPrepareRequest(selectNamedQuery2) + resp, err = conn.Do(prepareReq).Get() + if err != nil { + t.Errorf("failed to prepare: %v", err) + } + if resp.Data == nil { + t.Errorf("failed to prepare: Data is nil") + } + if resp.Code != OkCode { + t.Errorf("failed to unprepare prepared statement: code %d", resp.Code) + } + + if len(resp.Data) == 0 { + t.Errorf("failed to prepare: response Data has no elements") + } + stmt, ok := resp.Data[0].(*Prepared) + if !ok { + t.Errorf("failed to prepare: failed to cast the response Data to Prepared object") + } + if stmt.StatementID == 0 { + t.Errorf("failed to prepare: statement id is 0") + } +} + +func TestConnection_DoWithStrangerConn(t *testing.T) { + expectedErr := fmt.Errorf("the passed connected request doesn't belong to the current connection or connection pool") + + conn1 := &Connection{} + req := test_helpers.NewStrangerRequest() + + _, err := conn1.Do(req).Get() + if err == nil { + t.Fatalf("nil error catched") + } + if err.Error() != expectedErr.Error() { + t.Fatalf("Unexpected error catched") + } +} + +func TestNewPreparedFromResponse(t *testing.T) { + var ( + ErrNilResponsePassed = fmt.Errorf("pased nil response") + ErrNilResponseData = fmt.Errorf("response Data is nil") + ErrWrongDataFormat = fmt.Errorf("response Data format is wrong") + ) + testConn := &Connection{} + testCases := []struct { + name string + resp *Response + expectedError error + }{ + {"ErrNilResponsePassed", nil, ErrNilResponsePassed}, + {"ErrNilResponseData", &Response{Data: nil}, ErrNilResponseData}, + {"ErrWrongDataFormat", &Response{Data: []interface{}{}}, ErrWrongDataFormat}, + {"ErrWrongDataFormat", &Response{Data: []interface{}{"test"}}, ErrWrongDataFormat}, + {"nil", &Response{Data: []interface{}{&Prepared{}}}, nil}, + } + for _, testCase := range testCases { + t.Run("Expecting error "+testCase.name, func(t *testing.T) { + _, err := NewPreparedFromResponse(testConn, testCase.resp) + assert.Equal(t, err, testCase.expectedError) + }) + } +} + func TestSchema(t *testing.T) { var err error diff --git a/test_helpers/request_mock.go b/test_helpers/request_mock.go new file mode 100644 index 000000000..00674a3a7 --- /dev/null +++ b/test_helpers/request_mock.go @@ -0,0 +1,25 @@ +package test_helpers + +import ( + "github.com/tarantool/go-tarantool" + "gopkg.in/vmihailenco/msgpack.v2" +) + +type StrangerRequest struct { +} + +func NewStrangerRequest() *StrangerRequest { + return &StrangerRequest{} +} + +func (sr *StrangerRequest) Code() int32 { + return 0 +} + +func (sr *StrangerRequest) Body(resolver tarantool.SchemaResolver, enc *msgpack.Encoder) error { + return nil +} + +func (sr *StrangerRequest) Conn() *tarantool.Connection { + return &tarantool.Connection{} +} diff --git a/test_helpers/utils.go b/test_helpers/utils.go index b7c8fdc96..e07f34bf8 100644 --- a/test_helpers/utils.go +++ b/test_helpers/utils.go @@ -23,3 +23,16 @@ func ConnectWithValidation(t testing.TB, } return conn } + +func SkipIfSQLUnsupported(t testing.TB) { + t.Helper() + + // Tarantool supports SQL since version 2.0.0 + isLess, err := IsTarantoolVersionLess(2, 0, 0) + if err != nil { + t.Fatalf("Could not check the Tarantool version") + } + if isLess { + t.Skip() + } +}