Skip to content

Commit

Permalink
Implement user activation/deactivation (#1575)
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor authored Aug 19, 2024
1 parent 784f91c commit c00ccfd
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class User(CoreModel):
username: str
global_role: GlobalRole
email: Optional[str]
active: bool


class UserTokenCreds(CoreModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Add UserModel.active
Revision ID: d6b11105f659
Revises: 54a77e19c64c
Create Date: 2024-08-19 15:10:25.751199
"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "d6b11105f659"
down_revision = "54a77e19c64c"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.add_column(sa.Column("active", sa.Boolean(), nullable=True))

op.execute(sa.sql.text("UPDATE users SET active = TRUE"))

with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.alter_column("active", nullable=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.drop_column("active")

# ### end Alembic commands ###
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ class UserModel(BaseModel):
# token_hash is needed for fast search by token when stored token is encrypted
token_hash: Mapped[str] = mapped_column(String(2000), unique=True)
global_role: Mapped[GlobalRole] = mapped_column(Enum(GlobalRole))
# deactivated users cannot access API
active: Mapped[bool] = mapped_column(Boolean, default=True)

email: Mapped[Optional[str]] = mapped_column(String(200), nullable=True)

Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def create_user(
username=body.username,
global_role=body.global_role,
email=body.email,
active=body.active,
)
return users.user_model_to_user(res)

Expand All @@ -75,6 +76,7 @@ async def update_user(
username=body.username,
global_role=body.global_role,
email=body.email,
active=body.active,
)
if res is None:
raise ResourceNotExistsError()
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/server/schemas/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class CreateUserRequest(CoreModel):
username: str
global_role: GlobalRole
email: Optional[str]
active: bool = True


UpdateUserRequest = CreateUserRequest
Expand Down
12 changes: 6 additions & 6 deletions src/dstack/_internal/server/security/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
get_project_model_by_name,
get_user_project_role,
)
from dstack._internal.server.services.users import get_user_model_by_token
from dstack._internal.server.services.users import log_in_with_token
from dstack._internal.server.utils.routers import (
error_forbidden,
error_invalid_token,
Expand All @@ -26,7 +26,7 @@ async def __call__(
session: AsyncSession = Depends(get_session),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> UserModel:
user = await get_user_model_by_token(session=session, token=token.credentials)
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
return user
Expand All @@ -38,7 +38,7 @@ async def __call__(
session: AsyncSession = Depends(get_session),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> UserModel:
user = await get_user_model_by_token(session=session, token=token.credentials)
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
if user.global_role == GlobalRole.ADMIN:
Expand All @@ -53,7 +53,7 @@ async def __call__(
session: AsyncSession = Depends(get_session),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> Tuple[UserModel, ProjectModel]:
user = await get_user_model_by_token(session=session, token=token.credentials)
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
project = await get_project_model_by_name(session=session, project_name=project_name)
Expand All @@ -74,7 +74,7 @@ async def __call__(
session: AsyncSession = Depends(get_session),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> Tuple[UserModel, ProjectModel]:
user = await get_user_model_by_token(session=session, token=token.credentials)
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
project = await get_project_model_by_name(session=session, project_name=project_name)
Expand All @@ -96,7 +96,7 @@ async def __call__(
project_name: str,
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> Tuple[UserModel, ProjectModel]:
user = await get_user_model_by_token(session=session, token=token.credentials)
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
project = await get_project_model_by_name(session=session, project_name=project_name)
Expand Down
20 changes: 17 additions & 3 deletions src/dstack/_internal/server/services/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ async def create_user(
username: str,
global_role: GlobalRole,
email: Optional[str] = None,
active: bool = True,
) -> UserModel:
user_model = await get_user_model_by_name(session=session, username=username, ignore_case=True)
if user_model is not None:
Expand All @@ -74,6 +75,7 @@ async def create_user(
token=DecryptedString(plaintext=token),
token_hash=get_token_hash(token),
email=email,
active=active,
)
session.add(user)
await session.commit()
Expand All @@ -87,11 +89,16 @@ async def update_user(
username: str,
global_role: GlobalRole,
email: Optional[str] = None,
active: bool = True,
) -> UserModel:
await session.execute(
update(UserModel)
.where(UserModel.name == username)
.values(global_role=global_role, email=email)
.values(
global_role=global_role,
email=email,
active=active,
)
)
await session.commit()
return await get_user_model_by_name_or_error(session=session, username=username)
Expand Down Expand Up @@ -145,9 +152,14 @@ async def get_user_model_by_name_or_error(session: AsyncSession, username: str)
return res.scalar_one()


async def get_user_model_by_token(session: AsyncSession, token: str) -> Optional[UserModel]:
async def log_in_with_token(session: AsyncSession, token: str) -> Optional[UserModel]:
token_hash = get_token_hash(token)
res = await session.execute(select(UserModel).where(UserModel.token_hash == token_hash))
res = await session.execute(
select(UserModel).where(
UserModel.token_hash == token_hash,
UserModel.active == True,
)
)
user = res.scalar()
if user is None:
return None
Expand All @@ -167,6 +179,7 @@ def user_model_to_user(user_model: UserModel) -> User:
username=user_model.name,
global_role=user_model.global_role,
email=user_model.email,
active=user_model.active,
)


Expand All @@ -176,6 +189,7 @@ def user_model_to_user_with_creds(user_model: UserModel) -> UserWithCreds:
username=user_model.name,
global_role=user_model.global_role,
email=user_model.email,
active=user_model.active,
creds=UserTokenCreds(token=user_model.token.get_plaintext_or_error()),
)

Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ async def create_user(
global_role: GlobalRole = GlobalRole.ADMIN,
token: Optional[str] = None,
email: Optional[str] = None,
active: bool = True,
) -> UserModel:
if token is None:
token = str(uuid.uuid4())
Expand All @@ -81,6 +82,7 @@ async def create_user(
token=DecryptedString(plaintext=token),
token_hash=get_token_hash(token),
email=email,
active=active,
)
session.add(user)
await session.commit()
Expand Down
8 changes: 8 additions & 0 deletions src/tests/_internal/server/routers/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ async def test_returns_projects(self, test_db, session: AsyncSession):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
},
"backends": [],
"members": [],
Expand Down Expand Up @@ -88,6 +89,7 @@ async def test_creates_project(self, test_db, session: AsyncSession):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
},
"backends": [],
"members": [
Expand All @@ -97,6 +99,7 @@ async def test_creates_project(self, test_db, session: AsyncSession):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
},
"project_role": ProjectRole.ADMIN,
}
Expand Down Expand Up @@ -277,6 +280,7 @@ async def test_returns_project(self, test_db, session: AsyncSession):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
},
"backends": [],
"members": [
Expand All @@ -286,6 +290,7 @@ async def test_returns_project(self, test_db, session: AsyncSession):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
},
"project_role": ProjectRole.ADMIN,
}
Expand Down Expand Up @@ -338,6 +343,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession):
"username": admin.name,
"global_role": admin.global_role,
"email": None,
"active": True,
},
"project_role": ProjectRole.ADMIN,
},
Expand All @@ -347,6 +353,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession):
"username": user1.name,
"global_role": user1.global_role,
"email": None,
"active": True,
},
"project_role": ProjectRole.ADMIN,
},
Expand All @@ -356,6 +363,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession):
"username": user2.name,
"global_role": user2.global_role,
"email": None,
"active": True,
},
"project_role": ProjectRole.USER,
},
Expand Down
22 changes: 19 additions & 3 deletions src/tests/_internal/server/routers/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ async def test_returns_users(self, test_db, session):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
}
]

