Skip to content
This repository has been archived by the owner on Apr 24, 2020. It is now read-only.

Commit

Permalink
[#15] Can filter SelectQuery with sub-queries.
Browse files Browse the repository at this point in the history
  • Loading branch information
ducdetronquito committed Feb 4, 2017
1 parent 6eef67e commit 4fda01c
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 81 deletions.
96 changes: 45 additions & 51 deletions plume/plume.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@
]


class CSV(tuple):
"""Output a iterable as coma-separated value."""
__slots__ = ()

def __str__(self):
return ', '.join(str(value) for value in self)


class BracketCSV(CSV):
"""Output a iterable as coma-separated value between brackets."""
__slots__ = ()

def __str__(self):
return '(' + super().__str__() + ')'


class SQLiteAPI:
# Create table query
AUTOINCREMENT = 'AUTOINCREMENT'
Expand All @@ -19,6 +35,7 @@ class SQLiteAPI:
NOT_NULL = 'NOT NULL'
PK = 'PRIMARY KEY'
REAL = 'REAL'
REFERENCES = 'REFERENCES'
TABLE = 'TABLE'
TEXT = 'TEXT'
UNIQUE = 'UNIQUE'
Expand Down Expand Up @@ -68,7 +85,7 @@ class SQLiteAPI:
def create_table(cls, name, fields):
query = (
cls.CREATE, cls.TABLE, cls.IF, cls.invert[cls.EXISTS],
name.lower(), cls.to_csv(fields, bracket=True),
name.lower(), str(fields),
)
return ' '.join(query)

Expand All @@ -92,49 +109,33 @@ def drop_table(cls, name):
def insert_into(cls, table_name, field_names):
query = (
cls.INSERT, table_name.lower(),
cls.to_csv(field_names, bracket=True),
str(field_names),
cls.VALUES,
cls.to_csv([cls.PLACEHOLDER] * len(field_names), bracket=True)
str(BracketCSV([cls.PLACEHOLDER] * len(field_names))),
)
return ' '.join(query)

@classmethod
def select(cls, tables, fields=None, where=None, count=None, offset=None):
query = [
cls.SELECT, cls.to_csv(fields or cls.ALL).lower(),
cls.FROM, cls.to_csv(tables).lower(),
]
fields = str(fields) if fields else cls.ALL
query = [cls.SELECT, fields, cls.FROM, str(tables)]

if where is not None:
query.extend((cls.WHERE, str(where)))

if count is not None and offset is not None:
query.extend((cls.LIMIT, str(count), cls.OFFSET, str(offset)))
return ' '.join(query)

@classmethod
def update(cls, table_name, fields, where=None):
query = [
cls.UPDATE, table_name.lower(), cls.SET, cls.to_csv(str(fields), bracket=False)
]
query = [cls.UPDATE, table_name.lower(), cls.SET, CSV(fields)]

if where is not None:
query.extend((cls.WHERE, str(where)))

return ' '.join(query)

@staticmethod
def to_csv(values, bracket=False):
"""Convert a string value or a sequence of string values into a coma-separated string."""
try:
values = values.split()
except AttributeError:
pass

csv = ', '.join(values)

return '(' + csv + ')' if bracket else csv


class FilterableQuery:
__slots__ = ('_where',)
Expand Down Expand Up @@ -187,7 +188,7 @@ class SelectQuery(FilterableQuery):
def __init__(self, model, fields=None, where=None):
super().__init__(where)
self._model = model
self._tables = [model.__name__.lower()]
self._tables = CSV([model.__name__.lower()])
self._fields = fields
self._count = None
self._offset = None
Expand All @@ -200,7 +201,7 @@ def __str__(self):

def select(self, *args):
# Allow to filter Select-Query on columns.
self._fields = [str(field) for field in args]
self._fields = CSV(args)
return self

def limit(self, count, offset):
Expand Down Expand Up @@ -274,7 +275,7 @@ def _execute(self):

def dicts(self, allow_fields_subset=True):
"""Query the database and returns the result as a list of dict"""
fields = self._fields or self._model._fieldnames
fields = [field.name for field in self._fields] if self._fields else self._model._fieldnames

