From 0e8a105cd46192de2f5e4e345b416999bc66e799 Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Tue, 12 Nov 2024 18:02:24 +0800 Subject: [PATCH 1/6] Add support for Hive's `LOAD DATA` expr --- src/ast/mod.rs | 54 +++++++++++++ src/dialect/duckdb.rs | 5 ++ src/dialect/hive.rs | 5 ++ src/dialect/mod.rs | 10 +++ src/keywords.rs | 1 + src/parser/mod.rs | 52 +++++++++++-- tests/sqlparser_common.rs | 154 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 274 insertions(+), 7 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 505386fbf..2932f11a1 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -3329,6 +3329,22 @@ pub enum Statement { channel: Ident, payload: Option, }, + /// ```sql + /// LOAD DATA [LOCAL] INPATH 'filepath' [OVERWRITE] INTO TABLE tablename + /// [PARTITION (partcol1=val1, partcol2=val2 ...)] + /// [INPUTFORMAT 'inputformat' SERDE 'serde'] + /// ``` + /// Loading files into tables + /// + /// See Hive + LoadData { + local: bool, + inpath: String, + overwrite: bool, + table_name: ObjectName, + partitioned: Option>, + table_format: Option, + }, } impl fmt::Display for Statement { @@ -3931,6 +3947,36 @@ impl fmt::Display for Statement { Ok(()) } Statement::CreateTable(create_table) => create_table.fmt(f), + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { + write!( + f, + "LOAD DATA {local}INPATH '{inpath}' {overwrite}INTO TABLE {table_name}", + local = if *local { "LOCAL " } else { "" }, + inpath = inpath, + overwrite = if *overwrite { "OVERWRITE " } else { "" }, + table_name = table_name, + )?; + if let Some(ref parts) = &partitioned { + if !parts.is_empty() { + write!(f, " PARTITION ({})", display_comma_separated(parts))?; + } + } + if let Some(HiveLoadDataOption { + serde, + input_format, + }) = &table_format + { + write!(f, " INPUTFORMAT {input_format} SERDE {serde}")?; + } + Ok(()) + } Statement::CreateVirtualTable { name, if_not_exists, @@ -5816,6 +5862,14 @@ pub enum HiveRowFormat { DELIMITED { delimiters: Vec }, } +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct HiveLoadDataOption { + pub serde: Expr, + pub input_format: Expr, +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] diff --git a/src/dialect/duckdb.rs b/src/dialect/duckdb.rs index e1b8db118..905b04e36 100644 --- a/src/dialect/duckdb.rs +++ b/src/dialect/duckdb.rs @@ -66,4 +66,9 @@ impl Dialect for DuckDbDialect { fn supports_explain_with_utility_options(&self) -> bool { true } + + /// See DuckDB + fn supports_load_extension(&self) -> bool { + true + } } diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs index b97bf69be..571f9b9ba 100644 --- a/src/dialect/hive.rs +++ b/src/dialect/hive.rs @@ -56,4 +56,9 @@ impl Dialect for HiveDialect { fn supports_bang_not_operator(&self) -> bool { true } + + /// See Hive + fn supports_load_data(&self) -> bool { + true + } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index d95d7c70a..c47079903 100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -611,6 +611,16 @@ pub trait Dialect: Debug + Any { false } + /// Returns true if the dialect supports the `LOAD DATA` statement + fn supports_load_data(&self) -> bool { + false + } + + /// Returns true if the dialect supports the `LOAD extension` statement + fn supports_load_extension(&self) -> bool { + false + } + /// Returns true if this dialect expects the `TOP` option /// before the `ALL`/`DISTINCT` options in a `SELECT` statement. fn supports_top_before_distinct(&self) -> bool { diff --git a/src/keywords.rs b/src/keywords.rs index 9cdc90ce2..790268219 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -389,6 +389,7 @@ define_keywords!( INITIALLY, INNER, INOUT, + INPATH, INPUT, INPUTFORMAT, INSENSITIVE, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 756f4d68b..173626042 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -543,10 +543,7 @@ impl<'a> Parser<'a> { Keyword::INSTALL if dialect_of!(self is DuckDbDialect | GenericDialect) => { self.parse_install() } - // `LOAD` is duckdb specific https://duckdb.org/docs/extensions/overview - Keyword::LOAD if dialect_of!(self is DuckDbDialect | GenericDialect) => { - self.parse_load() - } + Keyword::LOAD => self.parse_load(), // `OPTIMIZE` is clickhouse specific https://clickhouse.tech/docs/en/sql-reference/statements/optimize/ Keyword::OPTIMIZE if dialect_of!(self is ClickHouseDialect | GenericDialect) => { self.parse_optimize_table() @@ -11178,6 +11175,22 @@ impl<'a> Parser<'a> { } } + pub fn parse_load_data_table_format( + &mut self, + ) -> Result, ParserError> { + if self.parse_keyword(Keyword::INPUTFORMAT) { + let input_format = self.parse_expr()?; + self.expect_keyword(Keyword::SERDE)?; + let serde = self.parse_expr()?; + Ok(Some(HiveLoadDataOption { + input_format, + serde, + })) + } else { + Ok(None) + } + } + /// Parse an UPDATE statement, returning a `Box`ed SetExpr /// /// This is used to reduce the size of the stack frames in debug builds @@ -12180,10 +12193,35 @@ impl<'a> Parser<'a> { Ok(Statement::Install { extension_name }) } - /// `LOAD [extension_name]` + /// Parse a SQL LOAD statement pub fn parse_load(&mut self) -> Result { - let extension_name = self.parse_identifier(false)?; - Ok(Statement::Load { extension_name }) + if self.dialect.supports_load_extension() { + let extension_name = self.parse_identifier(false)?; + Ok(Statement::Load { extension_name }) + } else if self.parse_keyword(Keyword::DATA) && self.dialect.supports_load_data() { + let local = self.parse_one_of_keywords(&[Keyword::LOCAL]).is_some(); + self.expect_keyword(Keyword::INPATH)?; + let inpath = self.parse_literal_string()?; + let overwrite = self.parse_one_of_keywords(&[Keyword::OVERWRITE]).is_some(); + self.expect_keyword(Keyword::INTO)?; + self.expect_keyword(Keyword::TABLE)?; + let table_name = self.parse_object_name(false)?; + let partitioned = self.parse_insert_partition()?; + let table_format = self.parse_load_data_table_format()?; + Ok(Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + }) + } else { + self.expected( + "Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements", + self.peek_token(), + ) + } } /// ```sql diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 25bf306ad..69c32d517 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11510,6 +11510,160 @@ fn parse_notify_channel() { } } +#[test] +fn parse_load_data() { + let dialects = all_dialects_where(|d| d.supports_load_data()); + + match dialects + .verified_stmt("LOAD DATA INPATH '/local/path/to/data.txt' INTO TABLE test.my_table") + { + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { + assert_eq!(false, local); + assert_eq!("/local/path/to/data.txt", inpath); + assert_eq!(false, overwrite); + assert_eq!( + ObjectName(vec![Ident::new("test"), Ident::new("my_table")]), + table_name + ); + assert_eq!(None, partitioned); + assert_eq!(None, table_format); + } + _ => unreachable!(), + }; + + // with OVERWRITE keyword + match dialects + .verified_stmt("LOAD DATA INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE my_table") + { + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { + assert_eq!(false, local); + assert_eq!("/local/path/to/data.txt", inpath); + assert_eq!(true, overwrite); + assert_eq!(ObjectName(vec![Ident::new("my_table")]), table_name); + assert_eq!(None, partitioned); + assert_eq!(None, table_format); + } + _ => unreachable!(), + }; + + // with LOCAL keyword + match dialects + .verified_stmt("LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table") + { + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { + assert_eq!(true, local); + assert_eq!("/local/path/to/data.txt", inpath); + assert_eq!(false, overwrite); + assert_eq!( + ObjectName(vec![Ident::new("test"), Ident::new("my_table")]), + table_name + ); + assert_eq!(None, partitioned); + assert_eq!(None, table_format); + } + _ => unreachable!(), + }; + + // with PARTITION clause + match dialects.verified_stmt("LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE my_table PARTITION (year = 2024, month = 11)") { + Statement::LoadData {local, inpath, overwrite, table_name, partitioned, table_format} => { + assert_eq!(true, local); + assert_eq!("/local/path/to/data.txt", inpath); + assert_eq!(false, overwrite); + assert_eq!(ObjectName(vec![Ident::new("my_table")]), table_name); + assert_eq!(Some(vec![ + Expr::BinaryOp{ + left: Box::new(Expr::Identifier(Ident::new("year"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))), + }, + Expr::BinaryOp{ + left: Box::new(Expr::Identifier(Ident::new("month"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))), + }]), partitioned); + assert_eq!(None, table_format); + } + _ => unreachable!(), + }; + + // with PARTITION clause + match dialects.verified_stmt("LOAD DATA LOCAL INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE good.my_table PARTITION (year = 2024, month = 11) INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'") { + Statement::LoadData {local, inpath, overwrite, table_name, partitioned, table_format} => { + assert_eq!(true, local); + assert_eq!("/local/path/to/data.txt", inpath); + assert_eq!(true, overwrite); + assert_eq!(ObjectName(vec![Ident::new("good"), Ident::new("my_table")]), table_name); + assert_eq!(Some(vec![ + Expr::BinaryOp{ + left: Box::new(Expr::Identifier(Ident::new("year"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))), + }, + Expr::BinaryOp{ + left: Box::new(Expr::Identifier(Ident::new("month"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))), + }]), partitioned); + assert_eq!(Some(HiveLoadDataOption {serde: Expr::Value(Value::SingleQuotedString("org.apache.hadoop.hive.serde2.OpenCSVSerde".to_string())), input_format: Expr::Value(Value::SingleQuotedString("org.apache.hadoop.mapred.TextInputFormat".to_string()))}), table_format); + } + _ => unreachable!(), + }; + + let dialects = all_dialects_where(|d| !d.supports_load_data() && d.supports_load_extension()); + + assert_eq!( + dialects + .parse_sql_statements( + "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table" + ) + .unwrap_err(), + ParserError::ParserError("Expected: end of statement, found: LOCAL".to_string()) + ); + + assert_eq!( + dialects + .parse_sql_statements( + "LOAD DATA INPATH '/local/path/to/data.txt' INTO TABLE test.my_table" + ) + .unwrap_err(), + ParserError::ParserError("Expected: end of statement, found: INPATH".to_string()) + ); + + let dialects = all_dialects_where(|d| !d.supports_load_data() && !d.supports_load_extension()); + + assert_eq!( + dialects.parse_sql_statements("LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table").unwrap_err(), + ParserError::ParserError("Expected: Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: LOCAL".to_string()) + ); + + assert_eq!( + dialects.parse_sql_statements("LOAD DATA INPATH '/local/path/to/data.txt' INTO TABLE test.my_table").unwrap_err(), + ParserError::ParserError("Expected: Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: INPATH".to_string()) + ); +} + #[test] fn test_select_top() { let dialects = all_dialects_where(|d| d.supports_top_before_distinct()); From a74195a34e854650f07f2c3e58bebd98319878d8 Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Tue, 12 Nov 2024 22:02:03 +0800 Subject: [PATCH 2/6] add more tests --- src/ast/mod.rs | 6 +++--- src/parser/mod.rs | 6 +++--- tests/sqlparser_common.rs | 20 +++++++++++++------- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 2932f11a1..e480b1026 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -3343,7 +3343,7 @@ pub enum Statement { overwrite: bool, table_name: ObjectName, partitioned: Option>, - table_format: Option, + table_format: Option, }, } @@ -3968,7 +3968,7 @@ impl fmt::Display for Statement { write!(f, " PARTITION ({})", display_comma_separated(parts))?; } } - if let Some(HiveLoadDataOption { + if let Some(HiveLoadDataFormat { serde, input_format, }) = &table_format @@ -5865,7 +5865,7 @@ pub enum HiveRowFormat { #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] -pub struct HiveLoadDataOption { +pub struct HiveLoadDataFormat { pub serde: Expr, pub input_format: Expr, } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 173626042..f2d7005aa 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -11177,12 +11177,12 @@ impl<'a> Parser<'a> { pub fn parse_load_data_table_format( &mut self, - ) -> Result, ParserError> { + ) -> Result, ParserError> { if self.parse_keyword(Keyword::INPUTFORMAT) { let input_format = self.parse_expr()?; self.expect_keyword(Keyword::SERDE)?; let serde = self.parse_expr()?; - Ok(Some(HiveLoadDataOption { + Ok(Some(HiveLoadDataFormat { input_format, serde, })) @@ -12218,7 +12218,7 @@ impl<'a> Parser<'a> { }) } else { self.expected( - "Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements", + "dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements", self.peek_token(), ) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 69c32d517..93a9d9d39 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11503,10 +11503,6 @@ fn parse_notify_channel() { dialects.parse_sql_statements(sql).unwrap_err(), ParserError::ParserError("Expected: an SQL statement, found: NOTIFY".to_string()) ); - assert_eq!( - dialects.parse_sql_statements(sql).unwrap_err(), - ParserError::ParserError("Expected: an SQL statement, found: NOTIFY".to_string()) - ); } } @@ -11626,11 +11622,21 @@ fn parse_load_data() { op: BinaryOperator::Eq, right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))), }]), partitioned); - assert_eq!(Some(HiveLoadDataOption {serde: Expr::Value(Value::SingleQuotedString("org.apache.hadoop.hive.serde2.OpenCSVSerde".to_string())), input_format: Expr::Value(Value::SingleQuotedString("org.apache.hadoop.mapred.TextInputFormat".to_string()))}), table_format); + assert_eq!(Some(HiveLoadDataFormat {serde: Expr::Value(Value::SingleQuotedString("org.apache.hadoop.hive.serde2.OpenCSVSerde".to_string())), input_format: Expr::Value(Value::SingleQuotedString("org.apache.hadoop.mapred.TextInputFormat".to_string()))}), table_format); } _ => unreachable!(), }; + // negative test case + assert_eq!( + dialects + .parse_sql_statements( + "LOAD DATA2 LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table" + ) + .unwrap_err(), + ParserError::ParserError("Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: DATA2".to_string()) + ); + let dialects = all_dialects_where(|d| !d.supports_load_data() && d.supports_load_extension()); assert_eq!( @@ -11655,12 +11661,12 @@ fn parse_load_data() { assert_eq!( dialects.parse_sql_statements("LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table").unwrap_err(), - ParserError::ParserError("Expected: Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: LOCAL".to_string()) + ParserError::ParserError("Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: LOCAL".to_string()) ); assert_eq!( dialects.parse_sql_statements("LOAD DATA INPATH '/local/path/to/data.txt' INTO TABLE test.my_table").unwrap_err(), - ParserError::ParserError("Expected: Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: INPATH".to_string()) + ParserError::ParserError("Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: INPATH".to_string()) ); } From 16cc753d9c9c409c516ef70cbe4c256d7818ea6c Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Tue, 12 Nov 2024 22:21:22 +0800 Subject: [PATCH 3/6] Adjust the structure of the test code --- tests/sqlparser_common.rs | 172 ++++++++++++++++++++++---------------- 1 file changed, 99 insertions(+), 73 deletions(-) diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 93a9d9d39..62e06e826 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11509,10 +11509,13 @@ fn parse_notify_channel() { #[test] fn parse_load_data() { let dialects = all_dialects_where(|d| d.supports_load_data()); + let only_supports_load_extension_dialects = + all_dialects_where(|d| !d.supports_load_data() && d.supports_load_extension()); + let not_supports_load_dialects = + all_dialects_where(|d| !d.supports_load_data() && !d.supports_load_extension()); - match dialects - .verified_stmt("LOAD DATA INPATH '/local/path/to/data.txt' INTO TABLE test.my_table") - { + let sql = "LOAD DATA INPATH '/local/path/to/data.txt' INTO TABLE test.my_table"; + match dialects.verified_stmt(sql) { Statement::LoadData { local, inpath, @@ -11535,9 +11538,8 @@ fn parse_load_data() { }; // with OVERWRITE keyword - match dialects - .verified_stmt("LOAD DATA INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE my_table") - { + let sql = "LOAD DATA INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE my_table"; + match dialects.verified_stmt(sql) { Statement::LoadData { local, inpath, @@ -11556,10 +11558,19 @@ fn parse_load_data() { _ => unreachable!(), }; + assert_eq!( + only_supports_load_extension_dialects + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError("Expected: end of statement, found: INPATH".to_string()) + ); + assert_eq!( not_supports_load_dialects.parse_sql_statements(sql).unwrap_err(), + ParserError::ParserError("Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: INPATH".to_string()) + ); + // with LOCAL keyword - match dialects - .verified_stmt("LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table") - { + let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table"; + match dialects.verified_stmt(sql) { Statement::LoadData { local, inpath, @@ -11581,93 +11592,108 @@ fn parse_load_data() { _ => unreachable!(), }; + assert_eq!( + only_supports_load_extension_dialects + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError("Expected: end of statement, found: LOCAL".to_string()) + ); + assert_eq!( + not_supports_load_dialects.parse_sql_statements(sql).unwrap_err(), + ParserError::ParserError("Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: LOCAL".to_string()) + ); + // with PARTITION clause - match dialects.verified_stmt("LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE my_table PARTITION (year = 2024, month = 11)") { - Statement::LoadData {local, inpath, overwrite, table_name, partitioned, table_format} => { + let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE my_table PARTITION (year = 2024, month = 11)"; + match dialects.verified_stmt(sql) { + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { assert_eq!(true, local); assert_eq!("/local/path/to/data.txt", inpath); assert_eq!(false, overwrite); assert_eq!(ObjectName(vec![Ident::new("my_table")]), table_name); - assert_eq!(Some(vec![ - Expr::BinaryOp{ - left: Box::new(Expr::Identifier(Ident::new("year"))), - op: BinaryOperator::Eq, - right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))), - }, - Expr::BinaryOp{ - left: Box::new(Expr::Identifier(Ident::new("month"))), - op: BinaryOperator::Eq, - right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))), - }]), partitioned); + assert_eq!( + Some(vec![ + Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("year"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))), + }, + Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("month"))), + op: BinaryOperator::Eq, + right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))), + } + ]), + partitioned + ); assert_eq!(None, table_format); } _ => unreachable!(), }; // with PARTITION clause - match dialects.verified_stmt("LOAD DATA LOCAL INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE good.my_table PARTITION (year = 2024, month = 11) INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'") { - Statement::LoadData {local, inpath, overwrite, table_name, partitioned, table_format} => { - assert_eq!(true, local); - assert_eq!("/local/path/to/data.txt", inpath); - assert_eq!(true, overwrite); - assert_eq!(ObjectName(vec![Ident::new("good"), Ident::new("my_table")]), table_name); - assert_eq!(Some(vec![ - Expr::BinaryOp{ + let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE good.my_table PARTITION (year = 2024, month = 11) INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'"; + match dialects.verified_stmt(sql) { + Statement::LoadData { + local, + inpath, + overwrite, + table_name, + partitioned, + table_format, + } => { + assert_eq!(true, local); + assert_eq!("/local/path/to/data.txt", inpath); + assert_eq!(true, overwrite); + assert_eq!( + ObjectName(vec![Ident::new("good"), Ident::new("my_table")]), + table_name + ); + assert_eq!( + Some(vec![ + Expr::BinaryOp { left: Box::new(Expr::Identifier(Ident::new("year"))), op: BinaryOperator::Eq, - right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))), + right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))), }, - Expr::BinaryOp{ + Expr::BinaryOp { left: Box::new(Expr::Identifier(Ident::new("month"))), op: BinaryOperator::Eq, - right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))), - }]), partitioned); - assert_eq!(Some(HiveLoadDataFormat {serde: Expr::Value(Value::SingleQuotedString("org.apache.hadoop.hive.serde2.OpenCSVSerde".to_string())), input_format: Expr::Value(Value::SingleQuotedString("org.apache.hadoop.mapred.TextInputFormat".to_string()))}), table_format); - } - _ => unreachable!(), - }; + right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))), + } + ]), + partitioned + ); + assert_eq!( + Some(HiveLoadDataFormat { + serde: Expr::Value(Value::SingleQuotedString( + "org.apache.hadoop.hive.serde2.OpenCSVSerde".to_string() + )), + input_format: Expr::Value(Value::SingleQuotedString( + "org.apache.hadoop.mapred.TextInputFormat".to_string() + )) + }), + table_format + ); + } + _ => unreachable!(), + }; // negative test case + let sql = "LOAD DATA2 LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table"; assert_eq!( dialects - .parse_sql_statements( - "LOAD DATA2 LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table" - ) + .parse_sql_statements(sql) .unwrap_err(), ParserError::ParserError("Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: DATA2".to_string()) ); - - let dialects = all_dialects_where(|d| !d.supports_load_data() && d.supports_load_extension()); - - assert_eq!( - dialects - .parse_sql_statements( - "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table" - ) - .unwrap_err(), - ParserError::ParserError("Expected: end of statement, found: LOCAL".to_string()) - ); - - assert_eq!( - dialects - .parse_sql_statements( - "LOAD DATA INPATH '/local/path/to/data.txt' INTO TABLE test.my_table" - ) - .unwrap_err(), - ParserError::ParserError("Expected: end of statement, found: INPATH".to_string()) - ); - - let dialects = all_dialects_where(|d| !d.supports_load_data() && !d.supports_load_extension()); - - assert_eq!( - dialects.parse_sql_statements("LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table").unwrap_err(), - ParserError::ParserError("Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: LOCAL".to_string()) - ); - - assert_eq!( - dialects.parse_sql_statements("LOAD DATA INPATH '/local/path/to/data.txt' INTO TABLE test.my_table").unwrap_err(), - ParserError::ParserError("Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: INPATH".to_string()) - ); } #[test] From fad038eab9d0b09af5c44afc83613dda1580510f Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Wed, 13 Nov 2024 22:08:09 +0800 Subject: [PATCH 4/6] Update src/parser/mod.rs more insightful suggestions. Co-authored-by: Ifeanyi Ubah --- src/parser/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index f2d7005aa..a65623abe 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -12218,7 +12218,7 @@ impl<'a> Parser<'a> { }) } else { self.expected( - "dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements", + " DATA` or an extension name after `LOAD`", self.peek_token(), ) } From 42c8743d95c45f8328876e0b5155d977e38a9f0d Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Wed, 13 Nov 2024 22:23:35 +0800 Subject: [PATCH 5/6] add more insightful suggestions --- src/parser/mod.rs | 2 +- tests/sqlparser_common.rs | 27 ++++++++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/src/parser/mod.rs b/src/parser/mod.rs index a65623abe..d246d0eb5 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -12218,7 +12218,7 @@ impl<'a> Parser<'a> { }) } else { self.expected( - " DATA` or an extension name after `LOAD`", + "`DATA` or an extension name after `LOAD`", self.peek_token(), ) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 62e06e826..68da8e601 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -11564,8 +11564,13 @@ fn parse_load_data() { .unwrap_err(), ParserError::ParserError("Expected: end of statement, found: INPATH".to_string()) ); - assert_eq!( not_supports_load_dialects.parse_sql_statements(sql).unwrap_err(), - ParserError::ParserError("Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: INPATH".to_string()) + assert_eq!( + not_supports_load_dialects + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError( + "Expected: `DATA` or an extension name after `LOAD`, found: INPATH".to_string() + ) ); // with LOCAL keyword @@ -11599,8 +11604,12 @@ fn parse_load_data() { ParserError::ParserError("Expected: end of statement, found: LOCAL".to_string()) ); assert_eq!( - not_supports_load_dialects.parse_sql_statements(sql).unwrap_err(), - ParserError::ParserError("Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: LOCAL".to_string()) + not_supports_load_dialects + .parse_sql_statements(sql) + .unwrap_err(), + ParserError::ParserError( + "Expected: `DATA` or an extension name after `LOAD`, found: LOCAL".to_string() + ) ); // with PARTITION clause @@ -11689,11 +11698,11 @@ fn parse_load_data() { // negative test case let sql = "LOAD DATA2 LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table"; assert_eq!( - dialects - .parse_sql_statements(sql) - .unwrap_err(), - ParserError::ParserError("Expected: dialect supports `LOAD DATA` or `LOAD extension` to parse `LOAD` statements, found: DATA2".to_string()) - ); + dialects.parse_sql_statements(sql).unwrap_err(), + ParserError::ParserError( + "Expected: `DATA` or an extension name after `LOAD`, found: DATA2".to_string() + ) + ); } #[test] From 61227cd0adf9bba07db094b738790f218a4ab58e Mon Sep 17 00:00:00 2001 From: wugeer <1284057728@qq.com> Date: Wed, 13 Nov 2024 22:36:16 +0800 Subject: [PATCH 6/6] add `LOAD extension` for Generic dialect --- src/dialect/generic.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/dialect/generic.rs b/src/dialect/generic.rs index 8cfac217b..4998e0f4b 100644 --- a/src/dialect/generic.rs +++ b/src/dialect/generic.rs @@ -115,4 +115,8 @@ impl Dialect for GenericDialect { fn supports_comment_on(&self) -> bool { true } + + fn supports_load_extension(&self) -> bool { + true + } }