From f2bd1d399ec7ac0e454ff7dbc024d159d043111e Mon Sep 17 00:00:00 2001 From: niklub Date: Wed, 22 May 2024 10:52:19 +0100 Subject: [PATCH] feat: RND-72: Add vendor extensions support in openapi autoschema (#5908) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit drf-yasg default `SwaggerAutoSchema` doesn’t provide a way to include custom extensions in the generated openapi.json. For example, we need to produce the following path to be able to use[ idiomatic method naming with fern](https://buildwithfern.com/learn/api-definition/openapi/extensions#sdk-method-names) ```yaml paths: "/api/annotations/{id}/": get: x-fern-sdk-group-name: annotations x-fern-sdk-method-name: get operationId: api_annotations_read summary: Get annotation by its ID description: Retrieve a specific annotation for a task using the annotation result ID. responses: "200": description: "" content: application/json: schema: $ref: "#/components/schemas/Annotation" tags: - Annotations ``` **Proposed solution**: - Extend SwaggerAutoSchema to include extra operation without modifying the names - Allow base settings config to specify vendor extensions prefixes **Additional changes** - openapi 3.0 incompatible namings in some paths - example of using x-fern in annotations.get --------- Co-authored-by: nik --- label_studio/core/settings/base.py | 8 ++++++++ label_studio/core/utils/openapi_extensions.py | 13 +++++++++++++ label_studio/data_import/api.py | 12 ++++++++++++ label_studio/ml/api.py | 8 +++----- label_studio/tasks/api.py | 2 ++ label_studio/users/api.py | 7 +++++-- 6 files changed, 43 insertions(+), 7 deletions(-) create mode 100644 label_studio/core/utils/openapi_extensions.py diff --git a/label_studio/core/settings/base.py b/label_studio/core/settings/base.py index 272a0c28de1..8cad9779a21 100644 --- a/label_studio/core/settings/base.py +++ b/label_studio/core/settings/base.py @@ -345,6 +345,13 @@ }, } +# specify the list of the extensions that are allowed to be presented in auto generated OpenAPI schema +# for example, by specifying in swagger_auto_schema(..., x_fern_sdk_group_name='projects') we can group endpoints +# /api/projects/: +# get: +# x-fern-sdk-group-name: projects +X_VENDOR_OPENAPI_EXTENSIONS = ['x-fern'] + # Swagger: automatic API documentation SWAGGER_SETTINGS = { 'SECURITY_DEFINITIONS': { @@ -362,6 +369,7 @@ 'APIS_SORTER': 'alpha', 'SUPPORTED_SUBMIT_METHODS': ['get', 'post', 'put', 'delete', 'patch'], 'OPERATIONS_SORTER': 'alpha', + 'DEFAULT_AUTO_SCHEMA_CLASS': 'core.utils.openapi_extensions.XVendorExtensionsAutoSchema', } SENTRY_DSN = get_env('SENTRY_DSN', None) diff --git a/label_studio/core/utils/openapi_extensions.py b/label_studio/core/utils/openapi_extensions.py new file mode 100644 index 00000000000..794f76a1420 --- /dev/null +++ b/label_studio/core/utils/openapi_extensions.py @@ -0,0 +1,13 @@ +from django.conf import settings +from drf_yasg.inspectors import SwaggerAutoSchema + + +class XVendorExtensionsAutoSchema(SwaggerAutoSchema): + allowed_extensions = tuple([e.replace('-', '_') for e in settings.X_VENDOR_OPENAPI_EXTENSIONS]) + + def get_operation(self, operation_keys=None): + operation = super(XVendorExtensionsAutoSchema, self).get_operation(operation_keys) + for key, value in self.overrides.items(): + if key.startswith(self.allowed_extensions): + operation[key.replace('_', '-')] = value + return operation diff --git a/label_studio/data_import/api.py b/label_studio/data_import/api.py index 421c2fb47e8..331b948d212 100644 --- a/label_studio/data_import/api.py +++ b/label_studio/data_import/api.py @@ -111,6 +111,8 @@ name='post', decorator=swagger_auto_schema( tags=['Import'], + x_fern_sdk_group_name='projects', + x_fern_sdk_method_name='import_tasks', responses=task_create_response_scheme, manual_parameters=[ openapi.Parameter( @@ -490,6 +492,8 @@ def post(self, *args, **kwargs): name='get', decorator=swagger_auto_schema( tags=['Import'], + x_fern_sdk_group_name='files', + x_fern_sdk_method_name='list', operation_summary='Get files list', manual_parameters=[ openapi.Parameter( @@ -515,6 +519,8 @@ def post(self, *args, **kwargs): name='delete', decorator=swagger_auto_schema( tags=['Import'], + x_fern_sdk_group_name='files', + x_fern_sdk_method_name='delete_many', operation_summary='Delete files', operation_description=""" Delete uploaded files for a specific project. @@ -561,6 +567,8 @@ def delete(self, request, *args, **kwargs): name='get', decorator=swagger_auto_schema( tags=['Import'], + x_fern_sdk_group_name='files', + x_fern_sdk_method_name='get', operation_summary='Get file upload', operation_description='Retrieve details about a specific uploaded file.', ), @@ -569,6 +577,8 @@ def delete(self, request, *args, **kwargs): name='patch', decorator=swagger_auto_schema( tags=['Import'], + x_fern_sdk_group_name='files', + x_fern_sdk_method_name='update', operation_summary='Update file upload', operation_description='Update a specific uploaded file.', request_body=FileUploadSerializer, @@ -578,6 +588,8 @@ def delete(self, request, *args, **kwargs): name='delete', decorator=swagger_auto_schema( tags=['Import'], + x_fern_sdk_group_name='files', + x_fern_sdk_method_name='delete', operation_summary='Delete file upload', operation_description='Delete a specific uploaded file.', ), diff --git a/label_studio/ml/api.py b/label_studio/ml/api.py index f82f78e0edd..0b38e16ae37 100644 --- a/label_studio/ml/api.py +++ b/label_studio/ml/api.py @@ -188,11 +188,10 @@ def perform_update(self, serializer): }, ), responses={ - 200: openapi.Response(title='Training OK', description='Training has successfully started.'), + 200: openapi.Response(description='Training has successfully started.'), 500: openapi.Response( description='Training error', schema=openapi.Schema( - title='Error message', description='Error message', type=openapi.TYPE_STRING, example='Server responded with an error.', @@ -230,11 +229,10 @@ def post(self, request, *args, **kwargs): ), ], responses={ - 200: openapi.Response(title='Predicting OK', description='Predicting has successfully started.'), + 200: openapi.Response(description='Predicting has successfully started.'), 500: openapi.Response( description='Predicting error', schema=openapi.Schema( - title='Error message', description='Error message', type=openapi.TYPE_STRING, example='Server responded with an error.', @@ -287,7 +285,7 @@ def post(self, request, *args, **kwargs): ], request_body=MLInteractiveAnnotatingRequest, responses={ - 200: openapi.Response(title='Annotating OK', description='Interactive annotation has succeeded.'), + 200: openapi.Response(description='Interactive annotation has succeeded.'), }, ), ) diff --git a/label_studio/tasks/api.py b/label_studio/tasks/api.py index 00512273527..6809efd87f2 100644 --- a/label_studio/tasks/api.py +++ b/label_studio/tasks/api.py @@ -241,6 +241,8 @@ def put(self, request, *args, **kwargs): tags=['Annotations'], operation_summary='Get annotation by its ID', operation_description='Retrieve a specific annotation for a task using the annotation result ID.', + x_fern_sdk_group_name='annotations', + x_fern_sdk_method_name='get', ), ) @method_decorator( diff --git a/label_studio/users/api.py b/label_studio/users/api.py index 9edd66992e7..79451e71e05 100644 --- a/label_studio/users/api.py +++ b/label_studio/users/api.py @@ -205,8 +205,11 @@ def post(self, request, *args, **kwargs): responses={ 200: openapi.Response( description='User token response', - type=openapi.TYPE_OBJECT, - properties={'detail': openapi.Schema(description='Token', type=openapi.TYPE_STRING)}, + schema=openapi.Schema( + description='User token', + type=openapi.TYPE_OBJECT, + properties={'detail': openapi.Schema(description='Token', type=openapi.TYPE_STRING)}, + ), ) }, ),