From 1bfa52d74e249a3a345b847bb73c273a35073d23 Mon Sep 17 00:00:00 2001
From: Sean Moriarity
Date: Thu, 1 Dec 2022 18:25:06 -0500
Subject: [PATCH] Unify event handlers in loops (#405)

---
 lib/axon/loop.ex        | 193 ++++++++++++++++++++++++----------------
 test/axon/loop_test.exs |  32 ++++++-
 2 files changed, 144 insertions(+), 81 deletions(-)

diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex
index 2686861d..5c66189a 100644
--- a/lib/axon/loop.ex
+++ b/lib/axon/loop.ex
@@ -313,22 +313,26 @@ defmodule Axon.Loop do
     {init_optimizer_fn, update_optimizer_fn} = build_optimizer_fns(optimizer)
     {init_loss_scale, scale_loss, unscale_grads} = build_loss_scale_fns(loss_scale)
 
-    init_fn = fn {inp, _}, init_model_state ->
-      model_state = init_model_fn.(inp, init_model_state)
-      optimizer_state = init_optimizer_fn.(model_state)
-      loss_scale_state = init_loss_scale.()
-
-      %{
-        i: Nx.tensor(0),
-        y_true: Nx.tensor(0.0),
-        y_pred: Nx.tensor(0.0),
-        loss: Nx.tensor(0.0),
-        gradient_step: Nx.tensor(0),
-        model_state: model_state,
-        gradient_state: zeros_like(model_state),
-        optimizer_state: optimizer_state,
-        loss_scale_state: loss_scale_state
-      }
+    init_fn = fn
+      {inp, _}, %{} = init_model_state ->
+        model_state = init_model_fn.(inp, init_model_state)
+        optimizer_state = init_optimizer_fn.(model_state)
+        loss_scale_state = init_loss_scale.()
+
+        %{
+          i: Nx.tensor(0),
+          y_true: Nx.tensor(0.0),
+          y_pred: Nx.tensor(0.0),
+          loss: Nx.tensor(0.0),
+          gradient_step: Nx.tensor(0),
+          model_state: model_state,
+          gradient_state: zeros_like(model_state),
+          optimizer_state: optimizer_state,
+          loss_scale_state: loss_scale_state
+        }
+
+      data, state ->
+        raise_bad_training_inputs!(data, state)
     end
 
     # TODO: We should probably compute in same compute policy as MP
@@ -345,59 +349,63 @@
       {model_out, loss}
     end
 
-    step_fn = fn {inp, tar}, state ->
-      %{
-        i: i,
-        gradient_step: gradient_step,
-        loss_scale_state: loss_scale_state,
-        gradient_state: gradient_state,
-        model_state: model_state,
-        optimizer_state: optimizer_state,
-        loss: loss
-      } = state
-
-      {{model_out, batch_loss}, gradients} =
-        Nx.Defn.value_and_grad(
-          model_state,
-          &objective_fn.(&1, loss_scale_state, inp, tar),
-          fn x -> elem(x, 1) end
-        )
-
-      {gradients, new_loss_scale_state} = unscale_grads.(gradients, loss_scale_state)
+    step_fn = fn
+      {inp, tar}, %{} = state ->
+        %{
+          i: i,
+          gradient_step: gradient_step,
+          loss_scale_state: loss_scale_state,
+          gradient_state: gradient_state,
+          model_state: model_state,
+          optimizer_state: optimizer_state,
+          loss: loss
+        } = state
+
+        {{model_out, batch_loss}, gradients} =
+          Nx.Defn.value_and_grad(
+            model_state,
+            &objective_fn.(&1, loss_scale_state, inp, tar),
+            fn x -> elem(x, 1) end
+          )
 
-      preds = model_out.prediction
-      new_state = model_out.state
+        {gradients, new_loss_scale_state} = unscale_grads.(gradients, loss_scale_state)
 
-      new_loss =
-        loss
-        |> Nx.multiply(i)
-        |> Nx.add(Nx.multiply(batch_loss, steps))
-        |> Nx.divide(Nx.add(i, 1))
+        preds = model_out.prediction
+        new_state = model_out.state
 
-      {new_model_state, new_optimizer_state, new_gradient_state, new_gradient_step} =
-        if Nx.greater_equal(gradient_step, steps - 1) do
-          {updates, new_optimizer_state} =
-            update_optimizer_fn.(gradients, optimizer_state, model_state)
+        new_loss =
+          loss
+          |> Nx.multiply(i)
+          |> Nx.add(Nx.multiply(batch_loss, steps))
+          |> Nx.divide(Nx.add(i, 1))
+
+        {new_model_state, new_optimizer_state, new_gradient_state, new_gradient_step} =
+          if Nx.greater_equal(gradient_step, steps - 1) do
+            {updates, new_optimizer_state} =
+              update_optimizer_fn.(gradients, optimizer_state, model_state)
 
-          new_gradient_state = zeros_like(model_state)
-          new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state)
-          {new_model_state, new_optimizer_state, new_gradient_state, 0}
-        else
-          {model_state, optimizer_state, gradient_state + gradients, gradient_step + 1}
-        end
+            new_gradient_state = zeros_like(model_state)
+            new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state)
+            {new_model_state, new_optimizer_state, new_gradient_state, 0}
+          else
+            {model_state, optimizer_state, gradient_state + gradients, gradient_step + 1}
+          end
 
-      %{
-        state
-        | i: Nx.add(i, 1),
-          gradient_step: new_gradient_step,
-          y_true: tar,
-          y_pred: preds,
-          loss: new_loss,
-          model_state: new_model_state,
-          gradient_state: new_gradient_state,
-          optimizer_state: new_optimizer_state,
-          loss_scale_state: new_loss_scale_state
-      }
+        %{
+          state
+          | i: Nx.add(i, 1),
+            gradient_step: new_gradient_step,
+            y_true: tar,
+            y_pred: preds,
+            loss: new_loss,
+            model_state: new_model_state,
+            gradient_state: new_gradient_state,
+            optimizer_state: new_optimizer_state,
+            loss_scale_state: new_loss_scale_state
+        }
+
+      data, state ->
+        raise_bad_training_inputs!(data, state)
     end
 
     {
@@ -406,6 +414,21 @@
     }
   end
 
+  defp raise_bad_training_inputs!(data, state) do
+    raise ArgumentError,
+          "invalid arguments given to train-step initialization," <>
+            " this usually happens when you pass invalid parameters" <>
+            " to Axon.Loop.run with a loop constructed using Axon.Loop.trainer" <>
+            " or Axon.Loop.evaluator, supervised training and evaluation loops" <>
+            " expect a stream or enumerable of inputs" <>
+            " of the form {x_train, y_train} where x_train and y_train" <>
+            " are batches of tensors, you must also provide an initial model" <>
+            " state such as an empty map: Axon.Loop.run(loop, data, %{}), got" <>
+            " input data: #{inspect(data)} and initial model state: " <>
+            " #{inspect(state)}"
+  end
+
   @doc """
   Creates a supervised evaluation step from a model and model state.
 
@@ -425,12 +448,16 @@
       }
     end
 
-    step_fn = fn {inp, tar}, %{model_state: model_state} ->
-      %{
-        model_state: model_state,
-        y_true: tar,
-        y_pred: forward_model_fn.(model_state, inp)
-      }
+    step_fn = fn
+      {inp, tar}, %{model_state: model_state} ->
+        %{
+          model_state: model_state,
+          y_true: tar,
+          y_pred: forward_model_fn.(model_state, inp)
+        }
+
+      data, state ->
+        raise_bad_training_inputs!(data, state)
     end
 
     {
@@ -587,8 +614,11 @@
 
     if log_interval > 0 do
       loop
-      |> log(:iteration_completed, &supervised_log_message_fn/1, :stdio, every: log_interval)
-      |> log(:epoch_completed, fn _ -> "\n" end, :stdio)
+      |> log(&supervised_log_message_fn/1,
+        event: :iteration_completed,
+        filter: [every: log_interval]
+      )
+      |> log(fn _ -> "\n" end, event: :epoch_completed)
     else
       loop
     end
@@ -655,7 +685,7 @@
     output_transform = fn state -> state.metrics end
 
     loop(step_fn, init_fn, output_transform)
-    |> log(:iteration_completed, &supervised_log_message_fn(&1, false), :stdio)
+    |> log(&supervised_log_message_fn(&1, false), event: :iteration_completed)
   end
 
   @doc """
@@ -829,8 +859,12 @@
   `message_fn` should take the loop state and return a binary
   representing the message to be written to the IO device.
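+
+  For example, something like the following logs the running metrics every
+  100 completed iterations:
+
+      loop
+      |> Axon.Loop.log(
+        fn %Axon.Loop.State{metrics: metrics} -> inspect(metrics) <> "\n" end,
+        event: :iteration_completed,
+        filter: [every: 100]
+      )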
""" - def log(%Loop{} = loop, event, message_fn, device \\ :stdio, filter \\ :always) - when is_function(message_fn, 1) do + def log(%Loop{} = loop, message_fn, opts \\ []) when is_function(message_fn, 1) do + opts = Keyword.validate!(opts, event: :iteration_completed, filter: :always, device: :stdio) + event = opts[:event] || :iteration_completed + filter = opts[:filter] || :always + device = opts[:device] || :stdio + log_fn = fn %State{} = state -> try do msg = message_fn.(state) @@ -888,16 +922,19 @@ defmodule Axon.Loop do model |> Axon.Loop.trainer(:mean_squared_error, :sgd) |> Axon.Loop.metric(:mean_absolute_error) - |> Axon.Loop.validate(model, validation_data, :iteration_completed, every: 10_000) + |> Axon.Loop.validate(model, validation_data, event: :iteration_completed, filter: [every: 10_000]) |> Axon.Loop.metric(:binary_cross_entropy) """ def validate( %Loop{metrics: metric_fns} = loop, model, validation_data, - event \\ :epoch_completed, - filter \\ :always + opts \\ [] ) do + opts = Keyword.validate!(opts, event: :epoch_completed, filter: :always) + event = opts[:event] || :epoch_completed + filter = opts[:filter] || :always + validation_loop = fn %State{metrics: metrics, step_state: step_state} = state -> %{model_state: model_state} = step_state @@ -909,7 +946,7 @@ defmodule Axon.Loop do metric(loop, v, k) end) ) - |> log(:completed, fn _ -> "\n" end) + |> log(fn _ -> "\n" end, event: :completed) |> run(validation_data, model_state) |> Access.get(0) |> Map.new(fn {k, v} -> diff --git a/test/axon/loop_test.exs b/test/axon/loop_test.exs index 4b7bfac2..12da7dff 100644 --- a/test/axon/loop_test.exs +++ b/test/axon/loop_test.exs @@ -198,9 +198,9 @@ defmodule Axon.LoopTest do assert %Loop{} = loop = Loop.evaluator(model) assert %Loop{} = loop = Loop.metric(loop, :mean_absolute_error) - ExUnit.CaptureIO.capture_io(fn -> - assert %{0 => %{"mean_absolute_error" => _}} = Loop.run(loop, data, model_state) - end) + assert ExUnit.CaptureIO.capture_io(fn -> + assert %{0 => %{"mean_absolute_error" => _}} = Loop.run(loop, data, model_state) + end) =~ "Batch" end test "eval_step/1 evalutes model on a single batch" do @@ -431,6 +431,32 @@ defmodule Axon.LoopTest do end end + describe "trainer" do + test "returns clear error on bad inputs" do + model = Axon.input("input") + data = Stream.repeatedly(fn -> Nx.tensor(5) end) + + assert_raise ArgumentError, ~r/invalid arguments/, fn -> + model + |> Axon.Loop.trainer(:categorical_cross_entropy, :adam) + |> Axon.Loop.run(data, %{}) + end + end + end + + describe "evaluator" do + test "returns clear error on bad inputs" do + model = Axon.input("input") + data = Stream.repeatedly(fn -> Nx.tensor(5) end) + + assert_raise ArgumentError, ~r/invalid arguments/, fn -> + model + |> Axon.Loop.evaluator() + |> Axon.Loop.run(data, %{}) + end + end + end + describe "serialization" do test "serialize_state/deserialize_state preserve loop state" do model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(2)