Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Fixing missing multi label checkbox and basic input validation #145357

Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { getAnalysisType } from '../../../../common/util/analytics_utils';
import { DataFrameAnalysisConfigType } from '../../../../common/types/data_frame_analytics';
import { ML_PAGES } from '../../../../common/constants/locator';
import { DEPLOYMENT_STATE, TRAINED_MODEL_TYPE } from '../../../../common/constants/trained_models';
import { isTestable, isTestEnabled } from './test_models';
import { isTestable } from './test_models';
import { ModelItem } from './models_list';

export function useModelActions({
Expand Down Expand Up @@ -361,7 +361,7 @@ export function useModelActions({
isPrimary: true,
available: isTestable,
onClick: (item) => onTestAction(item.model_id),
enabled: (item) => canTestTrainedModels && isTestEnabled(item),
enabled: (item) => canTestTrainedModels && isTestable(item, true),
},
],
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
*/

export { TestTrainedModelFlyout } from './test_flyout';
export { isTestable, isTestEnabled } from './utils';
export { isTestable } from './utils';
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,36 @@ export const InferenceInputFormIndexControls: FC<Props> = ({ inferrer, data }) =
setSelectedField,
} = data;

const runningState = useObservable(inferrer.getRunningState$());
const runningState = useObservable(inferrer.getRunningState$(), inferrer.getRunningState());
const pipeline = useObservable(inferrer.getPipeline$(), inferrer.getPipeline());
const inputComponent = useMemo(() => inferrer.getInputComponent(), [inferrer]);

return (
<>
<EuiFormRow label="Index">
<EuiFormRow label="Index" fullWidth>
<EuiSelect
options={dataViewListItems}
value={selectedDataViewId}
onChange={(e) => setSelectedDataViewId(e.target.value)}
hasNoInitialSelection={true}
disabled={runningState === RUNNING_STATE.RUNNING}
fullWidth
/>
</EuiFormRow>
<EuiSpacer size="m" />
<EuiFormRow
label={i18n.translate('xpack.ml.trainedModels.testModelsFlyout.indexInput.fieldInput', {
defaultMessage: 'Field',
})}
fullWidth
>
<EuiSelect
options={fieldNames}
value={selectedField}
onChange={(e) => setSelectedField(e.target.value)}
hasNoInitialSelection={true}
disabled={runningState === RUNNING_STATE.RUNNING}
fullWidth
/>
</EuiFormRow>

Expand All @@ -76,7 +80,7 @@ export const InferenceInputFormIndexControls: FC<Props> = ({ inferrer, data }) =
)}
>
<EuiCodeBlock language="json" fontSize="s" paddingSize="s" lineNumbers isCopyable={true}>
{JSON.stringify(inferrer.getPipeline(), null, 2)}
{JSON.stringify(pipeline, null, 2)}
</EuiCodeBlock>
</EuiAccordion>
</>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
* 2.0.
*/

import { BehaviorSubject } from 'rxjs';
import { BehaviorSubject, Observable, combineLatest, Subscription } from 'rxjs';
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { i18n } from '@kbn/i18n';