return [
{fieldname: value for fieldname, value in zip(fields, row)}
Expand All @@ -289,8 +290,12 @@ def tuples(self, named=False):
"""Output a SelectQuery as a list of tuples or namedtuples"""
if not named:
return self._execute()

fields = self._fields or self._model._fieldnames

if self._fields:
fields = CSV(field.name for field in self._fields)
else:
fields = CSV(self._model._fieldnames)

factory = namedtuple('factory', fields)
return [factory._make(row) for row in self._execute()]

Expand Down Expand Up @@ -398,7 +403,7 @@ def __ne__(self, other):

def __rshift__(self, expressions):
if not isinstance(expressions, SelectQuery):
expressions = CSExpression([self.format(exp) for exp in expressions])
expressions = BracketCSV((self.format(exp) for exp in expressions))
return Expression(self, SQLiteAPI.IN, expressions)


Expand All @@ -414,13 +419,6 @@ def __str__(self):
return ' '.join(str(e) for e in (self.lo, self.op, self.ro) if e is not None)


class CSExpression(tuple):
"""Coma-separated expression"""

def __str__(self):
return '(' + ', '.join(str(element) for element in self) + ')'


class Field(Node):
__slots__ = ('default', 'model_name', 'name', 'required', 'unique', 'value')
internal_type = None
Expand Down Expand Up @@ -564,7 +562,7 @@ def is_valid(self, value):


def sql(self):
return super().sql() + ['REFERENCES', self.related_model.__name__.lower() + '(pk)']
return super().sql() + [SQLiteAPI.REFERENCES, self.related_model.__name__.lower() + '(pk)']


class Model(metaclass=BaseModel):
Expand All @@ -591,19 +589,14 @@ def __eq__(self, other):
@classmethod
def create(cls, **kwargs):
"""Return an instance of the related model."""
field_names = [fieldname for fieldname in cls._fieldnames if fieldname in kwargs]
values = [kwargs[fieldname] for fieldname in cls._fieldnames if fieldname in kwargs]

fields = BracketCSV(field for field in cls._fieldnames if field in kwargs)
values = [kwargs[field] for field in cls._fieldnames if field in kwargs]
values = [value.pk if isinstance(value, Model) else value for value in values]

query = SQLiteAPI.insert_into(cls.__name__, field_names)

last_row_id = cls._db.insert_into(query, values)
kwargs = {field: value for field, value in zip(field_names, values)}

last_row_id = cls._db.insert_into(cls.__name__, fields, values)
kwargs = {field: value for field, value in zip(fields, values)}
kwargs['pk'] = last_row_id
instance = cls(**kwargs)

return instance
return cls(**kwargs)

@classmethod
def delete(cls, *args):
Expand Down Expand Up @@ -638,10 +631,10 @@ def drop_table(self, model_class):
self._connection.commit()

def create_table(self, model_class):
fields = [
fields = BracketCSV((
' '.join(getattr(model_class, fieldname).sql())
for fieldname in model_class._fieldnames
]
))

query = SQLiteAPI.create_table(model_class.__name__, fields)

Expand All @@ -656,8 +649,9 @@ def delete(self, table, where=None):
cursor.execute(query)
self._connection.commit()

def insert_into(self, query, values):
def insert_into(self, table, fields, values):
last_row_id = None
query = SQLiteAPI.insert_into(table, fields)

with closing(self._connection.cursor()) as cursor:
cursor.execute(query, values)
Expand Down
71 changes: 48 additions & 23 deletions tests/test_selectquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_can_select_fields(self):

def test_can_select_fields_as_strings(self):
query = SelectQuery(Trainer)
query.select('name', 'age').execute()
query.select(Trainer.name, Trainer.age).execute()

def test_can_select_fields_as_model_fields(self):
query = SelectQuery(Trainer)
Expand Down Expand Up @@ -57,7 +57,7 @@ def test_can_return_result_as_list_of_dicts(self):

def test_can_select_fields_when_returning_dicts(self):
self.add_trainer(['Giovanni'])
result = SelectQuery(Trainer).select('name').dicts()
result = SelectQuery(Trainer).select(Trainer.name).dicts()
assert len(result) == 1
expected_dict = result[0]
assert len(expected_dict.keys()) == 1
Expand All @@ -75,7 +75,7 @@ def test_can_return_result_as_a_list_of_model_instances(self):
def test_fail_to_select_fields_when_returning_model_instances(self):
self.add_trainer(['Giovanni'])
with pytest.raises(AttributeError):
result = SelectQuery(Trainer).select('name').execute()
result = SelectQuery(Trainer).select(Trainer.name).execute()

def test_can_return_result_as_a_list_of_tuples(self):
self.add_trainer(['Giovanni', 'James', 'Jessie'])
Expand All @@ -87,7 +87,7 @@ def test_can_return_result_as_a_list_of_tuples(self):

def test_can_select_fields_when_returning_tuples(self):
self.add_trainer(['Giovanni'])
result = SelectQuery(Trainer).select('name').tuples()
result = SelectQuery(Trainer).select(Trainer.name).tuples()
assert len(result) == 1
expected_tuple = result[0]
assert len(expected_tuple) == 1
Expand All @@ -106,7 +106,7 @@ def test_can_return_result_as_a_list_of_namedtuples(self):

def test_can_select_fields_when_returning_namedtuples(self):
self.add_trainer(['Giovanni'])
result = SelectQuery(Trainer).select('name').tuples(named=True)
result = SelectQuery(Trainer).select(Trainer.name).tuples(named=True)
assert len(result) == 1
expected_namedtuple = result[0]
assert len(expected_namedtuple) == 1
Expand Down Expand Up @@ -220,6 +220,17 @@ def test_select_with_one_filter(self):
assert giovanni.age == 42
assert james.name == 'James'
assert james.age == 21

def test_filter_with_query(self):
self.add_trainer(['Giovanni', 'James', 'Jessie'])
giovanni_age = SelectQuery(Trainer).select(Trainer.age).where(Trainer.name == 'Giovanni')
result = SelectQuery(Trainer).where(Trainer.age < giovanni_age).execute()
assert len(result) == 2
james, jessie = result
assert james.name == 'James'
assert james.age == 21
assert jessie.name == 'Jessie'
assert jessie.age == 17

def test_filter_with_AND_operator(self):
self.add_trainer(['Giovanni', 'James', 'Jessie'])
Expand Down Expand Up @@ -255,6 +266,35 @@ def test_filter_with_IN_operator(self):
assert james.age == 21
assert jessie.name == 'Jessie'
assert jessie.age == 17

def test_IN_operator_filter_with_query(self):
self.add_trainer(['Giovanni', 'James', 'Jessie'])
self.add_pokemon(['Kangaskhan', 'Koffing', 'Wobbuffet'])
trainer_pks = SelectQuery(Trainer).select(Trainer.pk).where(Trainer.name != 'Jessie')
result = SelectQuery(Pokemon).select(Pokemon.name).where(Pokemon.trainer >> trainer_pks).tuples()
assert len(result) == 2
assert result[0][0] == 'Kangaskhan'
assert result[1][0] == 'Koffing'

def test_IN_operator_filter_with_list_of_query(self):
self.add_trainer(['Giovanni', 'James', 'Jessie'])
self.add_pokemon(['Kangaskhan', 'Koffing', 'Wobbuffet'])
giovanni_pk = SelectQuery(Trainer).select(Trainer.pk).where(Trainer.name == 'Giovanni')
james_pk = SelectQuery(Trainer).select(Trainer.pk).where(Trainer.name == 'James')
result = SelectQuery(Pokemon).select(Pokemon.name).where(Pokemon.trainer >> [giovanni_pk, james_pk]).tuples()
assert len(result) == 2
assert result[0][0] == 'Kangaskhan'
assert result[1][0] == 'Koffing'

def test_IN_operator_filter_with_list_of_value_and_query(self):
self.add_trainer(['Giovanni', 'James', 'Jessie'])
self.add_pokemon(['Kangaskhan', 'Koffing', 'Wobbuffet'])
giovanni_pk = SelectQuery(Trainer).select(Trainer.pk).where(Trainer.name == 'Giovanni')
james_pk = 2
result = SelectQuery(Pokemon).select(Pokemon.name).where(Pokemon.trainer >> [giovanni_pk, james_pk]).tuples()
assert len(result) == 2
assert result[0][0] == 'Kangaskhan'
assert result[1][0] == 'Koffing'

def test_filter_with_field_restriction_and_tuples_output(self):
self.add_trainer(['Giovanni', 'James', 'Jessie'])
Expand All @@ -269,12 +309,6 @@ def test_filter_with_field_restriction_and_tuples_output(self):

def test_filter_with_field_restriction_and_namedtuples_output(self):
self.add_trainer(['Giovanni', 'James', 'Jessie'])
"""
assert (
"[NEED FIX]:Cannot output namedtuples with field restriction "
"because fields contains the model name !"
) is False
"""
result = (
Trainer.select(Trainer.name)
.where(Trainer.age >> [17, 21])
Expand All @@ -285,9 +319,9 @@ def test_filter_with_field_restriction_and_namedtuples_output(self):
print(james)
print(result)
assert len(james) == 1
assert james.trainer.name == 'James'
assert james.name == 'James'
assert len(jessie) == 1
assert jessie.trainer.name == 'Jessie'
assert jessie.name == 'Jessie'

def test_filter_on_table_with_related_field(self):
self.add_trainer('Giovanni')
Expand All @@ -301,13 +335,4 @@ def test_filter_on_table_with_related_field(self):
trainer = pokemon.trainer
assert trainer.name == 'Giovanni'
assert trainer.age == 42

def test_filter_with_nested_query(self):
self.add_trainer(['Giovanni', 'James', 'Jessie'])
self.add_pokemon(['Kangaskhan', 'Koffing', 'Wobbuffet'])
trainer_pks = SelectQuery(Trainer).select(Trainer.pk).where(Trainer.name != 'Jessie')
result = SelectQuery(Pokemon).select(Pokemon.name).where(Pokemon.trainer >> trainer_pks).tuples()
assert len(result) == 2
assert result[0][0] == 'Kangaskhan'
assert result[1][0] == 'Koffing'


Loading

0 comments on commit 4fda01c

Please sign in to comment.