diff --git a/.eslintrc b/.eslintrc index 7c3d515cd..1df3684d1 100644 --- a/.eslintrc +++ b/.eslintrc @@ -1,5 +1,6 @@ { "env": { + "es6": true, "node": true }, "rules": { diff --git a/index.js b/index.js index 72624076b..86017f010 100644 --- a/index.js +++ b/index.js @@ -1,3 +1,5 @@ +var ttCommon = require('template-tag-common'); + var Classes = Object.create(null); /** @@ -46,10 +48,18 @@ exports.createPoolCluster = function createPoolCluster(config) { * @return {Query} New query object * @public */ -exports.createQuery = function createQuery(sql, values, callback) { +exports.createQuery = function createQuery(...args) { var Connection = loadClass('Connection'); - - return Connection.createQuery(sql, values, callback); + if (ttCommon.calledAsTemplateTagQuick(args[0], args.length)) { + var Template = loadClass('Template'); + const sqlFragment = Template(...args); + return function (callback) { + return Connection.createQuery(sqlFragment.content, [], callback); + }; + } else { + const [ sql, values, callback ] = args; + return Connection.createQuery(sql, values, callback); + } }; /** @@ -106,12 +116,21 @@ exports.raw = function raw(sql) { return SqlString.raw(sql); }; -/** - * The type constants. - * @public - */ -Object.defineProperty(exports, 'Types', { - get: loadClass.bind(null, 'Types') +Object.defineProperties(exports, { + /** + * The type constants. + * @public + */ + 'Types': { + get: loadClass.bind(null, 'Types') + }, + /** + * The SQL template tag. + * @public + */ + 'sql': { + get: loadClass.bind(null, 'Template') + } }); /** @@ -147,6 +166,9 @@ function loadClass(className) { case 'SqlString': Class = require('./lib/protocol/SqlString'); break; + case 'Template': + Class = require('./lib/Template'); + break; case 'Types': Class = require('./lib/protocol/constants/types'); break; diff --git a/lib/Connection.js b/lib/Connection.js index ea452757e..e11cd0981 100644 --- a/lib/Connection.js +++ b/lib/Connection.js @@ -2,10 +2,12 @@ var Crypto = require('crypto'); var Events = require('events'); var Net = require('net'); var tls = require('tls'); +var ttCommon = require('template-tag-common'); var ConnectionConfig = require('./ConnectionConfig'); var Protocol = require('./protocol/Protocol'); var SqlString = require('./protocol/SqlString'); var Query = require('./protocol/sequences/Query'); +var Template = require('./Template'); var Util = require('util'); module.exports = Connection; @@ -191,8 +193,16 @@ Connection.prototype.rollback = function rollback(options, callback) { return this.query(options, callback); }; -Connection.prototype.query = function query(sql, values, cb) { - var query = Connection.createQuery(sql, values, cb); +Connection.prototype.query = function query(...args) { + if (ttCommon.calledAsTemplateTagQuick(args[0], args.length)) { + const sqlFragment = Template(...args); + return function (callback) { + return this.query(sqlFragment.content, [], callback); + }.bind(this); + } + + const [ sql, values, callback ] = args; + var query = Connection.createQuery(sql, values, callback); query._connection = this; if (!(typeof sql === 'object' && 'typeCast' in sql)) { diff --git a/lib/Template.js b/lib/Template.js new file mode 100644 index 000000000..d19406895 --- /dev/null +++ b/lib/Template.js @@ -0,0 +1,247 @@ +const Mysql = require('../index'); +const { + memoizedTagFunction, + trimCommonWhitespaceFromLines, + TypedString +} = require('template-tag-common'); + +// A simple lexer for SQL. +// SQL has many divergent dialects with subtly different +// conventions for string escaping and comments. +// This just attempts to roughly tokenize MySQL's specific variant. +// See also +// https://www.w3.org/2005/05/22-SPARQL-MySQL/sql_yacc +// https://github.com/twitter/mysql/blob/master/sql/sql_lex.cc +// https://dev.mysql.com/doc/refman/5.7/en/string-literals.html + +// "--" followed by whitespace starts a line comment +// "#" +// "/*" starts an inline comment ended at first "*/" +// \N means null +// Prefixed strings x'...' is a hex string, b'...' is a binary string, .... +// '...', "..." are strings. `...` escapes identifiers. +// doubled delimiters and backslash both escape +// doubled delimiters work in `...` identifiers + +const PREFIX_BEFORE_DELIMITER = new RegExp( + '^(?:' + + ( + // Comment + '--(?=[\\t\\r\\n ])[^\\r\\n]*' + + '|#[^\\r\\n]*' + + '|/[*][\\s\\S]*?[*]/' + ) + + '|' + + ( + // Run of non-comment non-string starts + '(?:[^\'"`\\-/#]|-(?!-)|/(?![*]))' + ) + + ')*'); +const DELIMITED_BODIES = { + '\'' : /^(?:[^'\\]|\\[\s\S]|'')*/, + '"' : /^(?:[^"\\]|\\[\s\S]|"")*/, + '`' : /^(?:[^`\\]|\\[\s\S]|``)*/ +}; + +/** + * Template tag that creates a new Error with a message. + * @param {!Array.} strs a valid TemplateObject. + * @return {string} A message suitable for the Error constructor. + */ +function msg (strs, ...dyn) { + let message = String(strs[0]); + for (let i = 0; i < dyn.length; ++i) { + message += JSON.stringify(dyn[i]) + strs[i + 1]; + } + return message; +} + +/** + * Returns a function that can be fed chunks of input and which + * returns a delimiter context. + * + * @return {!function (string) : string} + * a stateful function that takes a string of SQL text and + * returns the context after it. Subsequent calls will assume + * that context. + */ +function makeLexer () { + let errorMessage = null; + let delimiter = null; + return (text) => { + if (errorMessage) { + // Replay the error message if we've already failed. + throw new Error(errorMessage); + } + text = String(text); + while (text) { + const pattern = delimiter + ? DELIMITED_BODIES[delimiter] + : PREFIX_BEFORE_DELIMITER; + const match = pattern.exec(text); + if (!match) { + throw new Error( + errorMessage = msg`Failed to lex starting at ${text}`); + } + let nConsumed = match[0].length; + if (text.length > nConsumed) { + const chr = text.charAt(nConsumed); + if (delimiter) { + if (chr === delimiter) { + delimiter = null; + ++nConsumed; + } else { + throw new Error( + errorMessage = msg`Expected ${chr} at ${text}`); + } + } else if (Object.hasOwnProperty.call(DELIMITED_BODIES, chr)) { + delimiter = chr; + ++nConsumed; + } else { + throw new Error( + errorMessage = msg`Expected delimiter at ${text}`); + } + } + text = text.substring(nConsumed); + } + return delimiter; + }; +} + +/** A string wrapper that marks its content as a SQL identifier. */ +class Identifier extends TypedString {} + +/** + * A string wrapper that marks its content as a series of + * well-formed SQL tokens. + */ +class SqlFragment extends TypedString {} + +/** + * Analyzes the static parts of the tag content. + * + * @param {!Array.} strings a valid TemplateObject. + * @return { !{ + * raw: !Array., + * delimiters : !Array., + * chunks: !Array. + * } } + * A record like { raw, delimiters, chunks } + * where delimiter is a contextual cue and chunk is + * the adjusted raw text. + */ +function computeStatic (strings) { + const { raw } = trimCommonWhitespaceFromLines(strings); + + const delimiters = []; + const chunks = []; + + const lexer = makeLexer(); + + let delimiter = null; + for (let i = 0, len = raw.length; i < len; ++i) { + let chunk = String(raw[i]); + if (delimiter === '`') { + // Treat raw \` in an identifier literal as an ending delimiter. + chunk = chunk.replace(/^([^\\`]|\\[\s\S])*\\`/, '$1`'); + } + const newDelimiter = lexer(chunk); + if (newDelimiter === '`' && !delimiter) { + // Treat literal \` outside a string context as starting an + // identifier literal + chunk = chunk.replace( + /((?:^|[^\\])(?:\\\\)*)\\(`(?:[^`\\]|\\[\s\S])*)$/, '$1$2'); + } + + chunks.push(chunk); + delimiters.push(newDelimiter); + delimiter = newDelimiter; + } + + if (delimiter) { + throw new Error(`Unclosed quoted string: ${delimiter}`); + } + + return { raw, delimiters, chunks }; +} + +function interpolateSqlIntoFragment ( + { raw, delimiters, chunks }, strings, values) { + // A buffer to accumulate output. + let [ result ] = chunks; + for (let i = 1, len = raw.length; i < len; ++i) { + const chunk = chunks[i]; + // The count of values must be 1 less than the surrounding + // chunks of literal text. + if (i !== 0) { + const delimiter = delimiters[i - 1]; + const value = values[i - 1]; + if (delimiter) { + result += escapeDelimitedValue(value, delimiter); + } else { + result = appendValue(result, value, chunk); + } + } + + result += chunk; + } + + return new SqlFragment(result); +} + +function escapeDelimitedValue (value, delimiter) { + if (delimiter === '`') { + return Mysql.escapeId(String(value)).replace(/^`|`$/g, ''); + } + const escaped = Mysql.escape(String(value)); + return escaped.substring(1, escaped.length - 1); +} + +function appendValue (resultBefore, value, chunk) { + let needsSpace = false; + let result = resultBefore; + const valueArray = Array.isArray(value) ? value : [ value ]; + for (let i = 0, nValues = valueArray.length; i < nValues; ++i) { + if (i) { + result += ', '; + } + + const one = valueArray[i]; + let valueStr = null; + if (one instanceof SqlFragment) { + if (!/(?:^|[\n\r\t ,\x28])$/.test(result)) { + result += ' '; + } + valueStr = one.toString(); + needsSpace = i + 1 === nValues; + } else if (one instanceof Identifier) { + valueStr = Mysql.escapeId(one.toString()); + } else { + // If we need to handle nested arrays, we would recurse here. + valueStr = Mysql.format('?', one); + } + result += valueStr; + } + + if (needsSpace && chunk && !/^[\n\r\t ,\x29]/.test(chunk)) { + result += ' '; + } + + return result; +} + +/** + * Template tag function that contextually autoescapes values + * producing a SqlFragment. + */ +const sql = memoizedTagFunction(computeStatic, interpolateSqlIntoFragment); +sql.Identifier = Identifier; +sql.Fragment = SqlFragment; + +if (require('process').env.npm_lifecycle_event === 'test') { + // Expose for testing. + // Harmless if this leaks + sql.makeLexer = makeLexer; +} + +module.exports = sql; diff --git a/package.json b/package.json index b66820932..c99fac997 100644 --- a/package.json +++ b/package.json @@ -16,7 +16,8 @@ "bignumber.js": "4.0.4", "readable-stream": "2.3.3", "safe-buffer": "5.1.1", - "sqlstring": "2.3.0" + "sqlstring": "2.3.0", + "template-tag-common": "1.0.8" }, "devDependencies": { "after": "0.8.2", @@ -39,9 +40,9 @@ }, "scripts": { "lint": "eslint .", - "test": "node test/run.js", - "test-ci": "nyc --reporter=text npm test", - "test-cov": "nyc --reporter=html --reporter=text npm test", + "test": "TZ=GMT node test/run.js", + "test-ci": "TZ=GMT nyc --reporter=text npm test", + "test-cov": "TZ=GMT nyc --reporter=html --reporter=text npm test", "version": "node tool/version-changes.js && git add Changes.md" } } diff --git a/test/common.js b/test/common.js index db502b60f..b9de44109 100644 --- a/test/common.js +++ b/test/common.js @@ -29,6 +29,7 @@ common.Parser = require(common.lib + '/protocol/Parser'); common.PoolConfig = require(common.lib + '/PoolConfig'); common.PoolConnection = require(common.lib + '/PoolConnection'); common.SqlString = require(common.lib + '/protocol/SqlString'); +common.Template = require(common.lib + '/Template'); common.Types = require(common.lib + '/protocol/constants/types'); var Mysql = require(path.resolve(common.lib, '../index')); diff --git a/test/integration/connection/test-query.js b/test/integration/connection/test-query.js index a11b04dc4..3ea077276 100644 --- a/test/integration/connection/test-query.js +++ b/test/integration/connection/test-query.js @@ -4,17 +4,17 @@ var common = require('../../common'); common.getTestConnection(function (err, connection) { assert.ifError(err); - connection.query('SELECT 1', function (err, rows, fields) { + function callback (err, rows, fields) { assert.ifError(err); assert.deepEqual(rows, [{1: 1}]); assert.equal(fields[0].name, '1'); - }); + } - connection.query({ sql: 'SELECT ?' }, [ 1 ], function (err, rows, fields) { - assert.ifError(err); - assert.deepEqual(rows, [{1: 1}]); - assert.equal(fields[0].name, '1'); - }); + connection.query('SELECT 1', callback); + + connection.query({ sql: 'SELECT ?' }, [ 1 ], callback); + + connection.query`SELECT ${ 1 }`(callback); connection.end(assert.ifError); }); diff --git a/test/unit/template/test-template.js b/test/unit/template/test-template.js new file mode 100644 index 000000000..9ba843a6d --- /dev/null +++ b/test/unit/template/test-template.js @@ -0,0 +1,157 @@ +var assert = require('assert'); +var common = require('../../common'); +var test = require('utest'); +var Template = common.Template; + +function tokens (...chunks) { + const lexer = Template.makeLexer(); + const out = []; + for (let i = 0, len = chunks.length; i < len; ++i) { + out.push(lexer(chunks[i]) || '_'); + } + return out.join(','); +} + +test('template lexer', { + 'empty string': function () { + assert.equal(tokens(''), '_'); + }, + 'hash comments': function () { + assert.equal(tokens(' # "foo\n', ''), '_,_'); + }, + 'dash comments': function () { + assert.equal(tokens(' -- \'foo\n', ''), '_,_'); + }, + 'block comments': function () { + assert.equal(tokens(' /* `foo */', ''), '_,_'); + }, + 'dq': function () { + assert.equal(tokens('SELECT "foo"'), '_'); + assert.equal(tokens('SELECT `foo`, "foo"'), '_'); + assert.equal(tokens('SELECT "', '"'), '",_'); + assert.equal(tokens('SELECT "x', '"'), '",_'); + assert.equal(tokens('SELECT "\'', '"'), '",_'); + assert.equal(tokens('SELECT "`', '"'), '",_'); + assert.equal(tokens('SELECT """', '"'), '",_'); + assert.equal(tokens('SELECT "\\"', '"'), '",_'); + }, + 'sq': function () { + assert.equal(tokens('SELECT \'foo\''), '_'); + assert.equal(tokens('SELECT `foo`, \'foo\''), '_'); + assert.equal(tokens('SELECT \'', '\''), '\',_'); + assert.equal(tokens('SELECT \'x', '\''), '\',_'); + assert.equal(tokens('SELECT \'"', '\''), '\',_'); + assert.equal(tokens('SELECT \'`', '\''), '\',_'); + assert.equal(tokens('SELECT \'\'\'', '\''), '\',_'); + assert.equal(tokens('SELECT \'\\\'', '\''), '\',_'); + }, + 'bq': function () { + assert.equal(tokens('SELECT `foo`'), '_'); + assert.equal(tokens('SELECT "foo", `foo`'), '_'); + assert.equal(tokens('SELECT `', '`'), '`,_'); + assert.equal(tokens('SELECT `x', '`'), '`,_'); + assert.equal(tokens('SELECT `\'', '`'), '`,_'); + assert.equal(tokens('SELECT `"', '`'), '`,_'); + assert.equal(tokens('SELECT ```', '`'), '`,_'); + assert.equal(tokens('SELECT `\\`', '`'), '`,_'); + } +}); + +function runTagTest (golden, test) { + // Run multiply to test memoization bugs. + for (let i = 3; --i >= 0;) { + let result = test(); + if (result instanceof Template.Fragment) { + result = result.toString(); + } else { + throw new Error(`Expected SqlFragment not ${result}`); + } + assert.equal(result, golden); + } +} + +test('template tag', { + 'numbers': function () { + runTagTest( + 'SELECT 2', + () => Template`SELECT ${1 + 1}`); + }, + 'date': function () { + runTagTest( + `SELECT '2000-01-01 00:00:00.000'`, + () => Template`SELECT ${new Date(Date.UTC(2000, 0, 1, 0, 0, 0))}`); + }, + 'string': function () { + runTagTest( + `SELECT 'Hello, World!\\n'`, + () => Template`SELECT ${'Hello, World!\n'}`); + }, + 'identifier': function () { + runTagTest( + 'SELECT `foo`', + () => Template`SELECT ${new Template.Identifier('foo')}`); + }, + 'fragment': function () { + const fragment = new Template.Fragment('1 + 1'); + runTagTest( + `SELECT 1 + 1`, + () => Template`SELECT ${fragment}`); + }, + 'fragment no token merging': function () { + const fragment = new Template.Fragment('1 + 1'); + runTagTest( + `SELECT 1 + 1 FROM T`, + () => Template`SELECT${fragment}FROM T`); + }, + 'string in dq string': function () { + runTagTest( + `SELECT "Hello, World!\\n"`, + () => Template`SELECT "Hello, ${'World!'}\n"`); + }, + 'string in sq string': function () { + runTagTest( + `SELECT 'Hello, World!\\n'`, + () => Template`SELECT 'Hello, ${'World!'}\n'`); + }, + 'string after string in string': function () { + // The following tests check obliquely that '?' is not + // interpreted as a prepared statement meta-character + // internally. + runTagTest( + `SELECT 'Hello', "World?"`, + () => Template`SELECT '${'Hello'}', "World?"`); + }, + 'string before string in string': function () { + runTagTest( + `SELECT 'Hello?', 'World?'`, + () => Template`SELECT 'Hello?', '${'World?'}'`); + }, + 'number after string in string': function () { + runTagTest( + `SELECT 'Hello?', 123`, + () => Template`SELECT '${'Hello?'}', ${123}`); + }, + 'number before string in string': function () { + runTagTest( + `SELECT 123, 'World?'`, + () => Template`SELECT ${123}, '${'World?'}'`); + }, + 'string in identifier': function () { + runTagTest( + 'SELECT `foo`', + () => Template`SELECT \`${'foo'}\``); + }, + 'number in identifier': function () { + runTagTest( + 'SELECT `foo_123`', + () => Template`SELECT \`foo_${123}\``); + }, + 'array': function () { + const id = new Template.Identifier('foo'); + const frag = new Template.Fragment('1 + 1'); + const values = [ 123, 'foo', id, frag ]; + runTagTest( + "SELECT X FROM T WHERE X IN (123, 'foo', `foo`, 1 + 1)", + () => Template`SELECT X FROM T WHERE X IN (${values})`); + } +});