Skip to content

Commit

Permalink
Improve option handling
Browse files Browse the repository at this point in the history
  • Loading branch information
aragilar committed Jul 18, 2018
1 parent bcbaa6c commit beae526
Showing 1 changed file with 32 additions and 6 deletions.
38 changes: 32 additions & 6 deletions spaceplot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from matplotlib.pyplot import figure as _figure


def spaceplots(inputs, outputs, input_names=None, output_names=None, **kwargs):
def spaceplots(
inputs, outputs, input_names=None, output_names=None, limits=None, **kwargs
):
num_samples, num_inputs = inputs.shape
if input_names is not None:
if len(input_names) != num_inputs:
Expand All @@ -23,10 +25,21 @@ def spaceplots(inputs, outputs, input_names=None, output_names=None, **kwargs):
else:
output_names = [None] * num_outputs

if limits is not None:
if limits.shape[1] != 2:
raise RuntimeError(
"There must be a upper and lower limit for each output"
)
elif limits.shape[0] != num_outputs:
raise RuntimeError("Output data and limits don't match")
else:
limits = [[None, None]] * num_outputs

for out_index in range(num_outputs):
yield _subspace_plot(
inputs, outputs[:, out_index], input_names=input_names,
output_name=output_names[out_index], **kwargs
output_name=output_names[out_index], min_output=limits[out_index][0],
max_output=limits[out_index][1], **kwargs
)


Expand Down Expand Up @@ -79,7 +92,18 @@ def _setup_axes(
return fig, axes, grid


def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs):
def _subspace_plot(
inputs, output, *, input_names, output_name, scatter_args=None,
histogram_args=None, min_output=None, max_output=None
):
if scatter_args is None:
scatter_args = {}
if histogram_args is None:
histogram_args = {}
if min_output is None:
min_output = min(output)
if max_output is None:
max_output = max(output)

# see https://matplotlib.org/examples/pylab_examples/multi_image.html
_, num_inputs = inputs.shape
Expand All @@ -89,11 +113,13 @@ def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs):
if output_name is not None:
fig.suptitle(output_name)

norm = _Normalize(min(output), max(output)) # TODO: get from user if needed
norm = _Normalize(min_output, max_output)

hist_plots = []
for i in range(num_inputs):
hist_plots.append(_plot_hist(inputs[:, i], axis=axes[i][i]))
hist_plots.append(_plot_hist(
inputs[:, i], axis=axes[i][i], **histogram_args
))

scatter_plots = []
scatter_plots_grid = []
Expand All @@ -103,7 +129,7 @@ def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs):
sc_plot = _plot_scatter(
x=inputs[:, x_index], y=inputs[:, y_index], z=output,
axis=axes[y_index][x_index], # check order
norm=norm
norm=norm, **scatter_args
)
scatter_plots.append(sc_plot)
scatter_plots_grid[y_index].append(sc_plot)
Expand Down

0 comments on commit beae526

Please sign in to comment.