Add tests for initializers and fix bug
seanmor5 committed Dec 1, 2022
1 parent c1115e4 commit e0bc997
Showing 2 changed files with 188 additions and 52 deletions.
120 changes: 68 additions & 52 deletions lib/axon/initializers.ex
@@ -606,47 +606,70 @@ defmodule Axon.Initializers do
     opts =
       keyword!(opts, [:shape, type: {:f, 32}, scale: 1.0, mode: :fan_in, distribution: :normal])
 
-    fans = transform(opts[:shape], &compute_fans/1)
-
-    denominator =
-      transform(
-        {fans, opts[:mode]},
-        fn
-          {{fan_in, _}, :fan_in} ->
-            fan_in
-
-          {{_, fan_out}, :fan_out} ->
-            fan_out
-
-          {{fan_in, fan_out}, :fan_avg} ->
-            (fan_in + fan_out) / 2.0
-
-          {{_, _}, mode} ->
-            raise ArgumentError, "invalid mode #{inspect(mode)} passed to variance_scaling/1"
-        end
-      )
+    fans = compute_fans(opts[:shape])
+    denominator = compute_denominator(fans, opts[:mode])
 
     variance = Nx.divide(Nx.tensor(opts[:scale], type: opts[:type]), Nx.max(denominator, 1.0))
 
-    var_opts = transform(opts, &Keyword.take(&1, [:shape, :type]))
-
-    transform(
-      {key, opts[:distribution], variance, var_opts},
-      fn
-        {key, :normal, variance, opts} ->
-          var_normal(key, variance, opts)
-
-        {key, :uniform, variance, opts} ->
-          var_uniform(key, variance, opts)
-
-        {key, :truncated_normal, variance, opts} ->
-          var_truncated(key, variance, opts)
-
-        {_, dist, _, _} ->
-          raise ArgumentError,
-                "invalid distribution #{inspect(dist)} passed to variance_scaling/1"
-      end
-    )
+    apply_distribution(key, opts[:distribution], variance, shape: opts[:shape], type: opts[:type])
   end
 
+  deftransformp compute_fans(shape) do
+    rank = Nx.rank(shape)
+
+    {in_size, out_size} =
+      cond do
+        rank < 1 ->
+          {1, 1}
+
+        rank == 1 ->
+          {elem(shape, 0), elem(shape, 0)}
+
+        rank == 2 ->
+          {elem(shape, 0), elem(shape, 1)}
+
+        true ->
+          {elem(shape, rank - 2), elem(shape, rank - 1)}
+      end
+
+    receptive_field_size = Nx.size(shape) / in_size / out_size
+    fan_in = in_size * receptive_field_size
+    fan_out = out_size * receptive_field_size
+
+    {fan_in, fan_out}
+  end
+
+  deftransformp compute_denominator(fans, mode) do
+    case {fans, mode} do
+      {{fan_in, _}, :fan_in} ->
+        fan_in
+
+      {{_, fan_out}, :fan_out} ->
+        fan_out
+
+      {{fan_in, fan_out}, :fan_avg} ->
+        (fan_in + fan_out) / 2.0
+
+      {{_, _}, mode} ->
+        raise ArgumentError, "invalid mode #{inspect(mode)} passed to variance_scaling/1"
+    end
+  end
+
+  deftransformp apply_distribution(key, distribution, variance, opts) do
+    case distribution do
+      :normal ->
+        var_normal(key, variance, opts)
+
+      :uniform ->
+        var_uniform(key, variance, opts)
+
+      :truncated_normal ->
+        var_truncated(key, variance, opts)
+
+      dist ->
+        raise ArgumentError,
+              "invalid distribution #{inspect(dist)} passed to variance_scaling/1"
+    end
+  end
 
   @doc """
@@ -761,33 +784,26 @@ defmodule Axon.Initializers do
       variance
       |> Nx.sqrt()
       |> Nx.divide(0.87962566103423978)
+      |> Nx.as_type(type)
 
-    rand = Nx.Random.normal_split(key, 0.0, sigma, shape: shape, type: type)
-    Nx.clip(rand, -2, 2)
+    truncated_normal(key, -2, 2, shape: shape, type: type) * sigma
   end
 
-  defp compute_fans(shape) do
-    rank = Nx.rank(shape)
-
-    {fan_in, fan_out} =
-      cond do
-        rank < 1 ->
-          {1, 1}
-
-        rank == 1 ->
-          {elem(shape, 0), elem(shape, 0)}
-
-        rank == 2 ->
-          {elem(shape, 0), elem(shape, 1)}
-
-        true ->
-          receptive_field_size = Nx.size(shape) / elem(shape, 0) / elem(shape, 1)
-
-          fan_in = elem(shape, 0) * receptive_field_size
-          fan_out = elem(shape, 1) * receptive_field_size
-          {fan_in, fan_out}
-      end
-
-    {fan_in, fan_out}
-  end
+  defnp truncated_normal(key, lower, upper, opts \\ []) do
+    opts = keyword!(opts, [:shape, type: {:f, 32}])
+    shape = opts[:shape]
+    type = opts[:type]
+
+    sqrt2 = Nx.sqrt(2) |> Nx.as_type(type)
+    lower = Nx.as_type(lower, type)
+    upper = Nx.as_type(upper, type)
+
+    a = Nx.erf(lower / sqrt2)
+    b = Nx.erf(upper / sqrt2)
+
+    u = Nx.Random.uniform_split(key, a, b, shape: shape, type: type)
+    out = sqrt2 * Nx.erf_inv(u)
+
+    Nx.clip(out, lower, upper)
+  end
 end
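Two details worth noting in var_truncated: the constant 0.87962566103423978 is the standard deviation of a standard normal truncated to [-2, 2], so dividing sigma by it keeps the initializer at the requested variance after truncation; and the new truncated_normal/4 samples by inverse CDF, drawing u uniformly from [erf(lower / sqrt(2)), erf(upper / sqrt(2))] and mapping it through sqrt(2) * erf_inv(u), the quantile function of the standard normal. A quick sanity sketch of that identity (not part of the commit; assumes a recent Nx that provides Nx.Random and Nx.standard_deviation):

key = Nx.Random.key(42)
sqrt2 = :math.sqrt(2)

# Uniform draw restricted to the image of [-2, 2] under the normal CDF.
a = Nx.erf(-2 / sqrt2)
b = Nx.erf(2 / sqrt2)
u = Nx.Random.uniform_split(key, a, b, shape: {10_000})

# Quantile map back to the normal: every draw lands in [-2, 2].
sample = Nx.multiply(sqrt2, Nx.erf_inv(u))

# Approaches 0.87962566103423978, the constant divided out above.
Nx.standard_deviation(sample)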
120 changes: 120 additions & 0 deletions test/axon/initializers_test.exs
@@ -3,6 +3,126 @@ defmodule Axon.InitializersTest do
 
   doctest Axon.Initializers
 
+  describe "lecun_uniform/1" do
+    test "matches jax with defaults" do
+      init_fn = Axon.Initializers.lecun_uniform()
+
+      actual = init_fn.({6, 4}, :f32, Nx.Random.key(0))
+
+      expected =
+        Nx.tensor([
+          [0.20278636, -0.44988236, -0.67542195, 0.09019998],
+          [0.07444432, -0.5715227, -0.22269602, -0.65556777],
+          [-0.5834403, 0.41140956, -0.20922656, 0.04757705],
+          [-0.6660935, -0.11757841, 0.11348342, 0.58670807],
+          [-0.31940702, -0.4950906, 0.6199206, 0.02958001],
+          [0.01707217, 0.57443, 0.3266003, 0.64393777]
+        ])
+
+      assert_all_close(expected, actual)
+    end
+  end
+
+  describe "lecun_normal/1" do
+    test "matches jax with defaults" do
+      init_fn = Axon.Initializers.lecun_normal()
+
+      actual = init_fn.({6, 4}, :f32, Nx.Random.key(0))
+
+      expected =
+        Nx.tensor([
+          [0.16248588, -0.39667854, -0.7911283, 0.07110167],
+          [0.05860865, -0.5588784, -0.17921554, -0.73135567],
+          [-0.5787071, 0.35475218, -0.16787331, 0.03739766],
+          [-0.7614683, -0.09293934, 0.08966116, 0.58433205],
+          [-0.26443285, -0.4505209, 0.6471739, 0.02323576],
+          [0.01340684, 0.56362104, 0.27110383, 0.7013128]
+        ])
+
+      assert_all_close(expected, actual)
+    end
+  end
+
+  describe "glorot_uniform/1" do
+    test "matches jax with defaults" do
+      init_fn = Axon.Initializers.glorot_uniform()
+
+      actual = init_fn.({6, 4}, :f32, Nx.Random.key(0))
+
+      expected =
+        Nx.tensor([
+          [0.22214133, -0.49282146, -0.7398877, 0.09880914],
+          [0.08154967, -0.6260718, -0.24395128, -0.7181385],
+          [-0.6391269, 0.4506766, -0.22919622, 0.05211805],
+          [-0.7296689, -0.1288007, 0.12431487, 0.6427065],
+          [-0.34989288, -0.54234457, 0.67908907, 0.03240328],
+          [0.01870163, 0.6292566, 0.35777274, 0.7053985]
+        ])
+
+      assert_all_close(expected, actual)
+    end
+  end
+
+  describe "glorot_normal/1" do
+    test "matches jax with defaults" do
+      init_fn = Axon.Initializers.glorot_normal()
+
+      actual = init_fn.({6, 4}, :f32, Nx.Random.key(0))
+
+      expected =
+        Nx.tensor([
+          [0.17799434, -0.43453953, -0.8666375, 0.07788797],
+          [0.06420256, -0.6122206, -0.19632077, -0.8011599],
+          [-0.63394177, 0.3886115, -0.18389598, 0.04096708],
+          [-0.8341466, -0.10180994, 0.09821887, 0.64010364],
+          [-0.28967166, -0.49352086, 0.7089434, 0.0254535],
+          [0.01468645, 0.61741585, 0.29697934, 0.7682496]
+        ])
+
+      assert_all_close(expected, actual)
+    end
+  end
+
+  describe "he_uniform/1" do
+    test "matches jax with defaults" do
+      init_fn = Axon.Initializers.he_uniform()
+
+      actual = init_fn.({6, 4}, :f32, Nx.Random.key(0))
+
+      expected =
+        Nx.tensor([
+          [0.28678322, -0.63622975, -0.9551909, 0.12756205],
+          [0.10528016, -0.8082552, -0.31493974, -0.9271128],
+          [-0.82510924, 0.58182096, -0.29589105, 0.06728411],
+          [-0.9419985, -0.16628098, 0.1604898, 0.8297305],
+          [-0.45170975, -0.70016384, 0.87670016, 0.04183245],
+          [0.0241437, 0.8123667, 0.4618826, 0.9106655]
+        ])
+
+      assert_all_close(expected, actual)
+    end
+  end
+
+  describe "he_normal/1" do
+    test "matches jax with defaults" do
+      init_fn = Axon.Initializers.he_normal()
+
+      actual = init_fn.({6, 4}, :f32, Nx.Random.key(0))
+
+      expected =
+        Nx.tensor([
+          [0.22978972, -0.5609881, -1.1188242, 0.10055294],
+          [0.08288515, -0.7903734, -0.25344902, -1.034293],
+          [-0.81841534, 0.5016953, -0.2374087, 0.05288828],
+          [-1.0768787, -0.13143606, 0.12680002, 0.8263703],
+          [-0.37396452, -0.6371327, 0.915242, 0.03286032],
+          [0.01896013, 0.79708046, 0.38339868, 0.991806]
+        ])
+
+      assert_all_close(expected, actual)
+    end
+  end
+
   describe "orthogonal/1" do
     test "property" do
       init_fn = Axon.Initializers.orthogonal()
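The six new describes pin Axon's initializers to the values JAX produces for the same seed; because the PRNG key is passed in explicitly, the fixtures are fully deterministic. A minimal sketch of the API surface they exercise:

# Initializers return an arity-3 closure: (shape, type, key) -> tensor.
init_fn = Axon.Initializers.lecun_uniform()
kernel = init_fn.({6, 4}, :f32, Nx.Random.key(0))
Nx.shape(kernel)
#=> {6, 4}

Run the new tests with mix test test/axon/initializers_test.exs.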
