Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/encode/orm into implement…
Browse files Browse the repository at this point in the history
…-queryset-limit
  • Loading branch information
Kirt Gittens committed Apr 17, 2019
2 parents 39fbe3d + 1593ad1 commit 90ffd04
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 4 deletions.
12 changes: 11 additions & 1 deletion orm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def __new__(

class QuerySet:
def __init__(self, model_cls=None, filter_clauses=None, select_related=None, limit_count=None):
ESCAPE_CHARACTERS = ['%', '_']

self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related
Expand Down Expand Up @@ -133,14 +135,22 @@ def filter(self, **kwargs):
# Map the operation code onto SQLAlchemy's ColumnElement
# https://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.ColumnElement
op_attr = FILTER_OPERATORS[op]
has_escaped_character = False

if op in ["contains", "icontains"]:
value = "%" + value + "%"
has_escaped_character = any(c for c in self.ESCAPE_CHARACTERS
if c in value)
if has_escaped_character:
# enable escape modifier
for char in self.ESCAPE_CHARACTERS:
value = value.replace(char, f'\\{char}')
value = f"%{value}%"

if isinstance(value, Model):
value = value.pk

clause = getattr(column, op_attr)(value)
clause.modifiers['escape'] = '\\' if has_escaped_character else None
filter_clauses.append(clause)

return self.__class__(
Expand Down
1 change: 1 addition & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DATABASE_URL = "sqlite:///test.db"
4 changes: 3 additions & 1 deletion tests/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import databases
import orm

DATABASE_URL = "sqlite:///test.db"
from tests.settings import DATABASE_URL


database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()

Expand Down
3 changes: 2 additions & 1 deletion tests/test_foreignkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import databases
import orm

DATABASE_URL = "sqlite:///test.db"
from tests.settings import DATABASE_URL

database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()

Expand Down
16 changes: 15 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import databases
import orm

DATABASE_URL = "sqlite:///test.db"
from tests.settings import DATABASE_URL

database = databases.Database(DATABASE_URL, force_rollback=True)
metadata = sqlalchemy.MetaData()

Expand Down Expand Up @@ -138,6 +139,19 @@ async def test_model_filter():
products = await Product.objects.all(name__icontains="T")
assert len(products) == 2

# Test escaping % character from icontains, contains, and iexact
await Product.objects.create(name="100%-Cotton", rating=3)
await Product.objects.create(name="Cotton-100%-Egyptian", rating=3)
await Product.objects.create(name="Cotton-100%", rating=3)
products = Product.objects.filter(name__iexact="100%-cotton")
assert await products.count() == 1

products = Product.objects.filter(name__contains="%")
assert await products.count() == 3

products = Product.objects.filter(name__icontains="%")
assert await products.count() == 3


@async_adapter
async def test_model_exists():
Expand Down

0 comments on commit 90ffd04

Please sign in to comment.