Skip to content

Commit

Permalink
Use double instead of real
Browse files Browse the repository at this point in the history
Using real made it really slow and that much precision wasn't required anyway
  • Loading branch information
Marenz committed Sep 18, 2016
1 parent 396548e commit e97f07d
Showing 1 changed file with 49 additions and 48 deletions.
97 changes: 49 additions & 48 deletions source/app.d
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@ auto sigmoid ( T ) ( T x )
}

// Create a slice with the given dimension lengths (one dimension per length
// argument), filled with uniformly distributed random values in [min, max)
auto randomSlice ( Lengths... ) ( real min, real max, Lengths lengths )
auto randomSlice ( Lengths... ) ( double min, double max, Lengths lengths )
{
import std.random;

auto matrix = slice!real(lengths);
auto matrix = slice!double(lengths);

matrix.ndEach!((ref a) => a = uniform(min, max));

return matrix;
}

alias Matrix2d = Slice!(2, real*);
alias Vector = Slice!(1, real*);
alias Matrix2d = Slice!(2, double*);
alias Vector = Slice!(1, double*);

class LstmParam
{
Expand Down Expand Up @@ -56,19 +56,19 @@ class LstmParam
this.bf = randomSlice(-0.1, 0.1, mem_cell_ct);
this.bo = randomSlice(-0.1, 0.1, mem_cell_ct);

this.wg_diff = slice([mem_cell_ct, concat_len], cast(real)0.0);
this.wi_diff = slice([mem_cell_ct, concat_len], cast(real)0.0);
this.wf_diff = slice([mem_cell_ct, concat_len], cast(real)0.0);
this.wo_diff = slice([mem_cell_ct, concat_len], cast(real)0.0);
this.wg_diff = slice([mem_cell_ct, concat_len], cast(double)0.0);
this.wi_diff = slice([mem_cell_ct, concat_len], cast(double)0.0);
this.wf_diff = slice([mem_cell_ct, concat_len], cast(double)0.0);
this.wo_diff = slice([mem_cell_ct, concat_len], cast(double)0.0);


this.bg_diff = slice([mem_cell_ct], cast(real)0.0);
this.bi_diff = slice([mem_cell_ct], cast(real)0.0);
this.bf_diff = slice([mem_cell_ct], cast(real)0.0);
this.bo_diff = slice([mem_cell_ct], cast(real)0.0);
this.bg_diff = slice([mem_cell_ct], cast(double)0.0);
this.bi_diff = slice([mem_cell_ct], cast(double)0.0);
this.bf_diff = slice([mem_cell_ct], cast(double)0.0);
this.bo_diff = slice([mem_cell_ct], cast(double)0.0);
}

