Skip to content
This repository has been archived by the owner on Apr 24, 2020. It is now read-only.

Commit

Permalink
[#16] Can create EXISTS statement from SelectQuery.
Browse files Browse the repository at this point in the history
  • Loading branch information
ducdetronquito committed Feb 21, 2017
1 parent 0cf70fe commit 2fd5da0
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 35 deletions.
55 changes: 30 additions & 25 deletions plume/plume.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ def __str__(self):

class Transaction:
__slots__ = ('connection')

def __init__(self, connection):
self.connection = connection

def __enter__(self):
self.connection.execute('BEGIN')

def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type:
self.connection.execute('ROLLBACK')
Expand Down Expand Up @@ -110,10 +110,10 @@ def atomic(self):

def build(self, query, values=None, read_only=False):
raw_query = query.build()

if read_only:
return self.execute(raw_query, values)

if self._connection.in_transaction:
return self.execute(raw_query, values)

Expand All @@ -126,15 +126,15 @@ def build_create(self, query):
query._table.lower(), str(query._fields),
)
return ' '.join(query)

def build_delete(self, query):
output = [self.DELETE, self.FROM, query._table.lower()]

if query._filters is not None:
output.extend((self.WHERE, str(query._filters)))

return ' '.join(output)

def build_drop(self, query):
query = (
self.DROP, self.TABLE, self.IF, self.EXISTS, query._table.lower()
Expand All @@ -157,38 +157,40 @@ def build_select(self, query):
output.append(self.DISTINCT)

fields = str(CSV(query._fields)) if query._fields else self.ALL
tables = (table if isinstance(table, str) else table.__name__ for table in query._tables)
output.extend((fields, self.FROM, str(CSV(tables)).lower()))
output.append(fields)
if query._tables:
tables = (table if isinstance(table, str) else table.__name__ for table in query._tables)
output.extend((self.FROM, str(CSV(tables)).lower()))

if query._filters is not None:
output.extend((self.WHERE, str(query._filters)))

if query._limit is not None:
output.extend((self.LIMIT, str(query._limit)))

if query._offset is not None:
output.extend((self.OFFSET, str(query._offset)))

return ' '.join(output)

def build_update(self, query):
output = [self.UPDATE, query._table.lower(), self.SET, str(CSV(query._fields))]

if query._filters is not None:
output.extend((self.WHERE, str(query._filters)))

return ' '.join(output)

def create(self):
return CreateQuery(db=self)

def delete(self):
return DeleteQuery(db=self)

def drop(self, table=None):
query = DropQuery(db=self)
return query if table is None else query.table(table)

def execute(self, raw_query, values=None):
cursor = self._connection.cursor()
if values is None:
Expand All @@ -211,7 +213,7 @@ def register(self, *args):

def select(self, *args):
return SelectQuery(db=self).select(*args)

def update(self, *args):
return UpdateQuery(db=self).fields(*args)

Expand Down Expand Up @@ -497,7 +499,7 @@ def select(cls, *args):

@classmethod
def update(cls, *args):
return UpdateQuery(db=cls._db).table(model).set(*args)
return UpdateQuery(db=cls._db).table(model).fields(*args)

@classmethod
def where(cls, *args):
Expand Down Expand Up @@ -568,7 +570,7 @@ def __init__(self, db):

def __str__(self):
return ''.join(('(', self.build(), ')'))

def build(self):
return self._db.build_delete(self)

Expand All @@ -586,13 +588,13 @@ class DropQuery:
def __init__(self, db):
self._db = db
self._table = None

def __str__(self):
return self.build()

def build(self):
return self._db.build_drop(self)

def execute(self):
return self._db.build(self)

Expand Down Expand Up @@ -746,6 +748,9 @@ def execute(self):
cursor = self._db.build(self, read_only=True)
return cursor.fetchall()

def exists(self):
return Expression(self._db.EXISTS, self)

def get(self):
"""Returns a list of Model instances."""
model = self._tables[0]
Expand All @@ -755,7 +760,7 @@ def limit(self, limit:int):
""" Slice a SelectQuery without hiting the database."""
self._limit = limit
return self

def offset(self, offset:int):
self._offset = offset
return self
Expand All @@ -781,7 +786,7 @@ def __init__(self, db):

def __str__(self):
return ''.join(('(', self.build(), ')'))

def build(self):
return self._db.build_update(self)

Expand All @@ -796,7 +801,7 @@ def fields(self, *args):

self._fields.extend(args)
return self

def table(self, table):
self._table = table if isinstance(table, str) else table.__name__
return self
38 changes: 28 additions & 10 deletions tests/test_selectquery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from plume import *
from plume.plume import SelectQuery
from plume.plume import InsertQuery, SelectQuery
from utils import BaseTestCase, Pokemon, Trainer

import pytest
Expand Down Expand Up @@ -31,7 +31,7 @@ def test_can_select_fields_as_model_fields(self):

def test_has_limit_method(self):
assert hasattr(SelectQuery(self.db), 'limit') is True

def test_has_offset_method(self):
assert hasattr(SelectQuery(self.db), 'offset') is True

Expand Down Expand Up @@ -93,9 +93,14 @@ def test_can_select_fields_when_returning_tuples(self):
assert 'Giovanni' in expected_tuple[0]

def test_can_output_selectquery_as_string(self):
result = str(SelectQuery(self.db).tables(Trainer).where(Trainer.age > 18))
query = SelectQuery(self.db).tables(Trainer).where(Trainer.age > 18)
expected = "(SELECT * FROM trainer WHERE trainer.age > 18)"
assert result == expected
assert str(query) == expected

def test_can_ouput_exists_expression(self):
expression = SelectQuery(self.db).tables(Trainer).exists()
expected = 'EXISTS (SELECT * FROM trainer)'
assert str(expression) == expected


class TestSelectQueryLimitMethod(BaseTestCase):
Expand Down Expand Up @@ -315,9 +320,22 @@ def test_result_select_distinct(self):
result = SelectQuery(db=self.db).tables(Trainer).distinct(Trainer.name).execute()
assert len(result) == 1

"""
def test_can_check_if_exists(self):
query = Select().fields().tables()
expected = '(SELECT DISTINCT trainer.name FROM trainer)'
assert str(query) == expected
"""
def test_can_select_with_exists_return_false(self):
result = SelectQuery(self.db).select(
SelectQuery(self.db).tables(Trainer).exists()
).execute()
expected = 0
assert result[0][0] == expected

def test_can_select_with_exists_return_true(self):
InsertQuery(self.db).table(Trainer).from_dicts({
'name': 'Giovanni',
'age': 42
}).execute()

result = SelectQuery(self.db).select(
SelectQuery(self.db).tables(Trainer).exists()
).execute()
expected = 1
assert result[0][0] == expected

0 comments on commit 2fd5da0

Please sign in to comment.