Skip to content

Commit

Permalink
adding feature flag
Browse files Browse the repository at this point in the history
  • Loading branch information
shikha372 committed Sep 4, 2024
1 parent 6c6ac5f commit 9a4b791
Show file tree
Hide file tree
Showing 7 changed files with 317 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@ import { BedrockInvokeModel } from 'aws-cdk-lib/aws-stepfunctions-tasks';
* * aws stepfunctions describe-execution --execution-arn <exection-arn generated before> : should return status as SUCCEEDED
* This integ test does not actually verify a Step Functions execution, as not all AWS accounts have Bedrock model access.
*/
const app = new cdk.App();
const app = new cdk.App({
postCliContext: {
'@aws-cdk/aws-cdk.aws-stepfunctions-tasks:useNewS3UriParametersForBedrockInvokeModelTask': true,
},
});
const stack = new cdk.Stack(app, 'aws-stepfunctions-tasks-bedrock-invoke-model-integ');

const model = bedrock.FoundationModel.fromFoundationModelId(stack, 'Model', bedrock.FoundationModelIdentifier.AMAZON_TITAN_TEXT_G1_EXPRESS_V1);
Expand Down Expand Up @@ -70,14 +74,32 @@ const prompt3 = new BedrockInvokeModel(stack, 'Prompt3', {
outputPath: sfn.JsonPath.stringAt('$.names'),
});

/** Test for Bedrock s3 URI Path */
/**Test for Bedrock Input Path */
const prompt4 = new BedrockInvokeModel(stack, 'Prompt4', {
model,
body: sfn.TaskInput.fromObject(
{
inputText: sfn.JsonPath.format(
'Alphabetize this list of first names:\n{}',
sfn.JsonPath.stringAt('$.names'),
),
textGenerationConfig: {
maxTokenCount: 100,
temperature: 1,
},
},
),
inputPath: sfn.JsonPath.stringAt('$.names'),
});

/** Test for Bedrock s3 URI Path */
const prompt5 = new BedrockInvokeModel(stack, 'Prompt5', {
model,
input: { s3InputUri: sfn.JsonPath.stringAt('$.names') },
output: { s3OutputUri: sfn.JsonPath.stringAt('$.names') },
});

const chain = sfn.Chain.start(prompt1).next(prompt2).next(prompt3).next(prompt4);
const chain = sfn.Chain.start(prompt1).next(prompt2).next(prompt3).next(prompt4).next(prompt5);

new sfn.StateMachine(stack, 'StateMachine', {
definitionBody: sfn.DefinitionBody.fromChainable(chain),
Expand Down
38 changes: 36 additions & 2 deletions packages/aws-cdk-lib/aws-stepfunctions-tasks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,45 @@ const task = new tasks.BedrockInvokeModel(this, 'Prompt Model', {
names: sfn.JsonPath.stringAt('$.Body.results[0].outputText'),
},
});

```
### Using Input Path for S3 URI

Provide S3 URI as an input or output path to invoke a model

To specify the S3 URI as JSON path to your input or output fields, use props `s3InputUri` and `s3OutputUri` under BedrockInvokeModelProps and set
feature flag `@aws-cdk/aws-stepfunctions-tasks:useNewS3UriParametersForBedrockInvokeModelTask` to true.

If this flag is not set, then the existing behaviour of populating the S3Uri from `InputPath` and `OutputPath` will take effect.

```ts

import * as bedrock from 'aws-cdk-lib/aws-bedrock';

const model = bedrock.FoundationModel.fromFoundationModelId(
this,
'Model',
bedrock.FoundationModelIdentifier.AMAZON_TITAN_TEXT_G1_EXPRESS_V1,
);

const task = new tasks.BedrockInvokeModel(this, 'Prompt Model', {
model,
input : { s3InputUri: sfn.JsonPath.stringAt('$.prompt') },
output: { s3OutputUri: sfn.JsonPath.stringAt('$.prompt') },
});

```

### Using Input Path

Provide S3 URI as an input or output path to invoke a model

Currently, input and output Path provided in the BedrockInvokeModelProps input is defined as S3URI field under task definition of state machine.
To modify the existing behaviour, set `@aws-cdk/aws-stepfunctions-tasks:useNewS3UriParametersForBedrockInvokeModelTask` to true.

If this feature flag is enabled, S3URI fields will be generated from other Props(`s3InputUri` and `s3OutputUri`), and the given inputPath, OutputPath will be rendered as
it is in the JSON task definition.

```ts

import * as bedrock from 'aws-cdk-lib/aws-bedrock';
Expand All @@ -415,8 +449,8 @@ const model = bedrock.FoundationModel.fromFoundationModelId(

const task = new tasks.BedrockInvokeModel(this, 'Prompt Model', {
model,
s3InputUri: sfn.JsonPath.stringAt('$.prompt'),
s3OutputUri: sfn.JsonPath.stringAt('$.prompt'),
inputPath: sfn.JsonPath.stringAt('$.prompt'),
outputPath: sfn.JsonPath.stringAt('$.prompt'),
});

```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ import * as bedrock from '../../../aws-bedrock';
import * as iam from '../../../aws-iam';
import * as s3 from '../../../aws-s3';
import * as sfn from '../../../aws-stepfunctions';
import { Stack } from '../../../core';
import { Annotations, Stack, FeatureFlags } from '../../../core';
import * as cxapi from '../../../cx-api';
import { integrationResourceArn, validatePatternSupported } from '../private/task-utils';

/**
Expand Down Expand Up @@ -163,10 +164,17 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {

validatePatternSupported(this.integrationPattern, BedrockInvokeModel.SUPPORTED_INTEGRATION_PATTERNS);

const isBodySpecified = props.body !== undefined;
const isFeatureFlagEnabled = FeatureFlags.of(this).isEnabled(cxapi.USE_NEW_S3URI_PARAMETERS_FOR_BEDROCK_INVOKE_MODEL_TASK);

//Either specific props.input with bucket name and object key or input s3 path
const isInputSpecified = props.input!==undefined ? props.input?.s3Location !== undefined || props.input?.s3InputUri !== undefined : false;
const isBodySpecified = props.body !== undefined;

let isInputSpecified: boolean;
if (!isFeatureFlagEnabled) {
//Either specific props.input with bucket name and object key or input s3 path
isInputSpecified = (props.input !== undefined && props.input.s3Location !== undefined) || (props.inputPath !== undefined);
} else {
isInputSpecified = props.input!==undefined ? props.input?.s3Location !== undefined || props.input?.s3InputUri !== undefined : false;
}

if (isBodySpecified && isInputSpecified) {
throw new Error('Either `body` or `input` must be specified, but not both.');
Expand All @@ -181,18 +189,25 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
throw new Error('Output S3 object version is not supported.');
}

this.taskPolicies = this.renderPolicyStatements();
//Warning to let users know about the newly introduced props
if (props.inputPath || props.outputPath && !isFeatureFlagEnabled) {
Annotations.of(scope).addWarningV2('aws-cdk-lib/aws-stepfunctions-taks',
'These props will set the value of inputPath/outputPath as s3 URI under input/output field in state machine JSON definition. To modify the behaviour set feature flag `@aws-cdk/aws-stepfunctions-tasks:useNewS3UriParametersForBedrockInvokeModelTask": true` and use props s3InputUri/s3OutputUri');
}

this.taskPolicies = this.renderPolicyStatements(isFeatureFlagEnabled);
}

