Skip to content

Commit

Permalink
Merge pull request fastapi-users#1249 from fastapi-users/pydantic-v2
Browse files Browse the repository at this point in the history
Pydantic V2 support
  • Loading branch information
frankie567 committed Jul 12, 2023
2 parents 3bf0f88 + 5b6d5d4 commit 49ea718
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 43 deletions.
24 changes: 20 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on: [push, pull_request]

jobs:

test:
lint:
runs-on: ubuntu-latest
strategy:
matrix:
Expand All @@ -20,13 +20,29 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install hatch
hatch env create
- name: Lint and typecheck
run: |
hatch run lint-check
test:
runs-on: ubuntu-latest
strategy:
matrix:
python_version: [3.8, 3.9, '3.10', '3.11']

steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install hatch
- name: Test
run: |
hatch run test-cov-xml
hatch run test:test-cov-xml
- uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}
Expand All @@ -40,7 +56,7 @@ jobs:
release:
runs-on: ubuntu-latest
needs: test
needs: [lint, test]
if: startsWith(github.ref, 'refs/tags/')

steps:
Expand Down
3 changes: 2 additions & 1 deletion fastapi_users/authentication/transport/bearer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
TransportLogoutNotSupportedError,
)
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users.schemas import model_dump


class BearerResponse(BaseModel):
Expand All @@ -23,7 +24,7 @@ def __init__(self, tokenUrl: str):

async def get_login_response(self, token: str) -> Response:
bearer_response = BearerResponse(access_token=token, token_type="bearer")
return JSONResponse(bearer_response.dict())
return JSONResponse(model_dump(bearer_response))

async def get_logout_response(self) -> Response:
raise TransportLogoutNotSupportedError()
Expand Down
2 changes: 1 addition & 1 deletion fastapi_users/router/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,6 @@ async def callback(
request,
)

return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)

return router
2 changes: 1 addition & 1 deletion fastapi_users/router/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,6 @@ async def register(
},
)

return user_schema.from_orm(created_user)
return schemas.model_validate(user_schema, created_user)

return router
8 changes: 4 additions & 4 deletions fastapi_users/router/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def get_user_or_404(
async def me(
user: models.UP = Depends(get_current_active_user),
):
return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)

@router.patch(
"/me",
Expand Down Expand Up @@ -96,7 +96,7 @@ async def update_me(
user = await user_manager.update(
user_update, user, safe=True, request=request
)
return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)
except exceptions.InvalidPasswordException as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand Down Expand Up @@ -129,7 +129,7 @@ async def update_me(
},
)
async def get_user(user=Depends(get_user_or_404)):
return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)

