Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement user activation/deactivation #1575

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class User(CoreModel):
username: str
global_role: GlobalRole
email: Optional[str]
active: bool


class UserTokenCreds(CoreModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Add UserModel.active
Revision ID: d6b11105f659
Revises: 54a77e19c64c
Create Date: 2024-08-19 15:10:25.751199
"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "d6b11105f659"
down_revision = "54a77e19c64c"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.add_column(sa.Column("active", sa.Boolean(), nullable=True))

op.execute(sa.sql.text("UPDATE users SET active = TRUE"))

with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.alter_column("active", nullable=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.drop_column("active")

# ### end Alembic commands ###
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ class UserModel(BaseModel):
# token_hash is needed for fast search by token when stored token is encrypted
token_hash: Mapped[str] = mapped_column(String(2000), unique=True)
global_role: Mapped[GlobalRole] = mapped_column(Enum(GlobalRole))
# deactivated users cannot access API
active: Mapped[bool] = mapped_column(Boolean, default=True)

email: Mapped[Optional[str]] = mapped_column(String(200), nullable=True)

Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def create_user(
username=body.username,
global_role=body.global_role,
email=body.email,
active=body.active,
)
return users.user_model_to_user(res)

Expand All @@ -75,6 +76,7 @@ async def update_user(
username=body.username,
global_role=body.global_role,
email=body.email,
active=body.active,
)
if res is None:
raise ResourceNotExistsError()
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/server/schemas/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class CreateUserRequest(CoreModel):
username: str
global_role: GlobalRole
email: Optional[str]
active: bool = True


UpdateUserRequest = CreateUserRequest
Expand Down
12 changes: 6 additions & 6 deletions src/dstack/_internal/server/security/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
get_project_model_by_name,
get_user_project_role,
)
from dstack._internal.server.services.users import get_user_model_by_token
from dstack._internal.server.services.users import log_in_with_token
from dstack._internal.server.utils.routers import (
error_forbidden,
error_invalid_token,
Expand All @@ -26,7 +26,7 @@ async def __call__(
session: AsyncSession = Depends(get_session),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> UserModel:
user = await get_user_model_by_token(session=session, token=token.credentials)
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
return user
Expand All @@ -38,7 +38,7 @@ async def __call__(
session: AsyncSession = Depends(get_session),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> UserModel:
user = await get_user_model_by_token(session=session, token=token.credentials)
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
if user.global_role == GlobalRole.ADMIN:
Expand All @@ -53,7 +53,7 @@ async def __call__(
session: AsyncSession = Depends(get_session),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> Tuple[UserModel, ProjectModel]:
user = await get_user_model_by_token(session=session, token=token.credentials)
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
project = await get_project_model_by_name(session=session, project_name=project_name)
Expand All @@ -74,7 +74,7 @@ async def __call__(
session: AsyncSession = Depends(get_session),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> Tuple[UserModel, ProjectModel]:
user = await get_user_model_by_token(session=session, token=token.credentials)
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
project = await get_project_model_by_name(session=session, project_name=project_name)
Expand All @@ -96,7 +96,7 @@ async def __call__(
project_name: str,
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> Tuple[UserModel, ProjectModel]:
user = await get_user_model_by_token(session=session, token=token.credentials)
user = await log_in_with_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
project = await get_project_model_by_name(session=session, project_name=project_name)
Expand Down
20 changes: 17 additions & 3 deletions src/dstack/_internal/server/services/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ async def create_user(
username: str,
global_role: GlobalRole,
email: Optional[str] = None,
active: bool = True,
) -> UserModel:
user_model = await get_user_model_by_name(session=session, username=username, ignore_case=True)
if user_model is not None:
Expand All @@ -74,6 +75,7 @@ async def create_user(
token=DecryptedString(plaintext=token),
token_hash=get_token_hash(token),
email=email,
active=active,
)
session.add(user)
await session.commit()
Expand All @@ -87,11 +89,16 @@ async def update_user(
username: str,
global_role: GlobalRole,
email: Optional[str] = None,
active: bool = True,
) -> UserModel:
await session.execute(
update(UserModel)
.where(UserModel.name == username)
.values(global_role=global_role, email=email)
.values(
global_role=global_role,
email=email,
active=active,
)
)
await session.commit()
return await get_user_model_by_name_or_error(session=session, username=username)
Expand Down Expand Up @@ -145,9 +152,14 @@ async def get_user_model_by_name_or_error(session: AsyncSession, username: str)
return res.scalar_one()


async def get_user_model_by_token(session: AsyncSession, token: str) -> Optional[UserModel]:
async def log_in_with_token(session: AsyncSession, token: str) -> Optional[UserModel]:
token_hash = get_token_hash(token)
res = await session.execute(select(UserModel).where(UserModel.token_hash == token_hash))
res = await session.execute(
select(UserModel).where(
UserModel.token_hash == token_hash,
UserModel.active == True,
)
)
user = res.scalar()
if user is None:
return None
Expand All @@ -167,6 +179,7 @@ def user_model_to_user(user_model: UserModel) -> User:
username=user_model.name,
global_role=user_model.global_role,
email=user_model.email,
active=user_model.active,
)


Expand All @@ -176,6 +189,7 @@ def user_model_to_user_with_creds(user_model: UserModel) -> UserWithCreds:
username=user_model.name,
global_role=user_model.global_role,
email=user_model.email,
active=user_model.active,
creds=UserTokenCreds(token=user_model.token.get_plaintext_or_error()),
)

Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ async def create_user(
global_role: GlobalRole = GlobalRole.ADMIN,
token: Optional[str] = None,
email: Optional[str] = None,
active: bool = True,
) -> UserModel:
if token is None:
token = str(uuid.uuid4())
Expand All @@ -81,6 +82,7 @@ async def create_user(
token=DecryptedString(plaintext=token),
token_hash=get_token_hash(token),
email=email,
active=active,
)
session.add(user)
await session.commit()
Expand Down
8 changes: 8 additions & 0 deletions src/tests/_internal/server/routers/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ async def test_returns_projects(self, test_db, session: AsyncSession):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
},
"backends": [],
"members": [],
Expand Down Expand Up @@ -88,6 +89,7 @@ async def test_creates_project(self, test_db, session: AsyncSession):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
},
"backends": [],
"members": [
Expand All @@ -97,6 +99,7 @@ async def test_creates_project(self, test_db, session: AsyncSession):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
},
"project_role": ProjectRole.ADMIN,
}
Expand Down Expand Up @@ -277,6 +280,7 @@ async def test_returns_project(self, test_db, session: AsyncSession):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
},
"backends": [],
"members": [
Expand All @@ -286,6 +290,7 @@ async def test_returns_project(self, test_db, session: AsyncSession):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
},
"project_role": ProjectRole.ADMIN,
}
Expand Down Expand Up @@ -338,6 +343,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession):
"username": admin.name,
"global_role": admin.global_role,
"email": None,
"active": True,
},
"project_role": ProjectRole.ADMIN,
},
Expand All @@ -347,6 +353,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession):
"username": user1.name,
"global_role": user1.global_role,
"email": None,
"active": True,
},
"project_role": ProjectRole.ADMIN,
},
Expand All @@ -356,6 +363,7 @@ async def test_sets_project_members(self, test_db, session: AsyncSession):
"username": user2.name,
"global_role": user2.global_role,
"email": None,
"active": True,
},
"project_role": ProjectRole.USER,
},
Expand Down
22 changes: 19 additions & 3 deletions src/tests/_internal/server/routers/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ async def test_returns_users(self, test_db, session):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
}
]

