Skip to content

Commit

Permalink
rewriting observables
Browse files Browse the repository at this point in the history
  • Loading branch information
jgowdyelastic committed Nov 17, 2022
1 parent 69d9471 commit d818da9
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
* 2.0.
*/

import { BehaviorSubject } from 'rxjs';
import { BehaviorSubject, Observable, combineLatest, Subscription } from 'rxjs';
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { i18n } from '@kbn/i18n';

import { map } from 'rxjs/operators';
import { MLHttpFetchError } from '../../../../../../common/util/errors';
import { SupportedPytorchTasksType } from '../../../../../../common/constants/trained_models';
import { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
Expand Down Expand Up @@ -68,6 +69,9 @@ export abstract class InferenceBase<TInferResponse> {
protected isValid$ = new BehaviorSubject<boolean>(false);
protected readonly info: string[] = [];

protected validators$: Array<Observable<boolean>> = [];
private validatorsSubscriptions$: Subscription = new Subscription();

constructor(
protected readonly trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
protected readonly model: estypes.MlTrainedModelConfig,
Expand All @@ -76,10 +80,27 @@ export abstract class InferenceBase<TInferResponse> {
this.modelInputField = model.input?.field_names[0] ?? DEFAULT_INPUT_FIELD;
this.inputField = this.modelInputField;

this.inputText$.subscribe((inputText) => {
const inputTextPopulated = inputText.some((t) => t !== '');
this.inputTextValid$.next(inputTextPopulated);
});
this.validators$.push(
this.inputText$.pipe(map((inputText) => inputText.some((t) => t !== '')))
);
}

public destroy() {
this.validatorsSubscriptions$.unsubscribe();
}

protected initializeValidators() {
this.validatorsSubscriptions$.add(
combineLatest(this.validators$)
.pipe(
map((validationResults) => {
return validationResults.every((v) => !!v);
})
)
.subscribe((v) => {
this.isValid$.next(v);
})
);
}

public setStopped() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { i18n } from '@kbn/i18n';
import { combineLatest } from 'rxjs';
// import { combineLatest } from 'rxjs';
import { trainedModelsApiProvider } from '../../../../../services/ml_api_service/trained_models';
import { InferenceBase, INPUT_TYPE } from '../inference_base';
import type { InferResponse } from '../inference_base';
Expand Down Expand Up @@ -41,9 +41,7 @@ export class NerInference extends InferenceBase<NerResponse> {
) {
super(trainedModelsApi, model, inputType);

combineLatest([this.inputTextValid$]).subscribe(([inputTextValid]) => {
this.isValid$.next(inputTextValid);
});
this.initializeValidators();
}

protected async inferText() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
* 2.0.
*/

import { BehaviorSubject, combineLatest } from 'rxjs';
import { BehaviorSubject } from 'rxjs';
import { map } from 'rxjs/operators';
import { i18n } from '@kbn/i18n';
import { estypes } from '@elastic/elasticsearch';
import { InferenceBase, INPUT_TYPE } from '../inference_base';
Expand Down Expand Up @@ -65,12 +66,8 @@ export class QuestionAnsweringInference extends InferenceBase<QuestionAnsweringR
) {
super(trainedModelsApi, model, inputType);

combineLatest([this.inputTextValid$, this.questionText$]).subscribe(
([inputTextValid, questionText]) => {
const valid = inputTextValid && questionText !== '';
this.isValid$.next(valid);
}
);
this.validators$.push(this.questionText$.pipe(map((questionText) => questionText !== '')));
this.initializeValidators();
}

public async inferText() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import { i18n } from '@kbn/i18n';
import { estypes } from '@elastic/elasticsearch';
import { combineLatest } from 'rxjs';
import { map } from 'rxjs/operators';
import { InferenceBase, INPUT_TYPE } from '../inference_base';
import type { TextClassificationResponse, RawTextClassificationResponse } from './common';
import { processResponse, processInferenceResult } from './common';
Expand Down Expand Up @@ -38,10 +38,11 @@ export class FillMaskInference extends InferenceBase<TextClassificationResponse>
) {
super(trainedModelsApi, model, inputType);

combineLatest([this.inputTextValid$]).subscribe(([inputTextValid]) => {
const valid = inputTextValid && this.inputText$.getValue().every((t) => t.includes(MASK));
this.isValid$.next(valid);
});
this.validators$.push(
this.inputText$.pipe(map((inputText) => inputText.every((t) => t.includes(MASK))))
);

this.initializeValidators();
}

protected async inferText() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import { i18n } from '@kbn/i18n';
import { estypes } from '@elastic/elasticsearch';
import { combineLatest } from 'rxjs';
import { InferenceBase, INPUT_TYPE } from '../inference_base';
import type { InferenceType } from '../inference_base';
import { processInferenceResult, processResponse } from './common';
Expand Down Expand Up @@ -35,9 +34,7 @@ export class LangIdentInference extends InferenceBase<TextClassificationResponse
) {
super(trainedModelsApi, model, inputType);

combineLatest([this.inputTextValid$]).subscribe(([inputTextValid]) => {
this.isValid$.next(inputTextValid);
});
this.initializeValidators();
}

