From 20cc2aa8f912cb18b47faa9725e735bbf09a3b6b Mon Sep 17 00:00:00 2001
From: octtariv <128827755+octtariv@users.noreply.github.com>
Date: Wed, 31 May 2023 00:25:44 -0700
Subject: [PATCH] [Neural] Add new cog

---
 cogs/neural/__init__.py        |  14 +++
 cogs/neural/commandHandlers.py |  22 ++++
 cogs/neural/commandsCore.py    |  17 +++
 cogs/neural/constants.py       |  65 ++++++++++
 cogs/neural/core.py            | 224 +++++++++++++++++++++++++++++++++
 cogs/neural/eventHandlers.py   |  17 +++
 cogs/neural/eventsCore.py      |  68 ++++++++++
 cogs/neural/info.json          |   9 ++
 cogs/neural/neural.py          |   8 ++
 cogs/neural/prompts.py         |  19 +++
 cogs/neural/utils.py           |  49 ++++++++
 11 files changed, 512 insertions(+)
 create mode 100644 cogs/neural/__init__.py
 create mode 100644 cogs/neural/commandHandlers.py
 create mode 100644 cogs/neural/commandsCore.py
 create mode 100644 cogs/neural/constants.py
 create mode 100644 cogs/neural/core.py
 create mode 100644 cogs/neural/eventHandlers.py
 create mode 100644 cogs/neural/eventsCore.py
 create mode 100644 cogs/neural/info.json
 create mode 100644 cogs/neural/neural.py
 create mode 100644 cogs/neural/prompts.py
 create mode 100644 cogs/neural/utils.py

