Unify event handlers in loops (#405)

seanmor5 committed Dec 1, 2022
1 parent e0bc997 commit 1bfa52d
Showing 2 changed files with 144 additions and 81 deletions.
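At a high level, this commit replaces the positional event/filter/device arguments on loop handlers such as log and validate with a single keyword-options list, and routes malformed loop inputs into one clear ArgumentError. A hedged sketch of the resulting call style, inferred only from the signatures changed below (model, validation_data, data, and the every: intervals are placeholders):

    model
    |> Axon.Loop.trainer(:mean_squared_error, :sgd)
    |> Axon.Loop.metric(:mean_absolute_error)
    # previously: Axon.Loop.validate(loop, model, validation_data, :iteration_completed, every: 1_000)
    |> Axon.Loop.validate(model, validation_data, event: :iteration_completed, filter: [every: 1_000])
    # previously: Axon.Loop.log(loop, :epoch_completed, fn _ -> "\n" end, :stdio)
    |> Axon.Loop.log(fn _ -> "\n" end, event: :epoch_completed, device: :stdio)
    |> Axon.Loop.run(data, %{})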
193 changes: 115 additions & 78 deletions lib/axon/loop.ex
@@ -313,22 +313,26 @@ defmodule Axon.Loop do
{init_optimizer_fn, update_optimizer_fn} = build_optimizer_fns(optimizer)
{init_loss_scale, scale_loss, unscale_grads} = build_loss_scale_fns(loss_scale)

init_fn = fn {inp, _}, init_model_state ->
model_state = init_model_fn.(inp, init_model_state)
optimizer_state = init_optimizer_fn.(model_state)
loss_scale_state = init_loss_scale.()
init_fn = fn
{inp, _}, %{} = init_model_state ->
model_state = init_model_fn.(inp, init_model_state)
optimizer_state = init_optimizer_fn.(model_state)
loss_scale_state = init_loss_scale.()

%{
i: Nx.tensor(0),
y_true: Nx.tensor(0.0),
y_pred: Nx.tensor(0.0),
loss: Nx.tensor(0.0),
gradient_step: Nx.tensor(0),
model_state: model_state,
gradient_state: zeros_like(model_state),
optimizer_state: optimizer_state,
loss_scale_state: loss_scale_state
}
%{
i: Nx.tensor(0),
y_true: Nx.tensor(0.0),
y_pred: Nx.tensor(0.0),
loss: Nx.tensor(0.0),
gradient_step: Nx.tensor(0),
model_state: model_state,
gradient_state: zeros_like(model_state),
optimizer_state: optimizer_state,
loss_scale_state: loss_scale_state
}

data, state ->
raise_bad_training_inputs!(data, state)
end

# TODO: We should probably compute in same compute policy as MP
@@ -345,59 +349,63 @@ defmodule Axon.Loop do
{model_out, loss}
end

step_fn = fn {inp, tar}, state ->
%{
i: i,
gradient_step: gradient_step,
loss_scale_state: loss_scale_state,
gradient_state: gradient_state,
model_state: model_state,
optimizer_state: optimizer_state,
loss: loss
} = state

{{model_out, batch_loss}, gradients} =
Nx.Defn.value_and_grad(
model_state,
&objective_fn.(&1, loss_scale_state, inp, tar),
fn x -> elem(x, 1) end
)

{gradients, new_loss_scale_state} = unscale_grads.(gradients, loss_scale_state)
step_fn = fn
{inp, tar}, %{} = state ->
%{
i: i,
gradient_step: gradient_step,
loss_scale_state: loss_scale_state,
gradient_state: gradient_state,
model_state: model_state,
optimizer_state: optimizer_state,
loss: loss
} = state

{{model_out, batch_loss}, gradients} =
Nx.Defn.value_and_grad(
model_state,
&objective_fn.(&1, loss_scale_state, inp, tar),
fn x -> elem(x, 1) end
)

preds = model_out.prediction
new_state = model_out.state
{gradients, new_loss_scale_state} = unscale_grads.(gradients, loss_scale_state)

new_loss =
loss
|> Nx.multiply(i)
|> Nx.add(Nx.multiply(batch_loss, steps))
|> Nx.divide(Nx.add(i, 1))
preds = model_out.prediction
new_state = model_out.state

{new_model_state, new_optimizer_state, new_gradient_state, new_gradient_step} =
if Nx.greater_equal(gradient_step, steps - 1) do
{updates, new_optimizer_state} =
update_optimizer_fn.(gradients, optimizer_state, model_state)
new_loss =
loss
|> Nx.multiply(i)
|> Nx.add(Nx.multiply(batch_loss, steps))
|> Nx.divide(Nx.add(i, 1))

{new_model_state, new_optimizer_state, new_gradient_state, new_gradient_step} =
if Nx.greater_equal(gradient_step, steps - 1) do
{updates, new_optimizer_state} =
update_optimizer_fn.(gradients, optimizer_state, model_state)

new_gradient_state = zeros_like(model_state)
new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state)
{new_model_state, new_optimizer_state, new_gradient_state, 0}
else
{model_state, optimizer_state, gradient_state + gradients, gradient_step + 1}
end

new_gradient_state = zeros_like(model_state)
new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state)
{new_model_state, new_optimizer_state, new_gradient_state, 0}
else
{model_state, optimizer_state, gradient_state + gradients, gradient_step + 1}
end
%{
state
| i: Nx.add(i, 1),
gradient_step: new_gradient_step,
y_true: tar,
y_pred: preds,
loss: new_loss,
model_state: new_model_state,
gradient_state: new_gradient_state,
optimizer_state: new_optimizer_state,
loss_scale_state: new_loss_scale_state
}

%{
state
| i: Nx.add(i, 1),
gradient_step: new_gradient_step,
y_true: tar,
y_pred: preds,
loss: new_loss,
model_state: new_model_state,
gradient_state: new_gradient_state,
optimizer_state: new_optimizer_state,
loss_scale_state: new_loss_scale_state
}
data, state ->
raise_bad_training_inputs!(data, state)
end

{
@@ -406,6 +414,21 @@ defmodule Axon.Loop do
}
end

defp raise_bad_training_inputs!(data, state) do
raise ArgumentError,
"invalid arguments given to train-step initialization," <>
" this usually happens when you pass invalid parameters" <>
" to Axon.Loop.run with a loop constructed using Axon.Loop.trainer" <>
" or Axon.Loop.evaluator, supervised training and evaluation loops" <>
" expect a stream or enumerable of inputs" <>
" of the form {x_train, y_train} where x_train and y_train" <>
" are batches of tensors, you must also provide an initial model" <>
" state such as an empty map: Axon.Loop.run(loop, data, %{}), got" <>
" input data: #{inspect(data)} and initial model state: " <>
" #{inspect(state)}"
end

@doc """
Creates a supervised evaluation step from a model and model state.
@@ -425,12 +448,16 @@
}
end

step_fn = fn {inp, tar}, %{model_state: model_state} ->
%{
model_state: model_state,
y_true: tar,
y_pred: forward_model_fn.(model_state, inp)
}
step_fn = fn
{inp, tar}, %{model_state: model_state} ->
%{
model_state: model_state,
y_true: tar,
y_pred: forward_model_fn.(model_state, inp)
}

data, state ->
raise_bad_training_inputs!(data, state)
end

{
@@ -587,8 +614,11 @@ defmodule Axon.Loop do

if log_interval > 0 do
loop
|> log(:iteration_completed, &supervised_log_message_fn/1, :stdio, every: log_interval)
|> log(:epoch_completed, fn _ -> "\n" end, :stdio)
|> log(&supervised_log_message_fn/1,
event: :iteration_completed,
filter: [every: log_interval]
)
|> log(fn _ -> "\n" end, event: :epoch_completed)
else
loop
end
@@ -655,7 +685,7 @@ defmodule Axon.Loop do
output_transform = fn state -> state.metrics end

loop(step_fn, init_fn, output_transform)
|> log(:iteration_completed, &supervised_log_message_fn(&1, false), :stdio)
|> log(&supervised_log_message_fn(&1, false), event: :iteration_completed)
end

@doc """
@@ -829,8 +859,12 @@ defmodule Axon.Loop do
`message_fn` should take the loop state and return a binary
representing the message to be written to the IO device.
"""
def log(%Loop{} = loop, event, message_fn, device \\ :stdio, filter \\ :always)
when is_function(message_fn, 1) do
def log(%Loop{} = loop, message_fn, opts \\ []) when is_function(message_fn, 1) do
opts = Keyword.validate!(opts, event: :iteration_completed, filter: :always, device: :stdio)
event = opts[:event] || :iteration_completed
filter = opts[:filter] || :always
device = opts[:device] || :stdio

log_fn = fn %State{} = state ->
try do
msg = message_fn.(state)
@@ -888,16 +922,19 @@ defmodule Axon.Loop do
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error)
|> Axon.Loop.validate(model, validation_data, :iteration_completed, every: 10_000)
|> Axon.Loop.validate(model, validation_data, event: :iteration_completed, filter: [every: 10_000])
|> Axon.Loop.metric(:binary_cross_entropy)
"""
def validate(
%Loop{metrics: metric_fns} = loop,
model,
validation_data,
event \\ :epoch_completed,
filter \\ :always
opts \\ []
) do
opts = Keyword.validate!(opts, event: :epoch_completed, filter: :always)
event = opts[:event] || :epoch_completed
filter = opts[:filter] || :always

validation_loop = fn %State{metrics: metrics, step_state: step_state} = state ->
%{model_state: model_state} = step_state

@@ -909,7 +946,7 @@ defmodule Axon.Loop do
metric(loop, v, k)
end)
)
|> log(:completed, fn _ -> "\n" end)
|> log(fn _ -> "\n" end, event: :completed)
|> run(validation_data, model_state)
|> Access.get(0)
|> Map.new(fn {k, v} ->
32 changes: 29 additions & 3 deletions test/axon/loop_test.exs
@@ -198,9 +198,9 @@ defmodule Axon.LoopTest do
assert %Loop{} = loop = Loop.evaluator(model)
assert %Loop{} = loop = Loop.metric(loop, :mean_absolute_error)

ExUnit.CaptureIO.capture_io(fn ->
assert %{0 => %{"mean_absolute_error" => _}} = Loop.run(loop, data, model_state)
end)
assert ExUnit.CaptureIO.capture_io(fn ->
assert %{0 => %{"mean_absolute_error" => _}} = Loop.run(loop, data, model_state)
end) =~ "Batch"
end

test "eval_step/1 evalutes model on a single batch" do
@@ -431,6 +431,32 @@ defmodule Axon.LoopTest do
end
end

describe "trainer" do
test "returns clear error on bad inputs" do
model = Axon.input("input")
data = Stream.repeatedly(fn -> Nx.tensor(5) end)

assert_raise ArgumentError, ~r/invalid arguments/, fn ->
model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
|> Axon.Loop.run(data, %{})
end
end
end

describe "evaluator" do
test "returns clear error on bad inputs" do
model = Axon.input("input")
data = Stream.repeatedly(fn -> Nx.tensor(5) end)

assert_raise ArgumentError, ~r/invalid arguments/, fn ->
model
|> Axon.Loop.evaluator()
|> Axon.Loop.run(data, %{})
end
end
end

describe "serialization" do
test "serialize_state/deserialize_state preserve loop state" do
model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(2)
