diff --git a/src/dstack/_internal/core/models/users.py b/src/dstack/_internal/core/models/users.py index fd9cefc28..c1a3d039e 100644 --- a/src/dstack/_internal/core/models/users.py +++ b/src/dstack/_internal/core/models/users.py @@ -22,6 +22,7 @@ class User(CoreModel): username: str global_role: GlobalRole email: Optional[str] + active: bool class UserTokenCreds(CoreModel): diff --git a/src/dstack/_internal/server/migrations/versions/d6b11105f659_add_usermodel_active.py b/src/dstack/_internal/server/migrations/versions/d6b11105f659_add_usermodel_active.py new file mode 100644 index 000000000..80fd4de55 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/d6b11105f659_add_usermodel_active.py @@ -0,0 +1,36 @@ +"""Add UserModel.active + +Revision ID: d6b11105f659 +Revises: 54a77e19c64c +Create Date: 2024-08-19 15:10:25.751199 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "d6b11105f659" +down_revision = "54a77e19c64c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("users", schema=None) as batch_op: + batch_op.add_column(sa.Column("active", sa.Boolean(), nullable=True)) + + op.execute(sa.sql.text("UPDATE users SET active = TRUE")) + + with op.batch_alter_table("users", schema=None) as batch_op: + batch_op.alter_column("active", nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("users", schema=None) as batch_op: + batch_op.drop_column("active") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index bed7fc976..a4886a9ef 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -156,6 +156,8 @@ class UserModel(BaseModel): # token_hash is needed for fast search by token when stored token is encrypted token_hash: Mapped[str] = mapped_column(String(2000), unique=True) global_role: Mapped[GlobalRole] = mapped_column(Enum(GlobalRole)) + # deactivated users cannot access API + active: Mapped[bool] = mapped_column(Boolean, default=True) email: Mapped[Optional[str]] = mapped_column(String(200), nullable=True) diff --git a/src/dstack/_internal/server/routers/users.py b/src/dstack/_internal/server/routers/users.py index a47d5d8b4..c89b0b10e 100644 --- a/src/dstack/_internal/server/routers/users.py +++ b/src/dstack/_internal/server/routers/users.py @@ -60,6 +60,7 @@ async def create_user( username=body.username, global_role=body.global_role, email=body.email, + active=body.active, ) return users.user_model_to_user(res) @@ -75,6 +76,7 @@ async def update_user( username=body.username, global_role=body.global_role, email=body.email, + active=body.active, ) if res is None: raise ResourceNotExistsError() diff --git a/src/dstack/_internal/server/schemas/users.py b/src/dstack/_internal/server/schemas/users.py index 1298bd0ed..6579d9657 100644 --- a/src/dstack/_internal/server/schemas/users.py +++ b/src/dstack/_internal/server/schemas/users.py @@ -12,6 +12,7 @@ class CreateUserRequest(CoreModel): username: str global_role: GlobalRole email: Optional[str] + active: bool = True UpdateUserRequest = CreateUserRequest diff --git a/src/dstack/_internal/server/security/permissions.py b/src/dstack/_internal/server/security/permissions.py index b9727df70..9954e0c34 100644 --- a/src/dstack/_internal/server/security/permissions.py +++ b/src/dstack/_internal/server/security/permissions.py @@ -12,7 +12,7 @@ get_project_model_by_name, get_user_project_role, ) -from dstack._internal.server.services.users import get_user_model_by_token +from dstack._internal.server.services.users import log_in_with_token from dstack._internal.server.utils.routers import ( error_forbidden, error_invalid_token, @@ -26,7 +26,7 @@ async def __call__( session: AsyncSession = Depends(get_session), token: HTTPAuthorizationCredentials = Security(HTTPBearer()), ) -> UserModel: - user = await get_user_model_by_token(session=session, token=token.credentials) + user = await log_in_with_token(session=session, token=token.credentials) if user is None: raise error_invalid_token() return user @@ -38,7 +38,7 @@ async def __call__( session: AsyncSession = Depends(get_session), token: HTTPAuthorizationCredentials = Security(HTTPBearer()), ) -> UserModel: - user = await get_user_model_by_token(session=session, token=token.credentials) + user = await log_in_with_token(session=session, token=token.credentials) if user is None: raise error_invalid_token() if user.global_role == GlobalRole.ADMIN: @@ -53,7 +53,7 @@ async def __call__( session: AsyncSession = Depends(get_session), token: HTTPAuthorizationCredentials = Security(HTTPBearer()), ) -> Tuple[UserModel, ProjectModel]: - user = await get_user_model_by_token(session=session, token=token.credentials) + user = await log_in_with_token(session=session, token=token.credentials) if user is None: raise error_invalid_token() project = await get_project_model_by_name(session=session, project_name=project_name) @@ -74,7 +74,7 @@ async def __call__( session: AsyncSession = Depends(get_session), token: HTTPAuthorizationCredentials = Security(HTTPBearer()), ) -> Tuple[UserModel, ProjectModel]: - user = await get_user_model_by_token(session=session, token=token.credentials) + user = await log_in_with_token(session=session, token=token.credentials) if user is None: raise error_invalid_token() project = await get_project_model_by_name(session=session, project_name=project_name) @@ -96,7 +96,7 @@ async def __call__( project_name: str, token: HTTPAuthorizationCredentials = Security(HTTPBearer()), ) -> Tuple[UserModel, ProjectModel]: - user = await get_user_model_by_token(session=session, token=token.credentials) + user = await log_in_with_token(session=session, token=token.credentials) if user is None: raise error_invalid_token() project = await get_project_model_by_name(session=session, project_name=project_name) diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index 378b25057..8a247668f 100644 --- a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -62,6 +62,7 @@ async def create_user( username: str, global_role: GlobalRole, email: Optional[str] = None, + active: bool = True, ) -> UserModel: user_model = await get_user_model_by_name(session=session, username=username, ignore_case=True) if user_model is not None: @@ -74,6 +75,7 @@ async def create_user( token=DecryptedString(plaintext=token), token_hash=get_token_hash(token), email=email, + active=active, ) session.add(user) await session.commit() @@ -87,11 +89,16 @@ async def update_user( username: str, global_role: GlobalRole, email: Optional[str] = None, + active: bool = True, ) -> UserModel: await session.execute( update(UserModel) .where(UserModel.name == username) - .values(global_role=global_role, email=email) + .values( + global_role=global_role, + email=email, + active=active, + ) ) await session.commit() return await get_user_model_by_name_or_error(session=session, username=username) @@ -145,9 +152,14 @@ async def get_user_model_by_name_or_error(session: AsyncSession, username: str) return res.scalar_one() -async def get_user_model_by_token(session: AsyncSession, token: str) -> Optional[UserModel]: +async def log_in_with_token(session: AsyncSession, token: str) -> Optional[UserModel]: token_hash = get_token_hash(token) - res = await session.execute(select(UserModel).where(UserModel.token_hash == token_hash)) + res = await session.execute( + select(UserModel).where( + UserModel.token_hash == token_hash, + UserModel.active == True, + ) + ) user = res.scalar() if user is None: return None @@ -167,6 +179,7 @@ def user_model_to_user(user_model: UserModel) -> User: username=user_model.name, global_role=user_model.global_role, email=user_model.email, + active=user_model.active, ) @@ -176,6 +189,7 @@ def user_model_to_user_with_creds(user_model: UserModel) -> UserWithCreds: username=user_model.name, global_role=user_model.global_role, email=user_model.email, + active=user_model.active, creds=UserTokenCreds(token=user_model.token.get_plaintext_or_error()), ) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 1f38b874b..832e980f6 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -72,6 +72,7 @@ async def create_user( global_role: GlobalRole = GlobalRole.ADMIN, token: Optional[str] = None, email: Optional[str] = None, + active: bool = True, ) -> UserModel: if token is None: token = str(uuid.uuid4()) @@ -81,6 +82,7 @@ async def create_user( token=DecryptedString(plaintext=token), token_hash=get_token_hash(token), email=email, + active=active, ) session.add(user) await session.commit() diff --git a/src/tests/_internal/server/routers/test_projects.py b/src/tests/_internal/server/routers/test_projects.py index 11de70fc1..1545f8908 100644 --- a/src/tests/_internal/server/routers/test_projects.py +++ b/src/tests/_internal/server/routers/test_projects.py @@ -54,6 +54,7 @@ async def test_returns_projects(self, test_db, session: AsyncSession): "username": user.name, "global_role": user.global_role, "email": None, + "active": True, }, "backends": [], "members": [], @@ -88,6 +89,7 @@ async def test_creates_project(self, test_db, session: AsyncSession): "username": user.name, "global_role": user.global_role, "email": None, + "active": True, }, "backends": [], "members": [ @@ -97,6 +99,7 @@ async def test_creates_project(self, test_db, session: AsyncSession): "username": user.name, "global_role": user.global_role, "email": None, + "active": True, }, "project_role": ProjectRole.ADMIN, } @@ -277,6 +280,7 @@ async def test_returns_project(self, test_db, session: AsyncSession): "username": user.name, "global_role": user.global_role, "email": None, + "active": True, }, "backends": [], "members": [ @@ -286,6 +290,7 @@ async def test_returns_project(self, test_db, session: AsyncSession): "username": user.name, "global_role": user.global_role, "email": None, + "active": True, }, "project_role": ProjectRole.ADMIN, } @@ -338,6 +343,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession): "username": admin.name, "global_role": admin.global_role, "email": None, + "active": True, }, "project_role": ProjectRole.ADMIN, }, @@ -347,6 +353,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession): "username": user1.name, "global_role": user1.global_role, "email": None, + "active": True, }, "project_role": ProjectRole.ADMIN, }, @@ -356,6 +363,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession): "username": user2.name, "global_role": user2.global_role, "email": None, + "active": True, }, "project_role": ProjectRole.USER, }, diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py index e03e802a4..8e59ffa06 100644 --- a/src/tests/_internal/server/routers/test_users.py +++ b/src/tests/_internal/server/routers/test_users.py @@ -30,6 +30,7 @@ async def test_returns_users(self, test_db, session): "username": user.name, "global_role": user.global_role, "email": None, + "active": True, } ] @@ -40,7 +41,17 @@ def test_returns_40x_if_not_authenticated(self): assert response.status_code in [401, 403] @pytest.mark.asyncio - async def test_returns_logged_in_user(self, test_db, session): + async def test_returns_40x_if_deactivated(self, test_db, session: AsyncSession): + user = await create_user(session=session, active=False) + response = client.post("/api/users/get_my_user", headers=get_auth_headers(user.token)) + assert response.status_code in [401, 403] + user.active = True + await session.commit() + response = client.post("/api/users/get_my_user", headers=get_auth_headers(user.token)) + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_returns_logged_in_user(self, test_db, session: AsyncSession): user = await create_user(session=session) response = client.post("/api/users/get_my_user", headers=get_auth_headers(user.token)) assert response.status_code == 200 @@ -49,6 +60,7 @@ async def test_returns_logged_in_user(self, test_db, session): "username": user.name, "global_role": user.global_role, "email": None, + "active": True, } @@ -58,7 +70,7 @@ def test_returns_40x_if_not_authenticated(self): assert response.status_code in [401, 403] @pytest.mark.asyncio - async def test_returns_400_if_not_global_admin(self, test_db, session): + async def test_returns_400_if_not_global_admin(self, test_db, session: AsyncSession): user = await create_user(session=session, global_role=GlobalRole.USER) other_user = await create_user(session=session, name="other_user", token="1234") response = client.post( @@ -69,7 +81,7 @@ async def test_returns_400_if_not_global_admin(self, test_db, session): assert response.status_code == 400 @pytest.mark.asyncio - async def test_returns_logged_in_user(self, test_db, session): + async def test_returns_logged_in_user(self, test_db, session: AsyncSession): user = await create_user(session=session, global_role=GlobalRole.ADMIN) other_user = await create_user(session=session, name="other_user", token="1234") response = client.post( @@ -84,6 +96,7 @@ async def test_returns_logged_in_user(self, test_db, session): "global_role": other_user.global_role, "email": None, "creds": {"token": "1234"}, + "active": True, } @@ -104,6 +117,7 @@ async def test_creates_user(self, test_db, session: AsyncSession): "username": "test", "global_role": GlobalRole.USER, "email": "test@example.com", + "active": True, }, ) assert response.status_code == 200 @@ -112,6 +126,7 @@ async def test_creates_user(self, test_db, session: AsyncSession): "username": "test", "global_role": "user", "email": "test@example.com", + "active": True, } res = await session.execute(select(UserModel).where(UserModel.name == "test")) assert len(res.scalars().all()) == 1 @@ -135,6 +150,7 @@ async def test_return_400_if_username_taken(self, test_db, session: AsyncSession "username": "Test", "global_role": "user", "email": None, + "active": True, } # Username uniqueness check should be case insensitive for username in ["test", "Test", "TesT"]: