Skip to content

Commit

Permalink
Fixed #35064 -- Fixed Window(order_by) crash with DecimalFields on SQ…
Browse files Browse the repository at this point in the history
…Lite.

This avoids cast of Window(order_by) for DecimalFields on SQLite.

This was achieved by piggy-backing ExpressionList which already
implements a specialized as_sqlite() method to override the inherited
behaviour of Func through SQLiteNumericMixin.

Refs #31723.

Thanks Quoates for the report.
  • Loading branch information
charettes authored and felixxm committed Dec 29, 2023
1 parent 90d365d commit e16d0c1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
17 changes: 3 additions & 14 deletions django/db/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,12 +1267,12 @@ def as_sqlite(self, compiler, connection, **extra_context):

def get_group_by_cols(self):
group_by_cols = []
for partition in self.get_source_expressions():
group_by_cols.extend(partition.get_group_by_cols())
for expr in self.get_source_expressions():
group_by_cols.extend(expr.get_group_by_cols())
return group_by_cols


class OrderByList(Func):
class OrderByList(ExpressionList):
allowed_default = False
template = "ORDER BY %(expressions)s"

Expand All @@ -1287,17 +1287,6 @@ def __init__(self, *expressions, **extra):
)
super().__init__(*expressions, **extra)

def as_sql(self, *args, **kwargs):
if not self.source_expressions:
return "", ()
return super().as_sql(*args, **kwargs)

def get_group_by_cols(self):
group_by_cols = []
for order_by in self.get_source_expressions():
group_by_cols.extend(order_by.get_group_by_cols())
return group_by_cols


@deconstructible(path="django.db.models.ExpressionWrapper")
class ExpressionWrapper(SQLiteNumericMixin, Expression):
Expand Down
26 changes: 24 additions & 2 deletions tests/expressions_window/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,29 @@ def test_lag_decimalfield(self):
transform=lambda row: (row.name, row.bonus, row.department, row.lag),
)

def test_order_by_decimalfield(self):
qs = Employee.objects.annotate(
rank=Window(expression=Rank(), order_by="bonus")
).order_by("-bonus", "id")
self.assertQuerySetEqual(
qs,
[
("Miller", 250.0, 12),
("Johnson", 200.0, 11),
("Wilkinson", 150.0, 10),
("Smith", 137.5, 9),
("Brown", 132.5, 8),
("Adams", 125.0, 7),
("Jones", 112.5, 5),
("Jenson", 112.5, 5),
("Johnson", 100.0, 4),
("Smith", 95.0, 3),
("Williams", 92.5, 2),
("Moore", 85.0, 1),
],
transform=lambda row: (row.name, float(row.bonus), row.rank),
)

def test_first_value(self):
qs = Employee.objects.annotate(
first_value=Window(
Expand Down Expand Up @@ -1934,8 +1957,7 @@ def test_window_repr(self):
)
self.assertEqual(
repr(Window(expression=Avg("salary"), order_by=F("department").asc())),
"<Window: Avg(F(salary)) OVER (OrderByList(OrderBy(F(department), "
"descending=False)))>",
"<Window: Avg(F(salary)) OVER (OrderBy(F(department), descending=False))>",
)

def test_window_frame_repr(self):
Expand Down

0 comments on commit e16d0c1

Please sign in to comment.