Skip to content

Commit

Permalink
[Security solution] Attack Discovery "View in AI Assistant" button fix (
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic authored Sep 12, 2024
1 parent b02e1f3 commit ea6bb9e
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { useAssistantOverlay } from '.';
import { waitFor } from '@testing-library/react';
import { useFetchCurrentUserConversations } from '../api';
import { Conversation } from '../../assistant_context/types';
import { mockConnectors } from '../../mock/connectors';

const mockUseAssistantContext = {
registerPromptContext: jest.fn(),
Expand All @@ -27,19 +28,21 @@ jest.mock('../../assistant_context', () => {
};
});
jest.mock('../api/conversations/use_fetch_current_user_conversations');
const mockCreateConversation = jest.fn().mockResolvedValue({ id: 'conversation-id' });
jest.mock('../use_conversation', () => {
return {
useConversation: jest.fn(() => ({
createConversation: mockCreateConversation,
currentConversation: { id: 'conversation-id' },
})),
};
});
jest.mock('../helpers');

jest.mock('../../connectorland/helpers');
jest.mock('../../connectorland/use_load_connectors', () => {
return {
useLoadConnectors: jest.fn(() => ({
data: [],
data: mockConnectors,
error: null,
isSuccess: true,
})),
Expand Down Expand Up @@ -158,10 +161,78 @@ describe('useAssistantOverlay', () => {
result.current.showAssistantOverlay(true);
});

expect(mockCreateConversation).not.toHaveBeenCalled();
expect(mockUseAssistantContext.showAssistantOverlay).toHaveBeenCalledWith({
showOverlay: true,
promptContextId: 'id',
conversationTitle: 'conversation-id',
});
});

it('calls `showAssistantOverlay` and creates a new conversation when shouldCreateConversation: true and the conversation does not exist', async () => {
const isAssistantAvailable = true;
const { result } = renderHook(() =>
useAssistantOverlay(
'event',
'conversation-id',
'description',
() => Promise.resolve('data'),
'id',
null,
'tooltip',
isAssistantAvailable
)
);

act(() => {
result.current.showAssistantOverlay(true, true);
});

expect(mockCreateConversation).toHaveBeenCalledWith({
title: 'conversation-id',
apiConfig: {
actionTypeId: '.gen-ai',
connectorId: 'connectorId',
},
category: 'assistant',
});

await waitFor(() => {
expect(mockUseAssistantContext.showAssistantOverlay).toHaveBeenCalledWith({
showOverlay: true,
promptContextId: 'id',
conversationTitle: 'conversation-id',
});
});
});

it('calls `showAssistantOverlay` and does not create a new conversation when shouldCreateConversation: true and the conversation exists', async () => {
const isAssistantAvailable = true;
const { result } = renderHook(() =>
useAssistantOverlay(
'event',
'electric sheep',
'description',
() => Promise.resolve('data'),
'id',
null,
'tooltip',
isAssistantAvailable
)
);

act(() => {
result.current.showAssistantOverlay(true, true);
});

expect(mockCreateConversation).not.toHaveBeenCalled();

await waitFor(() => {
expect(mockUseAssistantContext.showAssistantOverlay).toHaveBeenCalledWith({
showOverlay: true,
promptContextId: 'id',
conversationTitle: 'electric sheep',
});
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ import { useCallback, useEffect, useMemo } from 'react';
import { useAssistantContext } from '../../assistant_context';
import { getUniquePromptContextId } from '../../assistant_context/helpers';
import type { PromptContext } from '../prompt_context/types';
import { useConversation } from '../use_conversation';
import { getDefaultConnector, mergeBaseWithPersistedConversations } from '../helpers';
import { getGenAiConfig } from '../../connectorland/helpers';
import { useLoadConnectors } from '../../connectorland/use_load_connectors';
import { FetchConversationsResponse, useFetchCurrentUserConversations } from '../api';
import { Conversation } from '../../assistant_context/types';

interface UseAssistantOverlay {
showAssistantOverlay: (show: boolean, silent?: boolean) => void;
Expand Down Expand Up @@ -76,6 +82,26 @@ export const useAssistantOverlay = (
*/
replacements?: Replacements | null
): UseAssistantOverlay => {
const { http } = useAssistantContext();
const { data: connectors } = useLoadConnectors({
http,
});

const defaultConnector = useMemo(() => getDefaultConnector(connectors), [connectors]);
const apiConfig = useMemo(() => getGenAiConfig(defaultConnector), [defaultConnector]);

const { createConversation } = useConversation();

const onFetchedConversations = useCallback(
(conversationsData: FetchConversationsResponse): Record<string, Conversation> =>
mergeBaseWithPersistedConversations({}, conversationsData),
[]
);
const { data: conversations, isLoading } = useFetchCurrentUserConversations({
http,
onFetch: onFetchedConversations,
isAssistantEnabled,
});
// memoize the props so that we can use them in the effect below:
const _category: PromptContext['category'] = useMemo(() => category, [category]);
const _description: PromptContext['description'] = useMemo(() => description, [description]);
Expand Down Expand Up @@ -104,16 +130,52 @@ export const useAssistantOverlay = (
// proxy show / hide calls to assistant context, using our internal prompt context id:
// silent:boolean doesn't show the toast notification if the conversation is not found
const showAssistantOverlay = useCallback(
async (showOverlay: boolean) => {
// shouldCreateConversation should only be passed for
// non-default conversations that may need to be initialized
async (showOverlay: boolean, shouldCreateConversation: boolean = false) => {
if (promptContextId != null) {
if (shouldCreateConversation) {
let conversation;
if (!isLoading) {
conversation = conversationTitle
? Object.values(conversations).find((conv) => conv.title === conversationTitle)
: undefined;
}

if (isAssistantEnabled && !conversation && defaultConnector && !isLoading) {
try {
await createConversation({
apiConfig: {
...apiConfig,
actionTypeId: defaultConnector?.actionTypeId,
connectorId: defaultConnector?.id,
},
category: 'assistant',
title: conversationTitle ?? '',
});
} catch (e) {
/* empty */
}
}
}
assistantContextShowOverlay({
showOverlay,
promptContextId,
conversationTitle: conversationTitle ?? undefined,
});
}
},
[assistantContextShowOverlay, conversationTitle, promptContextId]
[
apiConfig,
assistantContextShowOverlay,
conversationTitle,
conversations,
createConversation,
defaultConnector,
isAssistantEnabled,
isLoading,
promptContextId,
]
);

useEffect(() => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ import { useViewInAiAssistant } from './use_view_in_ai_assistant';
jest.mock('@kbn/elastic-assistant');
jest.mock('../../../assistant/use_assistant_availability');
jest.mock('../../get_attack_discovery_markdown/get_attack_discovery_markdown');

const mockUseAssistantOverlay = useAssistantOverlay as jest.Mock;
describe('useViewInAiAssistant', () => {
beforeEach(() => {
jest.clearAllMocks();

(useAssistantOverlay as jest.Mock).mockReturnValue({
mockUseAssistantOverlay.mockReturnValue({
promptContextId: 'prompt-context-id',
showAssistantOverlay: jest.fn(),
});
Expand Down Expand Up @@ -83,4 +83,16 @@ describe('useViewInAiAssistant', () => {

expect(result.current.disabled).toBe(true);
});

it('uses the title + last 5 of the attack discovery id as the conversation title', () => {
renderHook(() =>
useViewInAiAssistant({
attackDiscovery: mockAttackDiscovery,
})
);

expect(mockUseAssistantOverlay.mock.calls[0][1]).toEqual(
'Malware Attack With Credential Theft Attempt - b72b1'
);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ export const useViewInAiAssistant = ({
}),
[attackDiscovery]
);

const lastFive = attackDiscovery.id ? ` - ${attackDiscovery.id.slice(-5)}` : '';
const { promptContextId, showAssistantOverlay: showOverlay } = useAssistantOverlay(
category,
attackDiscovery.title, // conversation title
attackDiscovery.title + lastFive, // conversation title
attackDiscovery.title, // description used in context pill
getPromptContext,
attackDiscovery.id ?? null, // accept the UUID default for this prompt context
Expand Down

0 comments on commit ea6bb9e

Please sign in to comment.