Skip to content

Commit

Permalink
Update integration test cases (NVIDIA#429)
Browse files Browse the repository at this point in the history
* Update integration test cases

* Fix typo

* Disable example tests before we updated all the examples
  • Loading branch information
YuanTingHsieh authored Apr 23, 2022
1 parent 50e1699 commit 8aa11d8
Show file tree
Hide file tree
Showing 16 changed files with 73 additions and 294 deletions.
4 changes: 2 additions & 2 deletions tests/integration_test/apps/tf/config/config_fed_client.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"train"
],
"executor": {
"path": "trainer.SimpleTrainer",
"path": "tests.integration_test.tf2.trainer.SimpleTrainer",
"args": {
"epochs_per_round": 2
}
Expand All @@ -20,7 +20,7 @@
],
"filters": [
{
"path": "filter.ExcludeVars",
"name": "ExcludeVars",
"args": {
"exclude_vars": [
"flatten"
Expand Down
13 changes: 0 additions & 13 deletions tests/integration_test/apps/tf/custom/__init__.py

This file was deleted.

111 changes: 0 additions & 111 deletions tests/integration_test/apps/tf/custom/filter.py

This file was deleted.

1 change: 1 addition & 0 deletions tests/integration_test/system_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def read_yaml(yaml_file_path):

return data


params = [
# "./test_examples.yml",
"./test_internal.yml"
Expand Down
5 changes: 3 additions & 2 deletions tests/integration_test/test_examples.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
poc: ../../nvflare/poc
n_clients: 2
app_path: ../../examples
jobs_root_dir: ./jobs
apps_root_dir: ../../examples
snapshot_path: /tmp/snapshot-storage
cleanup: True

Expand All @@ -10,7 +11,7 @@ tests:
- tests.integration_test.validators.pt_model_validator.PTModelValidator
- app_name: hello-numpy-cross-val
validators:
- tests.integration_test.validators.cross_val_result_validator.CrossResultValidator
- tests.integration_test.validators.cross_val_result_validator.CrossValResultValidator
- tests.integration_test.validators.sag_result_validator.SAGResultValidator
- app_name: hello-numpy-sag
validators:
Expand Down
42 changes: 21 additions & 21 deletions tests/integration_test/test_internal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@ cleanup: True


tests:
# - app_name: global_model_eval
# validators:
# - tests.integration_test.validators.global_model_eval_validator.GlobalModelEvalValidator
# - tests.integration_test.validators.sag_result_validator.SAGResultValidator
# - app_name: pt
# validators:
# - tests.integration_test.validators.pt_model_validator.PTModelValidator
# - app_name: cross_val_one_client
# validators:
# - tests.integration_test.validators.cross_val_single_client_validator.CrossValSingleClientValidator
# - tests.integration_test.validators.sag_result_validator.SAGResultValidator
# - app_name: cross_val
# validators:
# - tests.integration_test.validators.cross_val_result_validator.CrossResultValidator
# - tests.integration_test.validators.sag_result_validator.SAGResultValidator
- app_name: global_model_eval
validators:
- tests.integration_test.validators.global_model_eval_validator.GlobalModelEvalValidator
- tests.integration_test.validators.sag_result_validator.SAGResultValidator
- app_name: pt
validators:
- tests.integration_test.validators.pt_model_validator.PTModelValidator
- app_name: cross_val_one_client
validators:
- tests.integration_test.validators.cross_val_result_validator.CrossValSingleClientResultValidator
- tests.integration_test.validators.sag_result_validator.SAGResultValidator
- app_name: cross_val
validators:
- tests.integration_test.validators.cross_val_result_validator.CrossValResultValidator
- tests.integration_test.validators.sag_result_validator.SAGResultValidator
- app_name: sag
validators:
- tests.integration_test.validators.sag_result_validator.SAGResultValidator
Expand All @@ -34,9 +34,9 @@ tests:
- app_name: tb_streaming
validators:
- tests.integration_test.validators.tb_result_validator.TBResultValidator
# - app_name: tf
# validators:
# - tests.integration_test.validators.tf_model_validator.TFModelValidator
# - app_name: cyclic
# validators:
# - tests.integration_test.validators.tf_model_validator.TFModelValidator
- app_name: tf
validators:
- tests.integration_test.validators.tf_model_validator.TFModelValidator
- app_name: cyclic
validators:
- tests.integration_test.validators.tf_model_validator.TFModelValidator
38 changes: 24 additions & 14 deletions tests/integration_test/validators/cross_val_result_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,34 @@

import os

from .app_result_validator import AppResultValidator
from nvflare.app_common.app_constant import AppConstants

from .app_result_validator import AppResultValidator

def check_cross_validation_result(server_data, client_data, run_data):

run_number = run_data["run_number"]
server_dir = server_data["server_path"]
client_names = list(client_data["client_names"])
def check_cross_validation_result(server_data, client_data, run_data, n_clients=-1):
if n_clients != -1:
client_names = [client_data["client_names"][i] for i in range(n_clients)]
else:
client_names = list(client_data["client_names"])

server_run_dir = os.path.join(server_dir, "run_" + str(run_number))
server_run_dir = os.path.join(server_data["server_path"], run_data["job_id"])

if not os.path.exists(server_run_dir):
print(f"check_cross_validation_result: server run dir {server_run_dir} doesn't exist.")
return False

cross_val_dir = os.path.join(server_run_dir, "cross_site_val")
cross_val_dir = os.path.join(server_run_dir, AppConstants.CROSS_VAL_DIR)
if not os.path.exists(cross_val_dir):
print(f"check_cross_validation_result: models dir {cross_val_dir} doesn't exist.")
return False

model_shareable_dir = os.path.join(cross_val_dir, "model_shareables")
model_shareable_dir = os.path.join(cross_val_dir, AppConstants.CROSS_VAL_MODEL_DIR_NAME)
if not os.path.exists(model_shareable_dir):
print(f"check_cross_validation_result: model shareable directory {model_shareable_dir} doesn't exist.")
return False

result_shareable_dir = os.path.join(cross_val_dir, "result_shareables")
result_shareable_dir = os.path.join(cross_val_dir, AppConstants.CROSS_VAL_RESULTS_DIR_NAME)
if not os.path.exists(result_shareable_dir):
print(f"check_cross_validation_result: result shareable directory {result_shareable_dir} doesn't exist.")
return False
Expand Down Expand Up @@ -73,12 +75,8 @@ def check_cross_validation_result(server_data, client_data, run_data):
return True


class CrossResultValidator(AppResultValidator):
def __init__(self):
super(CrossResultValidator, self).__init__()

class CrossValResultValidator(AppResultValidator):
def validate_results(self, server_data, client_data, run_data) -> bool:

cross_val_result = check_cross_validation_result(server_data, client_data, run_data)

print(f"CrossVal Result: {cross_val_result}")
Expand All @@ -87,3 +85,15 @@ def validate_results(self, server_data, client_data, run_data) -> bool:
raise ValueError("Cross val failed.")

return cross_val_result


class CrossValSingleClientResultValidator(AppResultValidator):
def validate_results(self, server_data, client_data, run_data) -> bool:
cross_val_result = check_cross_validation_result(server_data, client_data, run_data, n_clients=1)

print(f"CrossVal Result: {cross_val_result}")

if not cross_val_result:
raise ValueError("Cross val failed.")

return cross_val_result

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@

class FiltersResultValidator(AppResultValidator):
def validate_results(self, server_data, client_data, run_data) -> bool:
server_dir = server_data["server_path"]

server_run_dir = os.path.join(server_dir, run_data["job_id"])
server_run_dir = os.path.join(server_data["server_path"], run_data["job_id"])

if not os.path.exists(server_run_dir):
print(f"FiltersResultValidator: server run dir {server_run_dir} doesn't exist.")
Expand Down
Loading

0 comments on commit 8aa11d8

Please sign in to comment.