DESIGN CHANGE: SchedulerConfigProviderBase to provide structured and validated configuration to scheduler, avoid global config
pedohorse committed Jun 27, 2024
1 parent 0ec26c0 commit ecb1b79
Showing 24 changed files with 664 additions and 223 deletions.
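For orientation, below is a minimal sketch of the provider interface implied by the call sites in this commit. The method names are taken from the diff that follows; the signatures, return types, and docstrings are assumptions, and the extra accessors suggested by the SchedulerConfigProviderFileOverrides keyword arguments (broadcast settings, server addresses, minimal idle helpers) are omitted here because their names are not shown.

    from abc import ABC, abstractmethod
    from typing import Tuple


    class SchedulerConfigProviderBase(ABC):
        """Supplies structured, validated configuration to the scheduler (sketch)."""

        @abstractmethod
        def main_database_location(self) -> str:
            """path of the main sqlite database file"""

        @abstractmethod
        def main_database_connection_timeout(self) -> float:
            """connection timeout for the main database"""

        @abstractmethod
        def hardware_ban_timeout(self) -> float:
            """how long to ban a hwid after certain submission errors"""

        @abstractmethod
        def hardware_resource_definitions(self) -> Tuple['WorkerResourceDefinition', ...]:
            """definitions of generic resources that workers can have"""

        @abstractmethod
        def ping_intervals(self) -> Tuple[float, float, float, float]:
            """(active, idle, off/errored, dormant-mode multiplier) pinger intervals"""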
24 changes: 12 additions & 12 deletions src/lifeblood/basenode.py
@@ -167,11 +167,11 @@ def _ui_changed(self, definition_changed=False):
# # this may also apply to _ui_changed above, but nodes really SHOULD NOT change their own parameters during processing
# asyncio.get_event_loop().create_task(self.__parent.node_reports_changes_needs_saving(self.__parent_nid))

def _process_task_wrapper(self, task_dict) -> ProcessingResult:
def _process_task_wrapper(self, task_dict, node_config) -> ProcessingResult:
# with self.get_ui().lock_interface_readonly(): # TODO: this is bad, RETHINK!
# TODO: , in case threads do l1---r1 - release2 WILL leave lock in locked state forever, as it remembered it at l2
# TODO: l2---r2
return self.process_task(ProcessingContext(self, task_dict))
return self.process_task(ProcessingContext(self, task_dict, node_config))

def process_task(self, context: ProcessingContext) -> ProcessingResult:
"""
@@ -182,9 +182,9 @@ def process_task(self, context: ProcessingContext) -> ProcessingResult:
"""
raise NotImplementedError()

def _postprocess_task_wrapper(self, task_dict) -> ProcessingResult:
def _postprocess_task_wrapper(self, task_dict, node_config) -> ProcessingResult:
# with self.get_ui().lock_interface_readonly(): #TODO: read comment for _process_task_wrapper
return self.postprocess_task(ProcessingContext(self, task_dict))
return self.postprocess_task(ProcessingContext(self, task_dict, node_config))

def postprocess_task(self, context: ProcessingContext) -> ProcessingResult:
"""
@@ -279,9 +279,9 @@ def __init__(self, name: str):
ui.add_parameter('worker gpu mem cost', 'min <memory (GBs)> preferred', NodeParameterType.FLOAT, 0.0).set_value_limits(value_min=0)
ui.add_parameter('worker gpu mem cost preferred', None, NodeParameterType.FLOAT, 0.0).set_value_limits(value_min=0)

def __apply_requirements(self, task_dict: dict, result: ProcessingResult):
def __apply_requirements(self, task_dict: dict, node_config: dict, result: ProcessingResult):
if result.invocation_job is not None:
context = ProcessingContext(self, task_dict)
context = ProcessingContext(self, task_dict, node_config)
raw_groups = context.param_value('worker groups').strip()
reqs = result.invocation_job.requirements()
if raw_groups != '':
@@ -302,13 +302,13 @@ def __apply_requirements(self, task_dict: dict, result: ProcessingResult):
result.invocation_job.set_priority(context.param_value('priority adjustment'))
return result

def _process_task_wrapper(self, task_dict) -> ProcessingResult:
result = super(BaseNodeWithTaskRequirements, self)._process_task_wrapper(task_dict)
return self.__apply_requirements(task_dict, result)
def _process_task_wrapper(self, task_dict, node_config) -> ProcessingResult:
result = super(BaseNodeWithTaskRequirements, self)._process_task_wrapper(task_dict, node_config)
return self.__apply_requirements(task_dict, node_config, result)

def _postprocess_task_wrapper(self, task_dict) -> ProcessingResult:
result = super(BaseNodeWithTaskRequirements, self)._postprocess_task_wrapper(task_dict)
return self.__apply_requirements(task_dict, result)
def _postprocess_task_wrapper(self, task_dict, node_config) -> ProcessingResult:
result = super(BaseNodeWithTaskRequirements, self)._postprocess_task_wrapper(task_dict, node_config)
return self.__apply_requirements(task_dict, node_config, result)


# class BaseNodeWithEnvironmentRequirements(BaseNode):
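A hedged sketch of how the new parameter flows through basenode.py (the task and config values are made up, `some_node` is a placeholder, and how the scheduler actually builds node_config per node type is not shown in this excerpt):

    # illustrative values only
    task_dict = {'id': 42, 'attributes': '{}'}
    node_config = {'frames_per_chunk': 10}

    # the scheduler now passes the per-node mapping into the wrappers...
    result = some_node._process_task_wrapper(task_dict, node_config)

    # ...which hand it to the context, so nodes never read global config themselves:
    context = ProcessingContext(some_node, task_dict, node_config)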
22 changes: 16 additions & 6 deletions src/lifeblood/main_scheduler.py
@@ -8,6 +8,7 @@
from .scheduler import Scheduler
from .basenode_serializer_v1 import NodeSerializerV1
from .basenode_serializer_v2 import NodeSerializerV2
from .scheduler_config_provider_file import SchedulerConfigProviderFileOverrides
from .defaults import scheduler_port as default_scheduler_port, ui_port as default_ui_port
from .config import get_config, create_default_user_config_file, get_local_scratch_path
from . import logging
@@ -78,15 +79,24 @@ def create_default_scheduler(db_file_path, *,
helpers_minimal_idle_to_ensure=1,
server_addr: Optional[Tuple[str, int, int]] = None,
server_ui_addr: Optional[Tuple[str, int]] = None) -> Scheduler:
legacy_addr = None
message_addr = None
if server_addr is not None:
legacy_addr = (server_addr[0], server_addr[1])
message_addr = (server_addr[0], server_addr[2])
config = SchedulerConfigProviderFileOverrides(
main_db_location=db_file_path,
do_broadcast=do_broadcasting,
broadcast_interval=broadcast_interval,
minimal_idle_helpers=helpers_minimal_idle_to_ensure,
legacy_server_address=legacy_addr,
message_processor_address=message_addr,
ui_address=server_ui_addr,
)
return Scheduler(
db_file_path,
scheduler_config_provider=config,
node_data_provider=PluginNodeDataProvider(),
node_serializers=[NodeSerializerV2(), NodeSerializerV1()],
do_broadcasting=do_broadcasting,
broadcast_interval=broadcast_interval,
helpers_minimal_idle_to_ensure=helpers_minimal_idle_to_ensure,
server_addr=server_addr,
server_ui_addr=server_ui_addr,
)


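A usage sketch of the new create_default_scheduler path (the host, ports, and db path are illustrative): server_addr packs the legacy port and the message-processor port into one tuple, which the function splits into the two addresses the provider expects.

    # illustrative values only
    scheduler = create_default_scheduler(
        '/var/lib/lifeblood/main.db',
        do_broadcasting=True,
        broadcast_interval=10,
        server_addr=('192.168.0.11', 1384, 1386),   # (host, legacy port, message port)
        server_ui_addr=('192.168.0.11', 1385),
    )
    # internally this becomes:
    #   legacy_addr  == ('192.168.0.11', 1384)
    #   message_addr == ('192.168.0.11', 1386)
    # and all of it goes into SchedulerConfigProviderFileOverrides instead of Scheduler kwargs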
26 changes: 13 additions & 13 deletions src/lifeblood/processingcontext.py
@@ -1,11 +1,9 @@
from types import MappingProxyType
import re

from .attribute_serialization import deserialize_attributes_core
from .config import get_config
from .environment_resolver import EnvironmentResolverArguments

from typing import TYPE_CHECKING, Dict, Optional
from typing import Dict, Optional, TYPE_CHECKING, Union

if TYPE_CHECKING:
from .basenode import BaseNode
@@ -44,28 +42,30 @@ def __getattr__(self, item):
raise AttributeError(f'node has no field {item}')

class ConfigWrapper:
def __init__(self, node_type_id):
self.__config = get_config('scheduler.nodes')
self.__scheduler_globals = dict(get_config('scheduler').get_option_noasync('scheduler.globals', {}))
self.__nodetypeid = node_type_id
def __init__(self, node_config):
self.__node_config = node_config

def get(self, key, default=None):
return self.__config.get_option_noasync(f'{self.__nodetypeid}.{key}',
self.__scheduler_globals.get(key,
default))
return self.__node_config.get(key, default)

def __getitem__(self, item):
return self.get(item)

def __init__(self, node: "BaseNode", task_dict: dict):
def __init__(self, node: "BaseNode", task_dict: dict, node_config: Dict[str, Union[str, int, float, list, dict]]):
"""
All information node can access during processing.
This is read-only.
All modifications are to be done through ProcessingResult
:param node_config: extra mapping that node can access through parameter expressions
"""
task_dict = dict(task_dict)
self.__task_attributes = deserialize_attributes_core(task_dict.get('attributes', '{}'))
self.__task_dict = task_dict
self.__task_wrapper = ProcessingContext.TaskWrapper(task_dict)
self.__node_wrapper = ProcessingContext.NodeWrapper(node, self)
sanitized_name = re.sub(r'\W', lambda m: f'x{ord(m.group(0))}', node.type_name())
self.__env_args = EnvironmentResolverArguments.deserialize(task_dict.get('environment_resolver_data')) if task_dict.get('environment_resolver_data') is not None else None
self.__conf_wrapper = ProcessingContext.ConfigWrapper(sanitized_name)
self.__conf_wrapper = ProcessingContext.ConfigWrapper(node_config)
self.__node = node

def param_value(self, param_name: str):
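To make the ConfigWrapper change concrete, a small sketch (keys and values are made up): lookups used to go through get_config('scheduler.nodes') with a per-node-type prefix and fall back to scheduler.globals; now they resolve only against the mapping the scheduler passes in.

    node_config = {'frames_per_chunk': 10, 'renderer': 'karma'}   # illustrative
    wrapper = ProcessingContext.ConfigWrapper(node_config)

    wrapper.get('frames_per_chunk')        # -> 10
    wrapper['renderer']                    # -> 'karma'
    wrapper.get('missing_key', 'default')  # -> 'default'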
90 changes: 16 additions & 74 deletions src/lifeblood/scheduler/data_access.py
@@ -3,17 +3,17 @@
import sqlite3
import random
import struct
from dataclasses import dataclass, field
from dataclasses import dataclass
from ..attribute_serialization import serialize_attributes
from ..db_misc import sql_init_script
from ..expiring_collections import ExpiringValuesSetMap
from ..config import get_config
from ..enums import TaskState, InvocationState
from ..worker_metadata import WorkerMetadata
from ..logging import get_logger
from ..shared_lazy_sqlite_connection import SharedLazyAiosqliteConnection
from .. import aiosqlite_overlay
from ..environment_resolver import EnvironmentResolverArguments
from ..scheduler_config_provider_base import SchedulerConfigProviderBase

from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

@@ -40,95 +40,43 @@ class TaskSpawnData:
environment_resolver_arguments: Optional[EnvironmentResolverArguments]


@dataclass
class WorkerResourceDefinition:
name: str
type: Type
description: str
label: str # nicer looking user facing name


class DataAccess:
def __init__(self, db_path, db_connection_timeout):
def __init__(
self,
*,
config_provider: SchedulerConfigProviderBase
):
self.__logger = get_logger('scheduler.data_access')
self.db_path: str = db_path
self.db_timeout: int = db_connection_timeout

config = get_config('scheduler')
self.__db_path: str = config_provider.main_database_location()
self.__db_timeout: float = config_provider.main_database_connection_timeout()

# "public" members
self.mem_cache_workers_resources: dict = {}
self.mem_cache_workers_state: dict = {}
self.__mem_cache_invocations: dict = {}
#

# resource definitions
# TODO: load resource definitions from config
config_resources = config.get_option_noasync('resource_definitions.per_machine', None)
if config_resources is None: # use default resource definitions
self.__worker_resource_definitions: Tuple[WorkerResourceDefinition, ...] = (
WorkerResourceDefinition('cpu_count',
float,
'CPU core count',
'CPU count'),
WorkerResourceDefinition('cpu_mem',
int,
'RAM amount in bytes',
'RAM'),
WorkerResourceDefinition('gpu_count',
float,
'number of GPUs',
'GPU count'), # TODO: get rid of these in defaults when devices are implemented
WorkerResourceDefinition('gpu_mem',
int,
'combined GPU memory in bytes',
'GPU mem'),
)
else:
if not isinstance(config_resources, dict):
raise RuntimeError('bad config schema: resource_definitions.per_machine must be a mapping') # TODO: turn into config schema error or smth
conf_2_type_mapping = {
'int': int,
'float': float,
'number': float,
}
res_defs = []
for res_name, res_data in config_resources.items():
if res_name.startswith('total_'):
raise RuntimeError('resource name cannot start with "total_"') # TODO: turn into config schema error or smth
res_type = conf_2_type_mapping.get(res_data.get('type').lower(), None)
if res_type is None:
raise RuntimeError('resource type may be one of "int", "float", "number"') # TODO: turn into config schema error or smth
res_defs.append(WorkerResourceDefinition(
res_name,
res_type,
res_data.get('description', ''),
res_data.get('label', res_name),
))
self.__worker_resource_definitions: Tuple[WorkerResourceDefinition, ...] = tuple(res_defs)
#

self.__task_blocking_values: Dict[int, int] = {}
# on certain submission errors we might want to ban hwid for some time, as it can be assumed
# that consecutive submission attempts will result in the same error (like package resolution error)
self.__banned_hwids_per_task: ExpiringValuesSetMap = ExpiringValuesSetMap()
self.__ban_time = config.get_option_noasync('data_access.hwid_ban_timeout', 10)
self.__ban_time = config_provider.hardware_ban_timeout()

self.__workers_metadata: Dict[int, WorkerMetadata] = {}
#
# ensure database is initialized
with sqlite3.connect(db_path) as con:
with sqlite3.connect(self.__db_path) as con:
con.executescript(sql_init_script)
# update resource table straight away
# for now the logic is to keep existing columns
with sqlite3.connect(db_path) as con:
with sqlite3.connect(self.__db_path) as con:
con.row_factory = sqlite3.Row
cur = con.execute('PRAGMA table_info(resources)')
resource_rows = {x['name']: x for x in cur.fetchall() if x['name'] != 'hwid'}
cur.close()

need_commit = False
for res_def in self.__worker_resource_definitions:
for res_def in config_provider.hardware_resource_definitions():
col_type, col_def = {
int: ('INTEGER', 0),
float: ('INTEGER', 0),
@@ -146,7 +94,7 @@ def __init__(self, db_path, db_connection_timeout):
if need_commit:
con.commit()

with sqlite3.connect(db_path) as con:
with sqlite3.connect(self.__db_path) as con:
con.row_factory = sqlite3.Row
cur = con.execute('SELECT * FROM lifeblood_metadata')
metadata = cur.fetchone() # there should be exactly one single row.
@@ -204,12 +152,6 @@ async def create_task(self, newtask: TaskSpawnData, *, con: Optional[aiosqlite.C
new_id = newcur.lastrowid
return new_id

def get_worker_resource_definitions(self) -> Tuple[WorkerResourceDefinition, ...]:
"""
get definitions of generic resources that workers can have
"""
return self.__worker_resource_definitions

async def housekeeping(self):
"""
i don't like this explicit cleanup
@@ -452,10 +394,10 @@ def set_worker_metadata(self, worker_hwid, data: WorkerMetadata):

def data_connection(self) -> aiosqlite_overlay.ConnectionWithCallbacks:
# TODO: con.row_factory = aiosqlite.Row must be here, ALMOST all places use it anyway, need to prune
return aiosqlite_overlay.connect(self.db_path, timeout=self.db_timeout, pragmas_after_connect=('synchronous=NORMAL',))
return aiosqlite_overlay.connect(self.__db_path, timeout=self.__db_timeout, pragmas_after_connect=('synchronous=NORMAL',))

def lazy_data_transaction(self, key_name: str):
return SharedLazyAiosqliteConnection(None, self.db_path, key_name, timeout=self.db_timeout)
return SharedLazyAiosqliteConnection(None, self.__db_path, key_name, timeout=self.__db_timeout)

async def write_back_cache(self):
self.__logger.info('pinger syncing temporary tables back...')
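The default resource set that used to be built inline in DataAccess presumably now comes back from hardware_resource_definitions(). A sketch of what a provider implementation might return, reusing the defaults removed above (the WorkerResourceDefinition dataclass has moved out of data_access.py; its new module is not shown in this excerpt):

    def hardware_resource_definitions():   # sketch of a provider implementation
        return (
            WorkerResourceDefinition('cpu_count', float, 'CPU core count', 'CPU count'),
            WorkerResourceDefinition('cpu_mem', int, 'RAM amount in bytes', 'RAM'),
            WorkerResourceDefinition('gpu_count', float, 'number of GPUs', 'GPU count'),
            WorkerResourceDefinition('gpu_mem', int, 'combined GPU memory in bytes', 'GPU mem'),
        )

    # DataAccess then adds one column per definition to the `resources` table,
    # mapping both int and float to sqlite INTEGER with a default of 0 (see the loop above)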
17 changes: 9 additions & 8 deletions src/lifeblood/scheduler/pinger.py
@@ -4,9 +4,7 @@
from .. import logging
from ..worker_messsage_processor import WorkerControlClient
from ..enums import WorkerState, InvocationState, WorkerPingState, WorkerPingReply
from ..ui_protocol_data import TaskDelta
from .scheduler_component_base import SchedulerComponentBase
from ..config import get_config
from ..net_messages.address import AddressChain
from ..net_messages.exceptions import MessageTransferError, MessageTransferTimeoutError

@@ -17,15 +15,18 @@


class Pinger(SchedulerComponentBase):
def __init__(self, scheduler: "Scheduler"):
def __init__(
self,
scheduler: "Scheduler",
):
super().__init__(scheduler)
self.__pinger_logger = logging.get_logger('scheduler.worker_pinger')
config = get_config('scheduler')

self.__ping_interval = config.get_option_noasync('scheduler.pinger.ping_interval', 10) # interval for active workers (workers doing work)
self.__ping_idle_interval = config.get_option_noasync('scheduler.pinger.ping_idle_interval', 30) # interval for idle workers
self.__ping_off_interval = config.get_option_noasync('scheduler.pinger.ping_off_interval', 60) # interval for off/errored workers (not really used since workers need to report back first)
self.__dormant_mode_ping_interval_multiplier = config.get_option_noasync('scheduler.pinger.dormant_ping_multiplier', 5)
self.__ping_interval, self.__ping_idle_interval, self.__ping_off_interval, self.__dormant_mode_ping_interval_multiplier = self.scheduler.config_provider.ping_intervals()
#config.get_option_noasync('scheduler.pinger.ping_interval', 10) # interval for active workers (workers doing work)
#config.get_option_noasync('scheduler.pinger.ping_idle_interval', 30) # interval for idle workers
#config.get_option_noasync('scheduler.pinger.ping_off_interval', 60) # interval for off/errored workers (not really used since workers need to report back first)
#config.get_option_noasync('scheduler.pinger.dormant_ping_multiplier', 5)
self.__ping_interval_mult = 1

def _main_task(self):
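A hedged sketch of what ping_intervals() presumably returns, using the defaults from the removed lines above (the tuple order follows the unpacking in Pinger.__init__; whether the provider keeps exactly these defaults is an assumption):

    def ping_intervals():   # sketch of a provider implementation
        return (
            10,   # ping_interval: workers actively doing work
            30,   # ping_idle_interval: idle workers
            60,   # ping_off_interval: off/errored workers (rarely used, workers report back first)
            5,    # dormant-mode ping interval multiplier
        )

    ping, ping_idle, ping_off, dormant_mult = ping_intervals()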
