[Neural] Add new cog #648

Closed
14 changes: 14 additions & 0 deletions cogs/neural/__init__.py
@@ -0,0 +1,14 @@
import json
from pathlib import Path

from redbot.core.bot import Red

from .neural import Neural

with open(Path(__file__).parent / "info.json") as fp:
    __red_end_user_data_statement__: str = json.load(fp)["end_user_data_statement"]


async def setup(bot: Red) -> None:
    """Add the cog to the bot."""
    await bot.add_cog(Neural(bot))
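
Note that the loader above expects an info.json file beside the module containing an end_user_data_statement key; that file is presumably added elsewhere in the cog folder and is not shown in this excerpt.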
22 changes: 22 additions & 0 deletions cogs/neural/commandHandlers.py
@@ -0,0 +1,22 @@
from redbot.core import commands
from redbot.core.commands import Context

from .commandsCore import CommandsCore


class CommandHandlers(CommandsCore):
    @commands.group(name="neural")
    @commands.guild_only()
    @commands.admin_or_permissions(manage_guild=True)
    async def _groupNeural(self, ctx: Context) -> None:
        """Configure the neural cog."""

    @_groupNeural.command(name="activate")
    async def _commandNeuralActivate(self, ctx: Context) -> None:
        """Activate the neural cog for the current channel."""
        await self.commandNeuralActivate(ctx=ctx)

    @_groupNeural.command(name="deactivate")
    async def _commandNeuralDeactivate(self, ctx: Context) -> None:
        """Deactivate the neural cog for the current channel."""
        await self.commandNeuralDeactivate(ctx=ctx)
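
With Red's default prefix placeholder these register as [p]neural activate and [p]neural deactivate; both are server-only and restricted to admins or members with the Manage Server permission.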
17 changes: 17 additions & 0 deletions cogs/neural/commandsCore.py
@@ -0,0 +1,17 @@
from redbot.core.commands import Context

from .constants import KEY_NEURAL_ACTIVE
from .core import Core


class CommandsCore(Core):
    async def commandNeuralActivate(self, ctx: Context) -> None:
        """Activate the neural cog for the current channel."""
        await self.initChannelChatState(ctx.channel)
        await self.config.channel(ctx.channel).get_attr(KEY_NEURAL_ACTIVE).set(True)
        await ctx.send("SYSTEM: The neural cog is now **activated** for this channel.")

    async def commandNeuralDeactivate(self, ctx: Context) -> None:
        """Deactivate the neural cog for the current channel."""
        await self.config.channel(ctx.channel).get_attr(KEY_NEURAL_ACTIVE).set(False)
        await ctx.send("SYSTEM: The neural cog is now **deactivated** for this channel.")
65 changes: 65 additions & 0 deletions cogs/neural/constants.py
@@ -0,0 +1,65 @@
from typing import Dict, Final, TypedDict

SYSTEM_MESSAGE_STR: Final[str] = "SYSTEM"
COMMAND_FILTER_CUTOFF: Final[int] = 24 # Chars
BASE_TYPING_WAIT: Final[int] = 6 # Seconds
INIT_CHAT_HISTORY_LIMIT: Final[int] = 30 # Messages
INBOUND_BUFFER_CHAR_LIMIT: Final[int] = 1024
OLD_CHAT_HISTORY_CHAR_LIMIT: Final[int] = 512


# LLM API
LLM_API_BASE: Final[str] = "http://localhost:8000/v1"
LLM_API_KEY: Final[str] = "renrenren" # No auth implemented, any random string works
LLM_API_PRESENCE_PENALTY: Final[float] = 1.17647


# Chances to start chat
KEY_AT_MENTION_START: Final[str] = "atMentionStart"
KEY_CASUAL_MENTION_START: Final[str] = "casualMentionStart"
KEY_SELF_START: Final[str] = "selfStart"
CHAT_START_CHANCE: Final[Dict[str, float]] = {
    KEY_AT_MENTION_START: 1.0,
    KEY_CASUAL_MENTION_START: 0.5,
    KEY_SELF_START: 0.01,
}


# Chances to follow up after her own message during a chat
LONG_MESSAGE_CUTOFF: Final[int] = 32 # Chars

KEY_QUESTION_FOLLOW_UP: Final[str] = "question"
KEY_LONG_MESSAGE_FOLLOW_UP: Final[str] = "longMessage"
KEY_BASE_FOLLOW_UP: Final[str] = "baseChance"
FOLLOW_UP_CHANCE: Final[Dict[str, float]] = {
    KEY_QUESTION_FOLLOW_UP: 0.01,
    KEY_LONG_MESSAGE_FOLLOW_UP: 0.1,
    KEY_BASE_FOLLOW_UP: 0.7,
}


# Chances to end chat
CHAT_HISTORY_CHAR_LIMIT: Final[int] = 3072 # Lazy token estimate, do it properly later
STALE_CHAT_PERIOD: Final[int] = 900 # Seconds

KEY_TOKEN_LIMIT_END: Final[str] = "tokenLimitEnd"
KEY_STALE_CHAT_END: Final[str] = "staleChatEnd"
KEY_SELF_END: Final[str] = "selfEnd"
CHAT_END_CHANCE: Final[Dict[str, float]] = {
    KEY_TOKEN_LIMIT_END: 1.0,
    KEY_STALE_CHAT_END: 0.5,
    KEY_SELF_END: 0.01,
}


# Channel config
KEY_NEURAL_ACTIVE: Final[str] = "neuralActive"


class BaseChannel(TypedDict):
    neuralActive: bool


BASE_CHANNEL: Final[BaseChannel] = {
    KEY_NEURAL_ACTIVE: False,
}
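
All three chance tables are consumed the same way: pick the relevant key, then roll random() against its value, as core.py does below for FOLLOW_UP_CHANCE. A minimal sketch of that pattern for the start-of-chat table (the function and its arguments are illustrative, not part of this PR):

from random import random

from .constants import (
    CHAT_START_CHANCE,
    KEY_AT_MENTION_START,
    KEY_CASUAL_MENTION_START,
    KEY_SELF_START,
)


def shouldStartChat(atMention: bool, casualMention: bool) -> bool:
    """Illustrative only: choose a start chance, then roll against it."""
    if atMention:
        chance = CHAT_START_CHANCE[KEY_AT_MENTION_START]  # Always respond to @mentions
    elif casualMention:
        chance = CHAT_START_CHANCE[KEY_CASUAL_MENTION_START]
    else:
        chance = CHAT_START_CHANCE[KEY_SELF_START]
    return random() < chance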
224 changes: 224 additions & 0 deletions cogs/neural/core.py
@@ -0,0 +1,224 @@
import os
from asyncio import Lock, sleep
from collections import deque
from datetime import datetime, timedelta
from logging import FileHandler, Formatter, Logger, getLogger
from pathlib import Path
from random import random
from typing import Deque, Dict, List, Optional

from discord import Message, TextChannel
from langchain.llms import OpenAI

from redbot.core import Config, data_manager
from redbot.core.bot import Red

from .constants import (
    BASE_CHANNEL,
    BASE_TYPING_WAIT,
    CHAT_HISTORY_CHAR_LIMIT,
    FOLLOW_UP_CHANCE,
    INBOUND_BUFFER_CHAR_LIMIT,
    INIT_CHAT_HISTORY_LIMIT,
    KEY_BASE_FOLLOW_UP,
    KEY_LONG_MESSAGE_FOLLOW_UP,
    KEY_QUESTION_FOLLOW_UP,
    LLM_API_BASE,
    LLM_API_KEY,
    LLM_API_PRESENCE_PENALTY,
    LONG_MESSAGE_CUTOFF,
    OLD_CHAT_HISTORY_CHAR_LIMIT,
    STALE_CHAT_PERIOD,
    SYSTEM_MESSAGE_STR,
)
from .prompts import chatPrompt, chatPromptBotName, chatReversePrompt
from .utils import isProbablyCommand


class Core:
    def __init__(self, bot: Red) -> None:
        self.bot: Red = bot
        self.inboundBufferLock: Lock = Lock()
        self.chatLock: Lock = Lock()
        self.config: Config = Config.get_conf(
            self,
            identifier=5842647,
            force_registration=True,
        )
        self.config.register_channel(**BASE_CHANNEL)
        # Initialize logger and save to cog folder
        saveFolder: Path = data_manager.cog_data_path(cog_instance=self)
        self.logger: Logger = getLogger("red.luicogs.Neural")
        if not self.logger.handlers:
            logPath: str = os.path.join(saveFolder, "info.log")
            handler: FileHandler = FileHandler(
                filename=logPath,
                encoding="utf-8",
                mode="a",
            )
            handler.setFormatter(
                Formatter("%(asctime)s %(message)s", datefmt="[%Y/%m/%d %H:%M:%S]")
            )
            self.logger.addHandler(handler)
        # Initialize vars that track chat state
        self.chatChannelId: Optional[int] = None
        # Channel ID, data
        self.inboundBuffer: Dict[int, Deque[Message]] = {}  # TODO: class
        self.chatHistory: Dict[int, Deque[str]] = {}
        self.lastTypingTime: Dict[int, datetime] = {}
        self.lastReplyTime: Dict[int, datetime] = {}

    async def initChannelChatState(self, channel: TextChannel) -> None:
        """Initialize the chat state on `channel`."""
        self.chatChannelId = channel.id
        self.inboundBuffer[channel.id] = deque()
        messageHistory: List[Message] = [
            message async for message in channel.history(limit=INIT_CHAT_HISTORY_LIMIT)
        ]
        for message in reversed(messageHistory):
            if await self.filterInbound(message):
                await self.appendInboundBuffer(message)
        self.chatHistory[channel.id] = deque()
        self.updateLastTypingTime(channel.id)
        self.lastReplyTime[channel.id] = datetime.now()

    async def filterInbound(self, message: Message) -> bool:
        """Return whether `message` should be seen by the chat model."""
        if not isinstance(message.channel, TextChannel):
            return False
        if message.content.startswith(SYSTEM_MESSAGE_STR):
            return False
        if message.attachments:
            return False
        if await isProbablyCommand(self.bot, message):
            return False
        return True

    async def appendInboundBuffer(self, message: Message) -> None:
        """Append the message to its channel's inbound buffer."""
        async with self.inboundBufferLock:
            channelId: int = message.channel.id
            if channelId in self.inboundBuffer:
                self.inboundBuffer[channelId].append(message)
            else:
                self.inboundBuffer[channelId] = deque([message])
            # If necessary, delete older messages to keep buffer size under control
            bufferLengthChars: int = 0
            for msg in self.inboundBuffer[channelId]:
                bufferLengthChars += len(msg.content)
            while bufferLengthChars > INBOUND_BUFFER_CHAR_LIMIT:
                bufferLengthChars -= len(self.inboundBuffer[channelId][0].content)
                self.inboundBuffer[channelId].popleft()

    async def clearInboundBuffer(self, channel: TextChannel) -> None:
        """Clear the inbound buffer on `channel`."""
        async with self.inboundBufferLock:
            self.inboundBuffer[channel.id].clear()

    def updateLastTypingTime(self, channelId: int) -> None:
        """Update last typing timestamp for the channel."""
        self.lastTypingTime[channelId] = datetime.now()

    async def waitForTyping(self, channel: TextChannel) -> bool:
        """Wait for a while,
        then return whether someone had started typing on `channel` during the wait.
        """
        oldTypingTime: Optional[datetime] = self.lastTypingTime.get(channel.id)
        bufferLen: int = len(self.inboundBuffer[channel.id])
        # Don't wait too long if messages are piling up in the buffer
        waitSeconds: float = BASE_TYPING_WAIT / (bufferLen + 1) + random()
        await sleep(waitSeconds)
        newTypingTime: Optional[datetime] = self.lastTypingTime.get(channel.id)
        if oldTypingTime != newTypingTime:
            return True
        return False

    async def stepChat(self, channel: TextChannel) -> None:
        """Step the chat on `channel` forward by one reply cycle."""
        # Load then clear the inbound buffer
        inboundBufferSnapshot: Deque[Message] = self.inboundBuffer[channel.id].copy()
        await self.clearInboundBuffer(channel)
        # Load the incoming messages into chat history
        if channel.id not in self.chatHistory:
            self.chatHistory[channel.id] = deque()  # No chat history yet, make new
        for message in inboundBufferSnapshot:
            promptUsername: str = message.author.display_name
            if message.author.id == self.bot.user.id:
                promptUsername = chatPromptBotName
            # fmt: off
            formattedMessage: str = (
                f"{chatReversePrompt} {promptUsername}:\n"
                f"{message.content}\n"
                "\n"
            )
            # fmt: on
            self.chatHistory[channel.id].append(formattedMessage)
        # Join the chat history into a single string
        chatHistoryPrompt: str = "".join(self.chatHistory[channel.id])
        # Prepare the LLM query
        llm: OpenAI = OpenAI(
            openai_api_base=LLM_API_BASE,
            openai_api_key=LLM_API_KEY,
            presence_penalty=LLM_API_PRESENCE_PENALTY,
        )
        prompt: str = chatPrompt.format(
            channelName=channel.name,
            messageHistory=chatHistoryPrompt,
        )
        self.logger.info("\n======Passed to API:======\n%s", prompt)
        # Query the LLM
        async with channel.typing():
            response: str = await llm.apredict(
                text=prompt,
                stop=[chatReversePrompt],
            )
            await channel.send(response)
        # Update the chat history
        # fmt: off
        formattedResponse: str = (
            f"{chatReversePrompt} Ren:\n"
            f"{response}"
        )
        # fmt: on
        self.chatHistory[channel.id].append(formattedResponse)
        self.lastReplyTime[channel.id] = datetime.now()

    async def startChat(self, channel: TextChannel) -> None:
        """Start a chat on `channel`."""
        await self.stepChat(channel)

    async def continueChat(self, channel: TextChannel) -> None:
        """Continue the chat on `channel`."""
        await self.stepChat(channel)
        # Guard against a cache miss: channel.last_message can be None
        if channel.last_message and channel.last_message.author.id == self.bot.user.id:
            followUpChance: float
            if "?" in channel.last_message.content:
                followUpChance = FOLLOW_UP_CHANCE[KEY_QUESTION_FOLLOW_UP]
            elif len(channel.last_message.content) > LONG_MESSAGE_CUTOFF:
                followUpChance = FOLLOW_UP_CHANCE[KEY_LONG_MESSAGE_FOLLOW_UP]
            else:
                followUpChance = FOLLOW_UP_CHANCE[KEY_BASE_FOLLOW_UP]
            if random() < followUpChance:
                await sleep(random())
                await self.continueChat(channel)

    def endChat(self, channel: TextChannel) -> None:
        """End the chat on `channel`. Keep a small portion of the chat history."""
        chatHistoryLengthChars: int = len("".join(self.chatHistory[channel.id]))
        while chatHistoryLengthChars > OLD_CHAT_HISTORY_CHAR_LIMIT:
            chatHistoryLengthChars -= len(self.chatHistory[channel.id][0])
            self.chatHistory[channel.id].popleft()

    def hasOversizedChat(self, channel: TextChannel) -> bool:
        """Return whether the chat history on `channel` is over the token limit."""
        chatHistoryLengthChars: int = len("".join(self.chatHistory[channel.id]))
        if chatHistoryLengthChars > CHAT_HISTORY_CHAR_LIMIT:
            return True
        return False

    def hasStaleChat(self, channel: TextChannel) -> bool:
        """Return whether there is a stale chat on `channel`."""
        timeSinceLastReply: timedelta = datetime.now() - self.lastReplyTime[channel.id]
        if timeSinceLastReply.total_seconds() > STALE_CHAT_PERIOD:
            return True
        return False
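
Core provides the building blocks (buffering, filtering, typing-aware waits, the reply cycle, and the end-of-chat predicates) but not the driver that strings them together; that wiring lives in eventsCore.py, which is not part of this excerpt. A rough sketch of how the pieces could compose, purely as an illustration rather than the PR's actual control flow:

from random import random

from discord import Message, TextChannel

from .constants import CHAT_END_CHANCE, KEY_STALE_CHAT_END, KEY_TOKEN_LIMIT_END
from .core import Core


class ChatDriverSketch(Core):
    async def handleInbound(self, message: Message) -> None:
        """Illustrative driver: buffer the message, advance the chat, then maybe wind it down."""
        channel = message.channel
        if not isinstance(channel, TextChannel) or not await self.filterInbound(message):
            return
        await self.appendInboundBuffer(message)
        if channel.id != self.chatChannelId:
            return  # Start-of-chat decisions (CHAT_START_CHANCE) omitted here
        # Give humans a moment to finish typing before replying
        if await self.waitForTyping(channel):
            return
        async with self.chatLock:
            await self.continueChat(channel)
        # Wind the chat down when it grows too large or goes quiet for too long
        if self.hasOversizedChat(channel) and random() < CHAT_END_CHANCE[KEY_TOKEN_LIMIT_END]:
            self.endChat(channel)
        elif self.hasStaleChat(channel) and random() < CHAT_END_CHANCE[KEY_STALE_CHAT_END]:
            self.endChat(channel)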
17 changes: 17 additions & 0 deletions cogs/neural/eventHandlers.py
@@ -0,0 +1,17 @@
from discord import Message, RawTypingEvent

from redbot.core import commands

from .eventsCore import EventsCore


class EventHandlers(EventsCore):
    @commands.Cog.listener("on_message")
    async def _eventOnMessage(self, message: Message) -> None:
        """Decide when to start/continue/end a conversation."""
        await self.eventOnMessage(message=message)

    @commands.Cog.listener("on_raw_typing")
    async def _eventOnRawTyping(self, payload: RawTypingEvent) -> None:
        """Update a timestamp when someone starts typing on the chat channel."""
        await self.eventOnRawTyping(payload=payload)
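
eventsCore.py, which supplies eventOnMessage and eventOnRawTyping, is not included in this excerpt. For the typing listener specifically, the Core helper it presumably feeds is updateLastTypingTime; a minimal sketch of what such a handler could look like (the body is an assumption, not the PR's implementation):

from discord import RawTypingEvent

from .core import Core


class TypingSketch(Core):
    async def eventOnRawTyping(self, payload: RawTypingEvent) -> None:
        """Illustrative only: stamp typing activity for the channel being chatted in."""
        if payload.channel_id == self.chatChannelId:
            self.updateLastTypingTime(payload.channel_id)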