Skip to content

Commit

Permalink
✨ Pydantic 2.0 compatibility (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Sep 19, 2023
1 parent 7c9f121 commit c0c8ae6
Show file tree
Hide file tree
Showing 3 changed files with 428 additions and 334 deletions.
43 changes: 24 additions & 19 deletions flama/schemas/_libs/pydantic/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import typing as t

import pydantic
from pydantic.fields import ModelField
from pydantic.schema import field_schema, model_schema
from pydantic.fields import FieldInfo
from pydantic.json_schema import model_json_schema

from flama.injection import Parameter
from flama.schemas.adapter import Adapter
Expand All @@ -19,7 +19,7 @@
__all__ = ["PydanticAdapter"]

Schema = pydantic.BaseModel
Field = ModelField
Field = FieldInfo


class PydanticAdapter(Adapter[Schema, Field]):
Expand All @@ -44,13 +44,12 @@ def build_field(
if nullable:
annotation = t.Optional[annotation]

return ModelField.infer(
name=name,
annotation=annotation,
value=pydantic.Field(**kwargs),
class_validators=None,
config=pydantic.BaseConfig,
)
if default is Parameter.empty:
field = FieldInfo.from_annotation(annotation)
else:
field = FieldInfo.from_annotated_attribute(annotation, default)

return field

def build_schema(
self,
Expand All @@ -64,21 +63,21 @@ def build_schema(
**{
**(
{
name: (field.annotation, field.field_info)
for name, field in self.unique_schema(schema).__fields__.items()
name: (field_info.annotation, field_info)
for name, field_info in self.unique_schema(schema).model_fields.items()
}
if schema
else {}
),
**({name: (field.annotation, field.field_info) for name, field in fields.items()} if fields else {}),
**({name: (field.annotation, field) for name, field in fields.items()} if fields else {}),
},
)

def validate(self, schema: t.Union[Schema, t.Type[Schema]], values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
schema_cls = self.unique_schema(schema)

try:
return schema_cls(**values).dict()
return schema_cls(**values).model_dump()
except pydantic.ValidationError as errors:
raise SchemaValidationError(errors={str(error["loc"][0]): error for error in errors.errors()})

Expand All @@ -95,12 +94,18 @@ def dump(self, schema: t.Union[Schema, t.Type[Schema]], value: t.Dict[str, t.Any
def to_json_schema(self, schema: t.Union[Schema, t.Type[Schema], Field]) -> JSONSchema:
try:
if self.is_schema(schema):
json_schema = model_schema(schema, ref_prefix="#/components/schemas/")
json_schema = model_json_schema(schema, ref_template="#/components/schemas/{model}")
if "$defs" in json_schema:
del json_schema["$defs"]
elif self.is_field(schema):
json_schema = field_schema(schema, ref_prefix="#/components/schemas/", model_name_map={})[0]
if schema.allow_none:
types = [json_schema["type"]] if isinstance(json_schema["type"], str) else json_schema["type"]
json_schema["type"] = list(dict.fromkeys(types + ["null"]))
json_schema = model_json_schema(
self.build_schema(fields={"x": schema}), ref_template="#/components/schemas/{model}"
)["properties"]["x"]
if not schema.title: # Pydantic is introducing a default title, so we drop it
del json_schema["title"]
if "anyOf" in json_schema: # Just simplifying type definition from anyOf to a list of types
json_schema["type"] = [x["type"] for x in json_schema["anyOf"]]
del json_schema["anyOf"]
else:
raise TypeError("Not a valid schema class or field")

Expand Down
Loading

0 comments on commit c0c8ae6

Please sign in to comment.