Skip to content

Commit

Permalink
Tests run with PASSED (NVIDIA#427)
Browse files Browse the repository at this point in the history
* Tests run with PASSED

* Server/client/overseer working

* Action triggered

* Log shows correct behavior

* Fix site launcher stop server

* With some hacks, the HA test run passed

* Resolve four conflicts from new dev-2.1 commits

* Rename yaml file

* Make server restart later

* Remove threads

* Test code exits without being stuck

* Resolve admin controller conflicts

* Add back internal tests

* Fix issues after merging with dev-2.1
  • Loading branch information
IsaacYangSLA authored Apr 27, 2022
1 parent 0aa5858 commit d019dbd
Show file tree
Hide file tree
Showing 10 changed files with 792 additions and 77 deletions.
4 changes: 2 additions & 2 deletions nvflare/poc/admin/startup/fed_admin_HA.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
"path": "nvflare.ha.overseer_agent.HttpOverseerAgent",
"args": {
"role": "admin",
"overseer_end_point": "http://localhost:6000/api/v1",
"overseer_end_point": "http://localhost:5000/api/v1",
"project": "example_project",
"name": "admin"
}
}
}
}
}
27 changes: 25 additions & 2 deletions nvflare/poc/client/startup/fed_client_HA.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,33 @@
"overseer_agent": {
"path": "nvflare.ha.overseer_agent.HttpOverseerAgent",
"args": {
"overseer_end_point": "http://localhost:6000/api/v1",
"overseer_end_point": "http://localhost:5000/api/v1",
"project": "example_project",
"role": "client",
"name": "site1"
}
}
},
"components": [
{
"id": "resource_manager",
"path": "nvflare.apis.impl.list_resource_manager.ListResourceManager",
"args": {
"resources": {
"gpu": [
0,
1,
2,
3
]
}
}
},
{
"id": "resource_consumer",
"path": "nvflare.apis.impl.gpu_resource_consumer.GPUResourceConsumer",
"args": {
"gpu_resource_key": "gpu"
}
}
]
}
2 changes: 1 addition & 1 deletion nvflare/poc/overseer/startup/start.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
#!/usr/bin/env bash
FLASK_APP=nvflare.ha.overseer.overseer flask run --host=localhost --port=6000
FLASK_APP=nvflare.ha.overseer.overseer flask run --host=localhost --port=5000
38 changes: 36 additions & 2 deletions nvflare/poc/server/startup/fed_server_HA.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,47 @@
"overseer_agent": {
"path": "nvflare.ha.overseer_agent.HttpOverseerAgent",
"args": {
"overseer_end_point": "http://localhost:6000/api/v1",
"overseer_end_point": "http://localhost:5000/api/v1",
"project": "example_project",
"role": "server",
"name": "localhost",
"fl_port": "8002",
"admin_port": "8003",
"heartbeat_interval": 6
}
}
},
"components": [
{
"id": "job_scheduler",
"path": "nvflare.apis.impl.job_scheduler.DefaultJobScheduler",
"args": {
"max_jobs": 4
}
},
{
"id": "job_manager",
"path": "nvflare.apis.impl.job_def_manager.SimpleJobDefManager",
"args": {
"uri_root": "/tmp/jobs-storage",
"job_store_id": "job_store"
}
},
{
"id": "job_store",
"name": "FilesystemStorage",
"args": {}
},
{
"id": "study_manager",
"path": "nvflare.apis.impl.study_manager.StudyManager",
"args": {
"study_store_id": "study_store"
}
},
{
"id": "study_store",
"path": "nvflare.app_common.storages.filesystem_storage.FilesystemStorage",
"args": {}
}
]
}
231 changes: 228 additions & 3 deletions tests/integration_test/admin_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

import logging
import re
import time

from nvflare.ha.overseer_agent import HttpOverseerAgent
from nvflare.fuel.hci.client.api_status import APIStatus
from nvflare.fuel.hci.client.fl_admin_api import FLAdminAPI
from nvflare.fuel.hci.client.fl_admin_api_constants import FLDetailKey
Expand All @@ -23,7 +25,7 @@


class AdminController:
def __init__(self, jobs_root_dir, poll_period=10):
def __init__(self, jobs_root_dir, ha, poll_period=10):
"""
This class runs an app on a given server and clients.
"""
Expand All @@ -32,10 +34,17 @@ def __init__(self, jobs_root_dir, poll_period=10):
self.jobs_root_dir = jobs_root_dir
self.poll_period = poll_period

