diff --git a/src/ast/data_type.rs b/src/ast/data_type.rs index 2a6a004f4..9dd0e2265 100644 --- a/src/ast/data_type.rs +++ b/src/ast/data_type.rs @@ -20,7 +20,8 @@ use serde::{Deserialize, Serialize}; #[cfg(feature = "visitor")] use sqlparser_derive::{Visit, VisitMut}; -use crate::ast::ObjectName; +use crate::ast::{display_comma_separated, ObjectName}; +use crate::ast::ddl::StructField; use super::value::escape_single_quote_string; @@ -198,10 +199,15 @@ pub enum DataType { Custom(ObjectName, Vec), /// Arrays Array(Option>), + BracketArray(Option>), /// Enums Enum(Vec), /// Set Set(Vec), + /// Map + Map(Box, Box), + /// Struct + Struct(Vec) } impl fmt::Display for DataType { @@ -320,11 +326,17 @@ impl fmt::Display for DataType { DataType::Bytea => write!(f, "BYTEA"), DataType::Array(ty) => { if let Some(t) = &ty { - write!(f, "{t}[]") + write!(f, "ARRAY<{t}>") } else { write!(f, "ARRAY") } } + DataType::BracketArray(ty) => { + if let Some(t) = &ty { + write!(f, "{t}[]")? + } + Ok(()) + } DataType::Custom(ty, modifiers) => { if modifiers.is_empty() { write!(f, "{ty}") @@ -352,6 +364,14 @@ impl fmt::Display for DataType { } write!(f, ")") } + DataType::Map(key, value) => { + write!(f, "MAP<{}>", display_comma_separated(&[key, value])) + } + DataType::Struct(vals) => { + write!(f, "STRUCT<")?; + write!(f, "{}", display_comma_separated(vals))?; + write!(f, ">") + } } } } diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index a4640d557..1c5b536f4 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -678,6 +678,25 @@ impl fmt::Display for ColumnOption { } } +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct StructField { + pub(crate) name: Ident, + pub(crate) data_type: DataType, + pub(crate) options: Option, +} + +impl fmt::Display for StructField { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}: {}", self.name, self.data_type)?; + if let Some(option) = self.options.as_ref() { + write!(f, "{option}")?; + } + Ok(()) + } +} + /// `GeneratedAs`s are modifiers that follow a column option in a `generated`. /// 'ExpStored' is PostgreSQL specific #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] diff --git a/src/ast/mod.rs b/src/ast/mod.rs index a241f9509..057a22955 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -31,7 +31,7 @@ pub use self::data_type::{ pub use self::dcl::{AlterRoleOperation, ResetConfig, RoleOption, SetConfigValue}; pub use self::ddl::{ AlterColumnOperation, AlterIndexOperation, AlterTableOperation, ColumnDef, ColumnOption, - ColumnOptionDef, GeneratedAs, IndexType, KeyOrIndexDisplay, ProcedureParam, ReferentialAction, + ColumnOptionDef, GeneratedAs, IndexType, KeyOrIndexDisplay, ProcedureParam, ReferentialAction, StructField, TableConstraint, UserDefinedTypeCompositeAttributeDef, UserDefinedTypeRepresentation, }; pub use self::operator::{BinaryOperator, UnaryOperator}; diff --git a/src/keywords.rs b/src/keywords.rs index c73535fca..4b815fab2 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -365,6 +365,7 @@ define_keywords!( LOWER, MACRO, MANAGEDLOCATION, + MAP, MATCH, MATCHED, MATERIALIZED, @@ -575,6 +576,7 @@ define_keywords!( STORED, STRICT, STRING, + STRUCT, SUBMULTISET, SUBSTRING, SUBSTRING_REGEX, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 846215249..069a00c6c 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -21,6 +21,7 @@ use alloc::{ vec::Vec, }; use core::fmt; +use std::ops::Rem; use log::debug; @@ -114,6 +115,7 @@ mod recursion { Self { remaining_depth } } } + impl Drop for DepthGuard { fn drop(&mut self) { self.remaining_depth.fetch_add(1, Ordering::SeqCst); @@ -257,6 +259,7 @@ pub struct Parser<'a> { options: ParserOptions, /// ensure the stack does not overflow by limiting recursion depth recursion_counter: RecursionCounter, + max_depth: usize, } impl<'a> Parser<'a> { @@ -282,6 +285,7 @@ impl<'a> Parser<'a> { dialect, recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH), options: ParserOptions::default(), + max_depth: 1, } } @@ -2181,7 +2185,7 @@ impl<'a> Parser<'a> { token => { return token .cloned() - .unwrap_or_else(|| TokenWithLocation::wrap(Token::EOF)) + .unwrap_or_else(|| TokenWithLocation::wrap(Token::EOF)); } } } @@ -4648,6 +4652,13 @@ impl<'a> Parser<'a> { /// Parse a SQL datatype (in the context of a CREATE TABLE statement for example) pub fn parse_data_type(&mut self) -> Result { + self.parse_data_type_with_depth(1) + } + + pub fn parse_data_type_with_depth(&mut self, depth: usize) -> Result { + if depth > self.max_depth { + self.max_depth = depth - 1; + } let next_token = self.next_token(); let mut data = match next_token.token { Token::Word(w) => match w.keyword { @@ -4837,10 +4848,57 @@ impl<'a> Parser<'a> { // that ends with > will fail due to "C++" problem - >> is parsed as // Token::ShiftRight self.expect_token(&Token::Lt)?; - let inside_type = self.parse_data_type()?; + + let inside_type = self.parse_data_type_with_depth(depth + 1)?; + dbg!(depth, self.max_depth); + + if depth <= 1 { + dbg!("First Level"); + if (depth == 1 && self.max_depth == depth) + || (self.peek_previous_token()? == &Token::ShiftRight + && self.max_depth.rem(2) != 0) + { + self.expect_token(&Token::Gt)?; + } + } else if depth.rem(2) == 0 && depth != self.max_depth { + } else { + dbg!("Else Level"); + self.expect_token(&Token::ShiftRight)?; + } + + if dialect_of!(self is PostgreSqlDialect) { + Ok(DataType::BracketArray(Some(Box::new(inside_type)))) + } else { + Ok(DataType::Array(Some(Box::new(inside_type)))) + } + } + } + Keyword::MAP => { + self.expect_token(&Token::Lt)?; + let key = self.parse_data_type_with_depth(depth + 1)?; + let tok = self.consume_token(&Token::Comma); + debug!("Tok: {tok}"); + let value = self.parse_data_type_with_depth(depth + 1)?; + let tok = self.peek_token().token; + debug!("Next Tok: {tok}"); + if tok == Token::ShiftRight { + self.expect_token(&Token::ShiftRight)?; + } else if tok == Token::Gt { self.expect_token(&Token::Gt)?; - Ok(DataType::Array(Some(Box::new(inside_type)))) } + Ok(DataType::Map(Box::new(key), Box::new(value))) + } + Keyword::STRUCT => { + self.expect_token(&Token::Lt)?; + let fields = self.parse_comma_separated(Parser::parse_struct_fields)?; + let tok = self.peek_token().token; + debug!("Next Tok: {tok}"); + if tok == Token::ShiftRight { + self.expect_token(&Token::ShiftRight)?; + } else if tok == Token::Gt { + self.expect_token(&Token::Gt)?; + } + Ok(DataType::Struct(fields)) } _ => { self.prev_token(); @@ -4859,11 +4917,27 @@ impl<'a> Parser<'a> { // Keyword::ARRAY syntax from above while self.consume_token(&Token::LBracket) { self.expect_token(&Token::RBracket)?; - data = DataType::Array(Some(Box::new(data))) + data = DataType::BracketArray(Some(Box::new(data))) } Ok(data) } + pub fn peek_previous_token(&mut self) -> Result<&TokenWithLocation, ParserError> { + Ok(&self.tokens[self.index - 1]) + } + + pub fn parse_struct_fields(&mut self) -> Result { + let name = self.parse_identifier()?; + self.expect_token(&Token::Colon)?; + let data_type = self.parse_data_type()?; + let options = self.parse_optional_column_option()?; + Ok(StructField { + name, + data_type, + options, + }) + } + pub fn parse_string_values(&mut self) -> Result, ParserError> { self.expect_token(&Token::LParen)?; let mut values = Vec::new(); @@ -5028,12 +5102,12 @@ impl<'a> Parser<'a> { Token::EOF => { return Err(ParserError::ParserError( "Empty input when parsing identifier".to_string(), - ))? + ))?; } token => { return Err(ParserError::ParserError(format!( "Unexpected token in identifier: {token}" - )))? + )))?; } }; @@ -5046,19 +5120,19 @@ impl<'a> Parser<'a> { Token::EOF => { return Err(ParserError::ParserError( "Trailing period in identifier".to_string(), - ))? + ))?; } token => { return Err(ParserError::ParserError(format!( "Unexpected token following period in identifier: {token}" - )))? + )))?; } }, Token::EOF => break, token => { return Err(ParserError::ParserError(format!( "Unexpected token in identifier: {token}" - )))? + )))?; } } } @@ -6031,7 +6105,7 @@ impl<'a> Parser<'a> { _ => { return Err(ParserError::ParserError(format!( "expected OUTER, SEMI, ANTI or JOIN after {kw:?}" - ))) + ))); } } } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 3fdf3d211..d9e132e27 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -2416,17 +2416,41 @@ fn parse_create_table() { .contains("Expected constraint details after CONSTRAINT ")); } +#[test] +fn parse_double_greater_than_array() { + let supported_dialects = TestedDialects { + dialects: vec![Box::new(HiveDialect {})], + options: None, + }; + let levels = &[ + "CREATE TABLE t (a ARRAY, n INT)", + "CREATE TABLE t (a ARRAY>, n INT)", + "CREATE TABLE t (a ARRAY>>, n INT)", + "CREATE TABLE t (a ARRAY>>>, n INT)", + "CREATE TABLE t (a ARRAY>>>>, n INT)", + "CREATE TABLE t (a ARRAY>>>>>, n INT)", + "CREATE TABLE t (a ARRAY>>>>>>, n INT)", + "CREATE TABLE t (a ARRAY>>>>>>>, n INT)", + "CREATE TABLE t (a ARRAY>>>>>>>>, n INT)", + "CREATE TABLE t (a ARRAY>>>>>>>>>, n INT)", + ]; + for q in levels { + let statements = supported_dialects.parse_sql_statements(q).unwrap(); + println!("{}", statements[0]); + } +} + #[test] fn parse_create_table_hive_array() { // Parsing [] type arrays does not work in MsSql since [ is used in is_delimited_identifier_start let dialects = TestedDialects { - dialects: vec![Box::new(PostgreSqlDialect {}), Box::new(HiveDialect {})], + dialects: vec![Box::new(HiveDialect {})], options: None, }; let sql = "CREATE TABLE IF NOT EXISTS something (name int, val array)"; match dialects.one_statement_parses_to( sql, - "CREATE TABLE IF NOT EXISTS something (name INT, val INT[])", + "CREATE TABLE IF NOT EXISTS something (name INT, val ARRAY)", ) { Statement::CreateTable { if_not_exists, @@ -2457,13 +2481,9 @@ fn parse_create_table_hive_array() { _ => unreachable!(), } - // SnowflakeDialect using array diffrent + // SnowflakeDialect using array different let dialects = TestedDialects { - dialects: vec![ - Box::new(PostgreSqlDialect {}), - Box::new(HiveDialect {}), - Box::new(MySqlDialect {}), - ], + dialects: vec![Box::new(HiveDialect {}), Box::new(MySqlDialect {})], options: None, }; let sql = "CREATE TABLE IF NOT EXISTS something (name int, val array