Skip to content

Commit

Permalink
Implement ValidationContext(schema_map + alias) to enhance validation…
Browse files Browse the repository at this point in the history
… of ambiguous columns (gluesql#883)

# Symptom
gluesql> create table Foo(id int);
gluesql> select * from Foo a join Foo b on a.id = b.id;
| id | id |

gluesql> select id from Foo a join Foo b on a.id = b.id;
| id |
Even though there are two `id` columns, no error is returned, because the schemas are gathered into a HashMap keyed by table name.

# Fix to
gluesql> select id from Foo a join Foo b on a.id = b.id;
ERROR:  column reference "id" is ambiguous
  • Loading branch information
devgony committed Jan 27, 2023
1 parent 5038730 commit 9a26de4
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 79 deletions.
4 changes: 1 addition & 3 deletions core/src/plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ pub use {

pub async fn plan(storage: &dyn Store, statement: Statement) -> Result<Statement> {
let schema_map = fetch_schema_map(storage, &statement).await?;

let statement = validate(&schema_map, statement)?;

validate(&schema_map, &statement)?;
let statement = plan_primary_key(&schema_map, statement);
let statement = plan_index(&schema_map, statement)?;
let statement = plan_join(&schema_map, statement);
Expand Down
63 changes: 28 additions & 35 deletions core/src/plan/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@ use {
store::Store,
},
async_recursion::async_recursion,
futures::{
future,
stream::{self, StreamExt, TryStreamExt},
},
futures::stream::{self, StreamExt, TryStreamExt},
std::collections::HashMap,
};

Expand All @@ -22,26 +19,17 @@ pub async fn fetch_schema_map(
statement: &Statement,
) -> Result<HashMap<String, Schema>> {
match statement {
Statement::Query(query) => scan_query(storage, query).await.map(|schema_list| {
schema_list
.into_iter()
.map(|schema| (schema.table_name.clone(), schema))
.collect::<HashMap<_, _>>()
}),
Statement::Query(query) => scan_query(storage, query).await,
Statement::Insert {
table_name, source, ..
} => {
let table_schema = storage
.fetch_schema(table_name)
.await?
.map(|schema| vec![schema])
.unwrap_or_else(Vec::new);
.map(|schema| HashMap::from([(table_name.to_owned(), schema)]))
.unwrap_or_else(HashMap::new);
let source_schema_list = scan_query(storage, source).await?;
let schema_list = [table_schema, source_schema_list]
.into_iter()
.flatten()
.map(|schema| (schema.table_name.clone(), schema))
.collect();
let schema_list = table_schema.into_iter().chain(source_schema_list).collect();

Ok(schema_list)
}
Expand All @@ -61,7 +49,7 @@ pub async fn fetch_schema_map(
}
}

async fn scan_query(storage: &dyn Store, query: &Query) -> Result<Vec<Schema>> {
async fn scan_query(storage: &dyn Store, query: &Query) -> Result<HashMap<String, Schema>> {
let Query {
body,
limit,
Expand All @@ -71,7 +59,7 @@ async fn scan_query(storage: &dyn Store, query: &Query) -> Result<Vec<Schema>> {

let schema_list = match body {
SetExpr::Select(select) => scan_select(storage, select).await?,
SetExpr::Values(_) => Vec::new(),
SetExpr::Values(_) => HashMap::new(),
};

let schema_list = match (limit, offset) {
Expand All @@ -90,7 +78,7 @@ async fn scan_query(storage: &dyn Store, query: &Query) -> Result<Vec<Schema>> {
Ok(schema_list)
}

async fn scan_select(storage: &dyn Store, select: &Select) -> Result<Vec<Schema>> {
async fn scan_select(storage: &dyn Store, select: &Select) -> Result<HashMap<String, Schema>> {
let Select {
projection,
from,
Expand All @@ -100,13 +88,13 @@ async fn scan_select(storage: &dyn Store, select: &Select) -> Result<Vec<Schema>
} = select;

let projection = stream::iter(projection)
.then(|select_item| match select_item {
SelectItem::Expr { expr, .. } => scan_expr(storage, expr),
SelectItem::QualifiedWildcard(_) | SelectItem::Wildcard => {
Box::pin(future::ok(Vec::new()))
.then(|select_item| async move {
match select_item {
SelectItem::Expr { expr, .. } => scan_expr(storage, expr).await,
SelectItem::QualifiedWildcard(_) | SelectItem::Wildcard => Ok(HashMap::new()),
}
})
.try_collect::<Vec<Vec<Schema>>>()
.try_collect::<Vec<HashMap<String, Schema>>>()
.await?
.into_iter()
.flatten();
Expand All @@ -117,7 +105,7 @@ async fn scan_select(storage: &dyn Store, select: &Select) -> Result<Vec<Schema>

Ok(stream::iter(exprs)
.then(|expr| scan_expr(storage, expr))
.try_collect::<Vec<Vec<Schema>>>()
.try_collect::<Vec<HashMap<String, Schema>>>()
.await?
.into_iter()
.flatten()
Expand All @@ -129,21 +117,21 @@ async fn scan_select(storage: &dyn Store, select: &Select) -> Result<Vec<Schema>
async fn scan_table_with_joins(
storage: &dyn Store,
table_with_joins: &TableWithJoins,
) -> Result<Vec<Schema>> {
) -> Result<HashMap<String, Schema>> {
let TableWithJoins { relation, joins } = table_with_joins;
let schema_list = scan_table_factor(storage, relation).await?;

Ok(stream::iter(joins)
.then(|join| scan_join(storage, join))
.try_collect::<Vec<Vec<_>>>()
.try_collect::<Vec<HashMap<String, Schema>>>()
.await?
.into_iter()
.flatten()
.chain(schema_list)
.collect())
}

async fn scan_join(storage: &dyn Store, join: &Join) -> Result<Vec<Schema>> {
async fn scan_join(storage: &dyn Store, join: &Join) -> Result<HashMap<String, Schema>> {
let Join {
relation,
join_operator,
Expand All @@ -166,24 +154,29 @@ async fn scan_join(storage: &dyn Store, join: &Join) -> Result<Vec<Schema>> {
}

#[async_recursion(?Send)]
async fn scan_table_factor(storage: &dyn Store, table_factor: &TableFactor) -> Result<Vec<Schema>> {
async fn scan_table_factor(
storage: &dyn Store,
table_factor: &TableFactor,
) -> Result<HashMap<String, Schema>> {
match table_factor {
TableFactor::Table { name, .. } => {
let schema = storage.fetch_schema(name).await?;
let schema_list = schema.map(|schema| vec![schema]).unwrap_or_else(Vec::new);
let schema_list: HashMap<String, Schema> = schema.map_or_else(HashMap::new, |schema| {
HashMap::from([(name.to_owned(), schema)])
});

Ok(schema_list)
}
TableFactor::Derived { subquery, .. } => scan_query(storage, subquery).await,
TableFactor::Series { .. } | TableFactor::Dictionary { .. } => Ok(vec![]),
TableFactor::Series { .. } | TableFactor::Dictionary { .. } => Ok(HashMap::new()),
}
}

#[async_recursion(?Send)]
async fn scan_expr(storage: &dyn Store, expr: &Expr) -> Result<Vec<Schema>> {
async fn scan_expr(storage: &dyn Store, expr: &Expr) -> Result<HashMap<String, Schema>> {
let schema_list = match expr.into() {
PlanExpr::None | PlanExpr::Identifier(_) | PlanExpr::CompoundIdentifier { .. } => {
Vec::new()
HashMap::new()
}
PlanExpr::Expr(expr) => scan_expr(storage, expr).await?,
PlanExpr::TwoExprs(expr, expr2) => scan_expr(storage, expr)
Expand All @@ -199,7 +192,7 @@ async fn scan_expr(storage: &dyn Store, expr: &Expr) -> Result<Vec<Schema>> {
.collect(),
PlanExpr::MultiExprs(exprs) => stream::iter(exprs)
.then(|expr| scan_expr(storage, expr))
.try_collect::<Vec<Vec<_>>>()
.try_collect::<Vec<HashMap<String, Schema>>>()
.await?
.into_iter()
.flatten()
Expand Down
193 changes: 157 additions & 36 deletions core/src/plan/validate.rs
Original file line number Diff line number Diff line change
@@ -1,50 +1,171 @@
use {
super::PlanError,
crate::{
ast::{Expr, SelectItem, SetExpr, Statement},
ast::{Expr, Join, Query, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins},
data::Schema,
result::Result,
},
std::collections::HashMap,
std::{collections::HashMap, rc::Rc},
};

type SchemaMap = HashMap<String, Schema>;
/// Validates that a column selected by the user is not ambiguous.
pub fn validate(schema_map: &HashMap<String, Schema>, statement: Statement) -> Result<Statement> {
if let Statement::Query(query) = &statement {
if let SetExpr::Select(select) = &query.body {
if !select.from.joins.is_empty() {
select
.projection
.iter()
.map(|select_item| {
if let SelectItem::Expr {
expr: Expr::Identifier(ident),
..
} = select_item
{
let tables_with_given_col =
schema_map.iter().filter_map(|(_, schema)| {
match schema.column_defs.as_ref() {
Some(column_defs) => {
column_defs.iter().find(|col| &col.name == ident)
}
None => None,
}
});

if tables_with_given_col.count() > 1 {
return Err(
PlanError::ColumnReferenceAmbiguous(ident.to_owned()).into()
);
}
}

Ok(())
})
.collect::<Result<Vec<()>>>()?;
pub fn validate(schema_map: &SchemaMap, statement: &Statement) -> Result<()> {
if let Statement::Query(Query {
body: SetExpr::Select(select),
..
}) = &statement
{
for select_item in &select.projection {
if let SelectItem::Expr {
expr: Expr::Identifier(ident),
..
} = select_item
{
if let Some(context) = contextualize_stmt(schema_map, statement) {
context.validate_duplicated(ident)?;
}
}
}
}

Ok(statement)
Ok(())
}

/// A tree of column-name scopes used to detect ambiguous column references.
///
/// Leaves (`Data`) carry the column names of one table; `Bridge` nodes join
/// two sub-trees, e.g. the two sides of a join.
enum Context<'a> {
    /// A single scope of column names.
    Data {
        /// Column names visible in this scope, borrowed from the schema.
        /// `None` when the schema declares no column definitions.
        labels: Option<Vec<&'a String>>,
        /// Optional further context chained behind this scope.
        next: Option<Rc<Context<'a>>>,
    },
    /// Combination of two independent contexts.
    Bridge {
        /// First combined context.
        left: Rc<Context<'a>>,
        /// Second combined context.
        right: Rc<Context<'a>>,
    },
}

impl<'a> Context<'a> {
fn new(labels: Option<Vec<&'a String>>, next: Option<Rc<Context<'a>>>) -> Self {
Self::Data { labels, next }
}

fn concat(left: Option<Rc<Context<'a>>>, right: Option<Rc<Context<'a>>>) -> Option<Rc<Self>> {
match (left, right) {
(Some(left), Some(right)) => Some(Rc::new(Self::Bridge { left, right })),
(context @ Some(_), None) | (None, context @ Some(_)) => context,
(None, None) => None,
}
}

fn validate_duplicated(&self, column_name: &str) -> Result<()> {
fn validate(context: &Context, column_name: &str) -> Result<bool> {
let (left, right) = match context {
Context::Data { labels, next, .. } => {
let current = labels
.as_ref()
.map(|labels| labels.iter().any(|label| *label == column_name))
.unwrap_or(false);

let next = next
.as_ref()
.map(|next| validate(next, column_name))
.unwrap_or(Ok(false))?;

(current, next)
}
Context::Bridge { left, right } => {
let left = validate(left, column_name)?;
let right = validate(right, column_name)?;

(left, right)
}
};

if left && right {
Err(PlanError::ColumnReferenceAmbiguous(column_name.to_owned()).into())
} else {
Ok(left || right)
}
}

validate(self, column_name).map(|_| ())
}
}

/// Collects references to the names of all columns defined in `schema`.
///
/// Returns `None` when the schema has no `column_defs`.
fn get_lables(schema: &Schema) -> Option<Vec<&String>> {
    let column_defs = schema.column_defs.as_ref()?;
    let labels = column_defs
        .iter()
        .map(|column_def| &column_def.name)
        .collect::<Vec<_>>();

    Some(labels)
}

/// Builds the column-label context for a whole statement.
///
/// * `Query` — the context of the query itself.
/// * `Insert` — the target table's labels combined with the source query's
///   context.
/// * `DropTable` — the labels of every named table found in `schema_map`,
///   combined in order.
/// * anything else — `None`.
///
/// Tables missing from `schema_map` contribute nothing.
fn contextualize_stmt<'a>(
    schema_map: &'a SchemaMap,
    statement: &'a Statement,
) -> Option<Rc<Context<'a>>> {
    match statement {
        Statement::Query(query) => contextualize_query(schema_map, query),
        Statement::Insert {
            table_name, source, ..
        } => {
            let target = schema_map
                .get(table_name)
                .map(|schema| Rc::new(Context::new(get_lables(schema), None)));
            let source = contextualize_query(schema_map, source);

            Context::concat(target, source)
        }
        Statement::DropTable { names, .. } => {
            let mut combined = None;

            for name in names.iter() {
                let leaf = schema_map
                    .get(name)
                    .map(|schema| Rc::new(Context::new(get_lables(schema), None)));

                combined = Context::concat(combined, leaf);
            }

            combined
        }
        _ => None,
    }
}

/// Builds the column-label context for a query.
///
/// For a `SELECT`, the context of the primary relation is combined with the
/// contexts of every joined relation (in join order). A `VALUES` body
/// references no tables and yields `None`.
fn contextualize_query<'a>(schema_map: &'a SchemaMap, query: &'a Query) -> Option<Rc<Context<'a>>> {
    match &query.body {
        SetExpr::Select(select) => {
            let TableWithJoins { relation, joins } = &select.from;

            // The primary relation follows exactly the same rules as a joined
            // one, so resolve it through the shared helper instead of
            // repeating the match here.
            let by_table = contextualize_table_factor(schema_map, relation);

            let by_joins = joins
                .iter()
                .map(|Join { relation, .. }| contextualize_table_factor(schema_map, relation))
                .fold(None, Context::concat);

            Context::concat(by_table, by_joins)
        }
        SetExpr::Values(_) => None,
    }
}

/// Builds the column-label context contributed by a single table factor.
///
/// * `Table` — a leaf holding the table's column labels, or `None` when the
///   table is not present in `schema_map`.
/// * `Derived` — the context of the subquery.
/// * `Series` / `Dictionary` — `None`.
fn contextualize_table_factor<'a>(
    schema_map: &'a SchemaMap,
    table_factor: &'a TableFactor,
) -> Option<Rc<Context<'a>>> {
    // Each arm already yields `Option<Rc<Context>>`, so no further conversion
    // of the match result is needed.
    match table_factor {
        TableFactor::Table { name, .. } => schema_map
            .get(name)
            .map(|schema| Rc::new(Context::new(get_lables(schema), None))),
        TableFactor::Derived { subquery, .. } => contextualize_query(schema_map, subquery),
        TableFactor::Series { .. } | TableFactor::Dictionary { .. } => None,
    }
}
Loading

0 comments on commit 9a26de4

Please sign in to comment.