Skip to content

Commit

Permalink
Adds versions to block schemas (#6491)
Browse files Browse the repository at this point in the history
* WIP: Adds version to block schema

* WIP: Adds tests for block schema version

* Adds tests for filters

* Fixes read after block schema creation

* Fixes erroneously removed index
  • Loading branch information
desertaxle committed Aug 29, 2022
1 parent 94b3951 commit ba66ced
Show file tree
Hide file tree
Showing 13 changed files with 347 additions and 30 deletions.
28 changes: 27 additions & 1 deletion src/prefect/blocks/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import hashlib
import inspect
import logging
import sys
import warnings
from abc import ABC
from textwrap import dedent
Expand All @@ -25,7 +26,12 @@
from typing_extensions import ParamSpec, Self, get_args, get_origin

import prefect
from prefect.orion.schemas.core import BlockDocument, BlockSchema, BlockType
from prefect.orion.schemas.core import (
DEFAULT_BLOCK_SCHEMA_VERSION,
BlockDocument,
BlockSchema,
BlockType,
)
from prefect.utilities.asyncutils import asyncnullcontext, sync_compatible
from prefect.utilities.collections import remove_nested_keys
from prefect.utilities.dispatch import lookup_type, register_base_type
Expand Down Expand Up @@ -188,6 +194,7 @@ def block_initialization(self) -> None:
_block_type_id: Optional[UUID] = None
_block_schema_id: Optional[UUID] = None
_block_schema_capabilities: Optional[List[str]] = None
_block_schema_version: Optional[str] = None
_block_document_id: Optional[UUID] = None
_block_document_name: Optional[str] = None
_is_anonymous: Optional[bool] = None
Expand Down Expand Up @@ -220,6 +227,24 @@ def get_block_capabilities(cls) -> FrozenSet[str]:
}
)

@classmethod
def _get_current_package_version(cls):
    """
    Look up the version of the top-level package that defines this class.

    Returns:
        The string form of the top-level module's `__version__` attribute, or
        `DEFAULT_BLOCK_SCHEMA_VERSION` when the defining module cannot be
        determined or the top-level module has no `__version__` attribute.
    """
    defining_module = inspect.getmodule(cls)
    if not defining_module:
        return DEFAULT_BLOCK_SCHEMA_VERSION

    # Only the first dotted component of the module name is used; an empty
    # component falls back to "__main__".
    root_name, *_ = defining_module.__name__.split(".")
    root_module = sys.modules[root_name or "__main__"]
    try:
        return str(root_module.__version__)
    except AttributeError:
        # The top-level module does not expose a __version__ attribute.
        return DEFAULT_BLOCK_SCHEMA_VERSION

@classmethod
def get_block_schema_version(cls) -> str:
    """
    Return the version to record on this block's schema.

    A truthy `_block_schema_version` set on the class takes precedence;
    otherwise the result of `_get_current_package_version()` is returned.
    """
    pinned_version = cls._block_schema_version
    if pinned_version:
        return pinned_version
    return cls._get_current_package_version()

@classmethod
def _to_block_schema_reference_dict(cls):
return dict(
Expand Down Expand Up @@ -351,6 +376,7 @@ def _to_block_schema(cls, block_type_id: Optional[UUID] = None) -> BlockSchema:
block_type_id=block_type_id or cls._block_type_id,
block_type=cls._to_block_type(),
capabilities=list(cls.get_block_capabilities()),
version=cls.get_block_schema_version(),
)

@classmethod
Expand Down
8 changes: 6 additions & 2 deletions src/prefect/orion/api/block_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from uuid import UUID

import sqlalchemy as sa
from fastapi import Body, Depends, HTTPException, Path, Response, status
from fastapi import Body, Depends, HTTPException, Path, Query, Response, status
from fastapi.responses import Response

from prefect.blocks.core import Block
Expand Down Expand Up @@ -127,10 +127,14 @@ async def read_block_schema_by_checksum(
block_schema_checksum: str = Path(
..., description="The block schema checksum", alias="checksum"
),
version: Optional[str] = Query(
None,
description="Version of block schema. If not provided the most recently created block schema with the matching checksum will be returned.",
),
session: sa.orm.Session = Depends(dependencies.get_session),
) -> schemas.core.BlockSchema:
block_schema = await models.block_schemas.read_block_schema_by_checksum(
session=session, checksum=block_schema_checksum
session=session, checksum=block_schema_checksum, version=version
)
if not block_schema:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Block schema not found")
Expand Down
4 changes: 4 additions & 0 deletions src/prefect/orion/database/migrations/MIGRATION-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Each time a database migration is written, an entry is included here with:

This gives us a history of changes and will create merge conflicts if two migrations are made at once, flagging situations where a branch needs to be updated before merging.

# Add version to block schema
SQLite: `e757138e954a`
Postgres: `2d5e000696f1`

# Add work queue name to runs
SQLite: `575634b7acd4`
Postgres: `77eb737fc759`
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Adds block schema version
Revision ID: 2d5e000696f1
Revises: 77eb737fc759
Create Date: 2022-08-18 10:28:04.449256
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "2d5e000696f1"
down_revision = "77eb737fc759"
branch_labels = None
depends_on = None


