Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Allow updates for number of allocations and priority for trained model deployments #144704

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@ export function trainedModelsApiProvider(httpService: HttpService) {

startModelAllocation(
modelId: string,
queryParams?: { number_of_allocations: number; threads_per_allocation: number }
queryParams?: {
number_of_allocations: number;
threads_per_allocation: number;
priority: 'low' | 'normal';
}
) {
return httpService.http<{ acknowledge: boolean }>({
path: `${apiBasePath}/trained_models/${modelId}/deployment/_start`,
Expand All @@ -145,6 +149,14 @@ export function trainedModelsApiProvider(httpService: HttpService) {
});
},

updateModelDeployment(modelId: string, params: { number_of_allocations: number }) {
return httpService.http<{ acknowledge: boolean }>({
path: `${apiBasePath}/trained_models/${modelId}/deployment/_update`,
method: 'POST',
body: JSON.stringify(params),
});
},

inferTrainedModel(modelId: string, payload: any, timeout?: string) {
const body = JSON.stringify(payload);
return httpService.http<estypes.MlInferTrainedModelResponse>({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,26 @@ import type { Observable } from 'rxjs';
import type { CoreTheme, OverlayStart } from '@kbn/core/public';
import { css } from '@emotion/react';
import { numberValidator } from '@kbn/ml-agg-utils';
import { isCloudTrial } from '../../services/ml_server_info';
import { composeValidators, requiredValidator } from '../../../../common/util/validators';

interface StartDeploymentSetup {
interface DeploymentSetupProps {
config: ThreadingParams;
onConfigChange: (config: ThreadingParams) => void;
}

export interface ThreadingParams {
numOfAllocations: number;
threadsPerAllocations: number;
threadsPerAllocations?: number;
priority?: 'low' | 'normal';
}

const THREADS_MAX_EXPONENT = 4;

/**
* Form for setting threading params.
*/
export const StartDeploymentSetup: FC<StartDeploymentSetup> = ({ config, onConfigChange }) => {
export const DeploymentSetup: FC<DeploymentSetupProps> = ({ config, onConfigChange }) => {
const numOfAllocation = config.numOfAllocations;
const threadsPerAllocations = config.threadsPerAllocations;

Expand All @@ -69,12 +71,78 @@ export const StartDeploymentSetup: FC<StartDeploymentSetup> = ({ config, onConfi
[]
);

const toggleIdSelected = threadsPerAllocationsOptions.find(
(v) => v.value === threadsPerAllocations
)!.id;
const disableThreadingControls = config.priority === 'low';

return (
<EuiForm component={'form'} id={'startDeploymentForm'}>
{config.priority !== undefined ? (
<EuiDescribedFormGroup
titleSize={'xxs'}
title={
<h3>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.priorityLabel"
defaultMessage="Priority"
/>
</h3>
}
description={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.priorityHelp"
defaultMessage="Select low priority for demonstrations where each model will be very lightly used."
/>
}
>
<EuiFormRow
label={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.priorityLabel"
defaultMessage="Priority"
/>
}
hasChildLabel={false}
>
<EuiButtonGroup
legend={i18n.translate(
'xpack.ml.trainedModels.modelsList.startDeployment.priorityLegend',
{
defaultMessage: 'Priority selector',
}
)}
name={'priority'}
isFullWidth
idSelected={config.priority}
onChange={(optionId: string) => {
onConfigChange({ ...config, priority: optionId as ThreadingParams['priority'] });
}}
options={[
{
id: 'low',
value: 'low',
label: i18n.translate(
'xpack.ml.trainedModels.modelsList.startDeployment.lowPriorityLabel',
{
defaultMessage: 'low',
}
),
},
{
id: 'normal',
value: 'normal',
label: i18n.translate(
'xpack.ml.trainedModels.modelsList.startDeployment.normalPriorityLabel',
{
defaultMessage: 'normal',
}
),
},
]}
data-test-subj={'mlModelsStartDeploymentModalPriority'}
/>
</EuiFormRow>
</EuiDescribedFormGroup>
) : null}

<EuiDescribedFormGroup
titleSize={'xxs'}
title={
Expand All @@ -100,13 +168,15 @@ export const StartDeploymentSetup: FC<StartDeploymentSetup> = ({ config, onConfi
/>
}
hasChildLabel={false}
isDisabled={disableThreadingControls}
>
<EuiFieldNumber
disabled={disableThreadingControls}
fullWidth
min={1}
step={1}
name={'numOfAllocations'}
value={numOfAllocation}
value={disableThreadingControls ? 1 : numOfAllocation}
onChange={(event) => {
onConfigChange({ ...config, numOfAllocations: Number(event.target.value) });
}}
Expand All @@ -115,51 +185,59 @@ export const StartDeploymentSetup: FC<StartDeploymentSetup> = ({ config, onConfi
</EuiFormRow>
</EuiDescribedFormGroup>

<EuiDescribedFormGroup
titleSize={'xxs'}
title={
<h3>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.threadsPerAllocationLabel"
defaultMessage="Threads per allocation"
/>
</h3>
}
description={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.threadsPerAllocationHelp"
defaultMessage="Increase to improve latency for each request."
/>
}
>
<EuiFormRow
label={
{threadsPerAllocations !== undefined ? (
<EuiDescribedFormGroup
titleSize={'xxs'}
title={
<h3>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.threadsPerAllocationLabel"
defaultMessage="Threads per allocation"
/>
</h3>
}
description={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.threadsPerAllocationLabel"
defaultMessage="Threads per allocation"
id="xpack.ml.trainedModels.modelsList.startDeployment.threadsPerAllocationHelp"
defaultMessage="Increase to improve latency for each request."
/>
}
hasChildLabel={false}
>
<EuiButtonGroup
legend={i18n.translate(
'xpack.ml.trainedModels.modelsList.startDeployment.threadsPerAllocationLegend',
{
defaultMessage: 'Threads per allocation selector',
<EuiFormRow
label={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.threadsPerAllocationLabel"
defaultMessage="Threads per allocation"
/>
}
hasChildLabel={false}
isDisabled={disableThreadingControls}
>
<EuiButtonGroup
isDisabled={disableThreadingControls}
legend={i18n.translate(
'xpack.ml.trainedModels.modelsList.startDeployment.threadsPerAllocationLegend',
{
defaultMessage: 'Threads per allocation selector',
}
)}
name={'threadsPerAllocation'}
isFullWidth
idSelected={
disableThreadingControls
? '1'
: threadsPerAllocationsOptions.find((v) => v.value === threadsPerAllocations)!.id
}
)}
name={'threadsPerAllocation'}
isFullWidth
idSelected={toggleIdSelected}
onChange={(optionId) => {
const value = threadsPerAllocationsOptions.find((v) => v.id === optionId)!.value;
onConfigChange({ ...config, threadsPerAllocations: value });
}}
options={threadsPerAllocationsOptions}
data-test-subj={'mlModelsStartDeploymentModalThreadsPerAllocation'}
/>
</EuiFormRow>
</EuiDescribedFormGroup>
onChange={(optionId) => {
const value = threadsPerAllocationsOptions.find((v) => v.id === optionId)!.value;
onConfigChange({ ...config, threadsPerAllocations: value });
}}
options={threadsPerAllocationsOptions}
data-test-subj={'mlModelsStartDeploymentModalThreadsPerAllocation'}
/>
</EuiFormRow>
</EuiDescribedFormGroup>
) : null}
</EuiForm>
);
};
Expand All @@ -169,24 +247,28 @@ interface StartDeploymentModalProps {
startModelDeploymentDocUrl: string;
onConfigChange: (config: ThreadingParams) => void;
onClose: () => void;
initialParams?: ThreadingParams;
}

/**
* Modal window wrapper for {@link StartDeploymentSetup}
*
* @param onConfigChange
* @param onClose
* Modal window wrapper for {@link DeploymentSetup}
*/
export const StartDeploymentModal: FC<StartDeploymentModalProps> = ({
export const StartUpdateDeploymentModal: FC<StartDeploymentModalProps> = ({
modelId,
onConfigChange,
onClose,
startModelDeploymentDocUrl,
initialParams,
}) => {
const [config, setConfig] = useState<ThreadingParams>({
numOfAllocations: 1,
threadsPerAllocations: 1,
});
const [config, setConfig] = useState<ThreadingParams>(
initialParams ?? {
numOfAllocations: 1,
threadsPerAllocations: 1,
priority: isCloudTrial() ? 'low' : 'normal',
}
);

const isUpdate = initialParams !== undefined;

const numOfAllocationsValidator = composeValidators(
requiredValidator(),
Expand All @@ -208,11 +290,19 @@ export const StartDeploymentModal: FC<StartDeploymentModalProps> = ({
<EuiFlexItem grow={false}>
<EuiTitle size={'s'}>
<h2>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.modalTitle"
defaultMessage="Start {modelId} deployment"
values={{ modelId }}
/>
{isUpdate ? (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.updateDeployment.modalTitle"
defaultMessage="Update {modelId} deployment"
values={{ modelId }}
/>
) : (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.modalTitle"
defaultMessage="Start {modelId} deployment"
values={{ modelId }}
/>
)}
</h2>
</EuiTitle>
</EuiFlexItem>
Expand All @@ -236,7 +326,7 @@ export const StartDeploymentModal: FC<StartDeploymentModalProps> = ({
/>
<EuiSpacer size={'m'} />

<StartDeploymentSetup config={config} onConfigChange={setConfig} />
<DeploymentSetup config={config} onConfigChange={setConfig} />

<EuiSpacer size={'m'} />
</EuiModalBody>
Expand Down Expand Up @@ -272,10 +362,17 @@ export const StartDeploymentModal: FC<StartDeploymentModalProps> = ({
disabled={!!errors}
data-test-subj={'mlModelsStartDeploymentModalStartButton'}
>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.startButton"
defaultMessage="Start"
/>
{isUpdate ? (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.updateButton"
defaultMessage="Update"
/>
) : (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.startDeployment.startButton"
defaultMessage="Start"
/>
)}
</EuiButton>
</EuiModalFooter>
</EuiModal>
Expand All @@ -291,30 +388,38 @@ export const StartDeploymentModal: FC<StartDeploymentModalProps> = ({
*/
export const getUserInputThreadingParamsProvider =
(overlays: OverlayStart, theme$: Observable<CoreTheme>, startModelDeploymentDocUrl: string) =>
(modelId: string): Promise<ThreadingParams | void> => {
return new Promise(async (resolve, reject) => {
(modelId: string, initialParams?: ThreadingParams): Promise<ThreadingParams | void> => {
return new Promise(async (resolve) => {
try {
const modalSession = overlays.openModal(
toMountPoint(
wrapWithTheme(
<StartDeploymentModal
<StartUpdateDeploymentModal
startModelDeploymentDocUrl={startModelDeploymentDocUrl}
initialParams={initialParams}
modelId={modelId}
onConfigChange={(config) => {
modalSession.close();
resolve(config);

const resultConfig = { ...config };
if (resultConfig.priority === 'low') {
resultConfig.numOfAllocations = 1;
resultConfig.threadsPerAllocations = 1;
}

resolve(resultConfig);
}}
onClose={() => {
modalSession.close();
reject();
resolve();
}}
/>,
theme$
)
)
);
} catch (e) {
reject();
resolve();
}
});
};
Loading