diff --git a/CHANGELOG.md b/CHANGELOG.md index 64ee2343..3872e78a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - new `tooltip` property in the button component. - New `search_value` property in the shell component. - Fixed a display issue in the hero component when the button text is long and the viewport is narrow. + - reuse the existing opened database connection for the current query in `sqlpage.run_sql` instead of opening a new one. This makes it possible to create a temporary table in a file, and reuse it in an included script, create a SQL transaction that spans over multiple run_sql calls, and should generally make run_sql more performant. - Fixed a bug in the cookie component where removing a cookie from a subdirectory would not work. ## 0.22.0 (2024-05-29) diff --git a/src/webserver/database/execute_queries.rs b/src/webserver/database/execute_queries.rs index 078f112e..1cbc15ea 100644 --- a/src/webserver/database/execute_queries.rs +++ b/src/webserver/database/execute_queries.rs @@ -17,7 +17,9 @@ use super::syntax_tree::{extract_req_param, StmtParam}; use super::{highlight_sql_error, Database, DbItem}; use sqlx::any::{AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo}; use sqlx::pool::PoolConnection; -use sqlx::{Any, AnyConnection, Arguments, Either, Executor, Statement}; +use sqlx::{Any, Arguments, Either, Executor, Statement}; + +pub type DbConn = Option>; impl Database { pub(crate) async fn prepare_with( @@ -32,23 +34,23 @@ impl Database { .map_err(|e| highlight_sql_error("Failed to prepare SQL statement", query, e)) } } -pub fn stream_query_results<'a>( - db: &'a Database, + +pub fn stream_query_results_with_conn<'a>( sql_file: &'a ParsedSqlFile, request: &'a mut RequestInfo, + db_connection: &'a mut DbConn, ) -> impl Stream + 'a { async_stream::try_stream! { - let mut connection_opt = None; for res in &sql_file.statements { match res { ParsedStatement::CsvImport(csv_import) => { - let connection = take_connection(db, &mut connection_opt).await?; + let connection = take_connection(&request.app_state.db, db_connection).await?; log::debug!("Executing CSV import: {:?}", csv_import); run_csv_import(connection, csv_import, request).await?; }, ParsedStatement::StmtWithParams(stmt) => { - let query = bind_parameters(stmt, request).await?; - let connection = take_connection(db, &mut connection_opt).await?; + let query = bind_parameters(stmt, request, db_connection).await?; + let connection = take_connection(&request.app_state.db, db_connection).await?; log::trace!("Executing query {:?}", query.sql); let mut stream = connection.fetch_many(query); while let Some(elem) = stream.next().await { @@ -62,13 +64,13 @@ pub fn stream_query_results<'a>( } }, ParsedStatement::SetVariable { variable, value} => { - execute_set_variable_query(db, &mut connection_opt, request, variable, value).await + execute_set_variable_query(db_connection, request, variable, value).await .with_context(|| format!("Failed to set the {variable} variable to {value:?}") )?; }, ParsedStatement::StaticSimpleSelect(value) => { - for i in parse_dynamic_rows(DbItem::Row(exec_static_simple_select(value, request).await?)) { + for i in parse_dynamic_rows(DbItem::Row(exec_static_simple_select(value, request, db_connection).await?)) { yield i; } } @@ -83,12 +85,15 @@ pub fn stream_query_results<'a>( async fn exec_static_simple_select( columns: &[(String, SimpleSelectValue)], req: &RequestInfo, + db_connection: &mut DbConn, ) -> anyhow::Result { let mut map = serde_json::Map::with_capacity(columns.len()); for (name, value) in columns { let value = match value { SimpleSelectValue::Static(s) => s.clone(), - SimpleSelectValue::Dynamic(p) => extract_req_param_as_json(p, req).await?, + SimpleSelectValue::Dynamic(p) => { + extract_req_param_as_json(p, req, db_connection).await? + } }; map = add_value_to_map(map, (name.clone(), value)); } @@ -100,8 +105,9 @@ async fn exec_static_simple_select( async fn extract_req_param_as_json( param: &StmtParam, request: &RequestInfo, + db_connection: &mut DbConn, ) -> anyhow::Result { - if let Some(val) = extract_req_param(param, request).await? { + if let Some(val) = extract_req_param(param, request, db_connection).await? { Ok(serde_json::Value::String(val.into_owned())) } else { Ok(serde_json::Value::Null) @@ -111,22 +117,25 @@ async fn extract_req_param_as_json( /// This function is used to create a pinned boxed stream of query results. /// This allows recursive calls. pub fn stream_query_results_boxed<'a>( - db: &'a Database, sql_file: &'a ParsedSqlFile, request: &'a mut RequestInfo, + db_connection: &'a mut DbConn, ) -> Pin + 'a>> { - Box::pin(stream_query_results(db, sql_file, request)) + Box::pin(stream_query_results_with_conn( + sql_file, + request, + db_connection, + )) } async fn execute_set_variable_query<'a>( - db: &'a Database, - connection_opt: &mut Option>, + db_connection: &'a mut DbConn, request: &'a mut RequestInfo, variable: &StmtParam, statement: &StmtWithParams, ) -> anyhow::Result<()> { - let query = bind_parameters(statement, request).await?; - let connection = take_connection(db, connection_opt).await?; + let query = bind_parameters(statement, request, db_connection).await?; + let connection = take_connection(&request.app_state.db, db_connection).await?; log::debug!( "Executing query to set the {variable:?} variable: {:?}", query.sql @@ -169,21 +178,21 @@ fn vars_and_name<'a, 'b>( async fn take_connection<'a, 'b>( db: &'a Database, - conn: &'b mut Option>, -) -> anyhow::Result<&'b mut AnyConnection> { - match conn { - Some(c) => Ok(c), - None => match db.connection.acquire().await { - Ok(c) => { - log::debug!("Acquired a database connection"); - *conn = Some(c); - Ok(conn.as_mut().unwrap()) - } - Err(e) => { - let err_msg = format!("Unable to acquire a database connection to execute the SQL file. All of the {} {:?} connections are busy.", db.connection.size(), db.connection.any_kind()); - Err(anyhow::Error::new(e).context(err_msg)) - } - }, + conn: &'b mut DbConn, +) -> anyhow::Result<&'b mut PoolConnection> { + if let Some(c) = conn { + return Ok(c); + } + match db.connection.acquire().await { + Ok(c) => { + log::debug!("Acquired a database connection"); + *conn = Some(c); + Ok(conn.as_mut().unwrap()) + } + Err(e) => { + let err_msg = format!("Unable to acquire a database connection to execute the SQL file. All of the {} {:?} connections are busy.", db.connection.size(), db.connection.any_kind()); + Err(anyhow::Error::new(e).context(err_msg)) + } } } @@ -211,16 +220,17 @@ fn clone_anyhow_err(err: &anyhow::Error) -> anyhow::Error { e } -async fn bind_parameters<'a>( +async fn bind_parameters<'a, 'b>( stmt: &'a StmtWithParams, request: &'a RequestInfo, + db_connection: &'b mut DbConn, ) -> anyhow::Result> { let sql = stmt.query.as_str(); log::debug!("Preparing statement: {}", sql); let mut arguments = AnyArguments::default(); for (param_idx, param) in stmt.params.iter().enumerate() { log::trace!("\tevaluating parameter {}: {}", param_idx + 1, param); - let argument = extract_req_param(param, request).await?; + let argument = extract_req_param(param, request, db_connection).await?; log::debug!( "\tparameter {}: {}", param_idx + 1, diff --git a/src/webserver/database/mod.rs b/src/webserver/database/mod.rs index 68aa3fdc..dda62cee 100644 --- a/src/webserver/database/mod.rs +++ b/src/webserver/database/mod.rs @@ -50,3 +50,9 @@ pub fn highlight_sql_error( } anyhow::Error::new(db_err).context(msg) } + +impl std::fmt::Display for Database { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.connection.any_kind()) + } +} diff --git a/src/webserver/database/sqlpage_functions/function_definition_macro.rs b/src/webserver/database/sqlpage_functions/function_definition_macro.rs index e704afc9..3ff3f3db 100644 --- a/src/webserver/database/sqlpage_functions/function_definition_macro.rs +++ b/src/webserver/database/sqlpage_functions/function_definition_macro.rs @@ -1,7 +1,12 @@ /// Defines all sqlpage functions #[macro_export] macro_rules! sqlpage_functions { - ($($func_name:ident($(($request:ty)$(,)?)? $($param_name:ident : $param_type:ty),*);)*) => { + ($($func_name:ident( + $(($request:ty $(, $db_conn:ty)?))? + $(,)? + $($param_name:ident : $param_type:ty),* + ); + )*) => { #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum SqlPageFunctionName { $( #[allow(non_camel_case_types)] $func_name ),* @@ -47,10 +52,11 @@ macro_rules! sqlpage_functions { } } impl SqlPageFunctionName { - pub(crate) async fn evaluate<'a>( + pub(crate) async fn evaluate<'a, 'b>( &self, #[allow(unused_variables)] request: &'a RequestInfo, + db_connection: &'b mut Option>, params: Vec>> ) -> anyhow::Result>> { use $crate::webserver::database::sqlpage_functions::function_traits::*; @@ -66,7 +72,10 @@ macro_rules! sqlpage_functions { anyhow::bail!("Too many arguments. Remove extra argument {}", as_sql(extraneous_param)); } let result = $func_name( - $(<$request>::from(request),)* + $( + <$request>::from(request), + $(<$db_conn>::from(db_connection),)* + )* $($param_name.into()),* ).await; result.into_cow_result() diff --git a/src/webserver/database/sqlpage_functions/functions.rs b/src/webserver/database/sqlpage_functions/functions.rs index 1cbb716c..8fba61c0 100644 --- a/src/webserver/database/sqlpage_functions/functions.rs +++ b/src/webserver/database/sqlpage_functions/functions.rs @@ -1,5 +1,5 @@ use super::RequestInfo; -use crate::webserver::{http::SingleOrVec, ErrorWithStatus}; +use crate::webserver::{database::execute_queries::DbConn, http::SingleOrVec, ErrorWithStatus}; use anyhow::{anyhow, Context}; use futures_util::StreamExt; use std::{borrow::Cow, ffi::OsStr, str::FromStr}; @@ -27,7 +27,7 @@ super::function_definition_macro::sqlpage_functions! { read_file_as_data_url((&RequestInfo), file_path: Option>); read_file_as_text((&RequestInfo), file_path: Option>); request_method((&RequestInfo)); - run_sql((&RequestInfo), sql_file_path: Option>); + run_sql((&RequestInfo, &mut DbConn), sql_file_path: Option>); uploaded_file_mime_type((&RequestInfo), upload_name: Cow); uploaded_file_path((&RequestInfo), upload_name: Cow); @@ -347,6 +347,7 @@ async fn request_method(request: &RequestInfo) -> String { async fn run_sql<'a>( request: &'a RequestInfo, + db_connection: &mut DbConn, sql_file_path: Option>, ) -> anyhow::Result>> { use serde::ser::{SerializeSeq, Serializer}; @@ -373,9 +374,9 @@ async fn run_sql<'a>( } let mut results_stream = crate::webserver::database::execute_queries::stream_query_results_boxed( - &request.app_state.db, &sql_file, &mut tmp_req, + db_connection, ); let mut json_results_bytes = Vec::new(); let mut json_encoder = serde_json::Serializer::new(&mut json_results_bytes); diff --git a/src/webserver/database/syntax_tree.rs b/src/webserver/database/syntax_tree.rs index 49b196bd..a00e1471 100644 --- a/src/webserver/database/syntax_tree.rs +++ b/src/webserver/database/syntax_tree.rs @@ -20,7 +20,7 @@ use crate::webserver::database::sql::function_arg_to_stmt_param; use crate::webserver::http::SingleOrVec; use crate::webserver::http_request_info::RequestInfo; -use super::sqlpage_functions::functions::SqlPageFunctionName; +use super::{execute_queries::DbConn, sqlpage_functions::functions::SqlPageFunctionName}; use anyhow::{anyhow, Context as _}; /// Represents a parameter to a SQL statement. @@ -100,13 +100,16 @@ impl SqlPageFunctionCall { }) } - pub async fn evaluate<'a>( + pub async fn evaluate<'a, 'b>( &self, request: &'a RequestInfo, + db_connection: &'b mut DbConn, ) -> anyhow::Result>> { - let evaluated_args = self.arguments.iter().map(|x| extract_req_param(x, request)); - let evaluated_args = futures_util::future::try_join_all(evaluated_args).await?; - self.function.evaluate(request, evaluated_args).await + let mut params = Vec::with_capacity(self.arguments.len()); + for param in &self.arguments { + params.push(Box::pin(extract_req_param(param, request, db_connection)).await?); + } + self.function.evaluate(request, db_connection, params).await } } @@ -127,9 +130,10 @@ impl std::fmt::Display for SqlPageFunctionCall { /// Extracts the value of a parameter from the request. /// Returns `Ok(None)` when NULL should be used as the parameter value. -pub(super) async fn extract_req_param<'a>( +pub(super) async fn extract_req_param<'a, 'b>( param: &StmtParam, request: &'a RequestInfo, + db_connection: &'b mut DbConn, ) -> anyhow::Result>> { Ok(match param { // sync functions @@ -145,8 +149,8 @@ pub(super) async fn extract_req_param<'a>( StmtParam::Error(x) => anyhow::bail!("{}", x), StmtParam::Literal(x) => Some(Cow::Owned(x.to_string())), StmtParam::Null => None, - StmtParam::Concat(args) => concat_params(&args[..], request).await?, - StmtParam::FunctionCall(func) => func.evaluate(request).await.with_context(|| { + StmtParam::Concat(args) => concat_params(&args[..], request, db_connection).await?, + StmtParam::FunctionCall(func) => func.evaluate(request, db_connection).await.with_context(|| { format!( "Error in function call {func}.\nExpected {:#}", func.function @@ -155,13 +159,14 @@ pub(super) async fn extract_req_param<'a>( }) } -async fn concat_params<'a>( +async fn concat_params<'a, 'b>( args: &[StmtParam], request: &'a RequestInfo, + db_connection: &'b mut DbConn, ) -> anyhow::Result>> { let mut result = String::new(); for arg in args { - let Some(arg) = Box::pin(extract_req_param(arg, request)).await? else { + let Some(arg) = Box::pin(extract_req_param(arg, request, db_connection)).await? else { return Ok(None); }; result.push_str(&arg); diff --git a/src/webserver/http.rs b/src/webserver/http.rs index eb827462..2c3e22e2 100644 --- a/src/webserver/http.rs +++ b/src/webserver/http.rs @@ -1,5 +1,5 @@ use crate::render::{HeaderContext, PageContext, RenderContext}; -use crate::webserver::database::{execute_queries::stream_query_results, DbItem}; +use crate::webserver::database::{execute_queries::stream_query_results_with_conn, DbItem}; use crate::webserver::http_request_info::extract_request_info; use crate::webserver::ErrorWithStatus; use crate::{app_config, AppConfig, AppState, ParsedSqlFile}; @@ -229,8 +229,9 @@ async fn render_sql( let layout_context = &LayoutContext { is_embedded: req_param.get_variables.contains_key("_sqlpage_embed"), }; + let mut conn = None; let database_entries_stream = - stream_query_results(&app_state.db, &sql_file, &mut req_param); + stream_query_results_with_conn(&sql_file, &mut req_param, &mut conn); let response_with_writer = build_response_header_and_stream( Arc::clone(&app_state), database_entries_stream, diff --git a/tests/index.rs b/tests/index.rs index 009bca5c..6c20cb82 100644 --- a/tests/index.rs +++ b/tests/index.rs @@ -133,6 +133,11 @@ async fn test_files() { if test_file_path.extension().unwrap_or_default() != "sql" { continue; } + if test_file_path_string.contains(&format!("no{}", app_data.db.to_string().to_lowercase())) + { + // skipping because the test does not support the database + continue; + } let req_str = format!("/{}?x=1", test_file_path_string); let resp = req_path_with_app_data(&req_str, app_data.clone()) .await diff --git a/tests/select_temp_t.sql b/tests/select_temp_t.sql new file mode 100644 index 00000000..ec74c193 --- /dev/null +++ b/tests/select_temp_t.sql @@ -0,0 +1,2 @@ +-- see tests/sql_test_files/it_works_temp_table_accessible_in_run_sql.sql +select 'text' as component, x as contents from temp_t; \ No newline at end of file diff --git a/tests/sql_test_files/README.md b/tests/sql_test_files/README.md index 36d15a74..a3447f4e 100644 --- a/tests/sql_test_files/README.md +++ b/tests/sql_test_files/README.md @@ -6,6 +6,10 @@ Files with names starting with `it_works` should all return a page that contains the text "It works !" and does not contain the text "error" (case insensitive) when executed. +If a file name contains `nosqlite`, `nomssql`, `nopostgres` or `nomysql`, then +the test will be ignored when running against the corresponding database. +This allows using syntax that is not supported on all databases in some tests. + ## `error_` files Files with names starting with `error` should all return a page that contains diff --git a/tests/sql_test_files/it_works_temp_table_accessible_in_run_sql_nomssql.sql b/tests/sql_test_files/it_works_temp_table_accessible_in_run_sql_nomssql.sql new file mode 100644 index 00000000..7da8faf3 --- /dev/null +++ b/tests/sql_test_files/it_works_temp_table_accessible_in_run_sql_nomssql.sql @@ -0,0 +1,4 @@ +-- Doesnt work on mssql because it does not support "create temporary table" +create temporary table temp_t(x text); +insert into temp_t(x) values ('It works !'); +select 'dynamic' as component, sqlpage.run_sql('tests/select_temp_t.sql') AS properties; \ No newline at end of file