diff --git a/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts b/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts index 42844d0557ae93..4bad39b54cf3eb 100644 --- a/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts +++ b/x-pack/plugins/ml/public/application/services/ml_api_service/trained_models.ts @@ -126,7 +126,11 @@ export function trainedModelsApiProvider(httpService: HttpService) { startModelAllocation( modelId: string, - queryParams?: { number_of_allocations: number; threads_per_allocation: number } + queryParams?: { + number_of_allocations: number; + threads_per_allocation: number; + priority: 'low' | 'normal'; + } ) { return httpService.http<{ acknowledge: boolean }>({ path: `${apiBasePath}/trained_models/${modelId}/deployment/_start`, @@ -145,6 +149,14 @@ export function trainedModelsApiProvider(httpService: HttpService) { }); }, + updateModelDeployment(modelId: string, params: { number_of_allocations: number }) { + return httpService.http<{ acknowledge: boolean }>({ + path: `${apiBasePath}/trained_models/${modelId}/deployment/_update`, + method: 'POST', + body: JSON.stringify(params), + }); + }, + inferTrainedModel(modelId: string, payload: any, timeout?: string) { const body = JSON.stringify(payload); return httpService.http({ diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/start_deployment_setup.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/deployment_setup.tsx similarity index 50% rename from x-pack/plugins/ml/public/application/trained_models/models_management/start_deployment_setup.tsx rename to x-pack/plugins/ml/public/application/trained_models/models_management/deployment_setup.tsx index ea3f1c7a5705a8..f8082997a1c6ea 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/start_deployment_setup.tsx +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/deployment_setup.tsx @@ -33,16 +33,18 @@ import type { Observable } from 'rxjs'; import type { CoreTheme, OverlayStart } from '@kbn/core/public'; import { css } from '@emotion/react'; import { numberValidator } from '@kbn/ml-agg-utils'; +import { isCloudTrial } from '../../services/ml_server_info'; import { composeValidators, requiredValidator } from '../../../../common/util/validators'; -interface StartDeploymentSetup { +interface DeploymentSetupProps { config: ThreadingParams; onConfigChange: (config: ThreadingParams) => void; } export interface ThreadingParams { numOfAllocations: number; - threadsPerAllocations: number; + threadsPerAllocations?: number; + priority?: 'low' | 'normal'; } const THREADS_MAX_EXPONENT = 4; @@ -50,7 +52,7 @@ const THREADS_MAX_EXPONENT = 4; /** * Form for setting threading params. */ -export const StartDeploymentSetup: FC = ({ config, onConfigChange }) => { +export const DeploymentSetup: FC = ({ config, onConfigChange }) => { const numOfAllocation = config.numOfAllocations; const threadsPerAllocations = config.threadsPerAllocations; @@ -69,12 +71,78 @@ export const StartDeploymentSetup: FC = ({ config, onConfi [] ); - const toggleIdSelected = threadsPerAllocationsOptions.find( - (v) => v.value === threadsPerAllocations - )!.id; + const disableThreadingControls = config.priority === 'low'; return ( + {config.priority !== undefined ? ( + + + + } + description={ + + } + > + + } + hasChildLabel={false} + > + { + onConfigChange({ ...config, priority: optionId as ThreadingParams['priority'] }); + }} + options={[ + { + id: 'low', + value: 'low', + label: i18n.translate( + 'xpack.ml.trainedModels.modelsList.startDeployment.lowPriorityLabel', + { + defaultMessage: 'low', + } + ), + }, + { + id: 'normal', + value: 'normal', + label: i18n.translate( + 'xpack.ml.trainedModels.modelsList.startDeployment.normalPriorityLabel', + { + defaultMessage: 'normal', + } + ), + }, + ]} + data-test-subj={'mlModelsStartDeploymentModalPriority'} + /> + + + ) : null} + = ({ config, onConfi /> } hasChildLabel={false} + isDisabled={disableThreadingControls} > { onConfigChange({ ...config, numOfAllocations: Number(event.target.value) }); }} @@ -115,51 +185,59 @@ export const StartDeploymentSetup: FC = ({ config, onConfi - - - - } - description={ - - } - > - + + + } + description={ } - hasChildLabel={false} > - + } + hasChildLabel={false} + isDisabled={disableThreadingControls} + > + v.value === threadsPerAllocations)!.id } - )} - name={'threadsPerAllocation'} - isFullWidth - idSelected={toggleIdSelected} - onChange={(optionId) => { - const value = threadsPerAllocationsOptions.find((v) => v.id === optionId)!.value; - onConfigChange({ ...config, threadsPerAllocations: value }); - }} - options={threadsPerAllocationsOptions} - data-test-subj={'mlModelsStartDeploymentModalThreadsPerAllocation'} - /> - - + onChange={(optionId) => { + const value = threadsPerAllocationsOptions.find((v) => v.id === optionId)!.value; + onConfigChange({ ...config, threadsPerAllocations: value }); + }} + options={threadsPerAllocationsOptions} + data-test-subj={'mlModelsStartDeploymentModalThreadsPerAllocation'} + /> + + + ) : null} ); }; @@ -169,24 +247,28 @@ interface StartDeploymentModalProps { startModelDeploymentDocUrl: string; onConfigChange: (config: ThreadingParams) => void; onClose: () => void; + initialParams?: ThreadingParams; } /** - * Modal window wrapper for {@link StartDeploymentSetup} - * - * @param onConfigChange - * @param onClose + * Modal window wrapper for {@link DeploymentSetup} */ -export const StartDeploymentModal: FC = ({ +export const StartUpdateDeploymentModal: FC = ({ modelId, onConfigChange, onClose, startModelDeploymentDocUrl, + initialParams, }) => { - const [config, setConfig] = useState({ - numOfAllocations: 1, - threadsPerAllocations: 1, - }); + const [config, setConfig] = useState( + initialParams ?? { + numOfAllocations: 1, + threadsPerAllocations: 1, + priority: isCloudTrial() ? 'low' : 'normal', + } + ); + + const isUpdate = initialParams !== undefined; const numOfAllocationsValidator = composeValidators( requiredValidator(), @@ -208,11 +290,19 @@ export const StartDeploymentModal: FC = ({

- + {isUpdate ? ( + + ) : ( + + )}

@@ -236,7 +326,7 @@ export const StartDeploymentModal: FC = ({ /> - + @@ -272,10 +362,17 @@ export const StartDeploymentModal: FC = ({ disabled={!!errors} data-test-subj={'mlModelsStartDeploymentModalStartButton'} > - + {isUpdate ? ( + + ) : ( + + )} @@ -291,22 +388,30 @@ export const StartDeploymentModal: FC = ({ */ export const getUserInputThreadingParamsProvider = (overlays: OverlayStart, theme$: Observable, startModelDeploymentDocUrl: string) => - (modelId: string): Promise => { - return new Promise(async (resolve, reject) => { + (modelId: string, initialParams?: ThreadingParams): Promise => { + return new Promise(async (resolve) => { try { const modalSession = overlays.openModal( toMountPoint( wrapWithTheme( - { modalSession.close(); - resolve(config); + + const resultConfig = { ...config }; + if (resultConfig.priority === 'low') { + resultConfig.numOfAllocations = 1; + resultConfig.threadsPerAllocations = 1; + } + + resolve(resultConfig); }} onClose={() => { modalSession.close(); - reject(); + resolve(); }} />, theme$ @@ -314,7 +419,7 @@ export const getUserInputThreadingParamsProvider = ) ); } catch (e) { - reject(); + resolve(); } }); }; diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/model_actions.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/model_actions.tsx new file mode 100644 index 00000000000000..2bd18a782f0715 --- /dev/null +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/model_actions.tsx @@ -0,0 +1,387 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { Action } from '@elastic/eui/src/components/basic_table/action_types'; +import { i18n } from '@kbn/i18n'; +import { isPopulatedObject } from '@kbn/ml-is-populated-object'; +import { EuiToolTip } from '@elastic/eui'; +import React, { useCallback, useMemo } from 'react'; +import { BUILT_IN_MODEL_TAG } from '../../../../common/constants/data_frame_analytics'; +import { useTrainedModelsApiService } from '../../services/ml_api_service/trained_models'; +import { getUserConfirmationProvider } from './force_stop_dialog'; +import { useToastNotificationService } from '../../services/toast_notification_service'; +import { getUserInputThreadingParamsProvider } from './deployment_setup'; +import { useMlKibana, useMlLocator, useNavigateToPath } from '../../contexts/kibana'; +import { getAnalysisType } from '../../../../common/util/analytics_utils'; +import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics'; +import { ML_PAGES } from '../../../../common/constants/locator'; +import { DEPLOYMENT_STATE, TRAINED_MODEL_TYPE } from '../../../../common/constants/trained_models'; +import { isTestable, isTestEnabled } from './test_models'; +import { ModelItem } from './models_list'; + +export function useModelActions({ + onTestAction, + onModelsDeleteRequest, + onLoading, + isLoading, + fetchModels, +}: { + isLoading: boolean; + onTestAction: (model: ModelItem) => void; + onModelsDeleteRequest: (modelsIds: string[]) => void; + onLoading: (isLoading: boolean) => void; + fetchModels: () => void; +}): Array> { + const { + services: { + application: { navigateToUrl, capabilities }, + overlays, + theme, + docLinks, + }, + } = useMlKibana(); + + const startModelDeploymentDocUrl = docLinks.links.ml.startTrainedModelsDeployment; + + const navigateToPath = useNavigateToPath(); + + const { displayErrorToast, displaySuccessToast } = useToastNotificationService(); + + const urlLocator = useMlLocator()!; + + const trainedModelsApiService = useTrainedModelsApiService(); + + const canStartStopTrainedModels = capabilities.ml.canStartStopTrainedModels as boolean; + const canTestTrainedModels = capabilities.ml.canTestTrainedModels as boolean; + const canDeleteTrainedModels = capabilities.ml.canDeleteTrainedModels as boolean; + + const getUserConfirmation = useMemo( + () => getUserConfirmationProvider(overlays, theme), + [overlays, theme] + ); + + const getUserInputThreadingParams = useMemo( + () => getUserInputThreadingParamsProvider(overlays, theme.theme$, startModelDeploymentDocUrl), + [overlays, theme.theme$, startModelDeploymentDocUrl] + ); + + const isBuiltInModel = useCallback( + (item: ModelItem) => item.tags.includes(BUILT_IN_MODEL_TAG), + [] + ); + + return useMemo( + () => [ + { + name: i18n.translate('xpack.ml.trainedModels.modelsList.viewTrainingDataActionLabel', { + defaultMessage: 'View training data', + }), + description: i18n.translate( + 'xpack.ml.trainedModels.modelsList.viewTrainingDataActionLabel', + { + defaultMessage: 'View training data', + } + ), + icon: 'visTable', + type: 'icon', + available: (item) => !!item.metadata?.analytics_config?.id, + onClick: async (item) => { + if (item.metadata?.analytics_config === undefined) return; + + const analysisType = getAnalysisType( + item.metadata?.analytics_config.analysis + ) as DataFrameAnalysisConfigType; + + const url = await urlLocator.getUrl({ + page: ML_PAGES.DATA_FRAME_ANALYTICS_EXPLORATION, + pageState: { + jobId: item.metadata?.analytics_config.id as string, + analysisType, + ...(analysisType === 'classification' || analysisType === 'regression' + ? { + queryText: `${item.metadata?.analytics_config.dest.results_field}.is_training : true`, + } + : {}), + }, + }); + + await navigateToUrl(url); + }, + isPrimary: true, + }, + { + name: i18n.translate('xpack.ml.inference.modelsList.analyticsMapActionLabel', { + defaultMessage: 'Analytics map', + }), + description: i18n.translate('xpack.ml.inference.modelsList.analyticsMapActionLabel', { + defaultMessage: 'Analytics map', + }), + icon: 'graphApp', + type: 'icon', + isPrimary: true, + available: (item) => !!item.metadata?.analytics_config?.id, + onClick: async (item) => { + const path = await urlLocator.getUrl({ + page: ML_PAGES.DATA_FRAME_ANALYTICS_MAP, + pageState: { modelId: item.model_id }, + }); + + await navigateToPath(path, false); + }, + }, + { + name: i18n.translate('xpack.ml.inference.modelsList.startModelDeploymentActionLabel', { + defaultMessage: 'Start deployment', + }), + description: i18n.translate( + 'xpack.ml.inference.modelsList.startModelDeploymentActionLabel', + { + defaultMessage: 'Start deployment', + } + ), + 'data-test-subj': 'mlModelsTableRowStartDeploymentAction', + icon: 'play', + type: 'icon', + isPrimary: true, + enabled: (item) => { + const { state } = item.stats?.deployment_stats ?? {}; + return ( + canStartStopTrainedModels && + !isLoading && + state !== DEPLOYMENT_STATE.STARTED && + state !== DEPLOYMENT_STATE.STARTING + ); + }, + available: (item) => item.model_type === TRAINED_MODEL_TYPE.PYTORCH, + onClick: async (item) => { + const threadingParams = await getUserInputThreadingParams(item.model_id); + + if (!threadingParams) return; + + try { + onLoading(true); + await trainedModelsApiService.startModelAllocation(item.model_id, { + number_of_allocations: threadingParams.numOfAllocations, + threads_per_allocation: threadingParams.threadsPerAllocations!, + priority: threadingParams.priority!, + }); + displaySuccessToast( + i18n.translate('xpack.ml.trainedModels.modelsList.startSuccess', { + defaultMessage: 'Deployment for "{modelId}" has been started successfully.', + values: { + modelId: item.model_id, + }, + }) + ); + await fetchModels(); + } catch (e) { + displayErrorToast( + e, + i18n.translate('xpack.ml.trainedModels.modelsList.startFailed', { + defaultMessage: 'Failed to start "{modelId}"', + values: { + modelId: item.model_id, + }, + }) + ); + onLoading(false); + } + }, + }, + { + name: i18n.translate('xpack.ml.inference.modelsList.updateModelDeploymentActionLabel', { + defaultMessage: 'Update deployment', + }), + description: i18n.translate( + 'xpack.ml.inference.modelsList.updateModelDeploymentActionLabel', + { + defaultMessage: 'Update deployment', + } + ), + 'data-test-subj': 'mlModelsTableRowUpdateDeploymentAction', + icon: 'documentEdit', + type: 'icon', + isPrimary: false, + available: (item) => + item.model_type === TRAINED_MODEL_TYPE.PYTORCH && + canStartStopTrainedModels && + !isLoading && + item.stats?.deployment_stats?.state === DEPLOYMENT_STATE.STARTED, + onClick: async (item) => { + const threadingParams = await getUserInputThreadingParams(item.model_id, { + numOfAllocations: item.stats?.deployment_stats?.number_of_allocations!, + }); + + if (!threadingParams) return; + + try { + onLoading(true); + await trainedModelsApiService.updateModelDeployment(item.model_id, { + number_of_allocations: threadingParams.numOfAllocations, + }); + displaySuccessToast( + i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccess', { + defaultMessage: 'Deployment for "{modelId}" has been updated successfully.', + values: { + modelId: item.model_id, + }, + }) + ); + await fetchModels(); + } catch (e) { + displayErrorToast( + e, + i18n.translate('xpack.ml.trainedModels.modelsList.updateFailed', { + defaultMessage: 'Failed to update "{modelId}"', + values: { + modelId: item.model_id, + }, + }) + ); + onLoading(false); + } + }, + }, + { + name: i18n.translate('xpack.ml.inference.modelsList.stopModelDeploymentActionLabel', { + defaultMessage: 'Stop deployment', + }), + description: i18n.translate( + 'xpack.ml.inference.modelsList.stopModelDeploymentActionLabel', + { + defaultMessage: 'Stop deployment', + } + ), + 'data-test-subj': 'mlModelsTableRowStopDeploymentAction', + icon: 'stop', + type: 'icon', + isPrimary: true, + available: (item) => item.model_type === TRAINED_MODEL_TYPE.PYTORCH, + enabled: (item) => + canStartStopTrainedModels && + !isLoading && + isPopulatedObject(item.stats?.deployment_stats) && + item.stats?.deployment_stats?.state !== DEPLOYMENT_STATE.STOPPING, + onClick: async (item) => { + const requireForceStop = isPopulatedObject(item.pipelines); + + if (requireForceStop) { + const hasUserApproved = await getUserConfirmation(item); + if (!hasUserApproved) return; + } + + if (requireForceStop) { + const hasUserApproved = await getUserConfirmation(item); + if (!hasUserApproved) return; + } + + try { + onLoading(true); + await trainedModelsApiService.stopModelAllocation(item.model_id, { + force: requireForceStop, + }); + displaySuccessToast( + i18n.translate('xpack.ml.trainedModels.modelsList.stopSuccess', { + defaultMessage: 'Deployment for "{modelId}" has been stopped successfully.', + values: { + modelId: item.model_id, + }, + }) + ); + // Need to fetch model state updates + await fetchModels(); + } catch (e) { + displayErrorToast( + e, + i18n.translate('xpack.ml.trainedModels.modelsList.stopFailed', { + defaultMessage: 'Failed to stop "{modelId}"', + values: { + modelId: item.model_id, + }, + }) + ); + onLoading(false); + } + }, + }, + { + name: (model) => { + const enabled = !isPopulatedObject(model.pipelines); + return ( + + <> + {i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', { + defaultMessage: 'Delete model', + })} + + + ); + }, + description: i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', { + defaultMessage: 'Delete model', + }), + 'data-test-subj': 'mlModelsTableRowDeleteAction', + icon: 'trash', + type: 'icon', + color: 'danger', + isPrimary: false, + onClick: (model) => { + onModelsDeleteRequest([model.model_id]); + }, + available: (item) => canDeleteTrainedModels && !isBuiltInModel(item), + enabled: (item) => { + // TODO check for permissions to delete ingest pipelines. + // ATM undefined means pipelines fetch failed server-side. + return !isPopulatedObject(item.pipelines); + }, + }, + { + name: i18n.translate('xpack.ml.inference.modelsList.testModelActionLabel', { + defaultMessage: 'Test model', + }), + description: i18n.translate('xpack.ml.inference.modelsList.testModelActionLabel', { + defaultMessage: 'Test model', + }), + 'data-test-subj': 'mlModelsTableRowTestAction', + icon: 'inputOutput', + type: 'icon', + isPrimary: true, + available: isTestable, + onClick: onTestAction, + enabled: (item) => canTestTrainedModels && isTestEnabled(item), + }, + ], + [ + canDeleteTrainedModels, + canStartStopTrainedModels, + canTestTrainedModels, + displayErrorToast, + displaySuccessToast, + getUserConfirmation, + getUserInputThreadingParams, + isBuiltInModel, + navigateToPath, + navigateToUrl, + onTestAction, + trainedModelsApiService, + urlLocator, + onModelsDeleteRequest, + onLoading, + fetchModels, + isLoading, + ] + ); +} diff --git a/x-pack/plugins/ml/public/application/trained_models/models_management/models_list.tsx b/x-pack/plugins/ml/public/application/trained_models/models_management/models_list.tsx index a216436e6c90a4..b14a12b1b904ff 100644 --- a/x-pack/plugins/ml/public/application/trained_models/models_management/models_list.tsx +++ b/x-pack/plugins/ml/public/application/trained_models/models_management/models_list.tsx @@ -17,21 +17,18 @@ import { EuiSpacer, EuiTitle, SearchFilterConfig, - EuiToolTip, } from '@elastic/eui'; import { i18n } from '@kbn/i18n'; import { FormattedMessage } from '@kbn/i18n-react'; import { EuiBasicTableColumn } from '@elastic/eui/src/components/basic_table/basic_table'; import { EuiTableSelectionType } from '@elastic/eui/src/components/basic_table/table_types'; -import { Action } from '@elastic/eui/src/components/basic_table/action_types'; import { FIELD_FORMAT_IDS } from '@kbn/field-formats-plugin/common'; import { isPopulatedObject } from '@kbn/ml-is-populated-object'; -import { getUserInputThreadingParamsProvider } from './start_deployment_setup'; -import { getAnalysisType } from '../../data_frame_analytics/common'; +import { useModelActions } from './model_actions'; import { ModelsTableToConfigMapping } from '.'; import { ModelsBarStats, StatsBar } from '../../components/stats_bar'; -import { useMlKibana, useMlLocator, useNavigateToPath, useTimefilter } from '../../contexts/kibana'; +import { useMlKibana, useTimefilter } from '../../contexts/kibana'; import { useTrainedModelsApiService } from '../../services/ml_api_service/trained_models'; import { ModelPipelines, @@ -39,7 +36,6 @@ import { TrainedModelStat, } from '../../../../common/types/trained_models'; import { BUILT_IN_MODEL_TAG } from '../../../../common/constants/data_frame_analytics'; -import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics'; import { DeleteModelsModal } from './delete_models_modal'; import { ML_PAGES } from '../../../../common/constants/locator'; import { ListingPageUrlState } from '../../../../common/types/common'; @@ -49,14 +45,9 @@ import { useTableSettings } from '../../data_frame_analytics/pages/analytics_man import { useToastNotificationService } from '../../services/toast_notification_service'; import { useFieldFormatter } from '../../contexts/kibana/use_field_formatter'; import { useRefresh } from '../../routing/use_refresh'; -import { - DEPLOYMENT_STATE, - TRAINED_MODEL_TYPE, - BUILT_IN_MODEL_TYPE, -} from '../../../../common/constants/trained_models'; -import { getUserConfirmationProvider } from './force_stop_dialog'; +import { BUILT_IN_MODEL_TYPE } from '../../../../common/constants/trained_models'; import { SavedObjectsWarning } from '../../components/saved_objects_warning'; -import { TestTrainedModelFlyout, isTestable, isTestEnabled } from './test_models'; +import { TestTrainedModelFlyout } from './test_models'; type Stats = Omit; @@ -86,15 +77,9 @@ export const ModelsList: FC = ({ }) => { const { services: { - application: { navigateToUrl, capabilities }, - overlays, - theme, - docLinks, + application: { capabilities }, }, } = useMlKibana(); - const urlLocator = useMlLocator()!; - - const startModelDeploymentDocUrl = docLinks.links.ml.startTrainedModelsDeployment; useTimefilter({ timeRangeSelector: false, autoRefreshSelector: true }); @@ -118,12 +103,10 @@ export const ModelsList: FC = ({ const searchQueryText = pageState.queryText ?? ''; const canDeleteTrainedModels = capabilities.ml.canDeleteTrainedModels as boolean; - const canStartStopTrainedModels = capabilities.ml.canStartStopTrainedModels as boolean; - const canTestTrainedModels = capabilities.ml.canTestTrainedModels as boolean; const trainedModelsApiService = useTrainedModelsApiService(); - const { displayErrorToast, displaySuccessToast } = useToastNotificationService(); + const { displayErrorToast } = useToastNotificationService(); const [isLoading, setIsLoading] = useState(false); const [items, setItems] = useState([]); @@ -133,15 +116,6 @@ export const ModelsList: FC = ({ {} ); const [showTestFlyout, setShowTestFlyout] = useState(null); - // eslint-disable-next-line react-hooks/exhaustive-deps - const getUserConfirmation = useMemo(() => getUserConfirmationProvider(overlays, theme), []); - - const getUserInputThreadingParams = useMemo( - () => getUserInputThreadingParamsProvider(overlays, theme.theme$, startModelDeploymentDocUrl), - [overlays, theme.theme$, startModelDeploymentDocUrl] - ); - - const navigateToPath = useNavigateToPath(); const isBuiltInModel = useCallback( (item: ModelItem) => item.tags.includes(BUILT_IN_MODEL_TAG), @@ -287,231 +261,13 @@ export const ModelsList: FC = ({ /** * Table actions */ - const actions: Array> = [ - { - name: i18n.translate('xpack.ml.trainedModels.modelsList.viewTrainingDataActionLabel', { - defaultMessage: 'View training data', - }), - description: i18n.translate('xpack.ml.trainedModels.modelsList.viewTrainingDataActionLabel', { - defaultMessage: 'View training data', - }), - icon: 'visTable', - type: 'icon', - available: (item) => !!item.metadata?.analytics_config?.id, - onClick: async (item) => { - if (item.metadata?.analytics_config === undefined) return; - - const analysisType = getAnalysisType( - item.metadata?.analytics_config.analysis - ) as DataFrameAnalysisConfigType; - - const url = await urlLocator.getUrl({ - page: ML_PAGES.DATA_FRAME_ANALYTICS_EXPLORATION, - pageState: { - jobId: item.metadata?.analytics_config.id as string, - analysisType, - ...(analysisType === 'classification' || analysisType === 'regression' - ? { - queryText: `${item.metadata?.analytics_config.dest.results_field}.is_training : true`, - } - : {}), - }, - }); - - await navigateToUrl(url); - }, - isPrimary: true, - }, - { - name: i18n.translate('xpack.ml.inference.modelsList.analyticsMapActionLabel', { - defaultMessage: 'Analytics map', - }), - description: i18n.translate('xpack.ml.inference.modelsList.analyticsMapActionLabel', { - defaultMessage: 'Analytics map', - }), - icon: 'graphApp', - type: 'icon', - isPrimary: true, - available: (item) => !!item.metadata?.analytics_config?.id, - onClick: async (item) => { - const path = await urlLocator.getUrl({ - page: ML_PAGES.DATA_FRAME_ANALYTICS_MAP, - pageState: { modelId: item.model_id }, - }); - - await navigateToPath(path, false); - }, - }, - { - name: i18n.translate('xpack.ml.inference.modelsList.startModelDeploymentActionLabel', { - defaultMessage: 'Start deployment', - }), - description: i18n.translate('xpack.ml.inference.modelsList.startModelDeploymentActionLabel', { - defaultMessage: 'Start deployment', - }), - 'data-test-subj': 'mlModelsTableRowStartDeploymentAction', - icon: 'play', - type: 'icon', - isPrimary: true, - enabled: (item) => { - const { state } = item.stats?.deployment_stats ?? {}; - return ( - canStartStopTrainedModels && - !isLoading && - state !== DEPLOYMENT_STATE.STARTED && - state !== DEPLOYMENT_STATE.STARTING - ); - }, - available: (item) => item.model_type === TRAINED_MODEL_TYPE.PYTORCH, - onClick: async (item) => { - const threadingParams = await getUserInputThreadingParams(item.model_id); - - if (!threadingParams) return; - - try { - setIsLoading(true); - await trainedModelsApiService.startModelAllocation(item.model_id, { - number_of_allocations: threadingParams.numOfAllocations, - threads_per_allocation: threadingParams.threadsPerAllocations, - }); - displaySuccessToast( - i18n.translate('xpack.ml.trainedModels.modelsList.startSuccess', { - defaultMessage: 'Deployment for "{modelId}" has been started successfully.', - values: { - modelId: item.model_id, - }, - }) - ); - await fetchModelsData(); - } catch (e) { - displayErrorToast( - e, - i18n.translate('xpack.ml.trainedModels.modelsList.startFailed', { - defaultMessage: 'Failed to start "{modelId}"', - values: { - modelId: item.model_id, - }, - }) - ); - setIsLoading(false); - } - }, - }, - { - name: i18n.translate('xpack.ml.inference.modelsList.stopModelDeploymentActionLabel', { - defaultMessage: 'Stop deployment', - }), - description: i18n.translate('xpack.ml.inference.modelsList.stopModelDeploymentActionLabel', { - defaultMessage: 'Stop deployment', - }), - 'data-test-subj': 'mlModelsTableRowStopDeploymentAction', - icon: 'stop', - type: 'icon', - isPrimary: true, - available: (item) => item.model_type === TRAINED_MODEL_TYPE.PYTORCH, - enabled: (item) => - canStartStopTrainedModels && - !isLoading && - isPopulatedObject(item.stats?.deployment_stats) && - item.stats?.deployment_stats?.state !== DEPLOYMENT_STATE.STOPPING, - onClick: async (item) => { - const requireForceStop = isPopulatedObject(item.pipelines); - - if (requireForceStop) { - const hasUserApproved = await getUserConfirmation(item); - if (!hasUserApproved) return; - } - - if (requireForceStop) { - const hasUserApproved = await getUserConfirmation(item); - if (!hasUserApproved) return; - } - - try { - setIsLoading(true); - await trainedModelsApiService.stopModelAllocation(item.model_id, { - force: requireForceStop, - }); - displaySuccessToast( - i18n.translate('xpack.ml.trainedModels.modelsList.stopSuccess', { - defaultMessage: 'Deployment for "{modelId}" has been stopped successfully.', - values: { - modelId: item.model_id, - }, - }) - ); - // Need to fetch model state updates - await fetchModelsData(); - } catch (e) { - displayErrorToast( - e, - i18n.translate('xpack.ml.trainedModels.modelsList.stopFailed', { - defaultMessage: 'Failed to stop "{modelId}"', - values: { - modelId: item.model_id, - }, - }) - ); - setIsLoading(false); - } - }, - }, - { - name: (model) => { - const enabled = !isPopulatedObject(model.pipelines); - return ( - - <> - {i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', { - defaultMessage: 'Delete model', - })} - - - ); - }, - description: i18n.translate('xpack.ml.trainedModels.modelsList.deleteModelActionLabel', { - defaultMessage: 'Delete model', - }), - 'data-test-subj': 'mlModelsTableRowDeleteAction', - icon: 'trash', - type: 'icon', - color: 'danger', - isPrimary: false, - onClick: (model) => { - setModelIdsToDelete([model.model_id]); - }, - available: (item) => canDeleteTrainedModels && !isBuiltInModel(item), - enabled: (item) => { - // TODO check for permissions to delete ingest pipelines. - // ATM undefined means pipelines fetch failed server-side. - return !isPopulatedObject(item.pipelines); - }, - }, - { - name: i18n.translate('xpack.ml.inference.modelsList.testModelActionLabel', { - defaultMessage: 'Test model', - }), - description: i18n.translate('xpack.ml.inference.modelsList.testModelActionLabel', { - defaultMessage: 'Test model', - }), - 'data-test-subj': 'mlModelsTableRowTestAction', - icon: 'inputOutput', - type: 'icon', - isPrimary: true, - available: isTestable, - onClick: setShowTestFlyout, - enabled: (item) => canTestTrainedModels && isTestEnabled(item), - }, - ]; + const actions = useModelActions({ + isLoading, + fetchModels: fetchModelsData, + onTestAction: setShowTestFlyout, + onModelsDeleteRequest: setModelIdsToDelete, + onLoading: setIsLoading, + }); const toggleDetails = async (item: ModelItem) => { const itemIdToExpandedRowMapValues = { ...itemIdToExpandedRowMap }; diff --git a/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts b/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts index a54bad9d886fee..b02c58ebb2b6c5 100644 --- a/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts +++ b/x-pack/plugins/ml/server/lib/ml_client/ml_client.ts @@ -490,6 +490,15 @@ export function getMlClient( await modelIdsCheck(p); return mlClient.startTrainedModelDeployment(...p); }, + async updateTrainedModelDeployment(...p: Parameters) { + await modelIdsCheck(p); + const { model_id: modelId, number_of_allocations: numberOfAllocations } = p[0]; + return client.asInternalUser.transport.request({ + method: 'POST', + path: `/_ml/trained_models/${modelId}/deployment/_update`, + body: { number_of_allocations: numberOfAllocations }, + }); + }, async stopTrainedModelDeployment(...p: Parameters) { await modelIdsCheck(p); return mlClient.stopTrainedModelDeployment(...p); diff --git a/x-pack/plugins/ml/server/lib/ml_client/types.ts b/x-pack/plugins/ml/server/lib/ml_client/types.ts index 38e864f034e04c..ab68f63f8ce581 100644 --- a/x-pack/plugins/ml/server/lib/ml_client/types.ts +++ b/x-pack/plugins/ml/server/lib/ml_client/types.ts @@ -12,6 +12,10 @@ type OrigMlClient = ElasticsearchClient['ml']; export interface MlClient extends OrigMlClient { anomalySearch: ReturnType['anomalySearch']; + updateTrainedModelDeployment: (payload: { + model_id: string; + number_of_allocations: number; + }) => Promise<{ acknowledge: boolean }>; } export type MlClientParams = diff --git a/x-pack/plugins/ml/server/routes/apidoc.json b/x-pack/plugins/ml/server/routes/apidoc.json index 893f541e6c9f62..dbe45dfdb7a6e2 100644 --- a/x-pack/plugins/ml/server/routes/apidoc.json +++ b/x-pack/plugins/ml/server/routes/apidoc.json @@ -173,6 +173,7 @@ "GetTrainedModelsNodesOverview", "GetTrainedModelPipelines", "StartTrainedModelDeployment", + "UpdateTrainedModelDeployment", "StopTrainedModelDeployment", "PutTrainedModel", "DeleteTrainedModel", diff --git a/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts b/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts index b55dab613f7b77..d73d479e7aac33 100644 --- a/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts +++ b/x-pack/plugins/ml/server/routes/schemas/inference_schema.ts @@ -18,9 +18,14 @@ export const threadingParamsSchema = schema.maybe( schema.object({ number_of_allocations: schema.number(), threads_per_allocation: schema.number(), + priority: schema.oneOf([schema.literal('low'), schema.literal('normal')]), }) ); +export const updateDeploymentParamsSchema = schema.object({ + number_of_allocations: schema.number(), +}); + export const optionalModelIdSchema = schema.object({ /** * Model ID diff --git a/x-pack/plugins/ml/server/routes/trained_models.ts b/x-pack/plugins/ml/server/routes/trained_models.ts index f5c792d8f8fdfc..29a39ecfbf7e42 100644 --- a/x-pack/plugins/ml/server/routes/trained_models.ts +++ b/x-pack/plugins/ml/server/routes/trained_models.ts @@ -16,6 +16,7 @@ import { inferTrainedModelQuery, inferTrainedModelBody, threadingParamsSchema, + updateDeploymentParamsSchema, } from './schemas/inference_schema'; import { modelsProvider } from '../models/data_frame_analytics'; import { TrainedModelConfigResponse } from '../../common/types/trained_models'; @@ -324,6 +325,40 @@ export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization) }) ); + /** + * @apiGroup TrainedModels + * + * @api {post} /api/ml/trained_models/:modelId/deployment/_update Update trained model deployment + * @apiName UpdateTrainedModelDeployment + * @apiDescription Updates trained model deployment. + */ + router.post( + { + path: '/api/ml/trained_models/{modelId}/deployment/_update', + validate: { + params: modelIdSchema, + body: updateDeploymentParamsSchema, + }, + options: { + tags: ['access:ml:canStartStopTrainedModels'], + }, + }, + routeGuard.fullLicenseAPIGuard(async ({ mlClient, request, response }) => { + try { + const { modelId } = request.params; + const body = await mlClient.updateTrainedModelDeployment({ + model_id: modelId, + ...request.body, + }); + return response.ok({ + body, + }); + } catch (e) { + return response.customError(wrapError(e)); + } + }) + ); + /** * @apiGroup TrainedModels * diff --git a/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts b/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts index 3b4946e35bd703..2743df8a819498 100644 --- a/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts +++ b/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts @@ -191,6 +191,7 @@ export default function ({ getService }: FtrProviderContext) { it(`starts deployment of the imported model ${model.id}`, async () => { await ml.trainedModelsTable.startDeploymentWithParams(model.id, { + priority: 'normal', numOfAllocations: 1, threadsPerAllocation: 2, }); diff --git a/x-pack/test/functional/services/ml/trained_models_table.ts b/x-pack/test/functional/services/ml/trained_models_table.ts index c8d43207dd5ab9..6f2b76be1bdb19 100644 --- a/x-pack/test/functional/services/ml/trained_models_table.ts +++ b/x-pack/test/functional/services/ml/trained_models_table.ts @@ -249,6 +249,13 @@ export function TrainedModelsTableProvider( await this.assertNumOfAllocations(value); } + public async setPriority(value: 'low' | 'normal') { + await mlCommonUI.selectButtonGroupValue( + 'mlModelsStartDeploymentModalPriority', + value.toString() + ); + } + public async setThreadsPerAllocation(value: number) { await mlCommonUI.selectButtonGroupValue( 'mlModelsStartDeploymentModalThreadsPerAllocation', @@ -258,10 +265,11 @@ export function TrainedModelsTableProvider( public async startDeploymentWithParams( modelId: string, - params: { numOfAllocations: number; threadsPerAllocation: number } + params: { priority: 'low' | 'normal'; numOfAllocations: number; threadsPerAllocation: number } ) { await this.openStartDeploymentModal(modelId); + await this.setPriority(params.priority); await this.setNumOfAllocations(params.numOfAllocations); await this.setThreadsPerAllocation(params.threadsPerAllocation);