Skip to content

Commit

Permalink
refactor: extract message parsing logic
Browse files Browse the repository at this point in the history
Extracted the logic for parsing commit messages from generated text into a separate private method for better code organization and readability.
  • Loading branch information
tak-bro committed Jul 15, 2024
1 parent 14b1ef5 commit e85ae2f
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 128 deletions.
59 changes: 43 additions & 16 deletions src/services/ai/ai.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Observable, of } from 'rxjs';

import { CommitType, ValidConfig } from '../../utils/config.js';
import { StagedDiff } from '../../utils/git.js';
import { extraPrompt, generateDefaultPrompt, isValidConventionalMessage, isValidGitmojiMessage } from '../../utils/prompt.js';
import { extraPrompt, generateDefaultPrompt } from '../../utils/prompt.js';

// NOTE: get AI Type from key names
export const AIType = {
Expand All @@ -26,7 +26,7 @@ export interface CommitMessage {
value: string;
}

export interface ParsedMessage {
export interface RawCommitMessage {
message: string;
body: string;
}
Expand Down Expand Up @@ -63,7 +63,7 @@ export abstract class AIService {

protected buildPrompt(locale: string, diff: string, completions: number, maxLength: number, type: CommitType, prompt: string) {
const defaultPrompt = generateDefaultPrompt(locale, maxLength, type, prompt);
return `${defaultPrompt}\n${extraPrompt(completions)}\nHere are diff: \n${diff}`;
return `${defaultPrompt}\n${extraPrompt(completions, type)}\nHere are diff: \n${diff}`;
}

protected handleError$ = (error: AIServiceError): Observable<ReactiveListChoice> => {
Expand All @@ -79,7 +79,7 @@ export abstract class AIService {
});
};

protected sanitizeMessage(generatedText: string, type: CommitType, maxCount: number) {
protected sanitizeMessage(generatedText: string, type: CommitType, maxCount: number): CommitMessage[] {
const jsonPattern = /\[[\s\S]*?\]/;

try {
Expand All @@ -89,19 +89,10 @@ export abstract class AIService {
return [];
}
const jsonStr = jsonMatch[0];
const commitMessages: ParsedMessage[] = JSON.parse(jsonStr);
const commitMessages: RawCommitMessage[] = JSON.parse(jsonStr);
const filtedMessages = commitMessages
.filter(data => {
switch (type) {
case 'conventional':
return isValidConventionalMessage(data.message);
case 'gitmoji':
return isValidGitmojiMessage(data.message);
default:
return true;
}
})
.map((data: ParsedMessage) => {
.map(data => this.extractMessageAsType(data, type))
.map((data: RawCommitMessage) => {
return {
title: `${data.message}`,
value: data.body ? `${data.message}\n\n${data.body}` : `${data.message}`,
Expand All @@ -117,4 +108,40 @@ export abstract class AIService {
return [];
}
}

/**
 * Coerces a raw AI-generated message into the requested commit format.
 *
 * - 'conventional': keeps only the "type(scope): description" portion of the
 *   message and normalizes its casing via normalizeCommitMessage().
 * - 'gitmoji': keeps only the ":emoji: description" portion, lowercased.
 * - anything else: the message is returned untouched.
 *
 * When the message does not match the expected pattern it is passed through
 * unchanged rather than dropped.
 */
private extractMessageAsType(data: RawCommitMessage, type: CommitType): RawCommitMessage {
    switch (type) {
        // NOTE: each case gets its own block so the `const` declarations are
        // properly scoped (avoids the no-case-declarations lint violation).
        case 'conventional': {
            const conventionalPattern = /(\w+)(?:\(.*?\))?:\s*(.*)/;
            const conventionalMatch = data.message.match(conventionalPattern);
            const message = conventionalMatch ? conventionalMatch[0] : data.message;
            return {
                ...data,
                message: this.normalizeCommitMessage(message),
            };
        }
        case 'gitmoji': {
            const gitmojiTypePattern = /:\w*:\s*(.*)/;
            const gitmojiMatch = data.message.match(gitmojiTypePattern);
            return {
                ...data,
                message: gitmojiMatch ? gitmojiMatch[0].toLowerCase() : data.message,
            };
        }
        default:
            return data;
    }
}

/**
 * Normalizes the casing of a conventional commit message:
 * lowercases the type and the first letter of the description while leaving
 * the scope as written, e.g. "Feat(API): Add x" -> "feat(API): add x".
 *
 * Messages that do not match "type(scope): description" are returned
 * unchanged. (Rewritten to return early instead of reassigning the
 * `message` parameter.)
 */
private normalizeCommitMessage(message: string): string {
    const messagePattern = /^(\w+)(\(.*?\))?:\s(.*)$/;
    const match = message.match(messagePattern);
    if (!match) {
        return message;
    }

    const [, type, scope, description] = match;
    const normalizedType = type.toLowerCase();
    const normalizedDescription = description.charAt(0).toLowerCase() + description.slice(1);
    return `${normalizedType}${scope || ''}: ${normalizedDescription}`;
}
}
2 changes: 1 addition & 1 deletion src/services/ai/anthropic.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export class AnthropicService extends AIService {
const maxLength = this.params.config['max-length'];

const defaultPrompt = generateDefaultPrompt(locale, maxLength, type, userPrompt);
const systemPrompt = `${defaultPrompt}\n${extraPrompt(generate)}`;
const systemPrompt = `${defaultPrompt}\n${extraPrompt(generate, type)}`;

const params: Anthropic.MessageCreateParams = {
max_tokens: this.params.config['max-tokens'],
Expand Down
2 changes: 1 addition & 1 deletion src/services/ai/groq.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export class GroqService extends AIService {
const { locale, generate, type, prompt: userPrompt, logging } = this.params.config;
const maxLength = this.params.config['max-length'];
const defaultPrompt = generateDefaultPrompt(locale, maxLength, type, userPrompt);
const systemPrompt = `${defaultPrompt}\n${extraPrompt(generate)}`;
const systemPrompt = `${defaultPrompt}\n${extraPrompt(generate, type)}`;

const chatCompletion = await this.groq.chat.completions.create(
{
Expand Down
3 changes: 1 addition & 2 deletions src/services/ai/ollama.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import { extraPrompt, generateDefaultPrompt } from '../../utils/prompt.js';
import { capitalizeFirstLetter } from '../../utils/utils.js';
import { HttpRequestBuilder } from '../http/http-request.builder.js';


export interface OllamaServiceError extends AIServiceError {}

export class OllamaService extends AIService {
Expand Down Expand Up @@ -130,6 +129,6 @@ export class OllamaService extends AIService {
this.params.config.type,
this.params.config.prompt
);
return `${defaultPrompt}\n${extraPrompt(this.params.config.generate)}`;
return `${defaultPrompt}\n${extraPrompt(this.params.config.generate, this.params.config.type)}`;
}
}
50 changes: 29 additions & 21 deletions src/services/ai/openai.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { ReactiveListChoice } from 'inquirer-reactive-list-prompt';
import { Observable, catchError, concatMap, from, map, of } from 'rxjs';
import { fromPromise } from 'rxjs/internal/observable/innerFrom';

import { AIService, AIServiceError, AIServiceParams } from './ai.service.js';
import { AIService, AIServiceError, AIServiceParams, CommitMessage } from './ai.service.js';
import { generateCommitMessage } from '../../utils/openai.js';

export class OpenAIService extends AIService {
Expand All @@ -18,26 +18,8 @@ export class OpenAIService extends AIService {
}

generateCommitMessage$(): Observable<ReactiveListChoice> {
return fromPromise(
generateCommitMessage(
this.params.config.OPENAI_URL,
this.params.config.OPENAI_PATH,
this.params.config.OPENAI_KEY,
this.params.config.OPENAI_MODEL,
this.params.config.locale,
this.params.stagedDiff.diff,
this.params.config.generate,
this.params.config['max-length'],
this.params.config.type,
this.params.config.timeout,
this.params.config['max-tokens'],
this.params.config.temperature,
this.params.config.prompt,
this.params.config.logging,
this.params.config.proxy
)
).pipe(
concatMap(messages => from(messages)), // flat messages
return fromPromise(this.generateMessage()).pipe(
concatMap(messages => from(messages)),
map(data => ({
name: `${this.serviceName} ${data.title}`,
value: data.value,
Expand Down Expand Up @@ -75,4 +57,30 @@ export class OpenAIService extends AIService {
},
};
}

/**
 * Requests commit message completions from the OpenAI API and converts the
 * raw completion text into selectable commit message choices.
 */
private async generateMessage(): Promise<CommitMessage[]> {
    // Pull the request parameters out of the shared CLI configuration.
    const config = this.params.config;
    const { locale, generate, type } = config;
    const stagedDiff = this.params.stagedDiff.diff;
    const maxLength = config['max-length'];

    const generatedText = await generateCommitMessage(
        config.OPENAI_URL,
        config.OPENAI_PATH,
        config.OPENAI_KEY,
        config.OPENAI_MODEL,
        locale,
        stagedDiff,
        generate,
        maxLength,
        type,
        config.timeout,
        config['max-tokens'],
        config.temperature,
        config.prompt,
        config.logging,
        config.proxy
    );

    // `generate` doubles as the maximum number of suggestions to keep.
    return this.sanitizeMessage(generatedText, type, generate);
}
}
69 changes: 3 additions & 66 deletions src/utils/openai.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import http from 'http';
import https from 'https';

import {
type TiktokenModel,
encoding_for_model,
// encoding_for_model,
} from '@dqbd/tiktoken';
import { type TiktokenModel } from '@dqbd/tiktoken';
import createHttpsProxyAgent from 'https-proxy-agent';

import { KnownError } from './error.js';
import { createLogResponse } from './log.js';
import { generateDefaultPrompt, isValidConventionalMessage, isValidGitmojiMessage } from './prompt.js';
import { CommitMessage, ParsedMessage } from '../services/ai/ai.service.js';
import { generateDefaultPrompt } from './prompt.js';

import type { CommitType } from './config.js';
import type { ClientRequest, IncomingMessage } from 'http';
Expand Down Expand Up @@ -151,25 +146,6 @@ const sanitizeMessage = (message: string) =>
.replace(/[\n\r]/g, '')
.replace(/(\w)\.$/, '$1');

/**
 * Removes duplicate commit message suggestions.
 *
 * BUG FIX: the previous `Array.from(new Set(array))` deduplicated by object
 * reference only — structurally-equal messages parsed from separate JSON
 * objects were never considered duplicates — so it was a no-op in practice.
 * We now compare messages by their serialized content instead.
 */
export const deduplicateMessages = (array: CommitMessage[]): CommitMessage[] => {
    const seen = new Set<string>();
    return array.filter(message => {
        const key = JSON.stringify(message);
        if (seen.has(key)) {
            return false;
        }
        seen.add(key);
        return true;
    });
};

/**
 * Builds a placeholder string of `length` copies of 'z' (chosen as a
 * high-token-cost character) for token-count estimation.
 *
 * Uses String.prototype.repeat instead of a manual accumulation loop;
 * Math.max keeps the original behavior of returning '' for negative
 * lengths (repeat would throw a RangeError).
 */
const generateStringFromLength = (length: number) => {
    const highestTokenChar = 'z';
    return highestTokenChar.repeat(Math.max(0, length));
};

/**
 * Counts how many tokens `prompt` occupies for the given tiktoken model.
 */
const getTokens = (prompt: string, model: TiktokenModel) => {
    const encoding = encoding_for_model(model);
    const tokenCount = encoding.encode(prompt).length;
    // Release the encoder explicitly to avoid possible memory leaks.
    encoding.free();
    return tokenCount;
};

export const generateCommitMessage = async (
url: string,
path: string,
Expand Down Expand Up @@ -223,7 +199,7 @@ export const generateCommitMessage = async (
.map(choice => sanitizeMessage(choice.message!.content as string))
.join();
logging && createLogResponse('OPEN AI', diff, systemPrompt, fullText);
return parseCommitMessage(fullText, type, completions);
return fullText;
} catch (error) {
const errorAsAny = error as any;
if (errorAsAny.code === 'ENOTFOUND') {
Expand All @@ -232,42 +208,3 @@ export const generateCommitMessage = async (
throw errorAsAny;
}
};

/**
 * Parses the model's raw completion text into commit message choices.
 *
 * Expects the text to contain a JSON array of { message, body } objects.
 * Messages failing the format check for `type` are discarded, and at most
 * `maxCount` entries are returned; [] when nothing parseable is found.
 */
const parseCommitMessage = (generatedText: string, type: CommitType, maxCount: number): CommitMessage[] => {
    try {
        // Grab the first JSON-array-looking span from the completion.
        const arrayMatch = generatedText.match(/\[[\s\S]*?\]/);
        if (!arrayMatch) {
            // No valid JSON array found in the response.
            return [];
        }

        const rawMessages: ParsedMessage[] = JSON.parse(arrayMatch[0]);

        // Validity depends on the configured commit message convention.
        const isValidForType = (data: ParsedMessage): boolean => {
            if (type === 'conventional') {
                return isValidConventionalMessage(data.message);
            }
            if (type === 'gitmoji') {
                return isValidGitmojiMessage(data.message);
            }
            return true;
        };

        const choices = rawMessages.filter(isValidForType).map(data => ({
            title: `${data.message}`,
            value: data.body ? `${data.message}\n\n${data.body}` : `${data.message}`,
        }));

        return choices.length > maxCount ? choices.slice(0, maxCount) : choices;
    } catch (e) {
        // Malformed JSON from the model: treat as "no suggestions".
        return [];
    }
};
Loading

0 comments on commit e85ae2f

Please sign in to comment.