Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare for Nx 0.5 #471

Merged
merged 4 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/axon/activations.ex
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ defmodule Axon.Activations do
custom_grad(
Nx.max(x, 0),
[x],
fn g -> [{x, Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))}] end
fn g -> [Nx.select(Nx.greater(x, 0), g, Nx.broadcast(0, g))] end
)
end

Expand Down
4 changes: 2 additions & 2 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,9 @@ defmodule Axon.Compiler do

if event? and mode? do
if on_event == :backward do
Nx.Defn.Kernel.custom_grad(expr, fn _ans, g ->
Nx.Defn.Kernel.custom_grad(expr, [expr], fn g ->
hooked_g = Nx.Defn.Kernel.hook(g, hook_fn)
[{expr, hooked_g}]
[hooked_g]
end)
else
Nx.Defn.Kernel.hook(expr, hook_fn)
Expand Down
14 changes: 7 additions & 7 deletions lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -952,13 +952,13 @@ defmodule Axon.Layers do
norm = opts[:norm]

input
|> Nx.power(norm)
|> Nx.pow(norm)
|> Nx.window_sum(window_dimensions,
strides: strides,
padding: padding,
window_dilations: dilations
)
|> Nx.power(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
|> Nx.pow(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
end

@doc """
Expand Down Expand Up @@ -1078,9 +1078,9 @@ defmodule Axon.Layers do
Axon.Shape.adaptive_pool_window_size(input, window_strides, output_size, opts[:channels])

input
|> Nx.power(norm)
|> Nx.pow(norm)
|> Nx.window_sum(window_dimensions, padding: :valid, strides: window_strides)
|> Nx.power(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
|> Nx.pow(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
end

## Normalization
Expand Down Expand Up @@ -1438,7 +1438,7 @@ defmodule Axon.Layers do

mask = Nx.less(rand, keep_prob)

a = Nx.rsqrt(keep_prob * Nx.power(Nx.tensor(1, type: Nx.type(input)) * alpha_p, 2))
a = Nx.rsqrt(keep_prob * Nx.pow(Nx.tensor(1, type: Nx.type(input)) * alpha_p, 2))
b = -a * alpha_p * rate

x = Nx.select(mask, input, alpha_p)
Expand Down Expand Up @@ -1661,9 +1661,9 @@ defmodule Axon.Layers do
all_but_batch_and_feature = Axon.Shape.global_pool_axes(input, opts[:channels])

input
|> Nx.power(norm)
|> Nx.pow(norm)
|> Nx.sum(axes: all_but_batch_and_feature, keep_axes: opts[:keep_axes])
|> Nx.power(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
|> Nx.pow(Nx.divide(Nx.tensor(1, type: Nx.type(input)), norm))
end

## Sparse
Expand Down
3 changes: 0 additions & 3 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1811,9 +1811,6 @@ defmodule Axon.Loop do
Logger.debug("Axon.Loop finished batch step execution in #{us_to_ms(time)}ms")
end

# Force a garbage collection so any device or copied data is deallocated.
:erlang.garbage_collect()

batch_fn = {:compiled, batch_fn}
state = %{state | step_state: new_step_state, metrics: new_metrics}

Expand Down
2 changes: 1 addition & 1 deletion lib/axon/loss_scale.ex
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ defmodule Axon.LossScale do
precision during the model training process. Each loss-scale
implementation here returns a 3-tuple of the functions:

{init_fn, scale_fn, unscale_fn, adjust_fn} = Axon.LossScale.static(Nx.power(2, 15))
{init_fn, scale_fn, unscale_fn, adjust_fn} = Axon.LossScale.static(Nx.pow(2, 15))

You can use these to scale/unscale loss and gradients as well
as adjust the loss scale state.
Expand Down
2 changes: 1 addition & 1 deletion lib/axon/losses.ex
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ defmodule Axon.Losses do
loss =
y_true
|> Nx.subtract(y_pred)
|> Nx.power(2)
|> Nx.pow(2)
|> Nx.mean(axes: [-1])

reduction(loss, opts[:reduction])
Expand Down
2 changes: 1 addition & 1 deletion lib/axon/optimizers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ defmodule Axon.Optimizers do
end
end

model_params = Nx.random_uniform({784, 10})
{model_params, _key} = Nx.Random.uniform(key, shape: {784, 10})
{init_fn, update_fn} = Axon.Optimizers.adam(0.005)

optimizer_state =
Expand Down
4 changes: 2 additions & 2 deletions lib/axon/schedules.ex
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ defmodule Axon.Schedules do

decayed_value =
rate
|> Nx.power(p)
|> Nx.pow(p)
|> Nx.multiply(init_value)

Nx.select(
Expand Down Expand Up @@ -210,7 +210,7 @@ defmodule Axon.Schedules do
|> Nx.divide(k)
|> Nx.negate()
|> Nx.add(1)
|> Nx.power(p)
|> Nx.pow(p)
|> Nx.multiply(Nx.subtract(init_value, end_value))
|> Nx.add(end_value)
end
Expand Down
35 changes: 21 additions & 14 deletions lib/axon/updates.ex
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ defmodule Axon.Updates do
opts = keyword!(opts, eps: 1.0e-7)
eps = opts[:eps]

sum_of_squares = deep_merge(x, sum_of_squares, fn g, z -> Nx.power(g, 2) + z end)
sum_of_squares = deep_merge(x, sum_of_squares, fn g, z -> Nx.pow(g, 2) + z end)

inv_sqrt_squares = deep_new(sum_of_squares, fn z -> Nx.rsqrt(z + eps) end)

Expand Down Expand Up @@ -402,7 +402,7 @@ defmodule Axon.Updates do

mu_nu =
deep_merge(mu, nu, fn m, n ->
Nx.rsqrt(-Nx.power(m, 2) + n + eps)
Nx.rsqrt(-Nx.pow(m, 2) + n + eps)
end)

x = deep_merge(x, mu_nu, fn g, mn -> g * mn end)
Expand Down Expand Up @@ -499,7 +499,7 @@ defmodule Axon.Updates do
nu = update_moment(x, nu, b2, 2)
count_inc = count + 1

b2t = Nx.power(b2, count_inc)
b2t = Nx.pow(b2, count_inc)
ro = ro_inf - 2 * count_inc * b2t / (1 - b2t)

mu_hat = bias_correction(mu, b1, count + 1)
Expand Down Expand Up @@ -637,7 +637,7 @@ defmodule Axon.Updates do
sum_gs =
deep_reduce(x, Nx.tensor(0.0), fn leaf, acc ->
leaf
|> Nx.power(2)
|> Nx.pow(2)
|> Nx.sum()
|> Nx.add(acc)
end)
Expand Down Expand Up @@ -771,6 +771,9 @@ defmodule Axon.Updates do
Adds random Gaussian noise to the input.

## Options

* `:seed` - Random seed to use. Defaults to the
current system time.

* `:eta` - Controls amount of noise to add.
Defaults to `0.01`.
Expand All @@ -791,22 +794,26 @@ defmodule Axon.Updates do

def add_noise({init_fn, apply_fn} = combinator, opts)
when is_function(init_fn, 1) and is_function(apply_fn, 3) and is_list(opts) do
stateful(combinator, &init_add_noise/1, &apply_add_noise(&1, &2, &3, opts))
{seed, opts} = Keyword.pop_lazy(opts, :seed, fn -> :erlang.system_time() end)
stateful(combinator, &init_add_noise(&1, seed: seed), &apply_add_noise(&1, &2, &3, opts))
end

defnp init_add_noise(_params) do
%{count: Nx.tensor(0)}
defnp init_add_noise(_params, opts \\ []) do
%{count: Nx.tensor(0), key: Nx.Random.key(opts[:seed])}
end

defnp apply_add_noise(x, %{count: count}, _params, opts \\ []) do
defnp apply_add_noise(x, %{count: count, key: key}, _params, opts \\ []) do
opts = keyword!(opts, eta: 0.01, gamma: 0.55)
var = opts[:eta] / Nx.power(count + 1, opts[:gamma])
var = opts[:eta] / Nx.pow(count + 1, opts[:gamma])

noise = deep_new(x, fn z -> Nx.random_normal(z) end)
{noise, key} =
deep_map_reduce(x, key, fn z, key ->
Nx.Random.normal(key, shape: Nx.shape(z), type: Nx.type(z))
end)

updates = deep_merge(x, noise, fn g, n -> g + var * n end)

{updates, %{count: count + 1}}
{updates, %{count: count + 1, key: key}}
end

@doc """
Expand Down Expand Up @@ -869,7 +876,7 @@ defmodule Axon.Updates do

nu =
deep_merge(x, nu, fn g, v ->
v - (1 - b2) * Nx.sign(v - Nx.power(g, 2)) * Nx.power(g, 2)
v - (1 - b2) * Nx.sign(v - Nx.pow(g, 2)) * Nx.pow(g, 2)
end)

mu_hat = bias_correction(mu, b1, count + 1)
Expand Down Expand Up @@ -998,11 +1005,11 @@ defmodule Axon.Updates do
## Helpers

defnp update_moment(x, moment, decay, order) do
deep_merge(x, moment, fn g, z -> (1 - decay) * Nx.power(g, order) + decay * z end)
deep_merge(x, moment, fn g, z -> (1 - decay) * Nx.pow(g, order) + decay * z end)
end

defnp bias_correction(moment, decay, count) do
deep_new(moment, fn z -> z / (1 - Nx.power(decay, count)) end)
deep_new(moment, fn z -> z / (1 - Nx.pow(decay, count)) end)
end

defnp safe_norm(g, min_norm) do
Expand Down
6 changes: 3 additions & 3 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,23 @@ defmodule Axon.MixProject do
if path = System.get_env("AXON_NX_PATH") do
[path: path, override: true]
else
[]
[github: "elixir-nx/nx", sparse: "nx", override: true]
end
end

defp exla_opts do
if path = System.get_env("AXON_EXLA_PATH") do
[path: path]
else
[]
[github: "elixir-nx/nx", sparse: "exla", override: true]
end
end

defp torchx_opts do
if path = System.get_env("AXON_TORCHX_PATH") do
[path: path]
else
[]
[github: "elixir-nx/nx", sparse: "torchx", override: true]
end
end

Expand Down
10 changes: 5 additions & 5 deletions mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
"cc_precompiler": {:hex, :cc_precompiler, "0.1.5", "ac3ef86f31ab579b856192a948e956cc3e4bb5006e303c4ab4b24958108e218a", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "ee5b2e56eb03798231a3d322579fff509139a534ef54205d04c188e18cab1f57"},
"complex": {:hex, :complex, "0.4.3", "84db4aad241099a8785446ac6eacf498bf3a60634a0e45c7745d875714ddbf98", [:mix], [], "hexpm", "2ceda96ebddcc22697974f1a2666d4cc5dfdd34f8cd8c4f9dced037bcb41eeb5"},
"dll_loader_helper": {:hex, :dll_loader_helper, "0.1.10", "ba85d66f82c1748513dbaee71aa9d0593bb9a65dba246b980753c4d683b0a07b", [:make, :mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}, {:cc_precompiler, "~> 0.1", [hex: :cc_precompiler, repo: "hexpm", optional: false]}], "hexpm", "c0d02a2d8cd0085252f7551a343f89060bb7beb3f303d991e46a7370ed257485"},
"earmark_parser": {:hex, :earmark_parser, "1.4.29", "149d50dcb3a93d9f3d6f3ecf18c918fb5a2d3c001b5d3305c926cddfbd33355b", [:mix], [], "hexpm", "4902af1b3eb139016aed210888748db8070b8125c2342ce3dcae4f38dcc63503"},
"earmark_parser": {:hex, :earmark_parser, "1.4.30", "0b938aa5b9bafd455056440cdaa2a79197ca5e693830b4a982beada840513c5f", [:mix], [], "hexpm", "3b5385c2d36b0473d0b206927b841343d25adb14f95f0110062506b300cd5a1b"},
"elixir_make": {:hex, :elixir_make, "0.7.3", "c37fdae1b52d2cc51069713a58c2314877c1ad40800a57efb213f77b078a460d", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "24ada3e3996adbed1fa024ca14995ef2ba3d0d17b678b0f3f2b1f66e6ce2b274"},
"ex_doc": {:hex, :ex_doc, "0.29.1", "b1c652fa5f92ee9cf15c75271168027f92039b3877094290a75abcaac82a9f77", [:mix], [{:earmark_parser, "~> 1.4.19", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "b7745fa6374a36daf484e2a2012274950e084815b936b1319aeebcf7809574f6"},
"exla": {:hex, :exla, "0.4.2", "7d5008c36c942de75efddffe4a4e6aac98da722261b7188b23b1363282a146a8", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.4.2", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.4.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "c7c5d70073c30ca4fee3981d992d27f2a2c1d8333b012ab8d0f7330c3624ee79"},
"kino": {:hex, :kino, "0.8.0", "07603a32c111959ed48f08ac3808a0dda05433d28f8d2f06d65b25b255966649", [:mix], [{:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "736568d4de9eb56d8903bae6fe08b7c06db44efe37bb883165e755e623881c51"},
"exla": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "exla"]},
"kino": {:hex, :kino, "0.8.1", "da3b2cba121b7542146cffdb8af055fa0129395fa67aead9e7e3df93aed1f107", [:mix], [{:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "da45dd141db30db18973de0e3398bda3ab8cb0b5da58d6a0debbe5b864aba295"},
"kino_vega_lite": {:hex, :kino_vega_lite, "0.1.7", "c93fdfe6e35c4c5a4f8afd51a89786b2187e5a7da4595b13ea02a4329d9f0976", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.4", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "59ee442f0532266749d15dc9af4e2875bec61ccfa1b07636bc396ee63dfde8e7"},
"makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.0", "f8c570a0d33f8039513fbccaf7108c5d750f47d8defd44088371191b76492b0b", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "28b2cbdc13960a46ae9a8858c4bebdec3c9a6d7b4b9e7f4ed1502f8159f338e7"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"},
"nimble_parsec": {:hex, :nimble_parsec, "1.2.3", "244836e6e3f1200c7f30cb56733fd808744eca61fd182f731eac4af635cc6d0b", [:mix], [], "hexpm", "c8d789e39b9131acf7b99291e93dae60ab48ef14a7ee9d58c6964f59efb570b0"},
"nx": {:hex, :nx, "0.4.2", "444e9cc1b1e95edf8c9d9d9f22635349a0cd60cb6a07d4954f3016b2d6d178d7", [:mix], [{:complex, "~> 0.4.3", [hex: :complex, repo: "hexpm", optional: false]}], "hexpm", "9d8f110cf733c4bbc86f0a5fe08f6537e106c39bbcb6dfabc7ef33f14f12edb3"},
"nx": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "nx"]},
"table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"},
"table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"torchx": {:hex, :torchx, "0.4.2", "eefeed48f2808f1b29858a1a86458b13bc6ef44966727ec7a68e036d5b3d2960", [:make, :mix], [{:dll_loader_helper, "~> 0.1.0", [hex: :dll_loader_helper, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.4.1", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "9c409d393bce3a26c8fc6ba5d4e3466d25e2d57a667a09b42cf65ca8bcb405d3"},
"torchx": {:git, "https://github.com/elixir-nx/nx.git", "952dd193f8be7041a1ee835bbe58753baa5460cc", [sparse: "torchx"]},
"vega_lite": {:hex, :vega_lite, "0.1.6", "145ab4908bc890b02cef3526e890e9b899528eaa7aa9d6fa642b52a8a2c682c6", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "078c0d8cd9a8eca4ae8f9527c45c01d69cefb6b2235fd5179a227ac2f031d7ac"},
"xla": {:hex, :xla, "0.4.3", "cf6201aaa44d990298996156a83a16b9a87c5fbb257758dbf4c3e83c5e1c4b96", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "caae164b56dcaec6fbcabcd7dea14303afde07623b0cfa4a3cd2576b923105f5"},
}
Loading