if ha:
overseer_agent = HttpOverseerAgent(
role="admin", overseer_end_point="http://127.0.0.1:5000/api/v1", project="example_project", name="admin"
)
else:
overseer_agent = DummyOverseerAgent(sp_end_point="localhost:8002:8003")

self.admin_api: FLAdminAPI = FLAdminAPI(
upload_dir=self.jobs_root_dir,
download_dir=self.jobs_root_dir,
overseer_agent=DummyOverseerAgent(sp_end_point="localhost:8002:8003"),
overseer_agent=overseer_agent,
poc=True,
debug=False,
user_name="admin",
Expand Down Expand Up @@ -73,6 +82,9 @@ def get_run_data(self):

return run_data

def get_stats(self, target):
    """Return the admin API's ``show_stats`` response for the current job on ``target``."""
    api = self.admin_api
    return api.show_stats(self.job_id, target)

def ensure_clients_started(self, num_clients):
if not self.admin_api:
return False
Expand Down Expand Up @@ -132,7 +144,7 @@ def submit_job(self, job_name) -> bool:
if response["status"] != APIStatus.SUCCESS:
raise RuntimeError(f"submit_job failed: {response}")
self.job_id = response["details"]["job_id"]

self.last_job_name = job_name
return True

def wait_for_job_done(self):
Expand All @@ -152,6 +164,219 @@ def wait_for_job_done(self):
continue
training_done = True

def run_app_ha(self, site_launcher, ha_test):
# Poll the server until the run is reported as done, firing the HA test events
# listed in ha_test["events"] when their triggers are met and asserting the
# run state that each event is expected to produce.
#
# site_launcher: passed through to execute_actions() to kill/restart sites.
# ha_test: test-case dict; this method reads ha_test["events"], and each event's
#          "trigger", "actions" and "result_state" entries.
#
# Latest known workflow / task / round / finished flag, refreshed by process_stats().
run_state = {"workflow": None, "task": None, "round_number": None, "run_finished": None}

# Number of server log lines already consumed on earlier iterations.
last_read_line = 0
# Index into ha_events of the event currently being waited on / verified.
event = 0
ha_events = ha_test["events"]
event_test_status = [False for _ in range(len(ha_events))] # whether event has been successfully triggered

i = 0
training_done = False
while not training_done:
i += 1

# Fetch the server's log.txt through the admin API and keep only the lines
# that were not seen on a previous iteration.
server_logs = self.admin_api.cat_target(TargetType.SERVER, file="log.txt")["details"][
"message"
].splitlines()[last_read_line:]
last_read_line = len(server_logs) + last_read_line
server_logs_string = "\n".join(server_logs)

stats = self.get_stats(TargetType.SERVER)

# update run_state
# process_stats() returns (changed, wfs, run_state); wfs is a list of
# (workflow_name, info) pairs, which is why it is indexed by position below.
changed, wfs, run_state = self.process_stats(stats, run_state)

# Print the state when it changed, or periodically.
# NOTE(review): `/` is true division, so the modulus operand is a float; for
# poll periods that do not divide 10 evenly this may never compare equal to 0
# - confirm the intended print cadence.
if changed or i % (10 / self.poll_period) == 0:
i = 0
print("STATS: ", stats)
self.print_state(ha_test, run_state)

# check if event is triggered -> then execute the corresponding actions
if event <= len(ha_events) - 1 and not event_test_status[event]:
event_trigger = []

# A dict trigger lists run_state keys and the values they must have; for
# "workflow" the value is a position in wfs and the name at that position
# is compared. A str trigger fires when it appears in the new log lines.
if isinstance(ha_events[event]["trigger"], dict):
for k, v in ha_events[event]["trigger"].items():
if k == "workflow":
print(run_state)
print(wfs)
event_trigger.append(run_state[k] == wfs[v][0])
else:
event_trigger.append(run_state[k] == v)
elif isinstance(ha_events[event]["trigger"], str) and ha_events[event]["trigger"] in server_logs_string:
event_trigger.append(True)

# At least one condition was evaluated and all of them hold: mark the event
# as triggered and run its actions.
if event_trigger and all(event_trigger):
print(f"EVENT TRIGGERED: {ha_events[event]['trigger']}")
event_test_status[event] = True
self.execute_actions(site_launcher, ha_events[event]["actions"])
continue

# Ask the admin API for the server status; a non-success status is reported
# as "no active server".
response = self.admin_api.check_status(target_type=TargetType.SERVER)
if response and "status" in response and response["status"] != APIStatus.SUCCESS:
print("NO ACTIVE SERVER!")

elif (
response and "status" in response and "details" in response and response["status"] == APIStatus.SUCCESS
):

# compare run_state to expected result_state from the test case
if event <= len(ha_events) - 1 and event_test_status[event] and response["status"] == APIStatus.SUCCESS:
result_state = ha_events[event]["result_state"]
if any(list(run_state.values())):
# "unchanged" means the expected state is the trigger itself.
# NOTE(review): a str trigger is accepted above, but .items() below
# needs a dict - confirm "unchanged" is only used with dict triggers.
if result_state == "unchanged":
result_state = ha_events[event]["trigger"]
for k, v in result_state.items():
if k == "workflow":
print(f"ASSERT Current {k}: {run_state[k]} == Expected {k}: {wfs[v][0]}")
assert run_state[k] == wfs[v][0]
else:
print(f"ASSERT Current {k}: {run_state[k]} == Expected {k}: {v}")
assert run_state[k] == v
print("\n")
# Move on to the next event in ha_events.
event += 1

# check if run is stopped
if (
FLDetailKey.SERVER_ENGINE_STATUS in response["details"]
and response["details"][FLDetailKey.SERVER_ENGINE_STATUS] == "stopped"
):
# The server engine reports "stopped": look at the client statuses too.
response = self.admin_api.check_status(target_type=TargetType.CLIENT)
if response["status"] != APIStatus.SUCCESS:
print(f"CHECK status failed: {response}")
# NOTE(review): row[3] is presumably the client's status column - confirm
# against the check_status response format.
for row in response["details"]["client_statuses"]:
if row[3] != "stopped":
continue
training_done = True
# Wait one poll period before polling again.
time.sleep(self.poll_period)

# Every event in the test case must have fired at some point during the run.
assert all(event_test_status), "Test failed: not all test events were triggered"

def execute_actions(self, site_launcher, actions):
    """Run a list of HA test actions in order.

    Each action is a whitespace-separated string made of a command followed
    by its arguments, for example ``"sleep 5"``, ``"kill server"``,
    ``"kill client 1"`` or ``"restart overseer"``. Commands and targets
    other than the ones handled below are ignored.

    Args:
        site_launcher: object used to stop and start the server, the
            overseer and the clients.
        actions: iterable of action strings.
    """
    for action in actions:
        # split() without a separator tolerates repeated, leading and trailing
        # whitespace; split(" ") would turn "sleep  5" into ["sleep", "", "5"].
        tokens = action.split()

        print(f"ACTION: {action}")

        if not tokens:
            # Blank action string: nothing to do.
            continue
        command, args = tokens[0], tokens[1:]

        if command == "sleep":
            time.sleep(int(args[0]))

        elif command == "kill":
            target = args[0]
            if target == "server":
                active_server_id = site_launcher.get_active_server_id(self.admin_api.port)
                site_launcher.stop_server(active_server_id)
            elif target == "overseer":
                site_launcher.stop_overseer()
            elif target == "client":  # TODO fix client kill & restart during run
                # An explicit client id may follow the target; otherwise use the
                # first known client.
                if len(args) == 2:
                    client_id = int(args[1])
                else:
                    client_id = list(site_launcher.client_properties.keys())[0]
                self.admin_api.remove_client([site_launcher.client_properties[client_id]["name"]])
                site_launcher.stop_client(client_id)

        elif command == "restart":
            target = args[0]
            if target == "server":
                if len(args) == 2:
                    server_id = int(args[1])
                else:
                    print(site_launcher.server_properties)
                    server_id = list(site_launcher.server_properties.keys())[0]
                # NOTE(review): server_id is resolved but not passed on;
                # start_server() is called without arguments. Confirm whether
                # the launcher is expected to pick the server itself.
                site_launcher.start_server()
            elif target == "overseer":
                site_launcher.start_overseer()
            elif target == "client":  # TODO fix client kill & restart during run
                if len(args) == 2:
                    client_id = int(args[1])
                else:
                    client_id = list(site_launcher.client_properties.keys())[0]
                site_launcher.start_client(client_id)

def print_state(self, ha_test, state):
    """Print the last submitted job name, the HA test case name and each run-state entry."""
    rule = "-" * 30
    print("\n")
    print(f"Job name: {self.last_job_name}")
    print(f"HA test case: {ha_test['name']}")
    print(rule)
    for key, value in state.items():
        print(f"{key}: {value}")
    print(rule + "\n")

def process_stats(self, stats, run_state):
    """Update ``run_state`` in place from a ``show_stats`` response.

    Example of a successful response (the response also carries a ``raw``
    entry with a timestamp and the same data)::

        {'status': <APIStatus.SUCCESS: 'SUCCESS'>,
         'details': {
             'message': {
                 'ScatterAndGather': {
                     'tasks': {'train': []},
                     'phase': 'train',
                     'current_round': 0,
                     'num_rounds': 2},
                 'CrossSiteModelEval': {'tasks': {}},
                 'ServerRunner': {
                     'run_number': 1,
                     'status': 'started',
                     'workflow': 'scatter_and_gather'}}}}

    Args:
        stats: response dict as shown above; may be ``None``, empty, or
            missing ``status``/``details`` (then it is treated as "no
            successful stats" instead of raising).
        run_state: dict with the keys ``workflow``, ``task``,
            ``round_number`` and ``run_finished``; it is mutated.

    Returns:
        A tuple ``(changed, wfs, run_state)``: whether ``run_state`` differs
        from its value on entry, the list of ``(name, info)`` pairs of the
        reported entries that have a ``tasks`` key, and the updated
        ``run_state``.
    """
    wfs = {}
    prev_run_state = run_state.copy()
    print(f"run_state {prev_run_state}", flush=True)
    print(f"stats {stats}", flush=True)

    # Evaluate the success condition once, guarded, so that a None or partial
    # response cannot raise when "status" is looked up further down.
    succeeded = bool(stats) and "status" in stats and stats["status"] == APIStatus.SUCCESS

    if succeeded and "details" in stats and "message" in stats["details"]:
        wfs = stats["details"]["message"]
        if wfs:
            # Fresh data arrived: forget the previous workflow/task/round.
            run_state["workflow"], run_state["task"], run_state["round_number"] = None, None, None
        for item in wfs:
            tasks = wfs[item].get("tasks")
            if tasks:
                # An entry with a non-empty task table is the running workflow;
                # its first task name is reported as the current task.
                run_state["workflow"] = item
                run_state["task"] = list(tasks.keys())[0]
                if "current_round" in wfs[item]:
                    run_state["round_number"] = wfs[item]["current_round"]

    # On a successful response the run counts as finished once no
    # "ServerRunner" entry is reported; otherwise it is not finished.
    run_state["run_finished"] = succeeded and "ServerRunner" not in wfs

    wfs = [wf for wf in wfs.items() if "tasks" in wf[1]]

    return run_state != prev_run_state, wfs, run_state

def process_logs(self, logs, run_state):
    """Update ``run_state`` in place from the latest markers found in ``logs``.

    Three markers are looked for, and for each one the last occurrence in
    the whole text wins:

    * ``wf=<workflow>`` followed by ``,`` or ``]`` sets ``run_state["workflow"]``
    * ``task_name=<task>`` followed by ``,`` or ``]`` sets ``run_state["task"]``
    * ``Round <0-999> started.`` sets ``run_state["round_number"]`` (as ``int``)

    Args:
        logs: log text, possibly spanning many lines.
        run_state: dict to update; keys without a marker in ``logs`` are left
            untouched.

    Returns:
        A tuple ``(changed, run_state)`` where ``changed`` tells whether
        ``run_state`` differs from its value on entry.
    """
    prev_run_state = run_state.copy()

    # findall(...)[-1] picks the last occurrence across all lines. A trailing
    # negative lookahead built on ".*" stops at the first newline, so it only
    # finds the last occurrence on the first matching line.
    workflows = re.findall(r"wf=([^,\]]+)[,\]]", logs)
    if workflows:
        run_state["workflow"] = workflows[-1]

    task_names = re.findall(r"task_name=([^,\]]+)[,\]]", logs)
    if task_names:
        run_state["task"] = task_names[-1]

    # Round numbers 0-999 without leading zeros.
    rounds = re.findall(r"Round ([0-9]|[1-9][0-9]|[1-9][0-9][0-9]) started\.", logs)
    if rounds:
        run_state["round_number"] = int(rounds[-1])

    return run_state != prev_run_state, run_state

def finalize(self):
    """End the overseer agent, then ask the admin API to shut down all targets."""
    api = self.admin_api
    api.overseer_agent.end()
    api.shutdown(target_type=TargetType.ALL)
Loading

0 comments on commit d019dbd

Please sign in to comment.