diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 4f9aac885..b4b32a3a3 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -20,6 +20,7 @@ use alloc::{ }; use core::fmt::{self, Display}; +use core::ops::Deref; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; @@ -993,7 +994,26 @@ impl fmt::Display for LambdaFunction { /// Encapsulates the common pattern in SQL where either one unparenthesized item /// such as an identifier or expression is permitted, or multiple of the same -/// item in a parenthesized list. +/// item in a parenthesized list. For accessing items regardless of the form, +/// `OneOrManyWithParens` implements `Deref` and `IntoIterator`, +/// so you can call slice methods on it and iterate over items +/// # Examples +/// Acessing as a slice: +/// ``` +/// # use sqlparser::ast::OneOrManyWithParens; +/// let one = OneOrManyWithParens::One("a"); +/// +/// assert_eq!(one[0], "a"); +/// assert_eq!(one.len(), 1); +/// ``` +/// Iterating: +/// ``` +/// # use sqlparser::ast::OneOrManyWithParens; +/// let one = OneOrManyWithParens::One("a"); +/// let many = OneOrManyWithParens::Many(vec!["a", "b"]); +/// +/// assert_eq!(one.into_iter().chain(many).collect::>(), vec!["a", "a", "b"] ); +/// ``` #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] @@ -1004,6 +1024,125 @@ pub enum OneOrManyWithParens { Many(Vec), } +impl Deref for OneOrManyWithParens { + type Target = [T]; + + fn deref(&self) -> &[T] { + match self { + OneOrManyWithParens::One(one) => core::slice::from_ref(one), + OneOrManyWithParens::Many(many) => many, + } + } +} + +impl AsRef<[T]> for OneOrManyWithParens { + fn as_ref(&self) -> &[T] { + self + } +} + +impl<'a, T> IntoIterator for &'a OneOrManyWithParens { + type Item = &'a T; + type IntoIter = core::slice::Iter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +/// Owned iterator implementation of `OneOrManyWithParens` +#[derive(Debug, Clone)] +pub struct OneOrManyWithParensIntoIter { + inner: OneOrManyWithParensIntoIterInner, +} + +#[derive(Debug, Clone)] +enum OneOrManyWithParensIntoIterInner { + One(core::iter::Once), + Many( as IntoIterator>::IntoIter), +} + +impl core::iter::FusedIterator for OneOrManyWithParensIntoIter +where + core::iter::Once: core::iter::FusedIterator, + as IntoIterator>::IntoIter: core::iter::FusedIterator, +{ +} + +impl core::iter::ExactSizeIterator for OneOrManyWithParensIntoIter +where + core::iter::Once: core::iter::ExactSizeIterator, + as IntoIterator>::IntoIter: core::iter::ExactSizeIterator, +{ +} + +impl core::iter::Iterator for OneOrManyWithParensIntoIter { + type Item = T; + + fn next(&mut self) -> Option { + match &mut self.inner { + OneOrManyWithParensIntoIterInner::One(one) => one.next(), + OneOrManyWithParensIntoIterInner::Many(many) => many.next(), + } + } + + fn size_hint(&self) -> (usize, Option) { + match &self.inner { + OneOrManyWithParensIntoIterInner::One(one) => one.size_hint(), + OneOrManyWithParensIntoIterInner::Many(many) => many.size_hint(), + } + } + + fn count(self) -> usize + where + Self: Sized, + { + match self.inner { + OneOrManyWithParensIntoIterInner::One(one) => one.count(), + OneOrManyWithParensIntoIterInner::Many(many) => many.count(), + } + } + + fn fold(mut self, init: B, f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + match &mut self.inner { + OneOrManyWithParensIntoIterInner::One(one) => one.fold(init, f), + OneOrManyWithParensIntoIterInner::Many(many) => many.fold(init, f), + } + } +} + +impl core::iter::DoubleEndedIterator for OneOrManyWithParensIntoIter { + fn next_back(&mut self) -> Option { + match &mut self.inner { + OneOrManyWithParensIntoIterInner::One(one) => one.next_back(), + OneOrManyWithParensIntoIterInner::Many(many) => many.next_back(), + } + } +} + +impl IntoIterator for OneOrManyWithParens { + type Item = T; + + type IntoIter = OneOrManyWithParensIntoIter; + + fn into_iter(self) -> Self::IntoIter { + let inner = match self { + OneOrManyWithParens::One(one) => { + OneOrManyWithParensIntoIterInner::One(core::iter::once(one)) + } + OneOrManyWithParens::Many(many) => { + OneOrManyWithParensIntoIterInner::Many(many.into_iter()) + } + }; + + OneOrManyWithParensIntoIter { inner } + } +} + impl fmt::Display for OneOrManyWithParens where T: fmt::Display, @@ -6919,4 +7058,178 @@ mod tests { }); assert_eq!("INTERVAL '5' SECOND (1, 3)", format!("{interval}")); } + + #[test] + fn test_one_or_many_with_parens_deref() { + use core::ops::Index; + + let one = OneOrManyWithParens::One("a"); + + assert_eq!(one.deref(), &["a"]); + assert_eq!( as Deref>::deref(&one), &["a"]); + + assert_eq!(one[0], "a"); + assert_eq!(one.index(0), &"a"); + assert_eq!( + < as Deref>::Target as Index>::index(&one, 0), + &"a" + ); + + assert_eq!(one.len(), 1); + assert_eq!( as Deref>::Target::len(&one), 1); + + let many1 = OneOrManyWithParens::Many(vec!["b"]); + + assert_eq!(many1.deref(), &["b"]); + assert_eq!( as Deref>::deref(&many1), &["b"]); + + assert_eq!(many1[0], "b"); + assert_eq!(many1.index(0), &"b"); + assert_eq!( + < as Deref>::Target as Index>::index(&many1, 0), + &"b" + ); + + assert_eq!(many1.len(), 1); + assert_eq!( as Deref>::Target::len(&many1), 1); + + let many2 = OneOrManyWithParens::Many(vec!["c", "d"]); + + assert_eq!(many2.deref(), &["c", "d"]); + assert_eq!( + as Deref>::deref(&many2), + &["c", "d"] + ); + + assert_eq!(many2[0], "c"); + assert_eq!(many2.index(0), &"c"); + assert_eq!( + < as Deref>::Target as Index>::index(&many2, 0), + &"c" + ); + + assert_eq!(many2[1], "d"); + assert_eq!(many2.index(1), &"d"); + assert_eq!( + < as Deref>::Target as Index>::index(&many2, 1), + &"d" + ); + + assert_eq!(many2.len(), 2); + assert_eq!( as Deref>::Target::len(&many2), 2); + } + + #[test] + fn test_one_or_many_with_parens_as_ref() { + let one = OneOrManyWithParens::One("a"); + + assert_eq!(one.as_ref(), &["a"]); + assert_eq!( as AsRef<_>>::as_ref(&one), &["a"]); + + let many1 = OneOrManyWithParens::Many(vec!["b"]); + + assert_eq!(many1.as_ref(), &["b"]); + assert_eq!( as AsRef<_>>::as_ref(&many1), &["b"]); + + let many2 = OneOrManyWithParens::Many(vec!["c", "d"]); + + assert_eq!(many2.as_ref(), &["c", "d"]); + assert_eq!( + as AsRef<_>>::as_ref(&many2), + &["c", "d"] + ); + } + + #[test] + fn test_one_or_many_with_parens_ref_into_iter() { + let one = OneOrManyWithParens::One("a"); + + assert_eq!(Vec::from_iter(&one), vec![&"a"]); + + let many1 = OneOrManyWithParens::Many(vec!["b"]); + + assert_eq!(Vec::from_iter(&many1), vec![&"b"]); + + let many2 = OneOrManyWithParens::Many(vec!["c", "d"]); + + assert_eq!(Vec::from_iter(&many2), vec![&"c", &"d"]); + } + + #[test] + fn test_one_or_many_with_parens_value_into_iter() { + use core::iter::once; + + //tests that our iterator implemented methods behaves exactly as it's inner iterator, at every step up to n calls to next/next_back + fn test_steps(ours: OneOrManyWithParens, inner: I, n: usize) + where + I: IntoIterator + Clone, + { + fn checks(ours: OneOrManyWithParensIntoIter, inner: I) + where + I: Iterator + Clone + DoubleEndedIterator, + { + assert_eq!(ours.size_hint(), inner.size_hint()); + assert_eq!(ours.clone().count(), inner.clone().count()); + + assert_eq!( + ours.clone().fold(1, |a, v| a + v), + inner.clone().fold(1, |a, v| a + v) + ); + + assert_eq!(Vec::from_iter(ours.clone()), Vec::from_iter(inner.clone())); + assert_eq!( + Vec::from_iter(ours.clone().rev()), + Vec::from_iter(inner.clone().rev()) + ); + } + + let mut ours_next = ours.clone().into_iter(); + let mut inner_next = inner.clone().into_iter(); + + for _ in 0..n { + checks(ours_next.clone(), inner_next.clone()); + + assert_eq!(ours_next.next(), inner_next.next()); + } + + let mut ours_next_back = ours.clone().into_iter(); + let mut inner_next_back = inner.clone().into_iter(); + + for _ in 0..n { + checks(ours_next_back.clone(), inner_next_back.clone()); + + assert_eq!(ours_next_back.next_back(), inner_next_back.next_back()); + } + + let mut ours_mixed = ours.clone().into_iter(); + let mut inner_mixed = inner.clone().into_iter(); + + for i in 0..n { + checks(ours_mixed.clone(), inner_mixed.clone()); + + if i % 2 == 0 { + assert_eq!(ours_mixed.next_back(), inner_mixed.next_back()); + } else { + assert_eq!(ours_mixed.next(), inner_mixed.next()); + } + } + + let mut ours_mixed2 = ours.into_iter(); + let mut inner_mixed2 = inner.into_iter(); + + for i in 0..n { + checks(ours_mixed2.clone(), inner_mixed2.clone()); + + if i % 2 == 0 { + assert_eq!(ours_mixed2.next(), inner_mixed2.next()); + } else { + assert_eq!(ours_mixed2.next_back(), inner_mixed2.next_back()); + } + } + } + + test_steps(OneOrManyWithParens::One(1), once(1), 3); + test_steps(OneOrManyWithParens::Many(vec![2]), vec![2], 3); + test_steps(OneOrManyWithParens::Many(vec![3, 4]), vec![3, 4], 4); + } }