diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 505386fbf..e480b1026 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -3329,6 +3329,22 @@ pub enum Statement { channel: Ident, payload: Option, }, + /// ```sql + /// LOAD DATA [LOCAL] INPATH 'filepath' [OVERWRITE] INTO TABLE tablename + /// [PARTITION (partcol1=val1, partcol2=val2 ...)] + /// [INPUTFORMAT 'inputformat' SERDE 'serde'] + /// ``` + /// Loading files into tables + /// + /// See Hive + LoadData { + local: bool, + inpath: String, + overwrite: bool, + table_name: ObjectName, + partitioned: Option>, + table_format: Option, + }, } impl fmt::Display for Statement { @@ -3931,6 +3947,36 @@ impl fmt::Display for Statement { Ok(()) } Statement::CreateTable(create_table) => create_table.fmt(f), + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { + write!( + f, + "LOAD DATA {local}INPATH '{inpath}' {overwrite}INTO TABLE {table_name}", + local = if *local { "LOCAL " } else { "" }, + inpath = inpath, + overwrite = if *overwrite { "OVERWRITE " } else { "" }, + table_name = table_name, + )?; + if let Some(ref parts) = &partitioned { + if !parts.is_empty() { + write!(f, " PARTITION ({})", display_comma_separated(parts))?; + } + } + if let Some(HiveLoadDataFormat { + serde, + input_format, + }) = &table_format + { + write!(f, " INPUTFORMAT {input_format} SERDE {serde}")?; + } + Ok(()) + } Statement::CreateVirtualTable { name, if_not_exists, @@ -5816,6 +5862,14 @@ pub enum HiveRowFormat { DELIMITED { delimiters: Vec }, } +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct HiveLoadDataFormat { + pub serde: Expr, + pub input_format: Expr, +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] diff --git a/src/dialect/duckdb.rs b/src/dialect/duckdb.rs index e1b8db118..905b04e36 100644 --- a/src/dialect/duckdb.rs +++ b/src/dialect/duckdb.rs @@ -66,4 +66,9 @@ impl Dialect for DuckDbDialect { fn supports_explain_with_utility_options(&self) -> bool { true } + + /// See DuckDB + fn supports_load_extension(&self) -> bool { + true + } } diff --git a/src/dialect/generic.rs b/src/dialect/generic.rs index 8cfac217b..4998e0f4b 100644 --- a/src/dialect/generic.rs +++ b/src/dialect/generic.rs @@ -115,4 +115,8 @@ impl Dialect for GenericDialect { fn supports_comment_on(&self) -> bool { true } + + fn supports_load_extension(&self) -> bool { + true + } } diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs index b97bf69be..571f9b9ba 100644 --- a/src/dialect/hive.rs +++ b/src/dialect/hive.rs @@ -56,4 +56,9 @@ impl Dialect for HiveDialect { fn supports_bang_not_operator(&self) -> bool { true } + + /// See Hive + fn supports_load_data(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index d95d7c70a..c47079903 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -611,6 +611,16 @@ pub trait Dialect: Debug + Any { false } + /// Returns true if the dialect supports the `LOAD DATA` statement + fn supports_load_data(&self) -> bool { + false + } + + /// Returns true if the dialect supports the `LOAD extension` statement + fn supports_load_extension(&self) -> bool { + false + } + /// Returns true if this dialect expects the `TOP` option /// before the `ALL`/`DISTINCT` options in a `SELECT` statement. fn supports_top_before_distinct(&self) -> bool { diff --git a/src/keywords.rs b/src/keywords.rs index 9cdc90ce2..790268219 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -389,6 +389,7 @@ define_keywords!( INITIALLY, INNER, INOUT, + INPATH, INPUT, INPUTFORMAT, INSENSITIVE, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 756f4d68b..d246d0eb5 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -543,10 +543,7 @@ impl<'a> Parser<'a> { Keyword::INSTALL if dialect_of!(self is DuckDbDialect | GenericDialect) => { self.parse_install() } - // `LOAD` is duckdb specific https://duckdb.org/docs/extensions/overview - Keyword::LOAD if dialect_of!(self is DuckDbDialect | GenericDialect) => { - self.parse_load() - } + Keyword::LOAD => self.parse_load(), // `OPTIMIZE` is clickhouse specific https://clickhouse.tech/docs/en/sql-reference/statements/optimize/ Keyword::OPTIMIZE if dialect_of!(self is ClickHouseDialect | GenericDialect) => { self.parse_optimize_table() @@ -11178,6 +11175,22 @@ impl<'a> Parser<'a> { } } + pub fn parse_load_data_table_format( + &mut self, + ) -> Result, ParserError> { + if self.parse_keyword(Keyword::INPUTFORMAT) { + let input_format = self.parse_expr()?; + self.expect_keyword(Keyword::SERDE)?; + let serde = self.parse_expr()?; + Ok(Some(HiveLoadDataFormat { + input_format, + serde, + })) + } else { + Ok(None) + } + } + /// Parse an UPDATE statement, returning a `Box`ed SetExpr /// /// This is used to reduce the size of the stack frames in debug builds @@ -12180,10 +12193,35 @@ impl<'a> Parser<'a> { Ok(Statement::Install { extension_name }) } - /// `LOAD [extension_name]` + /// Parse a SQL LOAD statement pub fn parse_load(&mut self) -> Result { - let extension_name = self.parse_identifier(false)?; - Ok(Statement::Load { extension_name }) + if self.dialect.supports_load_extension() { + let extension_name = self.parse_identifier(false)?; + Ok(Statement::Load { extension_name }) + } else if self.parse_keyword(Keyword::DATA) && self.dialect.supports_load_data() { + let local = self.parse_one_of_keywords(&[Keyword::LOCAL]).is_some(); + self.expect_keyword(Keyword::INPATH)?; + let inpath = self.parse_literal_string()?; + let overwrite = self.parse_one_of_keywords(&[Keyword::OVERWRITE]).is_some(); + self.expect_keyword(Keyword::INTO)?; + self.expect_keyword(Keyword::TABLE)?; + let table_name = self.parse_object_name(false)?; + let partitioned = self.parse_insert_partition()?; + let table_format = self.parse_load_data_table_format()?; + Ok(Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + }) + } else { + self.expected( + "`DATA` or an extension name after `LOAD`", + self.peek_token(), + ) + } } /// ```sql diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 25bf306ad..68da8e601 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11503,13 +11503,208 @@ fn parse_notify_channel() { dialects.parse_sql_statements(sql).unwrap_err(), ParserError::ParserError("Expected: an SQL statement, found: NOTIFY".to_string()) ); - assert_eq!( - dialects.parse_sql_statements(sql).unwrap_err(), - ParserError::ParserError("Expected: an SQL statement, found: NOTIFY".to_string()) - ); } } +#[test] +fn parse_load_data() { + let dialects = all_dialects_where(|d| d.supports_load_data()); + let only_supports_load_extension_dialects = + all_dialects_where(|d| !d.supports_load_data() && d.supports_load_extension()); + let not_supports_load_dialects = + all_dialects_where(|d| !d.supports_load_data() && !d.supports_load_extension()); + + let sql = "LOAD DATA INPATH '/local/path/to/data.txt' INTO TABLE test.my_table"; + match dialects.verified_stmt(sql) { + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { + assert_eq!(false, local); + assert_eq!("/local/path/to/data.txt", inpath); + assert_eq!(false, overwrite); + assert_eq!( + ObjectName(vec![Ident::new("test"), Ident::new("my_table")]), + table_name + ); + assert_eq!(None, partitioned); + assert_eq!(None, table_format); + } + _ => unreachable!(), + }; + + // with OVERWRITE keyword + let sql = "LOAD DATA INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE my_table"; + match dialects.verified_stmt(sql) { + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { + assert_eq!(false, local); + assert_eq!("/local/path/to/data.txt", inpath); + assert_eq!(true, overwrite); + assert_eq!(ObjectName(vec![Ident::new("my_table")]), table_name); + assert_eq!(None, partitioned); + assert_eq!(None, table_format); + } + _ => unreachable!(), + }; + + assert_eq!( + only_supports_load_extension_dialects + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError("Expected: end of statement, found: INPATH".to_string()) + ); + assert_eq!( + not_supports_load_dialects + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError( + "Expected: `DATA` or an extension name after `LOAD`, found: INPATH".to_string() + ) + ); + + // with LOCAL keyword + let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table"; + match dialects.verified_stmt(sql) { + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { + assert_eq!(true, local); + assert_eq!("/local/path/to/data.txt", inpath); + assert_eq!(false, overwrite); + assert_eq!( + ObjectName(vec![Ident::new("test"), Ident::new("my_table")]), + table_name + ); + assert_eq!(None, partitioned); + assert_eq!(None, table_format); + } + _ => unreachable!(), + }; + + assert_eq!( + only_supports_load_extension_dialects + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError("Expected: end of statement, found: LOCAL".to_string()) + ); + assert_eq!( + not_supports_load_dialects + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError( + "Expected: `DATA` or an extension name after `LOAD`, found: LOCAL".to_string() + ) + ); + + // with PARTITION clause + let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE my_table PARTITION (year = 2024, month = 11)"; + match dialects.verified_stmt(sql) { + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { + assert_eq!(true, local); + assert_eq!("/local/path/to/data.txt", inpath); + assert_eq!(false, overwrite); + assert_eq!(ObjectName(vec![Ident::new("my_table")]), table_name); + assert_eq!( + Some(vec![ + Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("year"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))), + }, + Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("month"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))), + } + ]), + partitioned + ); + assert_eq!(None, table_format); + } + _ => unreachable!(), + }; + + // with PARTITION clause + let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE good.my_table PARTITION (year = 2024, month = 11) INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'"; + match dialects.verified_stmt(sql) { + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { + assert_eq!(true, local); + assert_eq!("/local/path/to/data.txt", inpath); + assert_eq!(true, overwrite); + assert_eq!( + ObjectName(vec![Ident::new("good"), Ident::new("my_table")]), + table_name + ); + assert_eq!( + Some(vec![ + Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("year"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))), + }, + Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("month"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))), + } + ]), + partitioned + ); + assert_eq!( + Some(HiveLoadDataFormat { + serde: Expr::Value(Value::SingleQuotedString( + "org.apache.hadoop.hive.serde2.OpenCSVSerde".to_string() + )), + input_format: Expr::Value(Value::SingleQuotedString( + "org.apache.hadoop.mapred.TextInputFormat".to_string() + )) + }), + table_format + ); + } + _ => unreachable!(), + }; + + // negative test case + let sql = "LOAD DATA2 LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table"; + assert_eq!( + dialects.parse_sql_statements(sql).unwrap_err(), + ParserError::ParserError( + "Expected: `DATA` or an extension name after `LOAD`, found: DATA2".to_string() + ) + ); +} + #[test] fn test_select_top() { let dialects = all_dialects_where(|d| d.supports_top_before_distinct());