Skip to content

Commit

Permalink
Adds limit_count to QuerySet constructor
Browse files Browse the repository at this point in the history
- Passes limit_count through .filter and .select_related
- Adds a test to verify that .filter persists the limit
  • Loading branch information
Kirt Gittens committed Apr 17, 2019
1 parent fb26f8e commit 39fbe3d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
8 changes: 5 additions & 3 deletions orm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def __new__(


class QuerySet:
def __init__(self, model_cls=None, filter_clauses=None, select_related=None):
def __init__(self, model_cls=None, filter_clauses=None, select_related=None, limit_count=None):
self.model_cls = model_cls
self.filter_clauses = [] if filter_clauses is None else filter_clauses
self._select_related = [] if select_related is None else select_related
self.limit_count = None
self.limit_count = None if limit_count is None else limit_count

def __get__(self, instance, owner):
return self.__class__(model_cls=owner)
Expand Down Expand Up @@ -147,6 +147,7 @@ def filter(self, **kwargs):
model_cls=self.model_cls,
filter_clauses=filter_clauses,
select_related=select_related,
limit_count=self.limit_count
)

def select_related(self, related):
Expand All @@ -158,6 +159,7 @@ def select_related(self, related):
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=related,
limit_count=self.limit_count
)

async def exists(self) -> bool:
Expand All @@ -170,8 +172,8 @@ def limit(self, rows_to_return: int):
model_cls=self.model_cls,
filter_clauses=self.filter_clauses,
select_related=self._select_related,
limit_count=rows_to_return
)
new_query_set.limit_count = rows_to_return
return new_query_set

async def count(self) -> int:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,13 @@ async def test_model_limit():
await User.objects.create(name="Lucy")

assert len(await User.objects.limit(2).all()) == 2


@async_adapter
async def test_model_limit_with_filter():
async with database:
await User.objects.create(name="Tom")
await User.objects.create(name="Tom")
await User.objects.create(name="Tom")

assert len(await User.objects.limit(2).filter(name__iexact='Tom').all()) == 2

0 comments on commit 39fbe3d

Please sign in to comment.