public async inferText() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import { i18n } from '@kbn/i18n';
import { estypes } from '@elastic/elasticsearch';
import { combineLatest } from 'rxjs';
import { InferenceBase, INPUT_TYPE } from '../inference_base';
import { processInferenceResult, processResponse } from './common';
import type { TextClassificationResponse, RawTextClassificationResponse } from './common';
Expand Down Expand Up @@ -35,9 +34,7 @@ export class TextClassificationInference extends InferenceBase<TextClassificatio
) {
super(trainedModelsApi, model, inputType);

combineLatest([this.inputTextValid$]).subscribe(([inputTextValid]) => {
this.isValid$.next(inputTextValid);
});
this.initializeValidators();
}

public async inferText() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import { i18n } from '@kbn/i18n';
import { BehaviorSubject } from 'rxjs';
import { map } from 'rxjs/operators';
import { estypes } from '@elastic/elasticsearch';
import { combineLatest } from 'rxjs';
import { trainedModelsApiProvider } from '../../../../../services/ml_api_service/trained_models';
import { InferenceBase, INPUT_TYPE } from '../inference_base';
import { processInferenceResult, processResponse } from './common';
Expand Down Expand Up @@ -41,12 +41,9 @@ export class ZeroShotClassificationInference extends InferenceBase<TextClassific
) {
super(trainedModelsApi, model, inputType);

combineLatest([this.inputTextValid$, this.labelsText$]).subscribe(
([inputTextValid, labelsText]) => {
const isValid = inputTextValid && labelsText !== '';
this.isValid$.next(isValid);
}
);
this.validators$.push(this.labelsText$.pipe(map((labelsText) => labelsText !== '')));

this.initializeValidators();
}

public async inferText() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import { i18n } from '@kbn/i18n';
import { estypes } from '@elastic/elasticsearch';
import { combineLatest } from 'rxjs';
import { InferenceBase, INPUT_TYPE } from '../inference_base';
import type { InferResponse } from '../inference_base';
import { getGeneralInputComponent } from '../text_input';
Expand Down Expand Up @@ -47,9 +46,7 @@ export class TextEmbeddingInference extends InferenceBase<TextEmbeddingResponse>
) {
super(trainedModelsApi, model, inputType);

combineLatest([this.inputTextValid$]).subscribe(([inputTextValid]) => {
this.isValid$.next(inputTextValid);
});
this.initializeValidators();
}

public async inferText() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
*/

import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import React, { FC, useMemo } from 'react';
import React, { FC, useMemo, useEffect } from 'react';

import { NerInference } from './models/ner';
import { QuestionAnsweringInference } from './models/question_answering';
Expand Down Expand Up @@ -69,6 +69,12 @@ export const SelectedModel: FC<Props> = ({ model, inputType }) => {
}
}, [inputType, model, trainedModels]);

useEffect(() => {
return () => {
inferrer?.destroy();
};
}, [inferrer]);

if (inferrer !== undefined) {
return <InferenceInputForm inferrer={inferrer} inputType={inputType} />;
}
Expand Down

0 comments on commit d818da9

Please sign in to comment.