def upgrade():
    """Add a `version` column to `block_schema` and index (checksum, version) uniquely."""
    version_column = sa.Column(
        "version", sa.String(), server_default="non-versioned", nullable=False
    )
    op.add_column("block_schema", version_column)

    # Replace the checksum-only index with one that also covers the version.
    op.drop_index("uq_block_schema__checksum", table_name="block_schema")
    op.create_index(
        "uq_block_schema__checksum_version",
        "block_schema",
        ["checksum", "version"],
        unique=True,
    )


def downgrade():
    """Restore the checksum-only index on `block_schema` and drop the `version` column."""
    table = "block_schema"

    op.drop_index("uq_block_schema__checksum_version", table_name=table)
    # NOTE(review): the restored index keeps the "uq_" name but is created with
    # unique=False — confirm this is intentional.
    op.create_index("uq_block_schema__checksum", table, ["checksum"], unique=False)
    op.drop_column(table, "version")
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Adds block schema version
Revision ID: e757138e954a
Revises: 575634b7acd4
Create Date: 2022-08-18 10:25:27.189680
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "e757138e954a"
down_revision = "575634b7acd4"
branch_labels = None
depends_on = None


def upgrade():
    """Add a `version` column to `block_schema` and index (checksum, version) uniquely."""
    version_column = sa.Column(
        "version", sa.String(), server_default="non-versioned", nullable=False
    )

    # Foreign key enforcement is switched off around the batch operation and
    # switched back on afterwards.
    op.execute("PRAGMA foreign_keys=OFF")

    with op.batch_alter_table("block_schema", schema=None) as table_ops:
        table_ops.add_column(version_column)
        # Replace the checksum-only index with one that also covers the version.
        table_ops.drop_index("uq_block_schema__checksum")
        table_ops.create_index(
            "uq_block_schema__checksum_version",
            ["checksum", "version"],
            unique=True,
        )

    op.execute("PRAGMA foreign_keys=ON")


def downgrade():
    """Restore the checksum-only index on `block_schema` and drop the `version` column."""
    with op.batch_alter_table("block_schema", schema=None) as table_ops:
        table_ops.drop_index("uq_block_schema__checksum_version")
        # NOTE(review): the restored index keeps the "uq_" name but is created
        # with unique=False — confirm this is intentional.
        table_ops.create_index(
            "uq_block_schema__checksum",
            ["checksum"],
            unique=False,
        )
        table_ops.drop_column("version")
10 changes: 8 additions & 2 deletions src/prefect/orion/database/orm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,11 @@ class ORMBlockSchema:
checksum = sa.Column(sa.String, nullable=False)
fields = sa.Column(JSON, server_default="{}", default=dict, nullable=False)
capabilities = sa.Column(JSON, server_default="[]", default=list, nullable=False)
version = sa.Column(
sa.String,
server_default=schemas.core.DEFAULT_BLOCK_SCHEMA_VERSION,
nullable=False,
)