Expand All @@ -40,7 +41,17 @@ def test_returns_40x_if_not_authenticated(self):
assert response.status_code in [401, 403]

@pytest.mark.asyncio
async def test_returns_logged_in_user(self, test_db, session):
async def test_returns_40x_if_deactivated(self, test_db, session: AsyncSession):
user = await create_user(session=session, active=False)
response = client.post("/api/users/get_my_user", headers=get_auth_headers(user.token))
assert response.status_code in [401, 403]
user.active = True
await session.commit()
response = client.post("/api/users/get_my_user", headers=get_auth_headers(user.token))
assert response.status_code == 200

@pytest.mark.asyncio
async def test_returns_logged_in_user(self, test_db, session: AsyncSession):
user = await create_user(session=session)
response = client.post("/api/users/get_my_user", headers=get_auth_headers(user.token))
assert response.status_code == 200
Expand All @@ -49,6 +60,7 @@ async def test_returns_logged_in_user(self, test_db, session):
"username": user.name,
"global_role": user.global_role,
"email": None,
"active": True,
}


Expand All @@ -58,7 +70,7 @@ def test_returns_40x_if_not_authenticated(self):
assert response.status_code in [401, 403]

@pytest.mark.asyncio
async def test_returns_400_if_not_global_admin(self, test_db, session):
async def test_returns_400_if_not_global_admin(self, test_db, session: AsyncSession):
user = await create_user(session=session, global_role=GlobalRole.USER)
other_user = await create_user(session=session, name="other_user", token="1234")
response = client.post(
Expand All @@ -69,7 +81,7 @@ async def test_returns_400_if_not_global_admin(self, test_db, session):
assert response.status_code == 400

@pytest.mark.asyncio
async def test_returns_logged_in_user(self, test_db, session):
async def test_returns_logged_in_user(self, test_db, session: AsyncSession):
user = await create_user(session=session, global_role=GlobalRole.ADMIN)
other_user = await create_user(session=session, name="other_user", token="1234")
response = client.post(
Expand All @@ -84,6 +96,7 @@ async def test_returns_logged_in_user(self, test_db, session):
"global_role": other_user.global_role,
"email": None,
"creds": {"token": "1234"},
"active": True,
}


Expand All @@ -104,6 +117,7 @@ async def test_creates_user(self, test_db, session: AsyncSession):
"username": "test",
"global_role": GlobalRole.USER,
"email": "test@example.com",
"active": True,
},
)
assert response.status_code == 200
Expand All @@ -112,6 +126,7 @@ async def test_creates_user(self, test_db, session: AsyncSession):
"username": "test",
"global_role": "user",
"email": "test@example.com",
"active": True,
}
res = await session.execute(select(UserModel).where(UserModel.name == "test"))
assert len(res.scalars().all()) == 1
Expand All @@ -135,6 +150,7 @@ async def test_return_400_if_username_taken(self, test_db, session: AsyncSession
"username": "Test",
"global_role": "user",
"email": None,
"active": True,
}
# Username uniqueness check should be case insensitive
for username in ["test", "Test", "TesT"]:
Expand Down