private renderPolicyStatements(): iam.PolicyStatement[] {
private renderPolicyStatements(isFeatureFlagEnabled?: boolean): iam.PolicyStatement[] {
const policyStatements = [
new iam.PolicyStatement({
actions: ['bedrock:InvokeModel'],
resources: [this.props.model.modelArn],
}),
];

if (this.props.input?.s3InputUri !== undefined) {
//For Compatibility with existing behaviour of input path
if (this.props.input?.s3InputUri !== undefined || (!isFeatureFlagEnabled && this.props.inputPath !== undefined)) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:GetObject'],
Expand Down Expand Up @@ -223,7 +238,8 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
);
}

if (this.props.output?.s3OutputUri !== undefined) {
//For Compatibility with existing behaviour of output path
if (this.props.output?.s3OutputUri !== undefined || (!isFeatureFlagEnabled && this.props.outputPath !== undefined)) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:PutObject'],
Expand Down Expand Up @@ -281,19 +297,19 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
* @internal
*/
protected _renderTask(): any {

const isFeatureFlagEnabled = FeatureFlags.of(this).isEnabled(cxapi.USE_NEW_S3URI_PARAMETERS_FOR_BEDROCK_INVOKE_MODEL_TASK);
const inputSource = this.getInputSource(this.props.input, this.props.inputPath, isFeatureFlagEnabled);
const outputSource = this.getOutputSource(this.props.output, this.props.outputPath, isFeatureFlagEnabled);
return {
Resource: integrationResourceArn('bedrock', 'invokeModel'),
Parameters: sfn.FieldUtils.renderObject({
ModelId: this.props.model.modelArn,
Accept: this.props.accept,
ContentType: this.props.contentType,
Body: this.props.body?.value,
Input: this.props.input?.s3Location ? {
S3Uri: `s3://${this.props.input.s3Location.bucketName}/${this.props.input.s3Location.objectKey}`,
} : this.props.input?.s3InputUri ? { S3Uri: this.props.input?.s3InputUri } : undefined,
Output: this.props.output?.s3Location ? {
S3Uri: `s3://${this.props.output.s3Location.bucketName}/${this.props.output.s3Location.objectKey}`,
} : this.props.output?.s3OutputUri? { S3Uri: this.props.output.s3OutputUri }: undefined,
Input: this.props.input || (this.props.inputPath && !isFeatureFlagEnabled) ? { S3Uri: inputSource } : undefined,
Output: this.props.output || ( this.props.outputPath && !isFeatureFlagEnabled) ? { S3Uri: outputSource } : undefined,
GuardrailIdentifier: this.props.guardrail?.guardrailIdentifier,
GuardrailVersion: this.props.guardrail?.guardrailVersion,
Trace: this.props.traceEnabled === undefined
Expand All @@ -304,5 +320,27 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
}),
};
};

private getInputSource(props?: BedrockInvokeModelInputProps, inputPath?: string, isFeatureFlagEnabled?: boolean): string | undefined {
if (props?.s3Location) {
return `s3://${props.s3Location.bucketName}/${props.s3Location.objectKey}`;
} else if (isFeatureFlagEnabled && props?.s3InputUri) {
return props.s3InputUri;
} else if (!isFeatureFlagEnabled && inputPath) {
return inputPath;
}
return undefined;
}

private getOutputSource(props?: BedrockInvokeModelOutputProps, outputPath?: string, isFeatureFlagEnabled?: boolean): string | undefined {
if (props?.s3Location) {
return `s3://${props.s3Location.bucketName}/${props.s3Location.objectKey}`;
} else if (isFeatureFlagEnabled && props?.s3OutputUri) {
return props.s3OutputUri;
} else if (!isFeatureFlagEnabled && outputPath) {
return outputPath;
}
return undefined;
}
}

Loading

0 comments on commit 9a4b791

Please sign in to comment.