Change log levels for some logs in private (NVIDIA#289)
* Change log levels for some logs in private

* Fix issues

* Undo remove
YuanTingHsieh authored Mar 11, 2022
1 parent e10d62c commit b6bab46
Showing 7 changed files with 95 additions and 280 deletions.
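For context, the main pattern in the diff below is demoting chatty, per-poll log messages from `print`/info level to debug level. A minimal sketch of that pattern using the standard `logging` module (the logger name and setup here are illustrative, not taken from the repository; the message string is the one shown in the `check_status` hunk):

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("client_executor")

# Before this commit: the status poll printed directly to stdout on every check.
# print("check status from process listener......")

# After: the same message goes through the logger at debug level, so it only
# shows up when debug logging is enabled for this module.
logger.debug("check status from process listener......")
```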
129 changes: 18 additions & 111 deletions nvflare/private/fed/client/client_executor.py
@@ -13,7 +13,6 @@
# limitations under the License.

import logging
import math
import os
import shlex
import subprocess
@@ -26,6 +25,7 @@
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.utils.common_utils import get_open_ports
from nvflare.fuel.utils.pipe.file_pipe import FilePipe

from .client_status import ClientStatus, get_status_message


@@ -57,25 +57,22 @@ def start_train(self, client, args, app_root, app_custom_folder, listen_port):
"""
pass

def check_status(self, client):
def check_status(self, client) -> str:
"""To check the status of the running client.
Args:
client: the FL client object
Returns: running FL client status message
Returns:
A client status message
"""
pass

def abort_train(self, client):
"""To abort the running client.
"""To abort the client training.
Args:
client: the FL client object
Returns: N/A
"""
pass

@@ -84,49 +81,40 @@ def abort_task(self, client):
Args:
client: the FL client object
Returns: N/A
"""
pass

def get_run_info(self):
"""To get the run_info from the InfoCollector.
Returns: current run info
def get_run_info(self) -> dict:
"""Get the run information.
Returns:
A dict of run information.
"""
pass

def get_errors(self):
"""To get the error_info from the InfoCollector.
"""Get the error information.
Returns: current errors
Returns:
A dict of error information.
"""
pass

def reset_errors(self):
"""To reset the error_info for the InfoCollector.
Returns: N/A
"""
"""Reset the error information."""
pass

def send_aux_command(self, shareable: Shareable):
"""To send the aux command to child process.
Args:
shareable: aux message Shareable
Returns: N/A
"""
pass

def cleanup(self):
"""Finalize cleanup."""
"""Cleanup."""
self.pipe.clear()


@@ -141,11 +129,10 @@ def __init__(self, uid, startup):
startup: startup folder
"""
ClientExecutor.__init__(self, uid, startup)
# self.client = client

self.startup = startup

self.conn_client = None
# self.pool = None

self.listen_port = get_open_ports(1)[0]

@@ -166,14 +153,6 @@ def create_pipe(self):
return pipe

def start_train(self, client, args, app_root, app_custom_folder, listen_port):
# self.pool = multiprocessing.Pool(processes=1)
# result = self.pool.apply_async(_start_client, (client, args, app_root))

# self.conn_client, child_conn = mp.Pipe()
# process = multiprocessing.Process(target=_start_client, args=(client, args, app_root, child_conn, self.pipe))
# # process = multiprocessing.Process(target=_start_new)
# process.start()

self.listen_port = listen_port

new_env = os.environ.copy()
@@ -213,12 +192,12 @@ def check_status(self, client):
data = {"command": AdminCommandNames.CHECK_STATUS, "data": {}}
self.conn_client.send(data)
status_message = self.conn_client.recv()
print("check status from process listener......")
self.logger.debug("check status from process listener......")
return status_message
else:
return get_status_message(client.status)
except:
self.logger.error("Check_status execution exception.")
self.logger.error("check_status() execution exception.")
return "execution exception. Please try again."

def get_run_info(self):
@@ -275,19 +254,9 @@ def send_aux_command(self, shareable: Shareable):
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

def abort_train(self, client):
# if client.status == ClientStatus.CROSS_SITE_VALIDATION:
# # Only aborts cross site validation.
# client.abort()
# elif client.status == ClientStatus.TRAINING_STARTED:
if client.status == ClientStatus.STARTED:
with self.lock:
if client.process:
# if client.platform == 'PT' and client.multi_gpu:
# # kill the sub-process group directly
# os.killpg(os.getpgid(client.process.pid), 9)
# else:
# client.process.terminate()

# kill the sub-process group directly
if self.conn_client:
data = {"command": AdminCommandNames.ABORT, "data": {}}
@@ -305,8 +274,6 @@ def abort_train(self, client):
client.process.terminate()
self.logger.debug("terminated")

# if self.pool:
# self.pool.terminate()
if self.conn_client:
self.conn_client.close()
self.conn_client = None
@@ -346,14 +313,8 @@ def wait_training_process_finish(self, client, args, app_root, app_custom_folder
self.conn_client.close()
self.conn_client = None

# # result.get()
# self.pool.close()
# self.pool.join()
# self.pool.terminate()

# Not to run cross_validation in a new process any more.
# Not to run cross_validation in a new process anymore
client.cross_site_validate = False

client.status = ClientStatus.STOPPED

def close(self):
@@ -362,57 +323,3 @@ def close(self):
self.conn_client.send(data)
self.conn_client = None
self.cleanup()


# class ThreadExecutor(ClientExecutor):
# def __init__(self, client, executor):
# self.client = client
# self.executor = executor

# def start_train(self, client, args, app_root, app_custom_folder, listen_port):
# future = self.executor.submit(lambda p: _start_client(*p), [client, args, app_root])

# def start_mgpu_train(self, client, args, app_root, gpu_number, app_custom_folder, listen_port):
# self.start_train(client, args, app_root)

# def check_status(self, client):
# return get_status_message(self.client.status)

# def abort_train(self, client):
# self.client.train_end = True
# self.client.fitter.train_ctx.ask_to_stop_immediately()
# self.client.fitter.train_ctx.set_prop("early_end", True)
# # self.client.model_manager.close()
# # self.client.status = ClientStatus.TRAINING_STOPPED
# return "Aborting the client..."


def update_client_properties(client, trainer):
# servers = [{t['name']: t['service']} for t in trainer.server_config]
retry_timeout = 30
# if trainer.client_config['retry_timeout']:
# retry_timeout = trainer.client_config['retry_timeout']
client.client_args = trainer.client_config
# client.servers = sorted(servers)[0]
# client.model_manager.federated_meta = {task_name: list() for task_name in tuple(client.servers)}
exclude_vars = trainer.client_config.get("exclude_vars", "dummy")
# client.model_manager.exclude_vars = re.compile(exclude_vars) if exclude_vars else None
# client.model_manager.privacy_policy = trainer.privacy
# client.model_manager.model_reader_writer = trainer.model_reader_writer
# client.model_manager.model_validator = trainer.model_validator
# client.pool = ThreadPool(len(client.servers))
# client.communicator.ssl_args = trainer.client_config
# client.communicator.secure_train = trainer.secure_train
# client.communicator.model_manager = client.model_manager
client.communicator.should_stop = False
client.communicator.retry = int(math.ceil(float(retry_timeout) / 5))
# client.communicator.outbound_filters = trainer.outbound_filters
# client.communicator.inbound_filters = trainer.inbound_filters
client.handlers = trainer.handlers
# client.inbound_filters = trainer.inbound_filters
client.executors = trainer.executors
# client.task_inbound_filters = trainer.task_inbound_filters
# client.task_outbound_filters = trainer.task_outbound_filters
# client.secure_train = trainer.secure_train
client.heartbeat_done = False
# client.fl_ctx = FLContext()
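As the `check_status` and `abort_train` hunks above show, the parent process drives the training subprocess by sending small command dicts over a connection and reading the reply. A minimal sketch of that request/reply exchange (the connection setup and the actual `AdminCommandNames` values are assumptions, not shown in this diff):

```python
from multiprocessing.connection import Connection


def send_command(conn: Connection, command: str) -> object:
    """Send a command dict to the child process and block until it replies."""
    data = {"command": command, "data": {}}
    conn.send(data)
    # recv() waits for the process listener on the other end to respond.
    return conn.recv()


# Hypothetical usage, mirroring the check_status hunk:
# status_message = send_command(conn_client, "check_status")
```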
8 changes: 5 additions & 3 deletions nvflare/private/fed/client/client_runner.py
@@ -244,10 +244,10 @@ def _try_run(self):

# reset to default fetch interval
task_fetch_interval = self.task_fetch_interval
self.log_info(fl_ctx, "fetching task from server ...")
self.log_debug(fl_ctx, "fetching task from server ...")
task = self.engine.get_task_assignment(fl_ctx)
if not task:
self.log_info(fl_ctx, "no task received - will try in {} secs".format(task_fetch_interval))
self.log_debug(fl_ctx, "no task received - will try in {} secs".format(task_fetch_interval))
continue

if task.name == SpecialTaskName.END_RUN:
@@ -258,7 +258,9 @@ def _try_run(self):
task_data = task.data
if task_data and isinstance(task_data, Shareable):
task_fetch_interval = task_data.get(TaskConstant.WAIT_TIME, self.task_fetch_interval)
self.log_info(fl_ctx, "server asked to try again - will try in {} secs".format(task_fetch_interval))
self.log_debug(
fl_ctx, "server asked to try again - will try in {} secs".format(task_fetch_interval)
)
continue

self.log_info(fl_ctx, "got task assignment: name={}, id={}".format(task.name, task.task_id))
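The client_runner.py changes follow the same idea: the task-fetch loop runs continuously, so its per-iteration messages move to debug level, while a real task assignment stays at info level. A rough sketch of that polling shape, using a plain logger instead of the runner's `log_debug`/`log_info` helpers and a simplified loop body (the real NVFlare runner handles special tasks and more states):

```python
import time


def poll_for_tasks(engine, fl_ctx, logger, default_interval=2.0):
    """Simplified task-fetch loop; returns the first task assignment received."""
    task_fetch_interval = default_interval
    while True:
        # Fires on every iteration, so debug level keeps normal logs readable.
        logger.debug("fetching task from server ...")
        task = engine.get_task_assignment(fl_ctx)
        if not task:
            logger.debug("no task received - will try in {} secs".format(task_fetch_interval))
            time.sleep(task_fetch_interval)
            continue
        # A task actually arriving is a notable event, so it stays at info level.
        logger.info("got task assignment: name={}, id={}".format(task.name, task.task_id))
        return task
```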