Expand All @@ -40,7 +41,17 @@ def test_returns_40x_if_not_authenticated(self):
assert response.status_code in [401, 403]

@pytest.mark.asyncio
async def test_returns_logged_in_user(self, test_db, session):
async def test_returns_40x_if_deactivated(self, test_db, session: AsyncSession):
user = await create_user(session=session, active=False)
response = client.post("/api/users/get_my_user", headers=get_auth_headers(user.token))
assert response.status_code in [401, 403]
user.active = True
await session.commit()
response = client.post("/api/users/get_my_user", headers=get_auth_headers(user.token))
assert response.status_code == 200

@pytest.mark.asyncio
async def test_returns_logged_in_user(self, test_db, session: AsyncSession):
user = await create_user(session=session)
response = client.post("/api/users/get_my_user", headers=get_auth_headers(user.token))
assert response.status_code == 200
Expand All @@ -49,6 +60,7 @@ async def test_returns_logged_in_user(self, test_db, session):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
}


Expand All @@ -58,7 +70,7 @@ def test_returns_40x_if_not_authenticated(self):
assert response.status_code in [401, 403]

@pytest.mark.asyncio
async def test_returns_400_if_not_global_admin(self, test_db, session):
async def test_returns_400_if_not_global_admin(self, test_db, session: AsyncSession):
user = await create_user(session=session, global_role=GlobalRole.USER)
other_user = await create_user(session=session, name="other_user", token="1234")
response = client.post(
Expand All @@ -69,7 +81,7 @@ async def test_returns_400_if_not_global_admin(self, test_db, session):
assert response.status_code == 400

@pytest.mark.asyncio
async def test_returns_logged_in_user(self, test_db, session):
async def test_returns_logged_in_user(self, test_db, session: AsyncSession):
user = await create_user(session=session, global_role=GlobalRole.ADMIN)
other_user = await create_user(session=session, name="other_user", token="1234")
response = client.post(
Expand All @@ -84,6 +96,7 @@ async def test_returns_logged_in_user(self, test_db, session):
"global_role": other_user.global_role,
"email": None,
"creds": {"token": "1234"},
"active": True,
}


Expand All @@ -104,6 +117,7 @@ async def test_creates_user(self, test_db, session: AsyncSession):
"username": "test",
"global_role": GlobalRole.USER,
"email": "test@example.com",
"active": True,
},
)
assert response.status_code == 200
Expand All @@ -112,6 +126,7 @@ async def test_creates_user(self, test_db, session: AsyncSession):
"username": "test",
"global_role": "user",
"email": "test@example.com",
"active": True,
}
res = await session.execute(select(UserModel).where(UserModel.name == "test"))
assert len(res.scalars().all()) == 1
Expand All @@ -135,6 +150,7 @@ async def test_return_400_if_username_taken(self, test_db, session: AsyncSession
"username": "Test",
"global_role": "user",
"email": None,
"active": True,
}
# Username uniqueness check should be case insensitive
for username in ["test", "Test", "TesT"]:
Expand Down

0 comments on commit c00ccfd

Please sign in to comment.