Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement project Manager role #1572

Merged
merged 3 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

class ProjectRole(str, enum.Enum):
ADMIN = "admin"
MANAGER = "manager"
USER = "user"


Expand Down
5 changes: 4 additions & 1 deletion src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,12 @@ def register_routes(app: FastAPI, ui: bool = True):

@app.exception_handler(ForbiddenError)
async def forbidden_error_handler(request: Request, exc: ForbiddenError):
msg = "Access denied"
if len(exc.args) > 0:
msg = exc.args[0]
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content=error_detail("Access denied"),
content=error_detail(msg),
)

@app.exception_handler(ServerClientError)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Add Manager project role

Revision ID: 54a77e19c64c
Revises: 710e5b3fac8f
Create Date: 2024-08-16 14:25:52.125915

"""

import sqlalchemy as sa
from alembic import op
from alembic_postgresql_enum import TableReference

# revision identifiers, used by Alembic.
revision = "54a77e19c64c"
down_revision = "710e5b3fac8f"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
if op.get_context().dialect.name == "postgresql":
op.sync_enum_values(
"public",
"projectrole",
["ADMIN", "MANAGER", "USER"],
[
TableReference(
table_schema="public", table_name="members", column_name="project_role"
)
],
enum_values_to_rename=[],
)
else:
with op.batch_alter_table("members", schema=None) as batch_op:
batch_op.alter_column(
"project_role",
existing_type=sa.VARCHAR(length=5),
type_=sa.Enum("ADMIN", "MANAGER", "USER", name="projectrole"),
existing_nullable=False,
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
if op.get_context().dialect.name == "postgresql":
op.sync_enum_values(
"public",
"projectrole",
["ADMIN", "USER"],
[
TableReference(
table_schema="public", table_name="members", column_name="project_role"
)
],
enum_values_to_rename=[],
)
else:
with op.batch_alter_table("members", schema=None) as batch_op:
batch_op.alter_column(
"project_role",
existing_type=sa.Enum("ADMIN", "MANAGER", "USER", name="projectrole"),
type_=sa.VARCHAR(length=5),
existing_nullable=False,
)
# ### end Alembic commands ###
11 changes: 8 additions & 3 deletions src/dstack/_internal/server/routers/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
DeleteProjectsRequest,
SetProjectMembersRequest,
)
from dstack._internal.server.security.permissions import Authenticated, ProjectAdmin, ProjectMember
from dstack._internal.server.security.permissions import (
Authenticated,
ProjectManager,
ProjectMember,
)
from dstack._internal.server.services import projects

router = APIRouter(prefix="/api/projects", tags=["projects"])
Expand Down Expand Up @@ -70,11 +74,12 @@ async def get_project(
async def set_project_members(
body: SetProjectMembersRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectManager()),
) -> Project:
_, project = user_project
user, project = user_project
await projects.set_project_members(
session=session,
user=user,
project=project,
members=body.members,
)
Expand Down
41 changes: 31 additions & 10 deletions src/dstack/_internal/server/security/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.server.db import get_session
from dstack._internal.server.models import ProjectModel, UserModel
from dstack._internal.server.services.projects import get_project_model_by_name
from dstack._internal.server.services.projects import (
get_project_model_by_name,
get_user_project_role,
)
from dstack._internal.server.services.users import get_user_model_by_token
from dstack._internal.server.utils.routers import (
error_forbidden,
Expand Down Expand Up @@ -58,12 +61,30 @@ async def __call__(
raise error_forbidden()
if user.global_role == GlobalRole.ADMIN:
return user, project
for member in project.members:
if member.user_id == user.id:
if member.project_role == ProjectRole.ADMIN:
return user, project
else:
raise error_forbidden()
project_role = get_user_project_role(user=user, project=project)
if project_role == ProjectRole.ADMIN:
return user, project
raise error_forbidden()


class ProjectManager:
async def __call__(
self,
project_name: str,
session: AsyncSession = Depends(get_session),
token: HTTPAuthorizationCredentials = Security(HTTPBearer()),
) -> Tuple[UserModel, ProjectModel]:
user = await get_user_model_by_token(session=session, token=token.credentials)
if user is None:
raise error_invalid_token()
project = await get_project_model_by_name(session=session, project_name=project_name)
if project is None:
raise error_forbidden()
if user.global_role in GlobalRole.ADMIN:
return user, project
project_role = get_user_project_role(user=user, project=project)
if project_role in [ProjectRole.ADMIN, ProjectRole.MANAGER]:
return user, project
raise error_forbidden()


Expand All @@ -83,7 +104,7 @@ async def __call__(
raise error_not_found()
if user.global_role == GlobalRole.ADMIN:
return user, project
for member in project.members:
if member.user_id == user.id:
return user, project
project_role = get_user_project_role(user=user, project=project)
if project_role is not None:
return user, project
raise error_forbidden()
67 changes: 46 additions & 21 deletions src/dstack/_internal/server/services/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,48 +144,66 @@ async def delete_projects(
await session.commit()


async def add_project_member(
session: AsyncSession,
project: ProjectModel,
user: UserModel,
project_role: ProjectRole,
commit: bool = True,
) -> MemberModel:
member = MemberModel(
user_id=user.id,
project_id=project.id,
project_role=project_role,
)
session.add(member)
if commit:
await session.commit()
return member


async def set_project_members(
session: AsyncSession,
user: UserModel,
project: ProjectModel,
members: List[MemberSetting],
):
# reload with members
project = await get_project_model_by_name_or_error(
session=session,
project_name=project.name,
)
project_role = get_user_project_role(user=user, project=project)
if project_role == ProjectRole.MANAGER:
new_admins_members = {
(m.username, m.project_role) for m in members if m.project_role == ProjectRole.ADMIN
}
current_admins_members = {
(m.user.name, m.project_role)
for m in project.members
if m.project_role == ProjectRole.ADMIN
}
if new_admins_members != current_admins_members:
raise ForbiddenError("Access denied: changing project admins")
await clear_project_members(session=session, project=project)
usernames = [m.username for m in members]
res = await session.execute(select(UserModel).where(UserModel.name.in_(usernames)))
users = res.scalars().all()
username_to_user = {user.name: user for user in users}
for member in members:
user = username_to_user.get(member.username)
if user is None:
user_to_add = username_to_user.get(member.username)
if user_to_add is None:
continue
await add_project_member(
session=session,
project=project,
user=user,
user=user_to_add,
project_role=member.project_role,
commit=False,
)
await session.commit()


async def add_project_member(
session: AsyncSession,
project: ProjectModel,
user: UserModel,
project_role: ProjectRole,
commit: bool = True,
) -> MemberModel:
member = MemberModel(
user_id=user.id,
project_id=project.id,
project_role=project_role,
)
session.add(member)
if commit:
await session.commit()
return member


async def clear_project_members(
session: AsyncSession,
project: ProjectModel,
Expand Down Expand Up @@ -305,6 +323,13 @@ async def create_project_model(
return project


def get_user_project_role(user: UserModel, project: ProjectModel) -> Optional[ProjectRole]:
for member in project.members:
if member.user_id == user.id:
return member.project_role
return None


def project_model_to_project(
project_model: ProjectModel,
include_backends: bool = True,
Expand Down
63 changes: 62 additions & 1 deletion src/tests/_internal/server/routers/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,10 @@ async def test_sets_project_members(self, test_db, session: AsyncSession):
project = await create_project(session=session)
admin = await create_user(session=session)
await add_project_member(
session=session, project=project, user=admin, project_role=ProjectRole.ADMIN
session=session,
project=project,
user=admin,
project_role=ProjectRole.ADMIN,
)
user1 = await create_user(session=session, name="user1")
user2 = await create_user(session=session, name="user2")
Expand Down Expand Up @@ -360,3 +363,61 @@ async def test_sets_project_members(self, test_db, session: AsyncSession):
res = await session.execute(select(MemberModel))
members = res.scalars().all()
assert len(members) == 3

@pytest.mark.asyncio
async def test_manager_cannot_set_project_admins(self, test_db, session: AsyncSession):
project = await create_project(session=session)
user = await create_user(session=session, global_role=GlobalRole.USER)
await add_project_member(
session=session,
project=project,
user=user,
project_role=ProjectRole.MANAGER,
)
user1 = await create_user(session=session, name="user1")
members = [
{
"username": user.name,
"project_role": ProjectRole.ADMIN,
},
{
"username": user1.name,
"project_role": ProjectRole.ADMIN,
},
]
body = {"members": members}
response = client.post(
f"/api/projects/{project.name}/set_members",
headers=get_auth_headers(user.token),
json=body,
)
assert response.status_code == 403

@pytest.mark.asyncio
async def test_non_manager_cannot_set_project_members(self, test_db, session: AsyncSession):
project = await create_project(session=session)
user = await create_user(session=session, global_role=GlobalRole.USER)
await add_project_member(
session=session,
project=project,
user=user,
project_role=ProjectRole.USER,
)
user1 = await create_user(session=session, name="user1")
members = [
{
"username": user.name,
"project_role": ProjectRole.ADMIN,
},
{
"username": user1.name,
"project_role": ProjectRole.ADMIN,
},
]
body = {"members": members}
response = client.post(
f"/api/projects/{project.name}/set_members",
headers=get_auth_headers(user.token),
json=body,
)
assert response.status_code == 403