@declared_attr
def block_type_id(cls):
Expand All @@ -788,8 +793,9 @@ def block_type(cls):
def __table_args__(cls):
return (
sa.Index(
"uq_block_schema__checksum",
"uq_block_schema__checksum_version",
"checksum",
"version",
unique=True,
),
sa.Index("ix_block_schema__created", "created"),
Expand Down Expand Up @@ -1264,7 +1270,7 @@ def block_type_unique_upsert_columns(self):
@property
def block_schema_unique_upsert_columns(self):
"""Unique columns for upserting a BlockSchema"""
return [self.BlockSchema.checksum]
return [self.BlockSchema.checksum, self.BlockSchema.version]

@property
def flow_unique_upsert_columns(self):
Expand Down
26 changes: 22 additions & 4 deletions src/prefect/orion/models/block_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def create_block_schema(

# Check for existing block schema based on calculated checksum
existing_block_schema = await read_block_schema_by_checksum(
session=session, checksum=checksum
session=session, checksum=checksum, version=block_schema.version
)
# Return existing block schema if it exists. Allows block schema creation to be called multiple
# times for the same schema without errors.
Expand Down Expand Up @@ -106,9 +106,14 @@ async def create_block_schema(
.where(
db.BlockSchema.checksum == insert_values["checksum"],
)
.order_by(db.BlockSchema.created.desc())
.limit(1)
.execution_options(populate_existing=True)
)

if block_schema.version is not None:
query = query.where(db.BlockSchema.version == block_schema.version)

result = await session.execute(query)
created_block_schema = copy(result.scalar())

Expand Down Expand Up @@ -671,6 +676,7 @@ async def read_block_schema_by_checksum(
session: sa.orm.Session,
checksum: str,
db: OrionDBInterface,
version: Optional[str] = None,
) -> Optional[BlockSchema]:
"""
Reads a block_schema by checksum. Will reconstruct the block schema's fields
Expand All @@ -679,21 +685,33 @@ async def read_block_schema_by_checksum(
Args:
session: A database session
checksum: a block_schema checksum
version: A block_schema version
Returns:
db.BlockSchema: the block_schema
"""
# Construction of a recursive query which returns the specified block schema
# along with any nested block schemas, coupled with the ID of their parent schema
# and the key that they reside under.

# The same checksum with different versions can occur in the DB. Return only the
# most recently created one.
root_block_schema_query = (
sa.select(db.BlockSchema).filter_by(checksum=checksum).cte("root_block_schema")
sa.select(db.BlockSchema)
.filter_by(checksum=checksum)
.order_by(db.BlockSchema.created.desc())
.limit(1)
)

if version is not None:
root_block_schema_query = root_block_schema_query.filter_by(version=version)

root_block_schema_cte = root_block_schema_query.cte("root_block_schema")

block_schema_references_query = (
sa.select(db.BlockSchemaReference)
.select_from(db.BlockSchemaReference)
.filter_by(parent_block_schema_id=root_block_schema_query.c.id)
.filter_by(parent_block_schema_id=root_block_schema_cte.c.id)
.cte("block_schema_references", recursive=True)
)
block_schema_references_join = (
Expand Down Expand Up @@ -725,7 +743,7 @@ async def read_block_schema_by_checksum(
)
.filter(
sa.or_(
db.BlockSchema.id == root_block_schema_query.c.id,
db.BlockSchema.id == root_block_schema_cte.c.id,
recursive_block_schema_references_cte.c.parent_block_schema_id.is_not(
None
),
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/orion/schemas/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ class Config:
class BlockSchemaCreate(
schemas.core.BlockSchema.subclass(
name="BlockSchemaCreate",
include_fields=["fields", "capabilities", "block_type_id"],
include_fields=["fields", "capabilities", "block_type_id", "version"],
)
):
"""Data used by the Orion API to create a block schema."""
Expand Down
22 changes: 5 additions & 17 deletions src/prefect/orion/schemas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
"flow_run_state_message",
]

# Version string used for block schemas when no explicit version is available.
DEFAULT_BLOCK_SCHEMA_VERSION = "non-versioned"


def raise_on_invalid_name(name: str) -> None:
"""
Expand Down Expand Up @@ -499,24 +501,10 @@ class BlockSchema(ORMBaseModel):
default_factory=list,
description="A list of Block capabilities",
)


class BlockSchemaReference(ORMBaseModel):
"""An ORM representation of a block schema reference."""

parent_block_schema_id: UUID = Field(
..., description="ID of block schema the reference is nested within"
)
parent_block_schema: Optional[BlockSchema] = Field(
None, description="The block schema the reference is nested within"
)
reference_block_schema_id: UUID = Field(
..., description="ID of the nested block schema"
version: str = Field(
DEFAULT_BLOCK_SCHEMA_VERSION,
description="Human readable identifier for the block schema",
)
reference_block_schema: Optional[BlockSchema] = Field(
None, description="The nested block schema"
)
name: str = Field(..., description="The name that the reference is nested under")


class BlockSchemaReference(ORMBaseModel):
Expand Down
23 changes: 23 additions & 0 deletions src/prefect/orion/schemas/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,24 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
return filters


class BlockSchemaFilterVersion(PrefectFilterBaseModel):
    """Filter by `BlockSchema.version`"""

    # Defaults to None (no filtering), so the annotation is Optional.
    any_: Optional[List[str]] = Field(
        None,
        example=["2.0.0", "2.1.0"],
        description="A list of block schema versions.",
    )

    def _get_filter_list(self, db: "OrionDBInterface") -> List:
        """
        Build the SQL filter clauses for the configured version criteria.

        Args:
            db: database interface exposing the `BlockSchema` ORM model

        Returns:
            A list holding an `IN` clause on `BlockSchema.version` when `any_`
            is set; otherwise an empty list.
        """
        filters = []
        if self.any_ is not None:
            filters.append(db.BlockSchema.version.in_(self.any_))
        return filters


class BlockSchemaFilter(PrefectOperatorFilterBaseModel):
"""Filter BlockSchemas"""

Expand All @@ -1023,6 +1041,9 @@ class BlockSchemaFilter(PrefectOperatorFilterBaseModel):
id: Optional[BlockSchemaFilterId] = Field(
None, description="Filter criteria for `BlockSchema.id`"
)
version: Optional[BlockSchemaFilterVersion] = Field(
None, description="Filter criteria for `BlockSchema.version`"
)

def _get_filter_list(self, db: "OrionDBInterface") -> List:
filters = []
Expand All @@ -1033,6 +1054,8 @@ def _get_filter_list(self, db: "OrionDBInterface") -> List:
filters.append(self.block_capabilities.as_sql_filter(db))
if self.id is not None:
filters.append(self.id.as_sql_filter(db))
if self.version is not None:
filters.append(self.version.as_sql_filter(db))

return filters

Expand Down
Loading

0 comments on commit ba66ced

Please sign in to comment.