Skip to content

Commit 4c00694

Browse files
committed
add support for variables
see #49 #69
1 parent 9870408 commit 4c00694

File tree

4 files changed

+144
-48
lines changed

4 files changed

+144
-48
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# CHANGELOG.md
22

3+
## 0.12.0 (unreleased)
4+
5+
- **variables** . SQLPage now support setting and reusing variables between statements. This allows you to write more complex SQL queries, and to reuse the result of a query in multiple places.
6+
```sql
7+
-- Set a variable
8+
SET $person = 'Alice';
9+
-- Use it in a query
10+
SELECT 'text' AS component, 'Hello ' || $person AS contents;
11+
```
12+
313
## 0.11.0 (2023-09-17)
414
- Support for **environment variables** ! You can now read environment variables from sql code using `sqlpage.environment_variable('VAR_NAME')`.
515
- Better support for connection options in mssql.

src/webserver/database/mod.rs

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ use futures_util::stream::Stream;
77
use futures_util::StreamExt;
88
use serde_json::Value;
99
use std::borrow::Cow;
10+
use std::collections::HashMap;
1011
use std::fmt::{Display, Formatter};
1112
use std::time::Duration;
1213

1314
use crate::app_config::AppConfig;
1415
pub use crate::file_cache::FileCache;
1516

