diff --git a/nvflare/private/fed/client/client_executor.py b/nvflare/private/fed/client/client_executor.py
index 6465122805b..1bec0bf30e7 100644
--- a/nvflare/private/fed/client/client_executor.py
+++ b/nvflare/private/fed/client/client_executor.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 import logging
-import math
 import os
 import shlex
 import subprocess
@@ -26,6 +25,7 @@
 from nvflare.apis.shareable import Shareable, make_reply
 from nvflare.apis.utils.common_utils import get_open_ports
 from nvflare.fuel.utils.pipe.file_pipe import FilePipe
+
 from .client_status import ClientStatus, get_status_message

@@ -57,25 +57,22 @@ def start_train(self, client, args, app_root, app_custom_folder, listen_port):
         """
         pass

-    def check_status(self, client):
+    def check_status(self, client) -> str:
         """To check the status of the running client.

         Args:
             client: the FL client object

-        Returns: running FL client status message
-
+        Returns:
+            A client status message
         """
         pass

     def abort_train(self, client):
-        """To abort the running client.
+        """To abort the client training.

         Args:
             client: the FL client object
-
-        Returns: N/A
-
         """
         pass
@@ -84,34 +81,28 @@ def abort_task(self, client):

         Args:
             client: the FL client object
-
-        Returns: N/A
-
         """
         pass

-    def get_run_info(self):
-        """To get the run_info from the InfoCollector.
-
-        Returns: current run info
+    def get_run_info(self) -> dict:
+        """Get the run information.

+        Returns:
+            A dict of run information.
         """
         pass

     def get_errors(self):
-        """To get the error_info from the InfoCollector.
+        """Get the error information.

-        Returns: current errors
+        Returns:
+            A dict of error information.
         """
         pass

     def reset_errors(self):
-        """To reset the error_info for the InfoCollector.
-
-        Returns: N/A
-
-        """
+        """Reset the error information."""
         pass

     def send_aux_command(self, shareable: Shareable):
@@ -119,14 +110,11 @@ def send_aux_command(self, shareable: Shareable):

         Args:
             shareable: aux message Shareable
-
-        Returns: N/A
-
         """
         pass

     def cleanup(self):
-        """Finalize cleanup."""
+        """Cleanup."""
         self.pipe.clear()
@@ -141,11 +129,10 @@ def __init__(self, uid, startup):
             startup: startup folder
         """
         ClientExecutor.__init__(self, uid, startup)
-        # self.client = client
+
         self.startup = startup
         self.conn_client = None
-        # self.pool = None

         self.listen_port = get_open_ports(1)[0]
@@ -166,14 +153,6 @@ def create_pipe(self):
         return pipe

     def start_train(self, client, args, app_root, app_custom_folder, listen_port):
-        # self.pool = multiprocessing.Pool(processes=1)
-        # result = self.pool.apply_async(_start_client, (client, args, app_root))
-
-        # self.conn_client, child_conn = mp.Pipe()
-        # process = multiprocessing.Process(target=_start_client, args=(client, args, app_root, child_conn, self.pipe))
-        # # process = multiprocessing.Process(target=_start_new)
-        # process.start()
-
         self.listen_port = listen_port

         new_env = os.environ.copy()
@@ -213,12 +192,12 @@ def check_status(self, client):
                 data = {"command": AdminCommandNames.CHECK_STATUS, "data": {}}
                 self.conn_client.send(data)
                 status_message = self.conn_client.recv()
-                print("check status from process listener......")
+                self.logger.debug("check status from process listener")
                 return status_message
             else:
                 return get_status_message(client.status)
         except:
-            self.logger.error("Check_status execution exception.")
+            self.logger.error("check_status() execution exception.")
             return "execution exception. Please try again."
     def get_run_info(self):
@@ -275,19 +254,9 @@ def send_aux_command(self, shareable: Shareable):
             return make_reply(ReturnCode.EXECUTION_EXCEPTION)

     def abort_train(self, client):
-        # if client.status == ClientStatus.CROSS_SITE_VALIDATION:
-        #     # Only aborts cross site validation.
-        #     client.abort()
-        # elif client.status == ClientStatus.TRAINING_STARTED:
         if client.status == ClientStatus.STARTED:
             with self.lock:
                 if client.process:
-                    # if client.platform == 'PT' and client.multi_gpu:
-                    #     # kill the sub-process group directly
-                    #     os.killpg(os.getpgid(client.process.pid), 9)
-                    # else:
-                    #     client.process.terminate()
-
                     # kill the sub-process group directly
                     if self.conn_client:
                         data = {"command": AdminCommandNames.ABORT, "data": {}}
@@ -305,8 +274,6 @@
                     client.process.terminate()
                     self.logger.debug("terminated")

-                # if self.pool:
-                #     self.pool.terminate()
                 if self.conn_client:
                     self.conn_client.close()
                 self.conn_client = None
@@ -346,14 +313,8 @@ def wait_training_process_finish(self, client, args, app_root, app_custom_folder
             self.conn_client.close()
             self.conn_client = None

-        # # result.get()
-        # self.pool.close()
-        # self.pool.join()
-        # self.pool.terminate()
-
-        # Not to run cross_validation in a new process any more.
+        # Do not run cross_validation in a new process anymore.
         client.cross_site_validate = False
-
         client.status = ClientStatus.STOPPED

     def close(self):
@@ -362,57 +323,3 @@
             self.conn_client.send(data)
         self.conn_client = None
         self.cleanup()
-
-
-# class ThreadExecutor(ClientExecutor):
-#     def __init__(self, client, executor):
-#         self.client = client
-#         self.executor = executor
-
-#     def start_train(self, client, args, app_root, app_custom_folder, listen_port):
-#         future = self.executor.submit(lambda p: _start_client(*p), [client, args, app_root])
-
-#     def start_mgpu_train(self, client, args, app_root, gpu_number, app_custom_folder, listen_port):
-#         self.start_train(client, args, app_root)
-
-#     def check_status(self, client):
-#         return get_status_message(self.client.status)
-
-#     def abort_train(self, client):
-#         self.client.train_end = True
-#         self.client.fitter.train_ctx.ask_to_stop_immediately()
-#         self.client.fitter.train_ctx.set_prop("early_end", True)
-#         # self.client.model_manager.close()
-#         # self.client.status = ClientStatus.TRAINING_STOPPED
-#         return "Aborting the client..."
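
Note: ProcessExecutor drives the training subprocess through a multiprocessing connection
(self.conn_client), exchanging plain dict messages such as
{"command": AdminCommandNames.CHECK_STATUS, "data": {}} and reading the reply with recv().
The child-side listener is not part of this diff; the following is a minimal sketch of what
such a listener could look like, assuming a multiprocessing.connection.Connection and a
hypothetical handle_admin_command() callback.

    # Illustrative sketch only -- not part of this change.
    from multiprocessing.connection import Connection

    def listen_command(conn: Connection, handle_admin_command) -> None:
        """Answer admin commands sent over the parent's conn_client."""
        while True:
            msg = conn.recv()  # e.g. {"command": "check_status", "data": {}}
            command = msg.get("command")
            reply = handle_admin_command(command, msg.get("data", {}))
            if reply is not None:
                conn.send(reply)  # the parent's conn_client.recv() reads this
            if command == "abort":
                break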
-
-
-def update_client_properties(client, trainer):
-    # servers = [{t['name']: t['service']} for t in trainer.server_config]
-    retry_timeout = 30
-    # if trainer.client_config['retry_timeout']:
-    #     retry_timeout = trainer.client_config['retry_timeout']
-    client.client_args = trainer.client_config
-    # client.servers = sorted(servers)[0]
-    # client.model_manager.federated_meta = {task_name: list() for task_name in tuple(client.servers)}
-    exclude_vars = trainer.client_config.get("exclude_vars", "dummy")
-    # client.model_manager.exclude_vars = re.compile(exclude_vars) if exclude_vars else None
-    # client.model_manager.privacy_policy = trainer.privacy
-    # client.model_manager.model_reader_writer = trainer.model_reader_writer
-    # client.model_manager.model_validator = trainer.model_validator
-    # client.pool = ThreadPool(len(client.servers))
-    # client.communicator.ssl_args = trainer.client_config
-    # client.communicator.secure_train = trainer.secure_train
-    # client.communicator.model_manager = client.model_manager
-    client.communicator.should_stop = False
-    client.communicator.retry = int(math.ceil(float(retry_timeout) / 5))
-    # client.communicator.outbound_filters = trainer.outbound_filters
-    # client.communicator.inbound_filters = trainer.inbound_filters
-    client.handlers = trainer.handlers
-    # client.inbound_filters = trainer.inbound_filters
-    client.executors = trainer.executors
-    # client.task_inbound_filters = trainer.task_inbound_filters
-    # client.task_outbound_filters = trainer.task_outbound_filters
-    # client.secure_train = trainer.secure_train
-    client.heartbeat_done = False
-    # client.fl_ctx = FLContext()
diff --git a/nvflare/private/fed/client/client_runner.py b/nvflare/private/fed/client/client_runner.py
index ef256f640dc..e7292b1b6ec 100644
--- a/nvflare/private/fed/client/client_runner.py
+++ b/nvflare/private/fed/client/client_runner.py
@@ -244,10 +244,10 @@ def _try_run(self):
                 # reset to default fetch interval
                 task_fetch_interval = self.task_fetch_interval
-                self.log_info(fl_ctx, "fetching task from server ...")
+                self.log_debug(fl_ctx, "fetching task from server ...")
                 task = self.engine.get_task_assignment(fl_ctx)
                 if not task:
-                    self.log_info(fl_ctx, "no task received - will try in {} secs".format(task_fetch_interval))
+                    self.log_debug(fl_ctx, "no task received - will try in {} secs".format(task_fetch_interval))
                     continue

                 if task.name == SpecialTaskName.END_RUN:
@@ -258,7 +258,9 @@
                     task_data = task.data
                     if task_data and isinstance(task_data, Shareable):
                         task_fetch_interval = task_data.get(TaskConstant.WAIT_TIME, self.task_fetch_interval)
-                    self.log_info(fl_ctx, "server asked to try again - will try in {} secs".format(task_fetch_interval))
+                    self.log_debug(
+                        fl_ctx, "server asked to try again - will try in {} secs".format(task_fetch_interval)
+                    )
                     continue

                 self.log_info(fl_ctx, "got task assignment: name={}, id={}".format(task.name, task.task_id))
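
Note: all of the demoted messages above sit on the client's task-polling hot path, so at INFO
level they flood the log once per fetch interval. The loop resets task_fetch_interval every
iteration and lets a TRY_AGAIN reply override it through TaskConstant.WAIT_TIME, which is how
the server can slow a client down. A condensed sketch of that control flow (the wait-time key
and the executor callback are stand-ins for the real names):

    # Simplified control flow -- illustrative only.
    import time

    def poll_tasks(engine, execute, default_interval: float = 2.0) -> None:
        while True:
            interval = default_interval  # reset to default each round
            task = engine.get_task_assignment()
            if not task:
                time.sleep(interval)  # nothing assigned - poll again
                continue
            if task.name == "end_run":
                return
            if task.name == "try_again":
                if task.data:  # a server-provided wait time wins
                    interval = task.data.get("wait_time", default_interval)
                time.sleep(interval)
                continue
            execute(task)  # caller-supplied task execution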
diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py
index 324dc179a3c..53b5db0d97b 100644
--- a/nvflare/private/fed/client/communicator.py
+++ b/nvflare/private/fed/client/communicator.py
@@ -27,6 +27,7 @@
 from nvflare.apis.fl_constant import FLContextKey
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.fl_exception import FLCommunicationError
+from nvflare.private.defs import SpecialTaskName
 from nvflare.private.fed.utils.fed_utils import make_context_data, make_shareeable_data, shareable_to_modeldata

@@ -68,7 +69,8 @@ def set_up_channel(self, channel_dict, token=None):
             channel_dict: grpc channel parameters
             token: client token

-        Returns: an initialised grpc channel
+        Returns:
+            An initialised grpc channel
         """
         if self.secure_train:
@@ -83,7 +85,7 @@ def set_up_channel(self, channel_dict, token=None):
                 certificate_chain=certificate_chain, private_key=private_key, root_certificates=trusted_certs
             )

-            # make sure that all headers are in lowecase,
+            # make sure that all headers are in lowercase,
             # otherwise grpc throws an exception
             call_credentials = grpc.metadata_call_credentials(
                 lambda context, callback: callback((("x-custom-token", token),), None)
@@ -108,7 +110,8 @@ def get_client_state(self, project_name, token, ssid, fl_ctx: FLContext):
             token: FL client token
             fl_ctx: FLContext

-        Returns: a ClientState message
+        Returns:
+            A ClientState message
         """
         state_message = fed_msg.ClientState(token=token, ssid=ssid)
@@ -148,14 +151,13 @@ def client_registration(self, client_name, servers, project_name):
             servers: FL servers
             project_name: FL study project name

-        Returns: FL token
+        Returns:
+            The client's token
         """
         local_ip = self.get_client_ip()

         login_message = fed_msg.ClientLogin(client_name=client_name, client_ip=local_ip)
-        # login_message = fed_msg.ClientLogin(
-        #     client_id=None, token=None, client_ip=local_ip)
         login_message.meta.project.name = project_name

         result, retry = None, self.retry
@@ -194,15 +196,16 @@ def client_registration(self, client_name, servers, project_name):
         return token, ssid

     def getTask(self, servers, project_name, token, ssid, fl_ctx: FLContext):
-        """Get registered with the remote server via channel, and fetch the server's model parameters.
+        """Get a task from the server.

         Args:
             servers: FL servers
             project_name: FL study project name
-            token: FL client token
+            token: client token
             fl_ctx: FLContext

-        Returns: a CurrentTask message from server
+        Returns:
+            A CurrentTask message from the server
         """
         global_model, retry = None, self.retry
@@ -218,10 +221,6 @@ def getTask(self, servers, project_name, token, ssid, fl_ctx: FLContext):
                     self.should_stop = False

                     end_time = time.time()
-                    self.logger.info(
-                        f"Received from {project_name} server "
-                        f" ({global_model.ByteSize()} Bytes). getTask time: {end_time - start_time} seconds"
-                    )

                     task = fed_msg.CurrentTask()
                     task.meta.CopyFrom(global_model.meta)
@@ -229,6 +228,16 @@
                     task.data.CopyFrom(global_model.data)
                     task.task_name = global_model.task_name

+                    if global_model.task_name == SpecialTaskName.TRY_AGAIN:
+                        self.logger.debug(
+                            f"Received from {project_name} server "
+                            f" ({global_model.ByteSize()} Bytes). getTask time: {end_time - start_time} seconds"
+                        )
+                    else:
+                        self.logger.info(
+                            f"Received from {project_name} server "
+                            f" ({global_model.ByteSize()} Bytes). getTask time: {end_time - start_time} seconds"
+                        )
                     return task
             except grpc.RpcError as grpc_error:
                 self.grpc_error_handler(
@@ -256,8 +265,8 @@ def submitUpdate(self, servers, project_name, token, ssid,
             shareable: execution task result shareable
             execute_task_name: execution task name

-        Returns: server message from the server
-
+        Returns:
+            A FederatedSummary message from the server.
         """
         client_state = self.get_client_state(project_name, token, ssid, fl_ctx)
         client_state.client_name = client_name
@@ -295,7 +304,7 @@ def submitUpdate(self, servers, project_name, token, ssid,

     def auxCommunicate(self, servers, project_name, token, ssid,
                        fl_ctx: FLContext, client_name, shareable, topic, timeout):
-        """To send the aux communication message to the server.
+        """Send the auxiliary communication message to the server.

         Args:
            servers: FL servers
@@ -307,7 +316,8 @@
             topic: aux message topic
             timeout: aux communication timeout

-        Returns: server response message
+        Returns:
+            An AuxReply message from the server
         """
         client_state = self.get_client_state(project_name, token, ssid, fl_ctx)
@@ -327,7 +337,7 @@
         while retry > 0:
             try:
                 start_time = time.time()
-                self.logger.info(f"Send AuxMessage to {project_name} server")
+                self.logger.debug(f"Send AuxMessage to {project_name} server")
                 server_msg = stub.AuxCommunicate(aux_message, timeout=timeout)
                 # Clear the stopping flag
                 # if the connection to server recovered.
@@ -351,7 +361,8 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext):
             token: FL client token
             fl_ctx: FLContext

-        Returns: server's reply to the last message
+        Returns:
+            The server's reply to the last message
         """
         server_message, retry = None, self.retry
@@ -428,9 +439,6 @@ def grpc_error_handler(self, service, grpc_error, action, start_time, retry, ver
             start_time: communication start time
             retry: retry number
             verbose: verbose to error print out
-
-        Returns: N/A
-
         """
         status_code = None
         if isinstance(grpc_error, grpc.Call):
+ """Send the auxiliary communication message to the server. Args: servers: FL servers @@ -307,7 +316,8 @@ def auxCommunicate(self, servers, project_name, token, ssid, topic: aux message topic timeout: aux communication timeout - Returns: server response message + Returns: + An AuxReply message from server """ client_state = self.get_client_state(project_name, token, ssid, fl_ctx) @@ -327,7 +337,7 @@ def auxCommunicate(self, servers, project_name, token, ssid, while retry > 0: try: start_time = time.time() - self.logger.info(f"Send AuxMessage to {project_name} server") + self.logger.debug(f"Send AuxMessage to {project_name} server") server_msg = stub.AuxCommunicate(aux_message, timeout=timeout) # Clear the stopping flag # if the connection to server recovered. @@ -351,7 +361,8 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): token: FL client token fl_ctx: FLContext - Returns: server's reply to the last message + Returns: + server's reply to the last message """ server_message, retry = None, self.retry @@ -428,9 +439,6 @@ def grpc_error_handler(self, service, grpc_error, action, start_time, retry, ver start_time: communication start time retry: retry number verbose: verbose to error print out - - Returns: N/A - """ status_code = None if isinstance(grpc_error, grpc.Call): diff --git a/nvflare/private/fed/client/fed_client.py b/nvflare/private/fed/client/fed_client.py index e0875388333..b2873307d9d 100644 --- a/nvflare/private/fed/client/fed_client.py +++ b/nvflare/private/fed/client/fed_client.py @@ -23,9 +23,10 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable +from nvflare.private.defs import SpecialTaskName from nvflare.private.event import fire_event +from nvflare.private.fed.utils.numproto import proto_to_bytes -from ..utils.numproto import proto_to_bytes from .fed_client_base import FederatedClientBase @@ -83,7 +84,10 @@ def fetch_task(self, fl_ctx: FLContext): pull_success, task_name, remote_tasks = self.pull_task(fl_ctx) fire_event(EventType.AFTER_PULL_TASK, self.handlers, fl_ctx) - self.logger.info(f"pull_task completed. Task name:{task_name} Status:{pull_success} ") + if task_name == SpecialTaskName.TRY_AGAIN: + self.logger.debug(f"pull_task completed. Task name:{task_name} Status:{pull_success} ") + else: + self.logger.info(f"pull_task completed. Task name:{task_name} Status:{pull_success} ") return pull_success, task_name, remote_tasks def extract_shareable(self, responses, fl_ctx: FLContext): diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index 2d7fecfa20a..f79a1e8e369 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The client of the federated training process.""" - import logging import threading from functools import partial @@ -35,7 +33,7 @@ class FederatedClientBase: - """Federated client-side base implementation. + """The client-side base implementation of federated learning. This class provide the tools function which will be used in both FedClient and FedClientLite. 
""" @@ -91,6 +89,7 @@ def __init__( self.engine = None self.status = ClientStatus.NOT_STARTED + self.remote_tasks = None self.sp_established = False self.overseer_agent = overseer_agent @@ -141,13 +140,10 @@ def _switch_ssid(self): self.logger.info(f"Primary SP switched to new SSID: {self.ssid}") def client_register(self, project_name): - """Register the client to the FL server and get the FL token. + """Register the client to the FL server. Args: - project_name: FL server project name - - Returns: N/A - + project_name: FL study project name. """ if not self.token: try: @@ -157,46 +153,42 @@ def client_register(self, project_name): self.fl_ctx.set_prop(FLContextKey.CLIENT_NAME, self.client_name, private=False) self.fl_ctx.set_prop(EngineConstant.FL_TOKEN, self.token, private=False) self.logger.info( - "Successfully registered client:{} for {}. Token:{} SSID:{}".format( + "Successfully registered client:{} for project {}. Token:{} SSID:{}".format( self.client_name, project_name, self.token, self.ssid ) ) - # fire_event(EventType.CLIENT_REGISTER, self.handlers, self.fl_ctx) - except FLCommunicationError as e: + except FLCommunicationError: self.communicator.heartbeat_done = True def fetch_execute_task(self, project_name, fl_ctx: FLContext): - """Get registered with the remote server via channel, and fetch the server's model parameters. + """Fetch a task from the server. Args: project_name: FL study project name fl_ctx: FLContext - Returns: a CurrentTask message from server - + Returns: + A CurrentTask message from server """ try: - self.logger.info("Starting to fetch execute task.") + self.logger.debug("Starting to fetch execute task.") task = self.communicator.getTask(self.servers, project_name, self.token, self.ssid, fl_ctx) return task except FLCommunicationError as e: self.logger.info(e) - # self.communicator.heartbeat_done = True def push_execute_result(self, project_name, shareable: Shareable, fl_ctx: FLContext): - """Read local model and push to self.server[task_name] channel. - - This function makes and sends a Contribution Message. + """Submit execution results of a task to server. Args: project_name: FL study project name shareable: Shareable object fl_ctx: FLContext - Returns: reply message - + Returns: + A FederatedSummary message from the server. """ try: self.logger.info("Starting to push execute result.") @@ -209,12 +201,9 @@ def push_execute_result(self, project_name, shareable: Shareable, fl_ctx: FLCont return message except FLCommunicationError as e: self.logger.info(e) - # self.communicator.heartbeat_done = True def send_aux_message(self, project_name, topic: str, shareable: Shareable, timeout: float, fl_ctx: FLContext): - """Read local model and push to self.server[task_name] channel. - - This function makes and sends a Contribution Message. + """Send auxiliary message to the server. 

         Args:
             project_name: FL study project name
@@ -223,11 +212,11 @@ def send_aux_message(self, project_name, topic: str, shareable: Shareable, timeo
             timeout: communication timeout
             fl_ctx: FLContext

-        Returns: reply message
-
+        Returns:
+            A reply message
         """
         try:
-            self.logger.info("Starting to send aux messagee.")
+            self.logger.debug("Starting to send aux message.")
             message = self.communicator.auxCommunicate(
                 self.servers, project_name, self.token, self.ssid, fl_ctx, self.client_name, shareable, topic, timeout)

@@ -235,7 +224,6 @@
             return message
         except FLCommunicationError as e:
             self.logger.info(e)
-            # self.communicator.heartbeat_done = True

     def send_heartbeat(self, project_name):
         try:
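
Note: fetch_execute_task(), push_execute_result(), and send_aux_message() all catch
FLCommunicationError, log it, and fall through, so they implicitly return None on a
communication failure. A caller therefore needs a None check, roughly like this sketch
(execute is a caller-supplied stand-in for local task execution):

    # Illustrative caller -- not part of this change.
    def run_one_round(client, project_name, fl_ctx, execute):
        task = client.fetch_execute_task(project_name, fl_ctx)
        if task is None:  # communication failed; retry on the next poll
            return
        result = execute(task)
        reply = client.push_execute_result(project_name, result, fl_ctx)
        if reply is None:
            return  # submission failed; the runner decides whether to retry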
self.logger.info("ignored result submission since Server Engine isn't ready") context.abort(grpc.StatusCode.OUT_OF_RANGE, "Server has stopped") - # fl_ctx = self.fl_ctx.clone_sticky() with self.engine.new_context() as fl_ctx: state_check = self.server_state.submit_result(fl_ctx) self._handle_state_check(context, state_check) self._ssid_check(request.client, context) - # if self.status == ServerStatus.TRAINING_STOPPED or self.status == ServerStatus.TRAINING_NOT_STARTED: - # context.abort(grpc.StatusCode.OUT_OF_RANGE, "Server training stopped") - # return - contribution = request client = self.client_manager.validate_client(contribution.client, context) @@ -461,7 +442,6 @@ def SubmitUpdate(self, request, context): shareable = shareable.from_bytes(proto_to_bytes(request.data.params["data"])) shared_fl_context = pickle.loads(proto_to_bytes(request.data.params["fl_context"])) - # fl_ctx.set_prop(FLContextKey.PEER_CONTEXT, shared_fl_context) fl_ctx.set_peer_context(shared_fl_context) shared_fl_context.set_prop(FLContextKey.SHAREABLE, shareable, private=False) @@ -482,14 +462,9 @@ def SubmitUpdate(self, request, context): time_seconds or "less than 1", ) - # fire_event(EventType.BEFORE_PROCESS_SUBMISSION, self.handlers, fl_ctx) - - # task_id = shared_fl_context.get_cookie(FLContextKey.TASK_ID) task_id = shareable.get_cookie(FLContextKey.TASK_ID) self.server_runner.process_submission(client, contribution_task_name, task_id, shareable, fl_ctx) - # fire_event(EventType.AFTER_PROCESS_SUBMISSION, self.handlers, fl_ctx) - response_comment = "Received from {} ({} Bytes, {} seconds)".format( contribution.client.client_name, contribution.ByteSize(), @@ -498,17 +473,10 @@ def SubmitUpdate(self, request, context): summary_info = fed_msg.FederatedSummary(comment=response_comment) summary_info.meta.CopyFrom(self.task_meta_info) - # with self.lock: - # self.fl_ctx.merge_sticky(fl_ctx) - return summary_info def AuxCommunicate(self, request, context): """Handle auxiliary channel communication.""" - # if not self.run_manager: - # context.abort(grpc.StatusCode.OUT_OF_RANGE, "Server has stopped") - - # fl_ctx = self.fl_ctx.clone_sticky() with self.engine.new_context() as fl_ctx: state_check = self.server_state.aux_communicate(fl_ctx) self._handle_state_check(context, state_check) @@ -541,8 +509,6 @@ def AuxCommunicate(self, request, context): shared_fl_context.set_prop(FLContextKey.SHAREABLE, shareable, private=False) topic = shareable.get_header(ReservedHeaderKey.TOPIC) - # aux_runner = fl_ctx.get_aux_runner() - # assert isinstance(aux_runner, ServerAuxRunner) reply = self.engine.dispatch(topic=topic, request=shareable, fl_ctx=fl_ctx) aux_reply = fed_msg.AuxReply() @@ -576,9 +542,7 @@ def Retrieve(self, request, context): messages = self.admin_server.get_outgoing_requests(client_token=client_name) if self.admin_server else [] response = admin_msg.Messages() - # response.message.CopyFrom(messages) for m in messages: - # message = response.message.add() response.message.append(message_to_proto(m)) return response @@ -602,8 +566,6 @@ def SendResult(self, request, context): return response def start_run(self, run_number, run_root, conf, args, snapshot): - # self.status = ServerStatus.STARTING - # Create the FL Engine workspace = Workspace(args.workspace, "server", args.config_folder) self.run_manager = RunManager( @@ -628,9 +590,6 @@ def start_run(self, run_number, run_root, conf, args, snapshot): if snapshot: self.engine.restore_components(snapshot=snapshot, fl_ctx=FLContext()) - # with open(os.path.join(run_root, 
         fl_ctx.set_prop(FLContextKey.APP_ROOT, run_root, sticky=True)
         fl_ctx.set_prop(FLContextKey.CURRENT_RUN, run_number, private=False, sticky=True)
         fl_ctx.set_prop(FLContextKey.WORKSPACE_ROOT, args.workspace, private=True, sticky=True)
@@ -642,17 +601,10 @@
         self.run_manager.add_handler(self.server_runner)
         self.run_manager.add_component("_Server_Runner", self.server_runner)

-        # self.controller.initialize_run(self.fl_ctx)
-
-        # return super().start()
-        # self.status = ServerStatus.STARTED
-        # self.run_engine()
-
         engine_thread = threading.Thread(target=self.run_engine)
-        # heartbeat_thread.daemon = True
         engine_thread.start()

         while self.engine.engine_info.status != MachineStatus.STOPPED:
-            # self.remove_dead_clients()
             if self.engine.asked_to_stop:
                 self.engine.abort_app_on_server()
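
Note: SubmitUpdate now logs before aborting with OUT_OF_RANGE when the engine is not ready.
Since grpc.ServicerContext.abort() raises to terminate the RPC, nothing after the guard runs
for an unready server; the log line is the only trace left behind. The guard pattern,
condensed from the hunk above:

    # Condensed guard pattern from SubmitUpdate above.
    import grpc

    def submit_update(self, request, context: grpc.ServicerContext):
        if self.server_runner is None or self.engine.run_manager is None:
            self.logger.info("ignored result submission since Server Engine isn't ready")
            context.abort(grpc.StatusCode.OUT_OF_RANGE, "Server has stopped")  # raises
        ...  # normal submission handling continues here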
"no task currently for client - asked client to try again later") + self.log_debug(fl_ctx, "no task currently for client - asked client to try again later") return self._task_try_again() if task_data: @@ -242,39 +242,7 @@ def process_task_request(self, client: Client, fl_ctx: FLContext) -> (str, str, self.log_error( fl_ctx, "bad task data generated by workflow {}: must be Shareable but got {}".format( - type(self.current_wf.id), type(task_data) - ), - ) - return self._task_try_again() - else: - task_data = Shareable() - - task_data.set_header(ReservedHeaderKey.TASK_ID, task_id) - task_data.set_header(ReservedHeaderKey.TASK_NAME, task_name) - task_data.add_cookie(ReservedHeaderKey.WORKFLOW, self.current_wf.id) - - if task_data: - if not isinstance(task_data, Shareable): - self.log_error( - fl_ctx, - "bad task data generated by workflow {}: must be Shareable but got {}".format( - type(self.current_wf.id), type(task_data) - ), - ) - return self._task_try_again() - else: - task_data = Shareable() - - task_data.set_header(ReservedHeaderKey.TASK_ID, task_id) - task_data.set_header(ReservedHeaderKey.TASK_NAME, task_name) - task_data.add_cookie(ReservedHeaderKey.WORKFLOW, self.current_wf.id) - - if task_data: - if not isinstance(task_data, Shareable): - self.log_error( - fl_ctx, - "bad task data generated by workflow {}: must be Shareable but got {}".format( - type(self.current_wf.id), type(task_data) + self.current_wf.id, type(task_data) ), ) return self._task_try_again() @@ -342,37 +310,23 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul # set the reply prop so log msg context could include RC from it fl_ctx.set_prop(FLContextKey.REPLY, result, private=True, sticky=False) - if not isinstance(result, Shareable): - self.log_error(fl_ctx, "invalid result submission: must be Shareable but got {}".format(type(result))) - return - - # set the reply prop so log msg context could include RC from it - fl_ctx.set_prop(FLContextKey.REPLY, result, private=True, sticky=False) - - if not isinstance(result, Shareable): - self.log_error(fl_ctx, "invalid result submission: must be Shareable but got {}".format(type(result))) - return - - # set the reply prop so log msg context could include RC from it - fl_ctx.set_prop(FLContextKey.REPLY, result, private=True, sticky=False) - fl_ctx.set_prop(FLContextKey.TASK_NAME, value=task_name, private=True, sticky=False) fl_ctx.set_prop(FLContextKey.TASK_RESULT, value=result, private=True, sticky=False) fl_ctx.set_prop(FLContextKey.TASK_ID, value=task_id, private=True, sticky=False) if self.status != "started": - self.log_info(fl_ctx, "ignored result submission since server runner is {}".format(self.status)) + self.log_info(fl_ctx, "ignored result submission since server runner's status is {}".format(self.status)) return peer_ctx = fl_ctx.get_peer_context() if not isinstance(peer_ctx, FLContext): - self.log_error(fl_ctx, "invalid result submission: no peer context; dropped.") + self.log_info(fl_ctx, "invalid result submission: no peer context - dropped") return peer_run_num = peer_ctx.get_run_number() if not peer_run_num or peer_run_num != self.run_num: # the client is on a different RUN - self.log_error(fl_ctx, "invalid result submission: not the same run number; dropped") + self.log_info(fl_ctx, "invalid result submission: not the same run number - dropped") return result.set_header(ReservedHeaderKey.TASK_NAME, task_name) @@ -397,7 +351,7 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul with 
@@ -397,7 +351,7 @@ def process_submission(self, client: Client, task_name: str, task_id: str, resul
         with self.wf_lock:
             try:
                 if self.current_wf is None:
-                    self.log_info(fl_ctx, "There's no current workflow - dropped submission.")
+                    self.log_info(fl_ctx, "no current workflow - dropped submission.")
                     return

                 wf_id = result.get_cookie(ReservedHeaderKey.WORKFLOW, None)
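
Note: the process_task_request and process_submission hunks above are mostly de-duplication.
The task_data validation block had been accidentally triplicated, as had the REPLY-prop and
isinstance(result, Shareable) check, and exactly one copy of each survives; the kept copy also
fixes the log message to format self.current_wf.id rather than type(self.current_wf.id). The
_task_try_again() helper these paths return through is not shown in this diff; judging from
the client side, which reads the wait time via task_data.get(TaskConstant.WAIT_TIME, ...), a
plausible shape (assumed, not confirmed by this diff) is:

    # Assumed shape of the helper -- not shown in this diff.
    def _task_try_again(self):
        task_data = Shareable()
        # put the interval where the client's task_data.get(...) will find it
        task_data[TaskConstant.WAIT_TIME] = self.task_request_interval
        return SpecialTaskName.TRY_AGAIN, "", task_data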