From beae5262df13d46c5da08aa7e193d7fd5bffa1ca Mon Sep 17 00:00:00 2001 From: James Tocknell Date: Wed, 18 Jul 2018 15:57:25 +1000 Subject: [PATCH] Improve option handling --- spaceplot/__init__.py | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/spaceplot/__init__.py b/spaceplot/__init__.py index f9090e9..0e6c2d3 100644 --- a/spaceplot/__init__.py +++ b/spaceplot/__init__.py @@ -6,7 +6,9 @@ from matplotlib.pyplot import figure as _figure -def spaceplots(inputs, outputs, input_names=None, output_names=None, **kwargs): +def spaceplots( + inputs, outputs, input_names=None, output_names=None, limits=None, **kwargs +): num_samples, num_inputs = inputs.shape if input_names is not None: if len(input_names) != num_inputs: @@ -23,10 +25,21 @@ def spaceplots(inputs, outputs, input_names=None, output_names=None, **kwargs): else: output_names = [None] * num_outputs + if limits is not None: + if limits.shape[1] != 2: + raise RuntimeError( + "There must be a upper and lower limit for each output" + ) + elif limits.shape[0] != num_outputs: + raise RuntimeError("Output data and limits don't match") + else: + limits = [[None, None]] * num_outputs + for out_index in range(num_outputs): yield _subspace_plot( inputs, outputs[:, out_index], input_names=input_names, - output_name=output_names[out_index], **kwargs + output_name=output_names[out_index], min_output=limits[out_index][0], + max_output=limits[out_index][1], **kwargs ) @@ -79,7 +92,18 @@ def _setup_axes( return fig, axes, grid -def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs): +def _subspace_plot( + inputs, output, *, input_names, output_name, scatter_args=None, + histogram_args=None, min_output=None, max_output=None +): + if scatter_args is None: + scatter_args = {} + if histogram_args is None: + histogram_args = {} + if min_output is None: + min_output = min(output) + if max_output is None: + max_output = max(output) # see https://matplotlib.org/examples/pylab_examples/multi_image.html _, num_inputs = inputs.shape @@ -89,11 +113,13 @@ def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs): if output_name is not None: fig.suptitle(output_name) - norm = _Normalize(min(output), max(output)) # TODO: get from user if needed + norm = _Normalize(min_output, max_output) hist_plots = [] for i in range(num_inputs): - hist_plots.append(_plot_hist(inputs[:, i], axis=axes[i][i])) + hist_plots.append(_plot_hist( + inputs[:, i], axis=axes[i][i], **histogram_args + )) scatter_plots = [] scatter_plots_grid = [] @@ -103,7 +129,7 @@ def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs): sc_plot = _plot_scatter( x=inputs[:, x_index], y=inputs[:, y_index], z=output, axis=axes[y_index][x_index], # check order - norm=norm + norm=norm, **scatter_args ) scatter_plots.append(sc_plot) scatter_plots_grid[y_index].append(sc_plot)