@router.patch(
"/{id}",
Expand Down Expand Up @@ -183,7 +183,7 @@ async def update_user(
user = await user_manager.update(
user_update, user, safe=False, request=request
)
return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)
except exceptions.InvalidPasswordException as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand Down
2 changes: 1 addition & 1 deletion fastapi_users/router/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def verify(
):
try:
user = await user_manager.verify(token, request)
return user_schema.from_orm(user)
return schemas.model_validate(user_schema, user)
except (exceptions.InvalidVerifyToken, exceptions.UserNotExists):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand Down
60 changes: 45 additions & 15 deletions fastapi_users/schemas.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
from typing import Generic, List, Optional, TypeVar
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar

from pydantic import BaseModel, EmailStr
from pydantic import BaseModel, ConfigDict, EmailStr
from pydantic.version import VERSION as PYDANTIC_VERSION

from fastapi_users import models

PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")

SCHEMA = TypeVar("SCHEMA", bound=BaseModel)

if PYDANTIC_V2: # pragma: no cover

def model_dump(model: BaseModel, *args, **kwargs) -> Dict[str, Any]:
return model.model_dump(*args, **kwargs) # type: ignore

def model_validate(schema: Type[SCHEMA], obj: Any, *args, **kwargs) -> SCHEMA:
return schema.model_validate(obj, *args, **kwargs) # type: ignore

else: # pragma: no cover # type: ignore

def model_dump(model: BaseModel, *args, **kwargs) -> Dict[str, Any]:
return model.dict(*args, **kwargs) # type: ignore

def model_validate(schema: Type[SCHEMA], obj: Any, *args, **kwargs) -> SCHEMA:
return schema.from_orm(obj) # type: ignore


class CreateUpdateDictModel(BaseModel):
def create_update_dict(self):
return self.dict(
return model_dump(
self,
exclude_unset=True,
exclude={
"id",
Expand All @@ -19,10 +41,10 @@ def create_update_dict(self):
)

def create_update_dict_superuser(self):
return self.dict(exclude_unset=True, exclude={"id"})
return model_dump(self, exclude_unset=True, exclude={"id"})


class BaseUser(Generic[models.ID], CreateUpdateDictModel):
class BaseUser(CreateUpdateDictModel, Generic[models.ID]):
"""Base User model."""

id: models.ID
Expand All @@ -31,8 +53,12 @@ class BaseUser(Generic[models.ID], CreateUpdateDictModel):
is_superuser: bool = False
is_verified: bool = False

class Config:
orm_mode = True
if PYDANTIC_V2: # pragma: no cover
model_config = ConfigDict(from_attributes=True) # type: ignore
else: # pragma: no cover

class Config:
orm_mode = True


class BaseUserCreate(CreateUpdateDictModel):
Expand All @@ -44,19 +70,19 @@ class BaseUserCreate(CreateUpdateDictModel):


class BaseUserUpdate(CreateUpdateDictModel):
password: Optional[str]
email: Optional[EmailStr]
is_active: Optional[bool]
is_superuser: Optional[bool]
is_verified: Optional[bool]
password: Optional[str] = None
email: Optional[EmailStr] = None
is_active: Optional[bool] = None
is_superuser: Optional[bool] = None
is_verified: Optional[bool] = None


U = TypeVar("U", bound=BaseUser)
UC = TypeVar("UC", bound=BaseUserCreate)
UU = TypeVar("UU", bound=BaseUserUpdate)


class BaseOAuthAccount(Generic[models.ID], BaseModel):
class BaseOAuthAccount(BaseModel, Generic[models.ID]):
"""Base OAuth account model."""

id: models.ID
Expand All @@ -67,8 +93,12 @@ class BaseOAuthAccount(Generic[models.ID], BaseModel):
account_id: str
account_email: str

class Config:
orm_mode = True
if PYDANTIC_V2: # pragma: no cover
model_config = ConfigDict(from_attributes=True) # type: ignore
else: # pragma: no cover

class Config:
orm_mode = True


class BaseOAuthAccountMixin(BaseModel):
Expand Down
17 changes: 15 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ dependencies = [
]

[tool.hatch.envs.default.scripts]
test = "pytest --cov=fastapi_users/ --cov-report=term-missing --cov-fail-under=100"
test-cov-xml = "pytest --cov=fastapi_users/ --cov-report=xml --cov-fail-under=100"
lint = [
"isort ./fastapi_users ./tests",
"isort ./docs/src -o fastapi_users",
Expand All @@ -94,6 +92,21 @@ lint-check = [
]
docs = "mkdocs serve"

[tool.hatch.envs.test]

[tool.hatch.envs.test.scripts]
test = "pytest --cov=fastapi_users/ --cov-report=term-missing --cov-fail-under=100"
test-cov-xml = "pytest --cov=fastapi_users/ --cov-report=xml --cov-fail-under=100"

[[tool.hatch.envs.test.matrix]]
pydantic = ["v1", "v2"]

[tool.hatch.envs.test.overrides]
matrix.pydantic.extra-dependencies = [
{value = "pydantic<2.0", if = ["v1"]},
{value = "pydantic>=2.0", if = ["v2"]},
]

[tool.hatch.build.targets.sdist]
support-legacy = true # Create setup.py

Expand Down
12 changes: 6 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@
excalibur_password_hash = password_helper.hash("excalibur")


IDType = uuid.UUID
IDType = UUID4


@dataclasses.dataclass
class UserModel(models.UserProtocol[IDType]):
email: str
hashed_password: str
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
id: IDType = dataclasses.field(default_factory=uuid.uuid4)
is_active: bool = True
is_superuser: bool = False
is_verified: bool = False
Expand All @@ -59,7 +59,7 @@ class OAuthAccountModel(models.OAuthAccountProtocol[IDType]):
access_token: str
account_id: str
account_email: str
id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
id: IDType = dataclasses.field(default_factory=uuid.uuid4)
expires_at: Optional[int] = None
refresh_token: Optional[str] = None

Expand All @@ -70,15 +70,15 @@ class UserOAuthModel(UserModel):


class User(schemas.BaseUser[IDType]):
first_name: Optional[str]
first_name: Optional[str] = None


class UserCreate(schemas.BaseUserCreate):
first_name: Optional[str]
first_name: Optional[str] = None


class UserUpdate(schemas.BaseUserUpdate):
first_name: Optional[str]
first_name: Optional[str] = None


class UserOAuth(User, schemas.BaseOAuthAccountMixin):
Expand Down
12 changes: 6 additions & 6 deletions tests/test_fastapi_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from fastapi import Depends, FastAPI, status

from fastapi_users import FastAPIUsers
from fastapi_users import FastAPIUsers, schemas
from tests.conftest import IDType, User, UserCreate, UserModel, UserUpdate


Expand Down Expand Up @@ -77,31 +77,31 @@ def current_verified_superuser(
def optional_current_user(
user: Optional[UserModel] = Depends(fastapi_users.current_user(optional=True)),
):
return User.from_orm(user) if user else None
return schemas.model_validate(User, user) if user else None

@app.get("/optional-current-active-user")
def optional_current_active_user(
user: Optional[UserModel] = Depends(
fastapi_users.current_user(optional=True, active=True)
),
):
return User.from_orm(user) if user else None
return schemas.model_validate(User, user) if user else None

@app.get("/optional-current-verified-user")
def optional_current_verified_user(
user: Optional[UserModel] = Depends(
fastapi_users.current_user(optional=True, verified=True)
),
):
return User.from_orm(user) if user else None
return schemas.model_validate(User, user) if user else None

@app.get("/optional-current-superuser")
def optional_current_superuser(
user: Optional[UserModel] = Depends(
fastapi_users.current_user(optional=True, active=True, superuser=True)
),
):
return User.from_orm(user) if user else None
return schemas.model_validate(User, user) if user else None

@app.get("/optional-current-verified-superuser")
def optional_current_verified_superuser(
Expand All @@ -111,7 +111,7 @@ def optional_current_verified_superuser(
)
),
):
return User.from_orm(user) if user else None
return schemas.model_validate(User, user) if user else None

async for client in get_test_client(app):
yield client
Expand Down
4 changes: 2 additions & 2 deletions tests/test_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import uuid
from typing import Callable

import pytest
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import UUID4
from pytest_mock import MockerFixture

from fastapi_users.exceptions import (
Expand Down Expand Up @@ -77,7 +77,7 @@ def _create_oauth2_password_request_form(username, password):
class TestGet:
async def test_not_existing_user(self, user_manager: UserManagerMock[UserModel]):
with pytest.raises(UserNotExists):
await user_manager.get(UUID4("d35d213e-f3d8-4f08-954a-7e0d1bea286f"))
await user_manager.get(uuid.UUID("d35d213e-f3d8-4f08-954a-7e0d1bea286f"))

async def test_existing_user(
self, user_manager: UserManagerMock[UserModel], user: UserModel
Expand Down

0 comments on commit 49ea718

Please sign in to comment.