Add Tensorboard ability to read from different folders (microsoft#112)
* Update run_tensorboard to work with logs in different folders

* fix flake8 complaints

* fix test and changelog

* move types-requests to test_requirements for running mypy

* update test

* move tensorboard example into docs

* shorten command line arg names

* Address PR comments
mebristo committed Sep 21, 2021
1 parent 19b081a commit 46b2d3f
Showing 17 changed files with 250 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check_changelog.yml
@@ -24,7 +24,7 @@ jobs:
if [[ "$changed_files" =~ CHANGELOG\.md ]]
then
changelog=$(cat ./CHANGELOG.md)
if [[ "$changelog" =~ https://github\.com/microsoft/InnerEye-DeepLearning/pull/$PR_NUMBER ]]
if [[ "$changelog" =~ https://github\.com/microsoft/hi-ml/pull/$PR_NUMBER ]]
then
echo "Changelog has been updated and contains the PR number."
else
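A note on the pattern used in this check: bash's `=~` performs an unanchored search, so the URL for PR 112 would also be found inside the URL for PR 1120. The following is a minimal Python sketch of the same check, not the workflow's own code. The helper name `changelog_mentions_pr` and the `(?!\d)` lookahead are assumptions added for illustration.

```python
import re


def changelog_mentions_pr(changelog: str, pr_number: int, strict: bool = True) -> bool:
    # Hypothetical helper; the workflow itself uses a bash [[ ... =~ ... ]] test.
    pattern = rf"https://github\.com/microsoft/hi-ml/pull/{pr_number}"
    if strict:
        # Assumption: reject a longer PR number that merely starts with pr_number.
        pattern += r"(?!\d)"
    return re.search(pattern, changelog) is not None


entry = "- ([#1120](https://github.com/microsoft/hi-ml/pull/1120)) Some later change"
# Unanchored search, comparable to the bash test in the workflow.
print(changelog_mentions_pr(entry, 112, strict=False))
# Strict variant with the added lookahead.
print(changelog_mentions_pr(entry, 112))
```

With the sample entry, the unanchored search reports a match for PR 112 while the strict variant does not.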
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -13,10 +13,10 @@ created.
## Upcoming

### Added
- ([#111](https://github.com/microsoft/InnerEye-DeepLearning/pull/111)) Adding changelog. Displaying changelog in sphinx docu. Ensure changelog is updated.
- ([#111](https://github.com/microsoft/hi-ml/pull/111)) Adding changelog. Displaying changelog in sphinx docu. Ensure changelog is updated.

### Changed

- ([#112](https://github.com/microsoft/hi-ml/pull/112)) Update himl_tensorboard to work with files not in 'logs' directory
### Fixed

### Removed
26 changes: 14 additions & 12 deletions docs/source/commandline_tools.md
@@ -7,35 +7,37 @@ From the command line, run the command
```himl-tb```

specifying one of
`[--experiment_name] [--latest_run_file] [--run_recovery_ids]`
`[--experiment] [--latest_run_file] [--run_recovery_ids] [--run_ids]`

This will start a TensorBoard session, by default running on port 6006. To use an alternative port, specify this with `--port`.

If `--experiment_name` is provided, the most recent Run from this experiment will be visualised.
If `--experiment` is provided, the most recent Run from this experiment will be visualised.
If `--latest_run_file` is provided, the script will expect to find a RunId in this file.
Alternatively you can specify the Runs to visualise via `--run_recovery_ids` or `--run_ids`.
You can specify the location where TensorBoard logs will be stored, using the `--run_logs_dir` argument.

If you choose to specify `--experiment_name`, you can also specify `--num_runs` to view and/or `--tags` to filter by.
By default, this tool expects that your TensorBoard logs live in a folder named 'logs' and will create a similarly named folder in your root directory. If your TensorBoard logs are stored elsewhere, you can specify this with the `--log_dir` argument.

If your AML config path is not ROOT_DIR/config.json, you must also specify `--config_path`.
If you choose to specify `--experiment`, you can also specify `--num_runs` to view and/or `--tags` to filter by.

If your AML config path is not ROOT_DIR/config.json, you must also specify `--config_file`.

For an example of how to create TensorBoard logs using PyTorch on AML, see the
[AML submitting script](examples/9/aml_sample.rst), which submits the following [pytorch sample script](examples/9/pytorch_sample.rst). To run this, you'll need to create an environment with at least pytorch and tensorboard as dependencies; see the [example conda environment](examples/9/tensorboard_env.rst). This will create an experiment named 'tensorboard_test' on your Workspace, with a single run. Go to outputs + logs -> outputs to see the tensorboard events file.
## Download files from AML Runs

From the command line, run the command

```himl-download```

specifying one of
`[--experiment_name] [--latest_run_file] [--run_recovery_ids] [--run_ids]`
`[--experiment] [--latest_run_file] [--run_recovery_ids] [--run_ids]`

If `--experiment_name` is provided, the most recent Run from this experiment will be downloaded.
If `--experiment` is provided, the most recent Run from this experiment will be downloaded.
If `--latest_run_file` is provided, the script will expect to find a RunId in this file.
Alternatively you can specify the Runs to download via `--run_recovery_ids` or `--run_ids`.

The files associated with your Run(s) will be downloaded to the location specified with `--output_dir` (by default ROOT_DIR/outputs)
Alternatively you can specify the Run to download via `--run_recovery_ids` or `--run_ids`.

If you choose to specify `--experiment_name`, you can also specify `--num_runs` to view and/or `--tags` to filter by.
The files associated with your Run will be downloaded to the location specified with `--output_dir` (by default ROOT_DIR/outputs)

If your AML config path is not `ROOT_DIR/config.json`, you must also specify `--config_path`.
If you choose to specify `--experiment`, you can also specify `--tags` to filter by.

If your AML config path is not `ROOT_DIR/config.json`, you must also specify `--config_file`.
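The renamed flags can be summarised with a small `argparse` sketch. This is an illustration only, not the parser shipped in `himl_tensorboard.py`. The flag names come from the documentation above and the `outputs` default from the parser in this commit. The `nargs` settings, the types, and the `config.json` default are assumptions.

```python
from argparse import ArgumentParser


def build_parser() -> ArgumentParser:
    # Hypothetical stand-in for the himl-tb parser; only the flag names are taken from the docs.
    parser = ArgumentParser(prog="himl-tb")
    parser.add_argument("--experiment", type=str, default=None)
    parser.add_argument("--latest_run_file", type=str, default=None)
    parser.add_argument("--run_recovery_ids", nargs="+", default=None)  # nargs is an assumption
    parser.add_argument("--run_ids", nargs="+", default=None)
    parser.add_argument("--log_dir", type=str, default="outputs")
    parser.add_argument("--port", type=int, default=6006)
    parser.add_argument("--config_file", type=str, default="config.json")  # assumed default
    return parser


# Equivalent to: himl-tb --experiment tensorboard_test --log_dir outputs
args = build_parser().parse_args(["--experiment", "tensorboard_test", "--log_dir", "outputs"])
print(args.experiment, args.log_dir, args.port)
```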
24 changes: 24 additions & 0 deletions docs/source/examples/9/aml_sample.py
@@ -0,0 +1,24 @@
#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------
from azureml.core import Environment, Experiment, ScriptRunConfig, Workspace


def main() -> None:
    ws = Workspace.from_config("config.json")
    experiment = Experiment(ws, "tensorboard_test")
    config = ScriptRunConfig(
        source_directory='.',
        script="pytorch_sample.py",
        compute_target="<name of compute target>"
    )
    env = Environment.from_conda_specification("TensorboardTestEnv", "tensorboard_env.yml")
    config.run_config.environment = env

    run = experiment.submit(config)
    run.wait_for_completion()


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions docs/source/examples/9/aml_sample.rst
@@ -0,0 +1,2 @@
.. literalinclude:: aml_sample.py
   :language: python
32 changes: 32 additions & 0 deletions docs/source/examples/9/pytorch_sample.py
@@ -0,0 +1,32 @@
#  ------------------------------------------------------------------------------------------
#  Adapted from the example at https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html
#  ------------------------------------------------------------------------------------------
from pathlib import Path
import torch
from torch.utils.tensorboard import SummaryWriter


def main() -> None:
    log_dir = Path("outputs")
    log_dir.mkdir(exist_ok=True)
    writer = SummaryWriter(log_dir=str(log_dir))

    x = torch.arange(-20, 20, 0.1).view(-1, 1)
    y = -2 * x + 0.1 * torch.randn(x.size())

    model = torch.nn.Linear(1, 1)
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(10):
        y1 = model(x)
        loss = criterion(y1, y)
        writer.add_scalar("Loss/train", loss, epoch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    writer.flush()


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions docs/source/examples/9/pytorch_sample.rst
@@ -0,0 +1,2 @@
.. literalinclude:: pytorch_sample.py
   :language: python
2 changes: 2 additions & 0 deletions docs/source/examples/9/tensorboard_env.rst
@@ -0,0 +1,2 @@
.. literalinclude:: tensorboard_env.yml
   :language: yaml
10 changes: 10 additions & 0 deletions docs/source/examples/9/tensorboard_env.yml
@@ -0,0 +1,10 @@
name: TensorboardTestEnv
channels:
  - defaults
  - pytorch
dependencies:
  - pip=20.1.1
  - python=3.7.3
  - pytorch=1.4.0
  - pip:
    - tensorboard==2.2.1
4 changes: 2 additions & 2 deletions src/health/azure/azure_util.py
@@ -438,13 +438,13 @@ def determine_run_id_source(args: Namespace) -> AzureRunIdSource:
"""
if "latest_run_file" in args and args.latest_run_file is not None:
return AzureRunIdSource.LATEST_RUN_FILE
if "experiment_name" in args and args.experiment_name is not None:
if "experiment" in args and args.experiment is not None:
return AzureRunIdSource.EXPERIMENT_LATEST
if "run_recovery_ids" in args and args.run_recovery_ids is not None:
return AzureRunIdSource.RUN_RECOVERY_ID
if "run_ids" in args and args.run_ids is not None:
return AzureRunIdSource.RUN_ID
raise ValueError("One of latest_run_file, experiment_name, run_recovery_ids or run_ids must be provided")
raise ValueError("One of latest_run_file, experiment, run_recovery_ids or run_ids must be provided")


def get_aml_runs_from_latest_run_file(args: Namespace, workspace: Workspace) -> List[Run]:
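The order of the checks in `determine_run_id_source` decides which source wins when several arguments are set. A self-contained sketch of that precedence follows. `RunIdSource` and `pick_source` are hypothetical stand-ins for `AzureRunIdSource` and the real function, so that the snippet runs without the AzureML SDK. It also shows that an empty list counts as "provided", which may matter for the `--run_ids` default of `[]` in `himl_tensorboard.py`.

```python
from argparse import Namespace
from enum import Enum


class RunIdSource(Enum):
    # Hypothetical stand-in for AzureRunIdSource.
    LATEST_RUN_FILE = 1
    EXPERIMENT_LATEST = 2
    RUN_RECOVERY_ID = 3
    RUN_ID = 4


def pick_source(args: Namespace) -> RunIdSource:
    # Same order of checks as in the hunk above: the first argument that is set wins.
    if "latest_run_file" in args and args.latest_run_file is not None:
        return RunIdSource.LATEST_RUN_FILE
    if "experiment" in args and args.experiment is not None:
        return RunIdSource.EXPERIMENT_LATEST
    if "run_recovery_ids" in args and args.run_recovery_ids is not None:
        return RunIdSource.RUN_RECOVERY_ID
    if "run_ids" in args and args.run_ids is not None:
        return RunIdSource.RUN_ID
    raise ValueError("One of latest_run_file, experiment, run_recovery_ids or run_ids must be provided")


# An experiment name takes precedence over explicit run ids.
print(pick_source(Namespace(experiment="tensorboard_test", run_ids=["run1234"])))
# An empty list is not None, so it is still treated as a run id source.
print(pick_source(Namespace(run_ids=[])))
```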
4 changes: 2 additions & 2 deletions src/health/azure/himl_download.py
@@ -22,7 +22,7 @@ def determine_output_dir_name(args: Namespace, run_id_source: AzureRunIdSource,
:return: The path in which to store the AML Run files
"""
if run_id_source == AzureRunIdSource.EXPERIMENT_LATEST:
output_path = output_dir / args.experiment_name
output_path = output_dir / args.experiment
elif run_id_source == AzureRunIdSource.LATEST_RUN_FILE:
output_path = output_dir / Path(args.latest_run_file).stem
elif run_id_source == AzureRunIdSource.RUN_RECOVERY_ID:
@@ -57,7 +57,7 @@ def main() -> None: # pragma: no cover
help="Optional path to most_recent_run.txt where the ID of the latest run is stored"
)
parser.add_argument(
"--experiment_name",
"--experiment",
type=str,
required=False,
help="The name of the AML Experiment that you wish to download Run files from"
113 changes: 103 additions & 10 deletions src/health/azure/himl_tensorboard.py
@@ -3,17 +3,99 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import os
import sys
import logging
from argparse import ArgumentParser
from pathlib import Path
from requests import Session
from typing import Any, Optional

from azureml._run_impl.run_watcher import RunWatcher
from azureml.tensorboard import Tensorboard

from health.azure.azure_util import get_aml_runs, determine_run_id_source
from health.azure.himl import get_workspace

from concurrent.futures import ThreadPoolExecutor
from subprocess import PIPE, Popen
from threading import Event

ROOT_DIR = Path.cwd()
OUTPUT_DIR = ROOT_DIR / "outputs"
TENSORBOARD_DIR = ROOT_DIR / "tensorboard_logs"


class WrappedTensorboard(Tensorboard):
    def __init__(self, remote_root: str, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.remote_root = remote_root

    def start(self) -> Optional[str]:
        """
        Start the Tensorboard instance, and begin processing logs.
        :return: The URL for accessing the Tensorboard instance.
        """
        self._tb_proc: Optional[Popen]
        if self._tb_proc is not None:
            return None

        self._executor = ThreadPoolExecutor()
        self._event = Event()
        self._session = Session()

        # Make a run watcher for each run we are monitoring
        self._run_watchers = []
        local_log_dirs = []
        for run in self._runs:
            run_local_root = os.path.join(self._local_root, run.id)
            local_log_dirs.append(f"{run.id}:{run_local_root}")
            run_watcher = RunWatcher(
                run,
                local_root=run_local_root,
                remote_root=self.remote_root,
                executor=self._executor,
                event=self._event,
                session=self._session)
            self._run_watchers.append(run_watcher)

        for w in self._run_watchers:
            self._executor.submit(w.refresh_requeue)

        # We use sys.executable here to ensure that we can import modules from the same environment
        # as the current process.
        # (using just "python" results in the global environment, which might not have a Tensorboard module)
        # sometimes, sys.executable might not give us what we want (i.e. in a notebook), and then we just have to hope
        # that "python" will give us something useful
        python_binary = sys.executable or "python"
        python_command = [
            python_binary, "-m", "tensorboard.main",
            "--port", str(self._port)
        ]
        if len(local_log_dirs) > 1:
            # logdir_spec is not recommended but it is the only working way to display multiple dirs
            logdir_str = ','.join(local_log_dirs)
            python_command.append("--logdir_spec")
            logging.info("Loading tensorboard files for > 1 run. You may notice reduced functionality as noted "
                         "here: https://github.com/tensorflow/tensorboard#logdir--logdir_spec-legacy-mode ")
        else:
            logdir_str = run_local_root
            python_command.append("--logdir")

        python_command.append(logdir_str)

        self._tb_proc = Popen(
            python_command,
            stderr=PIPE, stdout=PIPE, universal_newlines=True)
        if os.name == "nt":
            self._win32_kill_subprocess_on_exit(self._tb_proc)

        url = self._wait_for_url()
        # in notebooks, this shows as a clickable link (whereas the returned value is not parsed in output)
        logging.info(f"Tensorboard running at: {url}")

        return url
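
The choice between `--logdir` and `--logdir_spec` in `start` can be looked at in isolation. Below is a minimal sketch, assuming a hypothetical helper `build_tensorboard_command` that takes a mapping from run id to local folder. It is not part of `WrappedTensorboard`; it only reproduces the command construction shown above.

```python
from typing import Dict, List


def build_tensorboard_command(python_binary: str, port: int, run_dirs: Dict[str, str]) -> List[str]:
    # Hypothetical helper mirroring the command assembly in WrappedTensorboard.start.
    command = [python_binary, "-m", "tensorboard.main", "--port", str(port)]
    if len(run_dirs) > 1:
        # Several runs: pass "name:path" pairs, comma separated, via --logdir_spec.
        spec = ",".join(f"{run_id}:{path}" for run_id, path in run_dirs.items())
        command += ["--logdir_spec", spec]
    else:
        # Exactly one run: pass its folder via --logdir (unpacking fails on an empty mapping).
        (path,) = run_dirs.values()
        command += ["--logdir", path]
    return command


print(build_tensorboard_command("python", 6006, {"run1": "outputs/run1"}))
print(build_tensorboard_command("python", 6006, {"run1": "outputs/run1", "run2": "outputs/run2"}))
```

Unpacking the single value makes the empty case fail explicitly. In the method above, the single-run branch reads the loop variable `run_local_root`, so it relies on at least one run being present.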


def main() -> None: # pragma: no cover
@@ -33,11 +115,11 @@ def main() -> None: # pragma: no cover
help="The port to run Tensorboard on"
)
parser.add_argument(
"--run_logs_dir",
"--log_dir",
type=str,
default="tensorboard_logs",
default="outputs",
required=False,
help="Path to directory in which to store Tensorboard logs"
help="Path to directory in which Tensorboard files (summarywriter and TB logs) are stored"
)
parser.add_argument(
"--latest_run_file",
@@ -46,7 +128,7 @@ def main() -> None: # pragma: no cover
help="Optional path to most_recent_run.txt where details on latest run are stored"
)
parser.add_argument(
"--experiment_name",
"--experiment",
type=str,
required=False,
help="The name of the AML Experiment that you wish to view Runs from"
@@ -65,6 +147,12 @@ def main() -> None: # pragma: no cover
required=False,
help="Optional run recovery ids of the runs to plot"
)
parser.add_argument(
"--run_ids",
default=[],
nargs="+",
help="Optional run ids of the runs to plot"
)

args = parser.parse_args()

@@ -73,21 +161,26 @@ def main() -> None: # pragma: no cover
raise ValueError(
"You must provide a config.json file in the root folder to connect "
"to an AML workspace. This can be downloaded from your AML workspace (see README.md)"
)
)

workspace = get_workspace(aml_workspace=None, workspace_config_path=config_path)

run_id_source = determine_run_id_source(args)
runs = get_aml_runs(args, workspace, run_id_source)

print(f"Runs:\n{runs}")
if len(runs) == 0:
raise ValueError("No runs were found")

# start Tensorboard
print(f"runs: {runs}")
local_logs_dir = ROOT_DIR / args.log_dir
local_logs_dir.mkdir(exist_ok=True, parents=True)

remote_logs_dir = local_logs_dir.relative_to(ROOT_DIR)

run_logs_dir = OUTPUT_DIR / args.run_logs_dir
run_logs_dir.mkdir(exist_ok=True)
ts = Tensorboard(runs=runs, local_root=str(run_logs_dir), port=args.port)
ts = WrappedTensorboard(remote_root=str(remote_logs_dir) + '/',
runs=runs,
local_root=str(local_logs_dir),
port=args.port)

ts.start()
print("=============================================================================\n\n")
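The remote root passed to `WrappedTensorboard` is derived from `--log_dir` with `relative_to(ROOT_DIR)`. The sketch below illustrates what that implies for different inputs. It assumes POSIX-style paths, a hypothetical root of `/work/project`, and a hypothetical helper name `remote_root_for`. It uses `PurePosixPath` so that nothing is created on disk.

```python
from pathlib import PurePosixPath

ROOT_DIR = PurePosixPath("/work/project")  # hypothetical stand-in for Path.cwd()


def remote_root_for(log_dir: str) -> str:
    # Mirrors the two path steps in main(): join onto ROOT_DIR, then take the relative part.
    local_logs_dir = ROOT_DIR / log_dir
    return str(local_logs_dir.relative_to(ROOT_DIR)) + "/"


print(remote_root_for("outputs"))          # outputs/
print(remote_root_for("outputs/tb/run1"))  # outputs/tb/run1/

try:
    # Joining an absolute path discards ROOT_DIR, so relative_to() cannot succeed.
    remote_root_for("/tmp/tb_logs")
except ValueError as error:
    print(f"absolute path outside the root is rejected: {error}")
```

Under these assumptions, a relative `--log_dir` maps directly onto the remote folder name, while an absolute path outside the working directory raises `ValueError`.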
1 change: 1 addition & 0 deletions test_requirements.txt
@@ -5,3 +5,4 @@ pylint==2.9.5
pycobertura==2.0.1
pytest==6.2.2
pytest-cov==2.11.1
types-requests==2.25.6
4 changes: 2 additions & 2 deletions testhiml/health/azure/test_azure_util.py
@@ -430,7 +430,7 @@ def __init__(self, run_id: str = 'run1234') -> None:
def test_determine_run_id_source(tmp_path: Path) -> None:
parser = ArgumentParser()
parser.add_argument("--latest_run_file", type=str)
parser.add_argument("--experiment_name", type=str)
parser.add_argument("--experiment", type=str)
parser.add_argument("--run_recovery_ids", type=str)
parser.add_argument("--run_ids", type=str)

@@ -440,7 +440,7 @@ def test_determine_run_id_source(tmp_path: Path) -> None:
assert util.determine_run_id_source(mock_args) == util.AzureRunIdSource.LATEST_RUN_FILE

# If experiment name is provided, expect source to be experiment
mock_args = parser.parse_args(["--experiment_name", "fake_experiment"])
mock_args = parser.parse_args(["--experiment", "fake_experiment"])
assert util.determine_run_id_source(mock_args) == util.AzureRunIdSource.EXPERIMENT_LATEST

# If run recovery id is provided, expect source to be that
Binary file not shown.