import { map } from 'rxjs/operators';
import { MLHttpFetchError } from '../../../../../../common/util/errors';
import { SupportedPytorchTasksType } from '../../../../../../common/constants/trained_models';
import { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
Expand Down Expand Up @@ -58,21 +59,69 @@ export enum INPUT_TYPE {
export abstract class InferenceBase<TInferResponse> {
protected abstract readonly inferenceType: InferenceType;
protected abstract readonly inferenceTypeLabel: string;
protected inputField: string;
protected readonly modelInputField: string;
private inputText$ = new BehaviorSubject<string[]>([]);

protected inputText$ = new BehaviorSubject<string[]>([]);
private inputField$ = new BehaviorSubject<string>('');
private inferenceResult$ = new BehaviorSubject<TInferResponse[] | null>(null);
private inferenceError$ = new BehaviorSubject<MLHttpFetchError | null>(null);
private runningState$ = new BehaviorSubject<RUNNING_STATE>(RUNNING_STATE.STOPPED);
private isValid$ = new BehaviorSubject<boolean>(false);
private pipeline$ = new BehaviorSubject<estypes.IngestPipeline>({});

protected readonly info: string[] = [];

private subscriptions$: Subscription = new Subscription();

constructor(
protected readonly trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
protected readonly model: estypes.MlTrainedModelConfig,
protected readonly inputType: INPUT_TYPE
) {
this.modelInputField = model.input?.field_names[0] ?? DEFAULT_INPUT_FIELD;
this.inputField = this.modelInputField;
this.inputField$.next(this.modelInputField);
}

public destroy() {
this.subscriptions$.unsubscribe();
}

protected initialize(
additionalValidators?: Array<Observable<boolean>>,
additionalChanges?: Array<Observable<unknown>>
) {
this.initializeValidators(additionalValidators);
this.initializePipeline(additionalChanges);
}

private initializeValidators(additionalValidators?: Array<Observable<boolean>>) {
const validators$: Array<Observable<boolean>> = [
this.inputText$.pipe(map((inputText) => inputText.some((t) => t !== ''))),
...(additionalValidators ? additionalValidators : []),
];

this.subscriptions$.add(
combineLatest(validators$)
.pipe(
map((validationResults) => {
return validationResults.every((v) => !!v);
})
)
.subscribe(this.isValid$)
);
}

private initializePipeline(additionalChanges?: Array<Observable<unknown>>) {
const formObservables$: Array<Observable<unknown>> = [
this.inputField$.asObservable(),
...(additionalChanges ? additionalChanges : []),
];

this.subscriptions$.add(
combineLatest(formObservables$).subscribe(() => {
this.pipeline$.next(this.generatePipeline());
})
);
}

public setStopped() {
Expand Down Expand Up @@ -110,7 +159,11 @@ export abstract class InferenceBase<TInferResponse> {
}

public setInputField(field: string | undefined) {
this.inputField = field === undefined ? this.modelInputField : field;
this.inputField$.next(field === undefined ? this.modelInputField : field);
}

public getInputField() {
return this.inputField$.getValue();
}

public setInputText(text: string[]) {
Expand All @@ -121,18 +174,42 @@ export abstract class InferenceBase<TInferResponse> {
return this.inputText$.asObservable();
}

public getInputText() {
return this.inputText$.getValue();
}

public getInferenceResult$() {
return this.inferenceResult$.asObservable();
}

public getInferenceResult() {
return this.inferenceResult$.getValue();
}

public getInferenceError$() {
return this.inferenceError$.asObservable();
}

public getInferenceError() {
return this.inferenceError$.getValue();
}

public getRunningState$() {
return this.runningState$.asObservable();
}

public getRunningState() {
return this.runningState$.getValue();
}

public getIsValid$() {
return this.isValid$.asObservable();
}

public getIsValid() {
return this.isValid$.getValue();
}

protected abstract getInputComponent(): JSX.Element | null;
protected abstract getOutputComponent(): JSX.Element;

Expand All @@ -143,12 +220,20 @@ export abstract class InferenceBase<TInferResponse> {
protected abstract inferText(): Promise<TInferResponse[]>;
protected abstract inferIndex(): Promise<TInferResponse[]>;

public getPipeline(): estypes.IngestPipeline {
public generatePipeline(): estypes.IngestPipeline {
return {
processors: this.getProcessors(),
};
}

public getPipeline$() {
return this.pipeline$.asObservable();
}

public getPipeline(): estypes.IngestPipeline {
return this.pipeline$.getValue();
}

protected getBasicProcessors(
inferenceConfigOverrides?: InferenceOptions
): estypes.IngestProcessorContainer[] {
Expand All @@ -157,7 +242,7 @@ export abstract class InferenceBase<TInferResponse> {
model_id: this.model.model_id,
target_field: this.inferenceType,
field_map: {
[this.inputField]: this.modelInputField,
[this.inputField$.getValue()]: this.modelInputField,
},
...(inferenceConfigOverrides && Object.keys(inferenceConfigOverrides).length
? { inference_config: this.getInferenceConfig(inferenceConfigOverrides) }
Expand All @@ -179,17 +264,20 @@ export abstract class InferenceBase<TInferResponse> {
}

protected async runInfer<TRawInferResponse>(
getInferBody: (inputText: string) => estypes.MlInferTrainedModelRequest['body'],
getInferenceConfig: () => estypes.MlInferenceConfigUpdateContainer | void,
processResponse: (resp: TRawInferResponse, inputText: string) => TInferResponse
): Promise<TInferResponse[]> {
try {
this.setRunning();
const inputText = this.inputText$.getValue()[0];
const body = getInferBody(inputText);
const inferenceConfig = getInferenceConfig();

const resp = (await this.trainedModelsApi.inferTrainedModel(
this.model.model_id,
body,
{
docs: this.getInferDocs(),
...(inferenceConfig ? { inference_config: inferenceConfig } : {}),
},
DEFAULT_INFERENCE_TIME_OUT
)) as unknown as TRawInferResponse;

Expand Down Expand Up @@ -226,10 +314,14 @@ export abstract class InferenceBase<TInferResponse> {

protected abstract getProcessors(): estypes.IngestProcessorContainer[];

protected getInferDocs() {
return [{ [this.inputField$.getValue()]: this.inputText$.getValue()[0] }];
}

protected getPipelineDocs() {
return this.inputText$.getValue().map((v) => ({
_source: {
[this.inputField]: v,
[this.inputField$.getValue()]: v,
},
}));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ interface Props {
inferrer: InferrerType;
}

export const IndexInput: FC<Props> = ({ inferrer }) => {
export const IndexInputForm: FC<Props> = ({ inferrer }) => {
const data = useIndexInput({ inferrer });
const { reloadExamples, selectedField } = data;

const [errorText, setErrorText] = useState<string | null>(null);
const runningState = useObservable(inferrer.getRunningState$());
const examples = useObservable(inferrer.getInputText$()) ?? [];
const runningState = useObservable(inferrer.getRunningState$(), inferrer.getRunningState());
const examples = useObservable(inferrer.getInputText$(), inferrer.getInputText());
const isValid = useObservable(inferrer.getIsValid$(), inferrer.getIsValid());
const outputComponent = useMemo(() => inferrer.getOutputComponent(), [inferrer]);
const infoComponent = useMemo(() => inferrer.getInfoComponent(), [inferrer]);

Expand All @@ -60,7 +61,7 @@ export const IndexInput: FC<Props> = ({ inferrer }) => {
<EuiFlexItem grow={false}>
<EuiButton
onClick={run}
disabled={runningState === RUNNING_STATE.RUNNING || selectedField === undefined}
disabled={runningState === RUNNING_STATE.RUNNING || isValid === false}
fullWidth={false}
>
<FormattedMessage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import React, { FC } from 'react';

import { INPUT_TYPE } from '../inference_base';
import { TextInput } from './text_input';
import { IndexInput } from './index_input';
import { TextInputForm } from './text_input';
import { IndexInputForm } from './index_input';
import { InferrerType } from '..';

interface Props {
Expand All @@ -19,8 +19,8 @@ interface Props {

export const InferenceInputForm: FC<Props> = ({ inferrer, inputType }) => {
return inputType === INPUT_TYPE.TEXT ? (
<TextInput inferrer={inferrer} />
<TextInputForm inferrer={inferrer} />
) : (
<IndexInput inferrer={inferrer} />
<IndexInputForm inferrer={inferrer} />
);
};
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ enum TAB {
RAW,
}

export const TextInput: FC<Props> = ({ inferrer }) => {
export const TextInputForm: FC<Props> = ({ inferrer }) => {
const [selectedTab, setSelectedTab] = useState(TAB.TEXT);
const [errorText, setErrorText] = useState<string | null>(null);

const runningState = useObservable(inferrer.getRunningState$());
const inputText = useObservable(inferrer.getInputText$()) ?? [];
const isValid = useObservable(inferrer.getIsValid$(), inferrer.getIsValid());
const runningState = useObservable(inferrer.getRunningState$(), inferrer.getRunningState());
const inputComponent = useMemo(() => inferrer.getInputComponent(), [inferrer]);
const outputComponent = useMemo(() => inferrer.getOutputComponent(), [inferrer]);
const infoComponent = useMemo(() => inferrer.getInfoComponent(), [inferrer]);
Expand All @@ -54,7 +54,7 @@ export const TextInput: FC<Props> = ({ inferrer }) => {
<div>
<EuiButton
onClick={run}
disabled={runningState === RUNNING_STATE.RUNNING || inputText[0] === ''}
disabled={runningState === RUNNING_STATE.RUNNING || isValid === false}
fullWidth={false}
>
<FormattedMessage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { i18n } from '@kbn/i18n';
import { trainedModelsApiProvider } from '../../../../../services/ml_api_service/trained_models';
import { InferenceBase, INPUT_TYPE } from '../inference_base';
import type { InferResponse } from '../inference_base';
import { getGeneralInputComponent } from '../text_input';
Expand All @@ -32,11 +33,19 @@ export class NerInference extends InferenceBase<NerResponse> {
}),
];

constructor(
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE
) {
super(trainedModelsApi, model, inputType);

this.initialize();
}

protected async inferText() {
return this.runInfer<estypes.MlInferTrainedModelResponse>(
(inputText: string) => {
return { docs: [{ [this.inputField]: inputText }] };
},
() => {},
(resp, inputText) => {
return {
response: parseResponse(resp),
Expand All @@ -52,7 +61,7 @@ export class NerInference extends InferenceBase<NerResponse> {
return {
response: parseResponse({ inference_results: [doc._source[this.inferenceType]] }),
rawResponse: doc._source[this.inferenceType],
inputText: doc._source[this.inputField],
inputText: doc._source[this.getInputField()],
};
});
}
Expand Down
Loading