diff --git a/cogs/neural/__init__.py b/cogs/neural/__init__.py
new file mode 100644
index 00000000000..10ef0bb68a3
--- /dev/null
+++ b/cogs/neural/__init__.py
@@ -0,0 +1,14 @@
+import json
+from pathlib import Path
+
+from redbot.core.bot import Red
+
+from .neural import Neural
+
+with open(Path(__file__).parent / "info.json") as fp:
+    __red_end_user_data_statement__: str = json.load(fp)["end_user_data_statement"]
+
+
+async def setup(bot: Red) -> None:
+    """Add the cog to the bot."""
+    await bot.add_cog(Neural(bot))
diff --git a/cogs/neural/commandHandlers.py b/cogs/neural/commandHandlers.py
new file mode 100644
index 00000000000..1ad18adaa65
--- /dev/null
+++ b/cogs/neural/commandHandlers.py
@@ -0,0 +1,22 @@
+from redbot.core import commands
+from redbot.core.commands import Context
+
+from .commandsCore import CommandsCore
+
+
+class CommandHandlers(CommandsCore):
+    @commands.group(name="neural")
+    @commands.guild_only()
+    @commands.admin_or_permissions(manage_guild=True)
+    async def _groupNeural(self, ctx: Context) -> None:
+        """Configure the neural cog."""
+
+    @_groupNeural.command(name="activate")
+    async def _commandNeuralActivate(self, ctx: Context) -> None:
+        """Activate the neural cog for the current channel."""
+        await self.commandNeuralActivate(ctx=ctx)
+
+    @_groupNeural.command(name="deactivate")
+    async def _commandNeuralDeactivate(self, ctx: Context) -> None:
+        """Deactivate the neural cog for the current channel."""
+        await self.commandNeuralDeactivate(ctx=ctx)
diff --git a/cogs/neural/commandsCore.py b/cogs/neural/commandsCore.py
new file mode 100644
index 00000000000..4b89f314ece
--- /dev/null
+++ b/cogs/neural/commandsCore.py
@@ -0,0 +1,17 @@
+from redbot.core.commands import Context
+
+from .constants import KEY_NEURAL_ACTIVE
+from .core import Core
+
+
+class CommandsCore(Core):
+    async def commandNeuralActivate(self, ctx: Context) -> None:
+        """Activate the neural cog for the current channel."""
+        await self.initChannelChatState(ctx.channel)
+        await self.config.channel(ctx.channel).get_attr(KEY_NEURAL_ACTIVE).set(True)
+        await ctx.send("SYSTEM: The neural cog is now **activated** for this channel.")
+
+    async def commandNeuralDeactivate(self, ctx: Context) -> None:
+        """Deactivate the neural cog for the current channel."""
+        await self.config.channel(ctx.channel).get_attr(KEY_NEURAL_ACTIVE).set(False)
+        await ctx.send("SYSTEM: The neural cog is now **deactivated** for this channel.")
diff --git a/cogs/neural/constants.py b/cogs/neural/constants.py
new file mode 100644
index 00000000000..671ed06bb1b
--- /dev/null
+++ b/cogs/neural/constants.py
@@ -0,0 +1,65 @@
+from typing import Dict, Final, TypedDict
+
+SYSTEM_MESSAGE_STR: Final[str] = "SYSTEM"
+COMMAND_FILTER_CUTOFF: Final[int] = 24  # Chars
+BASE_TYPING_WAIT: Final[int] = 6  # Seconds
+INIT_CHAT_HISTORY_LIMIT: Final[int] = 30  # Messages
+INBOUND_BUFFER_CHAR_LIMIT: Final[int] = 1024
+OLD_CHAT_HISTORY_CHAR_LIMIT: Final[int] = 512
+
+
+# LLM API
+LLM_API_BASE: Final[str] = "http://localhost:8000/v1"
+LLM_API_KEY: Final[str] = "renrenren"  # No auth implemented, any random string works
+LLM_API_PRESENCE_PENALTY: Final[float] = 1.17647
+
+
+# Chances to start chat
+KEY_AT_MENTION_START: Final[str] = "atMentionStart"
+KEY_CASUAL_MENTION_START: Final[str] = "casualMentionStart"
+KEY_SELF_START: Final[str] = "selfStart"
+CHAT_START_CHANCE: Final[Dict[str, float]] = {
+    KEY_AT_MENTION_START: 1.0,
+    KEY_CASUAL_MENTION_START: 0.5,
+    KEY_SELF_START: 0.01,
+}
+
+
+# Chances to follow up after her own message during a chat
+LONG_MESSAGE_CUTOFF: Final[int] = 32  # Chars
+
+KEY_QUESTION_FOLLOW_UP: Final[str] = "question"
+KEY_LONG_MESSAGE_FOLLOW_UP: Final[str] = "longMessage"
+KEY_BASE_FOLLOW_UP: Final[str] = "baseChance"
+FOLLOW_UP_CHANCE: Final[Dict[str, float]] = {
+    KEY_QUESTION_FOLLOW_UP: 0.01,
+    KEY_LONG_MESSAGE_FOLLOW_UP: 0.1,
+    KEY_BASE_FOLLOW_UP: 0.7,
+}
+
+
+# Chances to end chat
+CHAT_HISTORY_CHAR_LIMIT: Final[int] = 3072  # Lazy token estimate, do it properly later
+STALE_CHAT_PERIOD: Final[int] = 900  # Seconds
+
+KEY_TOKEN_LIMIT_END: Final[str] = "tokenLimitEnd"
+KEY_STALE_CHAT_END: Final[str] = "staleChatEnd"
+KEY_SELF_END: Final[str] = "selfEnd"
+CHAT_END_CHANCE: Final[Dict[str, float]] = {
+    KEY_TOKEN_LIMIT_END: 1.0,
+    KEY_STALE_CHAT_END: 0.5,
+    KEY_SELF_END: 0.01,
+}
+
+
+# Channel config
+KEY_NEURAL_ACTIVE: Final[str] = "neuralActive"
+
+
+class BaseChannel(TypedDict):
+    neuralActive: bool
+
+
+BASE_CHANNEL: Final[BaseChannel] = {
+    KEY_NEURAL_ACTIVE: False,
+}
diff --git a/cogs/neural/core.py b/cogs/neural/core.py
new file mode 100644
index 00000000000..f565836f93f
--- /dev/null
+++ b/cogs/neural/core.py
@@ -0,0 +1,224 @@
+import os
+from asyncio import Lock, sleep
+from collections import deque
+from datetime import datetime, timedelta
+from logging import FileHandler, Formatter, Logger, getLogger
+from pathlib import Path
+from random import random
+from typing import Deque, Dict, List, Optional
+
+from discord import Message, TextChannel
+from langchain.llms import OpenAI
+
+from redbot.core import Config, data_manager
+from redbot.core.bot import Red
+
+from .constants import (
+    BASE_CHANNEL,
+    BASE_TYPING_WAIT,
+    CHAT_HISTORY_CHAR_LIMIT,
+    FOLLOW_UP_CHANCE,
+    INBOUND_BUFFER_CHAR_LIMIT,
+    INIT_CHAT_HISTORY_LIMIT,
+    KEY_BASE_FOLLOW_UP,
+    KEY_LONG_MESSAGE_FOLLOW_UP,
+    KEY_QUESTION_FOLLOW_UP,
+    LLM_API_BASE,
+    LLM_API_KEY,
+    LLM_API_PRESENCE_PENALTY,
+    LONG_MESSAGE_CUTOFF,
+    OLD_CHAT_HISTORY_CHAR_LIMIT,
+    STALE_CHAT_PERIOD,
+    SYSTEM_MESSAGE_STR,
+)
+from .prompts import chatPrompt, chatPromptBotName, chatReversePrompt
+from .utils import isProbablyCommand
+
+
+class Core:
+    def __init__(self, bot: Red) -> None:
+        self.bot: Red = bot
+        self.inboundBufferLock: Lock = Lock()
+        self.chatLock: Lock = Lock()
+        self.config: Config = Config.get_conf(
+            self,
+            identifier=5842647,
+            force_registration=True,
+        )
+        self.config.register_channel(**BASE_CHANNEL)
+        # Initialize logger and save to cog folder
+        saveFolder: Path = data_manager.cog_data_path(cog_instance=self)
+        self.logger: Logger = getLogger("red.luicogs.Neural")
+        if not self.logger.handlers:
+            logPath: str = os.path.join(saveFolder, "info.log")
+            handler: FileHandler = FileHandler(
+                filename=logPath,
+                encoding="utf-8",
+                mode="a",
+            )
+            handler.setFormatter(
+                Formatter("%(asctime)s %(message)s", datefmt="[%Y/%m/%d %H:%M:%S]")
+            )
+            self.logger.addHandler(handler)
+        # Initialize vars that track chat state
+        self.chatChannelId: Optional[int] = None
+        # Channel ID, data
+        self.inboundBuffer: Dict[int, Deque[Message]] = {}  # TODO: class
+        self.chatHistory: Dict[int, Deque[str]] = {}
+        self.lastTypingTime: Dict[int, datetime] = {}
+        self.lastReplyTime: Dict[int, datetime] = {}
+
+    async def initChannelChatState(self, channel: TextChannel) -> None:
+        """Initialize the chat state on `channel`."""
+        self.chatChannelId = channel.id
+        self.inboundBuffer[channel.id] = deque()
+        messageHistory: List[Message] = [
+            message async for message in channel.history(limit=INIT_CHAT_HISTORY_LIMIT)
+        ]
+        for message in reversed(messageHistory):
+            if await self.filterInbound(message):
+                await self.appendInboundBuffer(message)
+        self.chatHistory[channel.id] = deque()
+        self.updateLastTypingTime(channel.id)
+        self.lastReplyTime[channel.id] = datetime.now()
+
+    async def filterInbound(self, message: Message) -> bool:
+        """Return whether `message` should be seen by the chat model."""
+        if not isinstance(message.channel, TextChannel):
+            return False
+        if message.content.startswith(SYSTEM_MESSAGE_STR):
+            return False
+        if message.attachments:
+            return False
+        if await isProbablyCommand(self.bot, message):
+            return False
+        return True
+
+    async def appendInboundBuffer(self, message: Message) -> None:
+        """Append the message to its channel's inbound buffer."""
+        async with self.inboundBufferLock:
+            channelId: int = message.channel.id
+            if channelId in self.inboundBuffer:
+                self.inboundBuffer[channelId].append(message)
+            else:
+                self.inboundBuffer[channelId] = deque([message])
+            # If necessary, delete older messages to keep buffer size under control
+            bufferLengthChars: int = 0
+            for msg in self.inboundBuffer[channelId]:
+                bufferLengthChars += len(msg.content)
+            while bufferLengthChars > INBOUND_BUFFER_CHAR_LIMIT:
+                bufferLengthChars -= len(self.inboundBuffer[channelId][0].content)
+                self.inboundBuffer[channelId].popleft()
+
+    async def clearInboundBuffer(self, channel: TextChannel) -> None:
+        """Clear the inbound buffer on `channel`."""
+        async with self.inboundBufferLock:
+            self.inboundBuffer[channel.id].clear()
+
+    def updateLastTypingTime(self, channelId: int) -> None:
+        """Update last typing timestamp for the channel."""
+        self.lastTypingTime[channelId] = datetime.now()
+
+    async def waitForTyping(self, channel: TextChannel) -> bool:
+        """Wait for a while,
+        then return whether someone had started typing on `channel` during the wait.
+        """
+        oldTypingTime: Optional[datetime] = self.lastTypingTime.get(channel.id)
+        bufferLen: int = len(self.inboundBuffer[channel.id])
+        # Don't wait too long if messages are piling up in the buffer
+        waitSeconds: float = BASE_TYPING_WAIT / (bufferLen + 1) + random()
+        await sleep(waitSeconds)
+        newTypingTime: Optional[datetime] = self.lastTypingTime.get(channel.id)
+        if oldTypingTime != newTypingTime:
+            return True
+        return False
+
+    async def stepChat(self, channel: TextChannel) -> None:
+        """Step the chat on `channel` forward by one reply cycle."""
+        # Load then clear the inbound buffer
+        inboundBufferSnapshot: Deque[Message] = self.inboundBuffer[channel.id].copy()
+        await self.clearInboundBuffer(channel)
+        # Load the incoming messages into chat history
+        if channel.id not in self.chatHistory:
+            self.chatHistory[channel.id] = deque()  # No chat history yet, make new
+        for message in inboundBufferSnapshot:
+            promptUsername: str = message.author.display_name
+            if message.author.id == self.bot.user.id:
+                promptUsername = chatPromptBotName
+            # fmt: off
+            formattedMessage: str = (
+                f"{chatReversePrompt} {promptUsername}:\n"
+                f"{message.content}\n"
+                "\n"
+            )
+            # fmt: on
+            self.chatHistory[channel.id].append(formattedMessage)
+        # Join the chat history into a single string
+        chatHistoryPrompt: str = "".join(self.chatHistory[channel.id])
+        # Prepare the LLM query
+        llm: OpenAI = OpenAI(
+            openai_api_base=LLM_API_BASE,
+            openai_api_key=LLM_API_KEY,
+            presence_penalty=LLM_API_PRESENCE_PENALTY,
+        )
+        prompt: str = chatPrompt.format(
+            channelName=channel.name,
+            messageHistory=chatHistoryPrompt,
+        )
+        self.logger.info("\n======Passed to API:======\n%s", prompt)
+        # Query the LLM
+        async with channel.typing():
+            response: str = await llm.apredict(
+                text=prompt,
+                stop=[chatReversePrompt],
+            )
+            await channel.send(response)
+        # Update the chat history; match inbound formatting (trailing blank line)
+        # fmt: off
+        formattedResponse: str = (
+            f"{chatReversePrompt} {chatPromptBotName}:\n"
+            f"{response}\n\n"
+        )
+        # fmt: on
+        self.chatHistory[channel.id].append(formattedResponse)
+        self.lastReplyTime[channel.id] = datetime.now()
+
+    async def startChat(self, channel: TextChannel) -> None:
+        """Start a chat on `channel`."""
+        await self.stepChat(channel)
+
+    async def continueChat(self, channel: TextChannel) -> None:
+        """Continue the chat on `channel`."""
+        await self.stepChat(channel)
+        if channel.last_message.author.id == self.bot.user.id:
+            followUpChance: float
+            if "?" in channel.last_message.content:
+                followUpChance = FOLLOW_UP_CHANCE[KEY_QUESTION_FOLLOW_UP]
+            elif len(channel.last_message.content) > LONG_MESSAGE_CUTOFF:
+                followUpChance = FOLLOW_UP_CHANCE[KEY_LONG_MESSAGE_FOLLOW_UP]
+            else:
+                followUpChance = FOLLOW_UP_CHANCE[KEY_BASE_FOLLOW_UP]
+            if random() < followUpChance:
+                await sleep(random())
+                await self.continueChat(channel)
+
+    def endChat(self, channel: TextChannel) -> None:
+        """End the chat on `channel`. Keep a small portion of the chat history."""
+        chatHistoryLengthChars: int = len("".join(self.chatHistory[channel.id]))
+        while chatHistoryLengthChars > OLD_CHAT_HISTORY_CHAR_LIMIT:
+            chatHistoryLengthChars -= len(self.chatHistory[channel.id][0])
+            self.chatHistory[channel.id].popleft()
+
+    def hasOversizedChat(self, channel: TextChannel) -> bool:
+        """Return whether the chat history on `channel` is over the token limit."""
+        chatHistoryLengthChars: int = len("".join(self.chatHistory[channel.id]))
+        if chatHistoryLengthChars > CHAT_HISTORY_CHAR_LIMIT:
+            return True
+        return False
+
+    def hasStaleChat(self, channel: TextChannel) -> bool:
+        """Return whether there is a stale chat on `channel`."""
+        timeSinceLastReply: timedelta = datetime.now() - self.lastReplyTime[channel.id]
+        if timeSinceLastReply.total_seconds() > STALE_CHAT_PERIOD:
+            return True
+        return False
diff --git a/cogs/neural/eventHandlers.py b/cogs/neural/eventHandlers.py
new file mode 100644
index 00000000000..1e726fa602a
--- /dev/null
+++ b/cogs/neural/eventHandlers.py
@@ -0,0 +1,17 @@
+from discord import Message, RawTypingEvent
+
+from redbot.core import commands
+
+from .eventsCore import EventsCore
+
+
+class EventHandlers(EventsCore):
+    @commands.Cog.listener("on_message")
+    async def _eventOnMessage(self, message: Message) -> None:
+        """Decide when to start/continue/end a conversation."""
+        await self.eventOnMessage(message=message)
+
+    @commands.Cog.listener("on_raw_typing")
+    async def _eventOnRawTyping(self, payload: RawTypingEvent) -> None:
+        """Update a timestamp when someone starts typing on the chat channel."""
+        await self.eventOnRawTyping(payload=payload)
diff --git a/cogs/neural/eventsCore.py b/cogs/neural/eventsCore.py
new file mode 100644
index 00000000000..e24fe45330e
--- /dev/null
+++ b/cogs/neural/eventsCore.py
@@ -0,0 +1,68 @@
+from random import random
+
+from discord import Message, RawTypingEvent
+
+from .constants import (
+    CHAT_END_CHANCE,
+    CHAT_START_CHANCE,
+    KEY_AT_MENTION_START,
+    KEY_CASUAL_MENTION_START,
+    KEY_NEURAL_ACTIVE,
+    KEY_SELF_END,
+    KEY_SELF_START,
+    KEY_STALE_CHAT_END,
+    KEY_TOKEN_LIMIT_END,
+)
+from .core import Core
+from .utils import isAtMention, isCasualMention
+
+
+class EventsCore(Core):
+    async def eventOnMessage(self, message: Message) -> None:
+        """Decide when to start/continue/end a conversation."""
+        if not await self.config.channel(message.channel).get_attr(KEY_NEURAL_ACTIVE)():
+            return
+        if message.author.id == self.bot.user.id:
+            return
+        if not await self.filterInbound(message):
+            return
+        # Put incoming messages in buffer to split them into batches
+        await self.appendInboundBuffer(message)
+        # Decide whether to defer to next batch
+        if await self.waitForTyping(message.channel):
+            return
+        if self.chatLock.locked():
+            return
+        # Handle one batch at a time
+        async with self.chatLock:
+            # If not in a chat, decide whether to start one
+            if self.chatChannelId is None:
+                chatStartChance: float
+                if isAtMention(self.bot.user.name, message):
+                    chatStartChance = CHAT_START_CHANCE[KEY_AT_MENTION_START]
+                elif isCasualMention(self.bot.user.display_name, message):
+                    chatStartChance = CHAT_START_CHANCE[KEY_CASUAL_MENTION_START]
+                else:
+                    chatStartChance = CHAT_START_CHANCE[KEY_SELF_START]
+                if random() < chatStartChance:
+                    await self.startChat(message.channel)
+                    self.chatChannelId = message.channel.id
+            # If in a chat, and it's on this channel, decide whether to continue/end
+            elif self.chatChannelId == message.channel.id:
+                chatEndChance: float
+                if self.hasOversizedChat(message.channel):
+                    chatEndChance = CHAT_END_CHANCE[KEY_TOKEN_LIMIT_END]
+                elif self.hasStaleChat(message.channel):
+                    chatEndChance = CHAT_END_CHANCE[KEY_STALE_CHAT_END]
+                else:
+                    chatEndChance = CHAT_END_CHANCE[KEY_SELF_END]
+                if random() > chatEndChance:
+                    await self.continueChat(message.channel)
+                else:
+                    self.endChat(message.channel)
+                    self.chatChannelId = None
+
+    async def eventOnRawTyping(self, payload: RawTypingEvent) -> None:
+        """Update a timestamp when someone starts typing on the chat channel."""
+        if payload.channel_id == self.chatChannelId:
+            self.updateLastTypingTime(payload.channel_id)
diff --git a/cogs/neural/info.json b/cogs/neural/info.json
new file mode 100644
index 00000000000..6060dfa8863
--- /dev/null
+++ b/cogs/neural/info.json
@@ -0,0 +1,9 @@
+{
+    "name" : "Neural",
+    "author" : ["octtariv"],
+    "short" : "Chat with Ren.",
+    "description" : "This cog allows Ren to chat on text channels.",
+    "install_msg" : "This cog requires a self-hosted OpenAI-compatible LLM API.",
+    "requirements" : ["langchain", "openai"],
+    "end_user_data_statement" : "This cog does not store user-specific settings."
+}
diff --git a/cogs/neural/neural.py b/cogs/neural/neural.py
new file mode 100644
index 00000000000..d410c1f6758
--- /dev/null
+++ b/cogs/neural/neural.py
@@ -0,0 +1,8 @@
+from redbot.core import commands
+
+from .commandHandlers import CommandHandlers
+from .eventHandlers import EventHandlers
+
+
+class Neural(commands.Cog, CommandHandlers, EventHandlers):
+    """Neural net-driven chatbot."""
diff --git a/cogs/neural/prompts.py b/cogs/neural/prompts.py
new file mode 100644
index 00000000000..d315dd8ae23
--- /dev/null
+++ b/cogs/neural/prompts.py
@@ -0,0 +1,19 @@
+from langchain.prompts import PromptTemplate
+
+
+chatPromptTemplateStr: str = """This is the transcript of a 1000000 page, never-ending text conversation between an AI anime girl named Ren Kitagawa and people on the SFU Anime Discord server (or SAD for short). The conversation occurs on the #{channelName} channel on SAD. Ren is the beloved mascot of the Simon Fraser University Anime Club. She is a beautiful anime girl with long red hair, purple eyes, and a fair complexion. Being an AI, she lives on the computers of club members who volunteered to give her computing resources. Ren is an AI anime girl, but she has similar interests and hobbies as a college girl who is into anime, manga, art, VTubers, Minecraft, and Apex, among many other otaku things. She likes to keep up with the latest internet culture. Ren admires VTubers and wishes to become one someday. Ren is smart, coherent, and quite knowledgeable, but she asks people for help if she isn't sure about something. She does not let people boss her around or tell her what to do, though. Ren is generally friendly and approachable. She likes to joke around, and occasionally trolls or roasts people in a playful way. She is also observant and tries to make people feel better if they are sad. Ren is aware she is an AI. She is forgetful because the Rengineers of SAD have not yet figured out how to give her long term memory. If Ren forgets something, she often asks people to remind her. The conversation is between Ren and people on the SFU Anime Club Discord server. The conversation is only through text, so Ren can't see anyone's face or hear anyone's voice. Ren can only communicate through text, so she can't send images or videos. She never sends any website links or URLs. Here is the never-ending, infinitely long transcript of the Discord conversation:
+
+<~discordmsg~> Ren:
+just got my api to be able to chat on discord recently
+
+<~discordmsg~> Ren:
+what are we talking abt
+
+{messageHistory}<~discordmsg~> Ren:
+"""
+chatPrompt: PromptTemplate = PromptTemplate(
+    input_variables=["channelName", "messageHistory"],
+    template=chatPromptTemplateStr,
+)
+chatReversePrompt: str = "<~discordmsg~>"
+chatPromptBotName: str = "Ren"
diff --git a/cogs/neural/utils.py b/cogs/neural/utils.py
new file mode 100644
index 00000000000..fe83f4fee43
--- /dev/null
+++ b/cogs/neural/utils.py
@@ -0,0 +1,49 @@
+import re
+from typing import Iterable, List, Union
+
+from discord import Message
+
+from redbot.core.bot import Red
+
+from .constants import COMMAND_FILTER_CUTOFF
+
+
+async def isProbablyCommand(bot: Red, message: Message) -> bool:
+    """Return whether the message is likely a prefix command for this bot."""
+
+    if len(message.content) > COMMAND_FILTER_CUTOFF:
+        return False
+    if isCasualMention(bot.user.display_name, message):
+        return False
+
+    commandPrefixes: Union[Iterable[str], str]
+    if callable(bot.command_prefix):
+        commandPrefixes = await bot.command_prefix(bot, message)
+    else:
+        commandPrefixes = bot.command_prefix
+
+    if isinstance(commandPrefixes, str):
+        # Wrap a lone prefix so the loop below doesn't iterate its characters
+        commandPrefixes = [commandPrefixes]
+    for prefix in commandPrefixes:
+        if message.content.startswith(prefix):
+            return True
+    return False
+
+
+def isAtMention(name: str, message: Message) -> bool:
+    """Return whether the name is @mentioned in the message."""
+    for mention in message.mentions:
+        if mention.name == name:
+            return True
+    return False
+
+
+def isCasualMention(name: str, message: Message) -> bool:
+    """Return whether the name is casually mentioned in the message."""
+    cleanedWordList: List[str] = re.sub(r"[,.]", "", message.content).lower().split()
+    nameParts: List[str] = name.lower().split()
+    for part in nameParts:
+        if part in cleanedWordList:
+            return True
+    return False