1617
use crate::webserver::database::sql_pseudofunctions::extract_req_param;
17-
use crate::webserver::http::RequestInfo;
18+
use crate::webserver::http::{RequestInfo, SingleOrVec};
1819
use crate::MIGRATIONS_DIR;
1920
pub use sql::make_placeholder;
2021
pub use sql::ParsedSqlFile;
@@ -24,9 +25,12 @@ use sqlx::any::{
2425
use sqlx::migrate::Migrator;
2526
use sqlx::pool::{PoolConnection, PoolOptions};
2627
use sqlx::query::Query;
27-
use sqlx::{Any, AnyConnection, AnyPool, Arguments, ConnectOptions, Either, Executor, Statement};
28+
use sqlx::{
29+
Any, AnyConnection, AnyPool, Arguments, ConnectOptions, Either, Executor, Row, Statement,
30+
};
2831

2932
use self::sql::ParsedSQLStatement;
33+
use sql_pseudofunctions::StmtParam;
3034

3135
pub struct Database {
3236
pub(crate) connection: AnyPool,
@@ -97,27 +101,15 @@ fn migration_err(operation: &'static str) -> String {
97101
pub fn stream_query_results<'a>(
98102
db: &'a Database,
99103
sql_file: &'a ParsedSqlFile,
100-
request: &'a RequestInfo,
104+
request: &'a mut RequestInfo,
101105
) -> impl Stream<Item = DbItem> + 'a {
102-
async_stream::stream! {
106+
async_stream::try_stream! {
103107
let mut connection_opt = None;
104108
for res in &sql_file.statements {
105109
match res {
106-
ParsedSQLStatement::Statement(stmt)=>{
107-
let query = match bind_parameters(stmt, request) {
108-
Ok(q) => q,
109-
Err(e) => {
110-
yield DbItem::Error(e);
111-
continue;
112-
}
113-
};
114-
let connection = match take_connection(db, &mut connection_opt).await {
115-
Ok(c) => c,
116-
Err(e) => {
117-
yield DbItem::Error(e);
118-
return;
119-
}
120-
};
110+
ParsedSQLStatement::Statement(stmt) => {
111+
let query = bind_parameters(stmt, request)?;
112+
let connection = take_connection(db, &mut connection_opt).await?;
121113
let mut stream = query.fetch_many(connection);
122114
while let Some(elem) = stream.next().await {
123115
let is_err = elem.is_err();
@@ -127,13 +119,49 @@ pub fn stream_query_results<'a>(
127119
}
128120
}
129121
},
122+
ParsedSQLStatement::SetVariable { variable, value} => {
123+
let query = bind_parameters(value, request)?;
124+
let connection = take_connection(db, &mut connection_opt).await?;
125+
let row = query.fetch_optional(connection).await?;
126+
let (vars, name) = vars_and_name(request, variable)?;
127+
if let Some(row) = row {
128+
vars.insert(name.clone(), row_to_varvalue(&row));
129+
} else {
130+
vars.remove(&name);
131+
}
132+
},
130133
ParsedSQLStatement::StaticSimpleSelect(value) => {
131134
yield DbItem::Row(value.clone().into())
132135
}
133136
ParsedSQLStatement::Error(e) => yield DbItem::Error(clone_anyhow_err(e)),
134137
}
135138
}
136139
}
140+
.map(|res| res.unwrap_or_else(DbItem::Error))
141+
}
142+
143+
fn vars_and_name<'a>(
144+
request: &'a mut RequestInfo,
145+
variable: &StmtParam,
146+
) -> anyhow::Result<(&'a mut HashMap<String, SingleOrVec>, String)> {
147+
match variable {
148+
StmtParam::Get(name) => {
149+
let vars = &mut request.get_variables;
150+
Ok((vars, name.clone()))
151+
}
152+
StmtParam::Post(name) => {
153+
let vars = &mut request.post_variables;
154+
Ok((vars, name.clone()))
155+
}
156+
_ => Err(anyhow!(
157+
"Only GET and POST variables can be set, not {variable:?}"
158+
)),
159+
}
160+
}
161+
162+
fn row_to_varvalue(row: &AnyRow) -> SingleOrVec {
163+
row.try_get::<String, usize>(0)
164+
.map_or_else(|_| SingleOrVec::Vec(vec![]), SingleOrVec::Single)
137165
}
138166

139167
async fn take_connection<'a, 'b>(
@@ -293,7 +321,7 @@ fn set_custom_connect_options(options: &mut AnyConnectOptions, config: &AppConfi
293321
}
294322
struct PreparedStatement {
295323
statement: AnyStatement<'static>,
296-
parameters: Vec<sql_pseudofunctions::StmtParam>,
324+
parameters: Vec<StmtParam>,
297325
}
298326

299327
impl Display for PreparedStatement {

src/webserver/database/sql.rs

Lines changed: 84 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ pub(super) enum ParsedSQLStatement {
2626
Statement(PreparedStatement),
2727
StaticSimpleSelect(serde_json::Map<String, serde_json::Value>),
2828
Error(anyhow::Error),
29+
SetVariable {
30+
variable: StmtParam,
31+
value: PreparedStatement,
32+
},
2933
}
3034

3135
impl ParsedSqlFile {
@@ -40,20 +44,15 @@ impl ParsedSqlFile {
4044
statements.push(match parsed {
4145
ParsedStatement::StaticSimpleSelect(s) => ParsedSQLStatement::StaticSimpleSelect(s),
4246
ParsedStatement::Error(e) => ParsedSQLStatement::Error(e),
43-
ParsedStatement::StmtWithParams { query, params } => {
44-
let param_types = get_param_types(&params);
45-
match db.prepare_with(&query, &param_types).await {
46-
Ok(statement) => {
47-
log::debug!("Successfully prepared SQL statement '{query}'");
48-
ParsedSQLStatement::Statement(PreparedStatement {
49-
statement,
50-
parameters: params,
51-
})
52-
}
53-
Err(err) => {
54-
log::warn!("Failed to prepare {query:?}: {err:#}");
55-
ParsedSQLStatement::Error(err)
47+
ParsedStatement::StmtWithParams(stmt_with_params) => {
48+
prepare_query_with_params(db, stmt_with_params).await
49+
}
50+
ParsedStatement::SetVariable { variable, value } => {
51+
match prepare_query_with_params(db, value).await {
52+
ParsedSQLStatement::Statement(value) => {
53+
ParsedSQLStatement::SetVariable { variable, value }
5654
}
55+
err => err,
5756
}
5857
}
5958
});
@@ -71,19 +70,45 @@ impl ParsedSqlFile {
7170
}
7271
}
7372

73+
async fn prepare_query_with_params(
74+
db: &Database,
75+
StmtWithParams { query, params }: StmtWithParams,
76+
) -> ParsedSQLStatement {
77+
let param_types = get_param_types(&params);
78+
match db.prepare_with(&query, &param_types).await {
79+
Ok(statement) => {
80+
log::debug!("Successfully prepared SQL statement '{query}'");
81+
ParsedSQLStatement::Statement(PreparedStatement {
82+
statement,
83+
parameters: params,
84+
})
85+
}
86+
Err(err) => {
87+
log::warn!("Failed to prepare {query:?}: {err:#}");
88+
ParsedSQLStatement::Error(err)
89+
}
90+
}
91+
}
92+
7493
#[async_trait(? Send)]
7594
impl AsyncFromStrWithState for ParsedSqlFile {
7695
async fn from_str_with_state(app_state: &AppState, source: &str) -> anyhow::Result<Self> {
7796
Ok(ParsedSqlFile::new(&app_state.db, source).await)
7897
}
7998
}
8099

100+
struct StmtWithParams {
101+
query: String,
102+
params: Vec<StmtParam>,
103+
}
104+
81105
enum ParsedStatement {
82-
StmtWithParams {
83-
query: String,
84-
params: Vec<StmtParam>,
85-
},
106+
StmtWithParams(StmtWithParams),
86107
StaticSimpleSelect(serde_json::Map<String, serde_json::Value>),
108+
SetVariable {
109+
variable: StmtParam,
110+
value: StmtWithParams,
111+
},
87112
Error(anyhow::Error),
88113
}
89114

@@ -113,8 +138,16 @@ fn parse_single_statement(parser: &mut Parser<'_>, db_kind: AnyKind) -> Option<P
113138
return Some(ParsedStatement::StaticSimpleSelect(static_statement));
114139
}
115140
let params = ParameterExtractor::extract_parameters(&mut stmt, db_kind);
116-
let query = stmt.to_string();
117-
Some(ParsedStatement::StmtWithParams { query, params })
141+
if let Some((variable, query)) = extract_set_variable(&mut stmt) {
142+
return Some(ParsedStatement::SetVariable {
143+
variable,
144+
value: StmtWithParams { query, params },
145+
});
146+
}
147+
Some(ParsedStatement::StmtWithParams(StmtWithParams {
148+
query: stmt.to_string(),
149+
params,
150+
}))
118151
}
119152

120153
fn syntax_error(err: ParserError, parser: &mut Parser) -> ParsedStatement {
@@ -226,6 +259,24 @@ fn extract_static_simple_select(
226259
Some(map)
227260
}
228261

262+
fn extract_set_variable(stmt: &mut Statement) -> Option<(StmtParam, String)> {
263+
if let Statement::SetVariable {
264+
variable: ObjectName(name),
265+
value,
266+
local: false,
267+
hivevar: false,
268+
} = stmt
269+
{
270+
if let ([ident], [value]) = (name.as_mut_slice(), value.as_mut_slice()) {
271+
if let Some(variable) = extract_ident_param(ident) {
272+
let query = format!("SELECT {value}");
273+
return Some((variable, query));
274+
}
275+
}
276+
}
277+
None
278+
}
279+
229280
struct ParameterExtractor {
230281
db_kind: AnyKind,
231282
parameters: Vec<StmtParam>,
@@ -378,17 +429,24 @@ pub fn make_placeholder(db_kind: AnyKind, arg_number: usize) -> String {
378429
DEFAULT_PLACEHOLDER.to_string()
379430
}
380431

432+
fn extract_ident_param(Ident { value, quote_style }: &mut Ident) -> Option<StmtParam> {
433+
if quote_style.is_none() && value.starts_with('$') || value.starts_with(':') {
434+
let name = std::mem::take(value);
435+
Some(map_param(name))
436+
} else {
437+
None
438+
}
439+
}
440+
381441
impl VisitorMut for ParameterExtractor {
382442
type Break = ();
383443
fn pre_visit_expr(&mut self, value: &mut Expr) -> ControlFlow<Self::Break> {
384444
match value {
385-
Expr::Identifier(Ident {
386-
value: var_name,
387-
quote_style: None,
388-
}) if var_name.starts_with('$') || var_name.starts_with(':') => {
389-
let name = std::mem::take(var_name);
390-
*value = self.make_placeholder();
391-
self.parameters.push(map_param(name));
445+
Expr::Identifier(ident) => {
446+
if let Some(param) = extract_ident_param(ident) {
447+
*value = self.make_placeholder();
448+
self.parameters.push(param);
449+
}
392450
}
393451
Expr::Value(Value::Placeholder(param)) if !self.is_own_placeholder(param) =>
394452
// this check is to avoid recursively replacing placeholders in the form of '?', or '$1', '$2', which we emit ourselves

src/webserver/http.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ async fn render_sql(
210210
srv_req: &mut ServiceRequest,
211211
sql_file: Arc<ParsedSqlFile>,
212212
) -> actix_web::Result<HttpResponse> {
213-
let req_param = extract_request_info(srv_req).await;
213+
let mut req_param = extract_request_info(srv_req).await;
214214
log::debug!("Received a request with the following parameters: {req_param:?}");
215215
let app_state = srv_req
216216
.app_data::<web::Data<AppState>>()
@@ -220,7 +220,7 @@ async fn render_sql(
220220

221221
let (resp_send, resp_recv) = tokio::sync::oneshot::channel::<HttpResponse>();
222222
actix_web::rt::spawn(async move {
223-
let database_entries_stream = stream_query_results(&app_state.db, &sql_file, &req_param);
223+
let database_entries_stream = stream_query_results(&app_state.db, &sql_file, &mut req_param);
224224
let response_with_writer =
225225
build_response_header_and_stream(Arc::clone(&app_state), database_entries_stream).await;
226226
match response_with_writer {

0 commit comments

Comments
 (0)