Prepare for Nx 0.5 (#471)
seanmor5 committed Jan 31, 2023
1 parent 278f06e commit 899b3bf
Showing 15 changed files with 351 additions and 428 deletions.
2 changes: 1 addition & 1 deletion lib/axon/activations.ex
@@ -518,7 +518,7 @@ defmodule Axon.Activations do
custom_grad(
Nx.max(x, 0),
[x],
-fn g -> [{x, Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))}] end
+fn g -> [Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))] end
)
end
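Note on the change above: Nx 0.5 moves custom_grad to an explicit-inputs form, where the differentiated inputs are listed as the second argument and the callback returns one gradient per input instead of {input, gradient} pairs. A minimal, self-contained sketch of the updated ReLU gradient (module and function names here are illustrative):

    defmodule CustomGradSketch do
      import Nx.Defn

      defn relu(x) do
        custom_grad(
          Nx.max(x, 0),
          [x],
          # Pass the upstream gradient through only where the input is positive.
          fn g -> [Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))] end
        )
      end
    end

    # Nx.Defn.grad(&CustomGradSketch.relu/1).(Nx.tensor([-1.0, 2.0])) #=> [0.0, 1.0]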

4 changes: 2 additions & 2 deletions lib/axon/compiler.ex
@@ -889,9 +889,9 @@ defmodule Axon.Compiler do

if event? and mode? do
if on_event == :backward do
-Nx.Defn.Kernel.custom_grad(expr, fn _ans, g ->
+Nx.Defn.Kernel.custom_grad(expr, [expr], fn g ->
hooked_g = Nx.Defn.Kernel.hook(g, hook_fn)
-[{expr, hooked_g}]
+[hooked_g]
end)
else
Nx.Defn.Kernel.hook(expr, hook_fn)
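The backward hook keeps the same structure under the new API: the expression is passed as its own input to custom_grad and the hook is attached to the incoming gradient. A standalone sketch of that pattern outside the compiler (function names and labels are illustrative, not Axon internals):

    defmodule BackwardHookSketch do
      import Nx.Defn

      defn double_with_grad_hook(x) do
        expr = Nx.multiply(x, 2)

        custom_grad(expr, [expr], fn g ->
          # hook/2 calls the given function with the gradient when the defn runs.
          [hook(g, fn t -> IO.inspect(t, label: "backward") end)]
        end)
      end
    end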
14 changes: 7 additions & 7 deletions lib/axon/layers.ex
@@ -952,13 +952,13 @@ defmodule Axon.Layers do
norm = opts[:norm]

input
-|> Nx.power(norm)
+|> Nx.pow(norm)
|> Nx.window_sum(window_dimensions,
strides: strides,
padding: padding,
window_dilations: dilations
)
-|> Nx.power(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
+|> Nx.pow(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
end

@doc """
@@ -1078,9 +1078,9 @@ defmodule Axon.Layers do
Axon.Shape.adaptive_pool_window_size(input, window_strides, output_size, opts[:channels])

input
-|> Nx.power(norm)
+|> Nx.pow(norm)
|> Nx.window_sum(window_dimensions, padding: :valid, strides: window_strides)
-|> Nx.power(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
+|> Nx.pow(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
end

## Normalization
@@ -1438,7 +1438,7 @@ defmodule Axon.Layers do

mask = Nx.less(rand, keep_prob)

-a = Nx.rsqrt(keep_prob * Nx.power(Nx.tensor(1, type: Nx.type(input)) * alpha_p, 2))
+a = Nx.rsqrt(keep_prob * Nx.pow(Nx.tensor(1, type: Nx.type(input)) * alpha_p, 2))
b = -a * alpha_p * rate

x = Nx.select(mask, input, alpha_p)
@@ -1661,9 +1661,9 @@ defmodule Axon.Layers do
all_but_batch_and_feature = Axon.Shape.global_pool_axes(input, opts[:channels])

input
-|> Nx.power(norm)
+|> Nx.pow(norm)
|> Nx.sum(axes: all_but_batch_and_feature, keep_axes: opts[:keep_axes])
-|> Nx.power(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
+|> Nx.pow(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
end

## Sparse
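The pooling changes in this file are purely mechanical: Nx 0.5 renames Nx.power/2 to Nx.pow/2 with unchanged semantics. A quick sanity check, assuming Nx >= 0.5:

    x = Nx.tensor([1.0, 2.0, 3.0])
    Nx.pow(x, 2)
    #=> #Nx.Tensor<
    #     f32[3]
    #     [1.0, 4.0, 9.0]
    #   >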
3 changes: 0 additions & 3 deletions lib/axon/loop.ex
@@ -1811,9 +1811,6 @@ defmodule Axon.Loop do
Logger.debug("Axon.Loop finished batch step execution in #{us_to_ms(time)}ms")
end

-# Force a garbage collection so any device or copied data is deallocated.
-:erlang.garbage_collect()

batch_fn = {:compiled, batch_fn}
state = %{state | step_state: new_step_state, metrics: new_metrics}

2 changes: 1 addition & 1 deletion lib/axon/loss_scale.ex
@@ -7,7 +7,7 @@ defmodule Axon.LossScale do
precision during the model training process. Each loss-scale
implementation here returns a 3-tuple of the functions:
-{init_fn, scale_fn, unscale_fn, adjust_fn} = Axon.LossScale.static(Nx.power(2, 15))
+{init_fn, scale_fn, unscale_fn, adjust_fn} = Axon.LossScale.static(Nx.pow(2, 15))
You can use these to scale/unscale loss and gradients as well
as adjust the loss scale state.
2 changes: 1 addition & 1 deletion lib/axon/losses.ex
@@ -853,7 +853,7 @@ defmodule Axon.Losses do
loss =
y_true
|> Nx.subtract(y_pred)
-|> Nx.power(2)
+|> Nx.pow(2)
|> Nx.mean(axes: [-1])

reduction(loss, opts[:reduction])
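For reference, the mean-squared-error pipeline above evaluated on concrete tensors (a sketch over plain Nx calls, not a call into Axon.Losses):

    y_true = Nx.tensor([[0.0, 1.0]])
    y_pred = Nx.tensor([[0.5, 0.5]])

    y_true
    |> Nx.subtract(y_pred)
    |> Nx.pow(2)
    |> Nx.mean(axes: [-1])
    #=> #Nx.Tensor<
    #     f32[1]
    #     [0.25]
    #   >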
2 changes: 1 addition & 1 deletion lib/axon/optimizers.ex
@@ -37,7 +37,7 @@ defmodule Axon.Optimizers do
end
end
-model_params = Nx.random_uniform({784, 10})
+{model_params, _key} = Nx.Random.uniform(key, shape: {784, 10})
{init_fn, update_fn} = Axon.Optimizers.adam(0.005)
optimizer_state =
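The doc snippet now uses the explicit-key Nx.Random API that replaces Nx.random_uniform/1 in Nx 0.5. A minimal sketch with the key created up front; the seed value and the init_fn.(model_params) call are assumptions based on the usual Axon optimizer flow, not part of this diff:

    key = Nx.Random.key(42)
    {model_params, _key} = Nx.Random.uniform(key, shape: {784, 10})

    {init_fn, update_fn} = Axon.Optimizers.adam(0.005)
    optimizer_state = init_fn.(model_params)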
4 changes: 2 additions & 2 deletions lib/axon/schedules.ex
@@ -103,7 +103,7 @@ defmodule Axon.Schedules do

decayed_value =
rate
-|> Nx.power(p)
+|> Nx.pow(p)
|> Nx.multiply(init_value)

Nx.select(
@@ -210,7 +210,7 @@ defmodule Axon.Schedules do
|> Nx.divide(k)
|> Nx.negate()
|> Nx.add(1)
-|> Nx.power(p)
+|> Nx.pow(p)
|> Nx.multiply(Nx.subtract(init_value, end_value))
|> Nx.add(end_value)
end
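The second hunk is the polynomial-decay schedule, i.e. (1 - count / k)^p * (init_value - end_value) + end_value. Written out on scalars as a sketch that mirrors the diff (the sample values are arbitrary, and this is not the public Axon.Schedules API):

    decayed = fn count, init_value, end_value, p, k ->
      count
      |> Nx.divide(k)
      |> Nx.negate()
      |> Nx.add(1)
      |> Nx.pow(p)
      |> Nx.multiply(Nx.subtract(init_value, end_value))
      |> Nx.add(end_value)
    end

    decayed.(Nx.tensor(50), 1.0e-2, 1.0e-4, 2, 100)
    #=> ~2.575e-3, i.e. (1 - 0.5)^2 * (1.0e-2 - 1.0e-4) + 1.0e-4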
35 changes: 21 additions & 14 deletions lib/axon/updates.ex
@@ -220,7 +220,7 @@ defmodule Axon.Updates do
opts = keyword!(opts, eps: 1.0e-7)
eps = opts[:eps]

-sum_of_squares = deep_merge(x, sum_of_squares, fn g, z -> Nx.power(g, 2) + z end)
+sum_of_squares = deep_merge(x, sum_of_squares, fn g, z -> Nx.pow(g, 2) + z end)

inv_sqrt_squares = deep_new(sum_of_squares, fn z -> Nx.rsqrt(z + eps) end)

@@ -402,7 +402,7 @@ defmodule Axon.Updates do

mu_nu =
deep_merge(mu, nu, fn m, n ->
-Nx.rsqrt(-Nx.power(m, 2) + n + eps)
+Nx.rsqrt(-Nx.pow(m, 2) + n + eps)
end)

x = deep_merge(x, mu_nu, fn g, mn -> g * mn end)
@@ -499,7 +499,7 @@ defmodule Axon.Updates do
nu = update_moment(x, nu, b2, 2)
count_inc = count + 1

-b2t = Nx.power(b2, count_inc)
+b2t = Nx.pow(b2, count_inc)
ro = ro_inf - 2 * count_inc * b2t / (1 - b2t)

mu_hat = bias_correction(mu, b1, count + 1)
@@ -637,7 +637,7 @@ defmodule Axon.Updates do
sum_gs =
deep_reduce(x, Nx.tensor(0.0), fn leaf, acc ->
leaf
-|> Nx.power(2)
+|> Nx.pow(2)
|> Nx.sum()
|> Nx.add(acc)
end)
@@ -771,6 +771,9 @@ defmodule Axon.Updates do
Adds random Gaussian noise to the input.
## Options
+* `:seed` - Random seed to use. Defaults to the
+current system time.
* `:eta` - Controls amount of noise to add.
Defaults to `0.01`.
@@ -791,22 +794,26 @@ defmodule Axon.Updates do

def add_noise({init_fn, apply_fn} = combinator, opts)
when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
-stateful(combinator, &init_add_noise/1, &apply_add_noise(&1, &2, &3, opts))
+{seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end)
+stateful(combinator, &init_add_noise(&1, seed: seed), &apply_add_noise(&1, &2, &3, opts))
end

-defnp init_add_noise(_params) do
-%{count: Nx.tensor(0)}
+defnp init_add_noise(_params, opts \\ []) do
+%{count: Nx.tensor(0), key: Nx.Random.key(opts[:seed])}
end

-defnp apply_add_noise(x, %{count: count}, _params, opts \\ []) do
+defnp apply_add_noise(x, %{count: count, key: key}, _params, opts \\ []) do
opts = keyword!(opts, eta: 0.01, gamma: 0.55)
-var = opts[:eta] / Nx.power(count + 1, opts[:gamma])
+var = opts[:eta] / Nx.pow(count + 1, opts[:gamma])

-noise = deep_new(x, fn z -> Nx.random_normal(z) end)
+{noise, key} =
+  deep_map_reduce(x, key, fn z, key ->
+    Nx.Random.normal(key, shape: Nx.shape(z), type: Nx.type(z))
+  end)

updates = deep_merge(x, noise, fn g, n -> g + var * n end)

-{updates, %{count: count + 1}}
+{updates, %{count: count + 1, key: key}}
end

@doc """
@@ -869,7 +876,7 @@ defmodule Axon.Updates do

nu =
deep_merge(x, nu, fn g, v ->
-v - (1 - b2) * Nx.sign(v - Nx.power(g, 2)) * Nx.power(g, 2)
+v - (1 - b2) * Nx.sign(v - Nx.pow(g, 2)) * Nx.pow(g, 2)
end)

mu_hat = bias_correction(mu, b1, count + 1)
@@ -998,11 +1005,11 @@ defmodule Axon.Updates do
## Helpers

defnp update_moment(x, moment, decay, order) do
-deep_merge(x, moment, fn g, z -> (1 - decay) * Nx.power(g, order) + decay * z end)
+deep_merge(x, moment, fn g, z -> (1 - decay) * Nx.pow(g, order) + decay * z end)
end

defnp bias_correction(moment, decay, count) do
-deep_new(moment, fn z -> z / (1 - Nx.power(decay, count)) end)
+deep_new(moment, fn z -> z / (1 - Nx.pow(decay, count)) end)
end

defnp safe_norm(g, min_norm) do
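The add_noise change above replaces Nx.random_normal/1 with the explicit-key Nx.Random API, threading the PRNG key through the optimizer state. A standalone sketch of the same key-threading pattern without Axon's deep_* container helpers (the parameter map and seed are illustrative):

    key = Nx.Random.key(:erlang.system_time())
    params = %{w: Nx.broadcast(0.0, {2, 2}), b: Nx.broadcast(0.0, {2})}

    {noise, _key} =
      Enum.map_reduce(params, key, fn {name, p}, key ->
        {n, key} = Nx.Random.normal(key, shape: Nx.shape(p), type: Nx.type(p))
        {{name, n}, key}
      end)

    noise = Map.new(noise)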
6 changes: 3 additions & 3 deletions mix.exs
@@ -57,23 +57,23 @@ defmodule Axon.MixProject do
if path = System.get_env("AXON_NX_PATH") do
[path: path, override: true]
else
-[]
+[github: "elixir-nx/nx", sparse: "nx", override: true]
end
end

defp exla_opts do
if path = System.get_env("AXON_EXLA_PATH") do
[path: path]
else
-[]
+[github: "elixir-nx/nx", sparse: "exla", override: true]
end
end

defp torchx_opts do
if path = System.get_env("AXON_TORCHX_PATH") do
[path: path]
else
-[]
+[github: "elixir-nx/nx", sparse: "torchx", override: true]
end
end

10 changes: 5 additions & 5 deletions mix.lock
@@ -3,21 +3,21 @@
"cc_precompiler": {:hex, :cc_precompiler, "0.1.5", "ac3ef86f31ab579b856192a948e956cc3e4bb5006e303c4ab4b24958108e218a", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "ee5b2e56eb03798231a3d322579fff509139a534ef54205d04c188e18cab1f57"},
"complex": {:hex, :complex, "0.4.3", "84db4aad241099a8785446ac6eacf498bf3a60634a0e45c7745d875714ddbf98", [:mix], [], "hexpm", "2ceda96ebddcc22697974f1a2666d4cc5dfdd34f8cd8c4f9dced037bcb41eeb5"},
"dll_loader_helper": {:hex, :dll_loader_helper, "0.1.10", "ba85d66f82c1748513dbaee71aa9d0593bb9a65dba246b980753c4d683b0a07b", [:make, :mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}, {:cc_precompiler, "~> 0.1", [hex: :cc_precompiler, repo: "hexpm", optional: false]}], "hexpm", "c0d02a2d8cd0085252f7551a343f89060bb7beb3f303d991e46a7370ed257485"},
"earmark_parser": {:hex, :earmark_parser, "1.4.29", "149d50dcb3a93d9f3d6f3ecf18c918fb5a2d3c001b5d3305c926cddfbd33355b", [:mix], [], "hexpm", "4902af1b3eb139016aed210888748db8070b8125c2342ce3dcae4f38dcc63503"},
"earmark_parser": {:hex, :earmark_parser, "1.4.30", "0b938aa5b9bafd455056440cdaa2a79197ca5e693830b4a982beada840513c5f", [:mix], [], "hexpm", "3b5385c2d36b0473d0b206927b841343d25adb14f95f0110062506b300cd5a1b"},
"elixir_make": {:hex, :elixir_make, "0.7.3", "c37fdae1b52d2cc51069713a58c2314877c1ad40800a57efb213f77b078a460d", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "24ada3e3996adbed1fa024ca14995ef2ba3d0d17b678b0f3f2b1f66e6ce2b274"},
"ex_doc": {:hex, :ex_doc, "0.29.1", "b1c652fa5f92ee9cf15c75271168027f92039b3877094290a75abcaac82a9f77", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "b7745fa6374a36daf484e2a2012274950e084815b936b1319aeebcf7809574f6"},
"exla": {:hex, :exla, "0.4.2", "7d5008c36c942de75efddffe4a4e6aac98da722261b7188b23b1363282a146a8", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.4.2", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.4.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "c7c5d70073c30ca4fee3981d992d27f2a2c1d8333b012ab8d0f7330c3624ee79"},
"kino": {:hex, :kino, "0.8.0", "07603a32c111959ed48f08ac3808a0dda05433d28f8d2f06d65b25b255966649", [:mix], [{:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "736568d4de9eb56d8903bae6fe08b7c06db44efe37bb883165e755e623881c51"},
"exla": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "exla"]},
"kino": {:hex, :kino, "0.8.1", "da3b2cba121b7542146cffdb8af055fa0129395fa67aead9e7e3df93aed1f107", [:mix], [{:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "da45dd141db30db18973de0e3398bda3ab8cb0b5da58d6a0debbe5b864aba295"},
"kino_vega_lite": {:hex, :kino_vega_lite, "0.1.7", "c93fdfe6e35c4c5a4f8afd51a89786b2187e5a7da4595b13ea02a4329d9f0976", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.4", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "59ee442f0532266749d15dc9af4e2875bec61ccfa1b07636bc396ee63dfde8e7"},
"makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"},
"nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"},
"nx": {:hex, :nx, "0.4.2", "444e9cc1b1e95edf8c9d9d9f22635349a0cd60cb6a07d4954f3016b2d6d178d7", [:mix], [{:complex, "~> 0.4.3", [hex: :complex, repo: "hexpm", optional: false]}], "hexpm", "9d8f110cf733c4bbc86f0a5fe08f6537e106c39bbcb6dfabc7ef33f14f12edb3"},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "nx"]},
"table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"},
"table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"torchx": {:hex, :torchx, "0.4.2", "eefeed48f2808f1b29858a1a86458b13bc6ef44966727ec7a68e036d5b3d2960", [:make, :mix], [{:dll_loader_helper, "~> 0.1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.4.1", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "9c409d393bce3a26c8fc6ba5d4e3466d25e2d57a667a09b42cf65ca8bcb405d3"},
"torchx": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "torchx"]},
"vega_lite": {:hex, :vega_lite, "0.1.6", "145ab4908bc890b02cef3526e890e9b899528eaa7aa9d6fa642b52a8a2c682c6", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "078c0d8cd9a8eca4ae8f9527c45c01d69cefb6b2235fd5179a227ac2f031d7ac"},
"xla": {:hex, :xla, "0.4.3", "cf6201aaa44d990298996156a83a16b9a87c5fbb257758dbf4c3e83c5e1c4b96", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "caae164b56dcaec6fbcabcd7dea14303afde07623b0cfa4a3cd2576b923105f5"},
}