Skip to content

Commit

Permalink
Add "or_" filtering (PrefectHQ/orion#2139)
Browse files Browse the repository at this point in the history
Add or_ filtering
  • Loading branch information
zangell44 committed Jul 26, 2022
1 parent ae56b38 commit 970badb
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 14 deletions.
53 changes: 39 additions & 14 deletions src/prefect/orion/schemas/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import prefect.orion.schemas as schemas
from prefect.orion.utilities.schemas import PrefectBaseModel
from prefect.utilities.collections import AutoEnum

if TYPE_CHECKING:
from prefect.orion.database.interface import OrionDBInterface
Expand All @@ -23,6 +24,13 @@
# present in the schemas module


class Operator(AutoEnum):
"""Operators for combining filter criteria."""

and_ = AutoEnum.auto()
or_ = AutoEnum.auto()


class PrefectFilterBaseModel(PrefectBaseModel):
"""Base model for Prefect filters"""

Expand All @@ -32,13 +40,30 @@ class Config:
def as_sql_filter(self, db: "OrionDBInterface") -> BooleanClauseList:
"""Generate SQL filter from provided filter parameters. If no filters parameters are available, return a TRUE filter."""
filters = self._get_filter_list(db)
return sa.and_(*filters) if filters else sa.and_(True)
if not filters:
return True
return sa.and_(*filters)

def _get_filter_list(self, db: "OrionDBInterface") -> List:
"""Return a list of boolean filter statements based on filter parameters"""
raise NotImplementedError("_get_filter_list must be implemented")


class PrefectOperatorFilterBaseModel(PrefectFilterBaseModel):
"""Base model for Prefect filters that combines criteria with a user-provided operator"""

operator: Operator = Field(
default=Operator.and_,
description="Operator for combining filter criteria. Defaults to 'and_'.",
)

def as_sql_filter(self, db: "OrionDBInterface") -> BooleanClauseList:
filters = self._get_filter_list(db)
if not filters:
return True
return sa.and_(*filters) if self.operator == Operator.and_ else sa.or_(*filters)


class FlowFilterId(PrefectFilterBaseModel):
"""Filter by `Flow.id`."""

Expand Down Expand Up @@ -79,7 +104,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class FlowFilterTags(PrefectFilterBaseModel):
class FlowFilterTags(PrefectOperatorFilterBaseModel):
"""Filter by `Flow.tags`."""

all_: List[str] = Field(
Expand All @@ -100,7 +125,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class FlowFilter(PrefectFilterBaseModel):
class FlowFilter(PrefectOperatorFilterBaseModel):
"""Filter for flows. Only flows matching all criteria will be returned."""

id: Optional[FlowFilterId] = Field(
Expand Down Expand Up @@ -169,7 +194,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class FlowRunFilterTags(PrefectFilterBaseModel):
class FlowRunFilterTags(PrefectOperatorFilterBaseModel):
"""Filter by `FlowRun.tags`."""

all_: List[str] = Field(
Expand All @@ -194,7 +219,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class FlowRunFilterDeploymentId(PrefectFilterBaseModel):
class FlowRunFilterDeploymentId(PrefectOperatorFilterBaseModel):
"""Filter by `FlowRun.deployment_id`."""

any_: List[UUID] = Field(
Expand Down Expand Up @@ -340,7 +365,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class FlowRunFilterParentTaskRunId(PrefectFilterBaseModel):
class FlowRunFilterParentTaskRunId(PrefectOperatorFilterBaseModel):
"""Filter by `FlowRun.parent_task_run_id`."""

any_: List[UUID] = Field(
Expand All @@ -363,7 +388,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class FlowRunFilter(PrefectFilterBaseModel):
class FlowRunFilter(PrefectOperatorFilterBaseModel):
"""Filter flow runs. Only flow runs matching all criteria will be returned"""

id: Optional[FlowRunFilterId] = Field(
Expand Down Expand Up @@ -464,7 +489,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class TaskRunFilterTags(PrefectFilterBaseModel):
class TaskRunFilterTags(PrefectOperatorFilterBaseModel):
"""Filter by `TaskRun.tags`."""

all_: List[str] = Field(
Expand Down Expand Up @@ -573,7 +598,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class TaskRunFilter(PrefectFilterBaseModel):
class TaskRunFilter(PrefectOperatorFilterBaseModel):
"""Filter task runs. Only task runs matching all criteria will be returned"""

id: Optional[TaskRunFilterId] = Field(
Expand Down Expand Up @@ -669,7 +694,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class DeploymentFilterTags(PrefectFilterBaseModel):
class DeploymentFilterTags(PrefectOperatorFilterBaseModel):
"""Filter by `Deployment.tags`."""

all_: List[str] = Field(
Expand All @@ -694,7 +719,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class DeploymentFilter(PrefectFilterBaseModel):
class DeploymentFilter(PrefectOperatorFilterBaseModel):
"""Filter for deployments. Only deployments matching all criteria will be returned."""

id: Optional[DeploymentFilterId] = Field(
Expand Down Expand Up @@ -808,7 +833,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class LogFilter(PrefectFilterBaseModel):
class LogFilter(PrefectOperatorFilterBaseModel):
"""Filter logs. Only logs matching all criteria will be returned"""

level: Optional[LogFilterLevel] = Field(
Expand Down Expand Up @@ -936,7 +961,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class BlockSchemaFilter(PrefectFilterBaseModel):
class BlockSchemaFilter(PrefectOperatorFilterBaseModel):
"""Filter BlockSchemas"""

block_type_id: Optional[BlockSchemaFilterBlockTypeId] = Field(
Expand Down Expand Up @@ -989,7 +1014,7 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class BlockDocumentFilter(PrefectFilterBaseModel):
class BlockDocumentFilter(PrefectOperatorFilterBaseModel):
"""Filter BlockDocuments. Only BlockDocuments matching all criteria will be returned"""

is_anonymous: Optional[BlockDocumentFilterIsAnonymous] = Field(
Expand Down
40 changes: 40 additions & 0 deletions tests/orion/models/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,16 @@ class TestCountFlowsModels:
),
1,
],
# empty filter
[dict(flow_filter=filters.FlowFilter()), 4],
# multiple empty filters
[
dict(
flow_filter=filters.FlowFilter(),
flow_run_filter=filters.FlowRunFilter(),
),
4,
],
]

@pytest.mark.parametrize("kwargs,expected", params)
Expand Down Expand Up @@ -549,6 +559,16 @@ class TestCountFlowRunModels:
),
1,
],
# empty filter
[dict(flow_filter=filters.FlowFilter()), 12],
# multiple empty filters
[
dict(
flow_filter=filters.FlowFilter(),
flow_run_filter=filters.FlowRunFilter(),
),
12,
],
]

@pytest.mark.parametrize("kwargs,expected", params)
Expand Down Expand Up @@ -724,6 +744,16 @@ class TestCountTaskRunsModels:
),
0,
],
# empty filter
[dict(flow_filter=filters.FlowFilter()), 10],
# multiple empty filters
[
dict(
flow_filter=filters.FlowFilter(),
flow_run_filter=filters.FlowRunFilter(),
),
10,
],
]

@pytest.mark.parametrize("kwargs,expected", params)
Expand Down Expand Up @@ -875,6 +905,16 @@ class TestCountDeploymentModels:
),
0,
],
# empty filter
[dict(flow_filter=filters.FlowFilter()), 3],
# multiple empty filters
[
dict(
flow_filter=filters.FlowFilter(),
flow_run_filter=filters.FlowRunFilter(),
),
3,
],
]

@pytest.mark.parametrize("kwargs,expected", params)
Expand Down
11 changes: 11 additions & 0 deletions tests/orion/models/test_flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,17 @@ async def test_read_flow_runs_filters_by_multiple_criteria(self, flow, session):
)
assert len(result) == 0

# filter using OR
result = await models.flow_runs.read_flow_runs(
session=session,
flow_run_filter=schemas.filters.FlowRunFilter(
operator="or_",
id=schemas.filters.FlowRunFilterId(any_=[flow_run_2.id]),
tags=schemas.filters.FlowRunFilterTags(all_=["blue"]),
),
)
assert {res.id for res in result} == {flow_run_1.id, flow_run_2.id}

async def test_read_flow_runs_filters_by_flow_criteria(self, flow, session):
flow_run_1 = await models.flow_runs.create_flow_run(
session=session,
Expand Down

0 comments on commit 970badb

Please sign in to comment.