void applyDiff ( real lr = 1.0L )
void applyDiff ( double lr = 1.0L )
{
void apply (S) ( ref S a, ref S diff )
{
Expand All @@ -93,10 +93,10 @@ class LstmParam
apply(this.bo, this.bo_diff);


this.wg_diff[] = cast(real)0.0;
this.wi_diff[] = cast(real)0.0;
this.wf_diff[] = cast(real)0.0;
this.wo_diff[] = cast(real)0.0;
this.wg_diff[] = cast(double)0.0;
this.wi_diff[] = cast(double)0.0;
this.wf_diff[] = cast(double)0.0;
this.wo_diff[] = cast(double)0.0;

this.bg_diff[] = 0.0;
this.bi_diff[] = 0.0;
Expand All @@ -113,16 +113,16 @@ class LstmState

this ( long mem_cell_ct, long x_dim )
{
this.g = slice([mem_cell_ct], cast(real)0.0);
this.i = slice([mem_cell_ct], 0.0L);
this.f = slice([mem_cell_ct], 0.0L);
this.o = slice([mem_cell_ct], 0.0L);
this.s = slice([mem_cell_ct], 0.0L);
this.h = slice([mem_cell_ct], 0.0L);

this.bottom_diff_h = slice([this.h.length], 0.0L);
this.bottom_diff_s = slice([this.s.length], 0.0L);
this.bottom_diff_x = slice([x_dim], 0.0L);
this.g = slice([mem_cell_ct], cast(double)0.0);
this.i = slice([mem_cell_ct], 0.0);
this.f = slice([mem_cell_ct], 0.0);
this.o = slice([mem_cell_ct], 0.0);
this.s = slice([mem_cell_ct], 0.0);
this.h = slice([mem_cell_ct], 0.0);

this.bottom_diff_h = slice([this.h.length], 0.0);
this.bottom_diff_s = slice([this.s.length], 0.0);
this.bottom_diff_x = slice([x_dim], 0.0);
}
}

Expand All @@ -146,16 +146,16 @@ class LstmNode
Vector h_prev = Vector() )
{
// If this is the first lstm node in the network
if (s_prev.length == 0) s_prev = slice([this.state.s.length], 0.0L);
if (h_prev.length == 0) h_prev = slice([this.state.h.length], 0.0L);
if (s_prev.length == 0) s_prev = slice([this.state.s.length], 0.0);
if (h_prev.length == 0) h_prev = slice([this.state.h.length], 0.0);

// save data for use in backprop
this.s_prev = s_prev;
this.h_prev = h_prev;

import std.range;

real[] merged = x.ptr[0 .. x.elementsCount] ~
double[] merged = x.ptr[0 .. x.elementsCount] ~
s_prev.ptr[0 .. s_prev.elementsCount];

// concatenate x(t) and h(t-1)
Expand All @@ -172,9 +172,10 @@ class LstmNode
// assigns result of func(pw.dot(xc) + pb) to state
void apply ( alias func ) ( ref Vector state, Matrix2d pw, Vector pb )
{
state = pw.dot(xc).perElement!add(pb).ndMap!func.slice;
state = pw.dot(xc).perElement!add(pb).ndMap!(a =>
cast(double)func(a)).slice;

//gemv!(real, real, real)(&glas, 1.0L, pw, xc, 0.0L, state);
//gemv!(double, double, double)(&glas, 1.0L, pw, xc, 0.0, state);
//auto zip = assumeSameStructure!("a", "b")(state, pb);
//zip.ndEach!(z => z.a = func(z.a + z.b))();
}
Expand Down Expand Up @@ -238,8 +239,8 @@ class LstmNode
this.param.bg_diff = this.param.bi_diff.perElement!add(dg_input);

// compute bottom diff
//typeof(this.xc) dxc; dxc[] = 0.0L;
auto dxc = slice!real(this.xc.shape, 0.0L);
//typeof(this.xc) dxc; dxc[] = 0.0;
auto dxc = slice!double(this.xc.shape, 0.0);

import mir.ndslice.iteration;

Expand All @@ -259,8 +260,8 @@ class LstmNetwork
{
interface LossLayer
{
real loss ( Vector pred, real label );
Vector bottomDiff ( Vector pred, real label );
double loss ( Vector pred, double label );
Vector bottomDiff ( Vector pred, double label );
}

LstmParam param;
Expand All @@ -279,7 +280,7 @@ class LstmNetwork
Will *NOT* update parameters. To update parameters,
call this.param.applyDiff()
*/
real yListIs ( Vector y_list, LossLayer loss_layer )
double yListIs ( Vector y_list, LossLayer loss_layer )
{
assert(x_list.length == y_list.length);
assert(x_list.length > 0);
Expand All @@ -293,7 +294,7 @@ class LstmNetwork

// here s is not affecting loss due to h(t+1), hence we set equal to
// zero
auto diff_s = slice([this.param.mem_cell_ct], 0.0L);
auto diff_s = slice([this.param.mem_cell_ct], 0.0);
this.node_list[idx].topDiffIs(diff_h, diff_s);

idx -= 1;
Expand All @@ -320,12 +321,12 @@ class LstmNetwork

void xListClear ( )
{
this.x_list = slice!real([0, 0]);
this.x_list = slice!double([0, 0]);
}

void xListAdd ( Vector x )
{
auto new_matrix = slice!real([this.x_list.length + 1, x.length]);
auto new_matrix = slice!double([this.x_list.length + 1, x.length]);

//writefln("matrix created, lens: %s", new_matrix.shape());
//writefln("vec shape: %s", x.shape());
Expand Down Expand Up @@ -375,16 +376,16 @@ auto dot ( Matrix2d mat, Vector vec )

GlasContext glas;

Vector result = slice!real(mat.length);
Vector result = slice!double(mat.length);


gemv!(real, real, real)(&glas, 1.0L, mat, vec, 0.0L, result);
gemv!(double, double, double)(&glas, 1.0L, mat, vec, 0.0, result);

return result;
}

// Performs a per-element operation between the constant and each element of the vector
auto perElement ( alias op ) ( real a, Vector b )
auto perElement ( alias op ) ( double a, Vector b )
{
return b.ndMap!(x => op(a, x)).slice;
}
Expand All @@ -400,7 +401,7 @@ auto perElement ( alias op, T ) ( T a, T b )
// Calculates the outer product of a and b
auto outer ( Vector a, Vector b )
{
auto result = slice!real(a.length, b.length);
auto result = slice!double(a.length, b.length);

size_t idx_a, idx_b;

Expand All @@ -422,13 +423,13 @@ auto outer ( Vector a, Vector b )
// Computes square loss with first element of hidden layer array
class ToyLossLayer : LstmNetwork.LossLayer
{
override real loss ( Vector pred, real label )
override double loss ( Vector pred, double label )
{
return (pred[0] - label) * (pred[0] - label);
}
override Vector bottomDiff ( Vector pred, real label )
override Vector bottomDiff ( Vector pred, double label )
{
auto diff = slice!real([pred.length], 0.0L);
auto diff = slice!double([pred.length], 0.0);
diff[0] = 2 * (pred[0] - label);
return diff;
}
Expand All @@ -445,7 +446,7 @@ void main()
auto concat_len = x_dim + mem_cell_ct;
auto param = new LstmParam(mem_cell_ct, x_dim);
auto net = new LstmNetwork(param);
auto y_list = slice!real([4]);
auto y_list = slice!double([4]);
y_list[] = [-0.5, 0.2, 0.1, -0.5];

auto input_val_arr = randomSlice(-1.0L, 1.0L, y_list.length, x_dim);
Expand Down

0 comments on commit e97f07d

Please sign in to comment.