Merged PR 21151: Cleaning up fp16 behavior
This PR improves the clipping and pruning of NaNs and Infs during fp16 training, ultimately avoiding the underflow problems we have been facing so far.
emjotde committed Oct 26, 2021
1 parent 7f06f3c commit 1404201
Showing 13 changed files with 233 additions and 141 deletions.
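For orientation, here is a minimal, self-contained sketch of the per-update flow this commit targets: the loss is multiplied by a cost-scaling factor before backprop, the resulting gradient is sanitized (NaNs pruned to zero, Infs clipped to the fp16 range of +/-65504), and only a sane gradient is descaled and applied. This is an illustration only, not Marian's implementation; the names sanitize and applyUpdate are hypothetical.

// Minimal sketch of the per-update flow (illustration only, not Marian code).
#include <cmath>
#include <cstddef>
#include <vector>

// Prune NaNs to 0 and clip +/-Inf to the fp16 range; return whether the gradient is sane.
static bool sanitize(std::vector<float>& grad, bool pruneNaN, bool clipInf) {
  bool sane = true;
  for(float& g : grad) {
    if(std::isnan(g)) {
      if(pruneNaN) g = 0.f; else sane = false;
    } else if(std::isinf(g)) {
      if(clipInf) g = g > 0.f ? 65504.f : -65504.f; else sane = false;
    }
  }
  return sane;
}

// One optimizer step: descale and apply the gradient only if it is sane.
static bool applyUpdate(std::vector<float>& params, std::vector<float>& grad,
                        float costScalingFactor, float learningRate) {
  if(!sanitize(grad, /*pruneNaN=*/true, /*clipInf=*/true))
    return false; // skip the update; the cost scaler can then back off
  for(std::size_t i = 0; i < params.size(); ++i)
    params[i] -= learningRate * grad[i] / costScalingFactor; // undo the cost scaling
  return true;
}

int main() {
  std::vector<float> params(4, 1.f);
  std::vector<float> grad = {256.f, NAN, INFINITY, -512.f}; // scaled gradient with bad values
  bool ok = applyUpdate(params, grad, /*costScalingFactor=*/256.f, /*learningRate=*/0.1f);
  return ok ? 0 : 1;
}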
4 changes: 2 additions & 2 deletions src/common/aliases.cpp
@@ -29,8 +29,8 @@ void ConfigParser::addAliases(cli::CLIWrapper& cli) {
cli.alias("fp16", "true", [&](YAML::Node& config) {
if(mode_ == cli::mode::training) {
config["precision"] = std::vector<std::string>({"float16", "float32"}); // inference type, optimization type, save type
// scaling factor (power of 2), frequency, multiplier at increase, tolerance, range, minimum factor
config["cost-scaling"] = std::vector<std::string>({"0", "1000", "2", "0.05", "10", "1e-5"});
// scaling factor, frequency, multiplier at increase, minimum scaling factor
config["cost-scaling"] = std::vector<std::string>({"256.f", "1000", "2.f", "256.f"});
} else {
config["precision"] = std::vector<std::string>({"float16"}); // for inference we do not need the other types
}
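Reading the new defaults above as (initial scaling factor, frequency, multiplier, minimum factor), a small sketch of the schedule they could drive follows. The CostScaler struct and its semantics are an assumed interpretation for illustration, not the commit's optimizer code.

// Sketch of a dynamic cost-scaling schedule driven by the four values above
// (assumed semantics for illustration; not the commit's optimizer code).
#include <algorithm>

struct CostScaler {
  float factor     = 256.f; // initial scaling factor
  int   frequency  = 1000;  // clean updates required before increasing
  float multiplier = 2.f;   // growth factor on increase, divisor on decrease
  float minimum    = 256.f; // never scale below this
  int   cleanUpdates = 0;

  // Call once per update with the result of the gradient sanitization.
  void update(bool saneGradient) {
    if(saneGradient) {
      if(++cleanUpdates >= frequency) { // enough clean updates in a row: scale up
        factor *= multiplier;
        cleanUpdates = 0;
      }
    } else {                            // bad gradient: back off, but respect the floor
      factor = std::max(minimum, factor / multiplier);
      cleanUpdates = 0;
    }
  }
};

int main() {
  CostScaler s;
  for(int i = 0; i < 1000; ++i) s.update(true); // 1000 clean updates -> factor doubles to 512
  s.update(false);                              // one bad gradient   -> back to the 256 floor
  return s.factor == 256.f ? 0 : 1;
}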
6 changes: 3 additions & 3 deletions src/common/config_parser.cpp
@@ -522,15 +522,15 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
// mixed precision training
cli.add<bool>("--fp16",
"Shortcut for mixed precision training with float16 and cost-scaling, "
"corresponds to: --precision float16 float32 --cost-scaling 0 1000 2 0.05 10 1e-5f");
"corresponds to: --precision float16 float32 --cost-scaling 256.f 1000 2.f 256.f");
cli.add<std::vector<std::string>>("--precision",
"Mixed precision training for forward/backward pass and optimizaton. "
"Defines types for: forward/backward pass, optimization.",
{"float32", "float32"});
cli.add<std::vector<std::string>>("--cost-scaling",
"Dynamic cost scaling for mixed precision training: "
"power of 2, scaling window, scaling factor, tolerance, range, minimum factor")
->implicit_val("0.f 1000 2.f 0.05f 10 1e-5f");
"scaling factor, frequency, multiplier, minimum factor")
->implicit_val("256.f 1000 2.f 256.f");
cli.add<size_t>("--gradient-norm-average-window",
"Window size over which the exponential average of the gradient norm is recorded (for logging and scaling). "
"After this many updates about 90% of the mass of the exponential average comes from these updates",
10 changes: 5 additions & 5 deletions src/common/definitions.h
@@ -106,24 +106,24 @@ using Weak = std::weak_ptr<T>;
/** @brief Creates shared_ptr of any type, passes all arguments to any available
* constructor */
template <class T, typename... Args>
Ptr<T> New(Args&&... args) {
return Ptr<T>(new T(std::forward<Args>(args)...));
inline Ptr<T> New(Args&&... args) {
return std::make_shared<T>(std::forward<Args>(args)...);
}

template <class T>
Ptr<T> New(Ptr<T> p) {
inline Ptr<T> New(Ptr<T> p) {
return Ptr<T>(p);
}

/** @brief Creates IntrusivePtr of any type, passes all arguments to any available
* constructor */
template <class T, typename... Args>
IPtr<T> INew(Args&&... args) {
inline IPtr<T> INew(Args&&... args) {
return IPtr<T>(new T(std::forward<Args>(args)...));
}

template <class T>
IPtr<T> INew(Ptr<T> p) {
inline IPtr<T> INew(Ptr<T> p) {
return IPtr<T>(p);
}
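As a standalone illustration of the change above (generic C++, not Marian code): New now forwards to std::make_shared, so the object and the shared_ptr control block come from a single allocation, and the added inline keyword serves as an inlining hint for these thin header-defined wrappers. The Options struct below is a hypothetical example type.

// Standalone illustration of the New<T> pattern above (generic C++, not Marian code).
#include <memory>
#include <string>
#include <utility>

template <class T>
using Ptr = std::shared_ptr<T>;

template <class T, typename... Args>
inline Ptr<T> New(Args&&... args) {
  // make_shared performs one allocation for the object and its control block
  return std::make_shared<T>(std::forward<Args>(args)...);
}

struct Options {
  std::string name;
  int value;
  Options(std::string n, int v) : name(std::move(n)), value(v) {}
};

int main() {
  Ptr<Options> opts = New<Options>("beam-size", 6); // one allocation via make_shared
  return opts->value == 6 ? 0 : 1;
}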

15 changes: 10 additions & 5 deletions src/models/transformer.h
@@ -147,8 +147,7 @@ class Transformer : public EncoderOrDecoderBase {

int dimDepth = dimModel / dimHeads;

auto output
= reshape(input, {dimBatch * dimBeam, dimSteps, dimHeads, dimDepth});
auto output = reshape(input, {dimBatch * dimBeam, dimSteps, dimHeads, dimDepth});

return transpose(output, {0, 2, 1, 3}); // [dimBatch*dimBeam, dimHeads, dimSteps, dimDepth]
}
@@ -361,9 +360,9 @@ class Transformer : public EncoderOrDecoderBase {

Expr LayerAttention(std::string prefix,
Expr input, // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
const Expr& keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
const Expr& values, // ...?
const Expr& mask, // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
Expr keys, // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
Expr values, // ...?
Expr mask, // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
int dimHeads,
bool cache = false,
bool saveAttentionWeights = false) {
@@ -373,6 +372,12 @@
auto opsPre = opt<std::string>("transformer-preprocess");
auto output = preProcess(prefix + "_Wo", opsPre, input, dropProb);

// fixes missing norm for keys and values in self-attention with pre-norm
if(input == keys)
keys = output;
if(input == values)
values = output;

// multi-head self-attention over previous input
output = MultiHead(prefix, dimModel, dimHeads, output, keys, values, mask, cache, saveAttentionWeights);
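To make the intent of the identity checks above concrete, here is a toy, self-contained illustration with placeholder types (Node, Expr, preProcess, layerAttention are hypothetical stand-ins, not Marian's expression graph): in self-attention the keys and values alias the input node, so after the pre-norm preprocessing they must be redirected to the normalized output as well.

// Toy illustration of the aliasing check above (placeholder types, not Marian's graph).
#include <memory>

struct Node {};                        // stands in for a graph expression
using Expr = std::shared_ptr<Node>;

static Expr preProcess(Expr /*x*/) {   // stands in for layer-norm / dropout pre-processing
  return std::make_shared<Node>();
}

static void layerAttention(Expr input, Expr keys, Expr values) {
  Expr output = preProcess(input);     // pre-norm applied to the query side
  if(input == keys)   keys   = output; // self-attention: keys must follow the norm
  if(input == values) values = output; // ...and so must the values
  // MultiHead(output, keys, values, ...) would now see consistently normalized inputs.
  (void)keys; (void)values;            // silence unused-variable warnings in this sketch
}

int main() {
  Expr x = std::make_shared<Node>();
  layerAttention(x, x, x);             // self-attention call: all three alias the same node
  return 0;
}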

4 changes: 4 additions & 0 deletions src/tensors/cpu/tensor_operators.cpp
@@ -24,6 +24,10 @@ void IsNaN(const Tensor /*in*/, Ptr<Allocator> /*allocator*/, bool& /*isNaN*/, b
ABORT("Not implemented");
}

bool SanitizeGradient(marian::Tensor /*in*/, Ptr<Allocator> /*allocator*/, bool /*pruneNaN*/, bool /*clipInf*/) {
ABORT("Not implemented");
}

template <bool add, typename To, typename From>
void CopyCastTo(To* out, const From* in, int length) {
for(int i = 0; i < length; ++i)
12 changes: 4 additions & 8 deletions src/tensors/gpu/element.cu
@@ -29,7 +29,9 @@ __global__ void gElement(
indices[i] = tensors[i].shape().bindex(dims);
}

tensors[0].data()[index] = functional::apply(functor, tensors, indices);
// This performs the internal application of the functor in float32 regardless of the input type.
// This seems to incur no speed penalty while improving precision.
tensors[0].data()[index] = (T)functional::applyWithCast<float>(functor, tensors, indices);
}
}
}
@@ -65,13 +67,7 @@ void Element(Functor functor, Tensor out, Tensors... tensors) {
ElementTyped<float>(functor, out, tensors...);
} else if(out->type() == Type::float16) {
#if COMPILE_FP16
std::vector<marian::Tensor> ts({out, tensors...});
bool div2 = std::all_of(ts.cbegin(), ts.cend(), [](marian::Tensor t){ return t->shape()[-1] % 2 == 0; });
if(div2) {
ElementTyped<halfx2>(functor, out, tensors...);
} else {
ElementTyped<half>(functor, out, tensors...);
}
ElementTyped<half>(functor, out, tensors...);
#else
ABORT("FP16 not supported with chosen current hardware or CUDA version");
#endif
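The comment in the first hunk above explains that the functor is now applied internally in float32 regardless of the storage type. Below is a host-side sketch of that pattern; elementwiseBinary is a hypothetical helper for illustration only, while the real kernel works on packed tensor views and may store half-precision values.

// Host-side sketch of the cast-to-float pattern noted above (illustration only).
#include <cstddef>

template <typename T, typename Functor>
void elementwiseBinary(Functor f, T* out, const T* a, const T* b, std::size_t n) {
  for(std::size_t i = 0; i < n; ++i)
    out[i] = (T)f((float)a[i], (float)b[i]); // compute in float32, store back as T
}

int main() {
  float a[3] = {1.f, 2.f, 3.f}, b[3] = {4.f, 5.f, 6.f}, out[3];
  elementwiseBinary([](float x, float y) { return x + y; }, out, a, b, 3);
  return out[0] == 5.f ? 0 : 1;
}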
147 changes: 110 additions & 37 deletions src/tensors/gpu/tensor_operators.cu
@@ -16,15 +16,12 @@ namespace gpu {
namespace atomics {

static inline __device__ void atomicAdd(float *address, float val) {
//*address += val;
::atomicAdd(address, val);
}

#if COMPILE_FP16
// @TODO: copied from CuTorch, adapt this better, give credit.
static inline __device__ void atomicAdd(half *address, half val) {
//*address += val;

#if __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000 // compute capability 70 and higher with CUDA 10
::atomicAdd(address, val);
#else // __CUDA_ARCH__ < 700
@@ -50,7 +47,8 @@ static inline __device__ void atomicAdd(half *address, half val) {
} while (assumed != old);
#endif // __CUDA_ARCH__
}
#endif
#endif // COMPILE_FP16


}

@@ -96,6 +94,81 @@ void IsNaN(const Tensor in, Ptr<Allocator> allocator, bool& isNaN, bool& isInf)
cudaStreamSynchronize(0);
}

template <typename T>
__global__ void gSanitizeGradient(T* in, int length,
bool* isNaN, bool* isInf,
bool pruneNaN, bool clipInf,
float forNaN = 0.f, float forInf = 65504.f, float forInfNeg = -65504.f) {
for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < length) {
float v = (float)in[index];
// handle NaN
if(isnan(v)) {
if(pruneNaN) {
in[index] = (T)forNaN;
} else {
*isNaN = true;
}
}
// handle +/- Inf
if(isinf(v)) {
if(clipInf) {
in[index] = v > 0 ? (T)forInf : (T)forInfNeg;
} else {
*isInf = true;
}
}
}
}
}

// This function is meant to clean gradients, i.e. clip infinities and prune NaNs if required.
// If all NaNs and Infs have been removed we return `true` to indicate a sane gradient.
// If `clipInf` is set, infinities are replaced with the maximum/minimum non-inf value for the tensor.
// In that case infinities do not result in a bad gradient, since they get clipped.
// If `pruneNaN` is set, NaNs are replaced with 0. Since the NaNs are removed, they do not result
// in a bad gradient.
// If NaNs or infinities are detected but not removed (either because of `pruneNaN=false` or `clipInf=false`),
// we return `false` indicating a bad gradient.
bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf) {
cudaSetDevice(in->getDeviceId().no);

int length = in->size();

int threads = std::min(MAX_THREADS, length);
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));

auto mem = allocator->alloc<bool>(2);
bool* dIsNaN = &mem->data<bool>()[0];
bool* dIsInf = &mem->data<bool>()[1];
fill(in->getBackend(), dIsNaN, dIsNaN + 2, false);

float forNaN = 0.f;
float forInf = NumericLimits<float>(in->type()).max;
float forInfNeg = NumericLimits<float>(in->type()).lowest;

if(in->type() == Type::float32) {
gSanitizeGradient<<<blocks, threads>>>(in->data<float>(), length, dIsNaN, dIsInf, pruneNaN, clipInf, forNaN, forInf, forInfNeg);
#if COMPILE_FP16
} else if(in->type() == Type::float16) {
gSanitizeGradient<<<blocks, threads>>>(in->data<half>(), length, dIsNaN, dIsInf, pruneNaN, clipInf, forNaN, forInf, forInfNeg);
#endif
} else {
ABORT("gSanitizeGradient for type {} not implemented", in->type());
}

bool isNaN, isInf;
CudaCopy(dIsNaN, dIsNaN + 1, &isNaN);
CudaCopy(dIsInf, dIsInf + 1, &isInf);

allocator->free(mem);

cudaStreamSynchronize(0);

return !isNaN && !isInf;
}

template <bool add, typename To, typename From>
__global__ void gCopyCastTo(To* out, const From* in, int length) {
for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
@@ -1090,7 +1163,7 @@ void PasteRows(Tensor out,
size_t rowsToCopy = indices->size();

int threads = std::min(MAX_THREADS, (int)cols);
#if 1 // @TODO: make this configurable with a 'deterministic' flag
#if 0 // @TODO: make this configurable with a 'deterministic' flag
// If we only use one block, then each core operates on a different column,
// hence the summation becomes deterministic.
// However, we only use e.g. 512 cores out of possibly 3000+, so this will be
@@ -1355,7 +1428,7 @@ __global__ void gGRUFastForward(T* out,
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
T m = !mask || mask[j];
float m = !mask || mask[j];
T* rowOut = out + j * cols;
const T* rowState = state + j * cols;

Expand All @@ -1365,21 +1438,21 @@ __global__ void gGRUFastForward(T* out,
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols) {
T r = functional::Ops<T>::sigmoid(xWrow[i] + sUrow[i] + b[i]);
float r = functional::Ops<float>::sigmoid((float)xWrow[i] + (float)sUrow[i] + (float)b[i]);

int k = i + cols;

T z = functional::Ops<T>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
float z = functional::Ops<float>::sigmoid((float)xWrow[k] + (float)sUrow[k] + (float)b[k]);

int l = i + 2 * cols;
T h;
float h;
if(final)
h = functional::Ops<T>::tanh(xWrow[l] + (sUrow[l] + b[l]) * r);
h = functional::Ops<float>::tanh((float)xWrow[l] + ((float)sUrow[l] + (float)b[l]) * r);
else
h = functional::Ops<T>::tanh(xWrow[l] + sUrow[l] * r + b[l]);
h = functional::Ops<float>::tanh((float)xWrow[l] + (float)sUrow[l] * r + (float)b[l]);

T out = ((T)1.f - z) * h + z * rowState[i];
rowOut[i] = m * out + ((T)1.f - m) * rowState[i];
float out = (1.f - z) * h + z * (float)rowState[i];
rowOut[i] = (T)(m * out + (1.f - m) * (float)rowState[i]);
}
}
}
@@ -1441,7 +1514,7 @@ __global__ void gGRUFastBackward(T* outState,
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
T m = !mask || mask[j];
float m = !mask || mask[j];

T* rowOutState = outState + j * cols;
T* rowOutXW = outXW + j * cols * 3;
@@ -1459,56 +1532,56 @@
int k = i + cols;
int l = i + 2 * cols;

T r = functional::Ops<T>::sigmoid(rowXW[i] + rowSU[i] + b[i]);
T z = functional::Ops<T>::sigmoid(rowXW[k] + rowSU[k] + b[k]);
float r = functional::Ops<float>::sigmoid((float)rowXW[i] + (float)rowSU[i] + (float)b[i]);
float z = functional::Ops<float>::sigmoid((float)rowXW[k] + (float)rowSU[k] + (float)b[k]);

T h;
float h;
if(final)
h = functional::Ops<T>::tanh(rowXW[l] + (rowSU[l] + b[l]) * r);
h = functional::Ops<float>::tanh((float)rowXW[l] + ((float)rowSU[l] + (float)b[l]) * r);
else
h = functional::Ops<T>::tanh(rowXW[l] + rowSU[l] * r + b[l]);
h = functional::Ops<float>::tanh((float)rowXW[l] + (float)rowSU[l] * r + (float)b[l]);

T adj = rowAdj[i];
float adj = rowAdj[i];

T t = ((T)1.f - z) * ((T)1.f - h * h);
float t = (1.f - z) * (1.f - h * h);

// df/ds
if(outState)
rowOutState[i] += (m * z - m + (T)1.f) * adj;
rowOutState[i] += (T)((m * z - m + 1.f) * adj);

// df/d(xW_r) ...
T dfdxW_r = m * r * ((T)1.f - r) * t * adj;
float dfdxW_r = m * r * (1.f - r) * t * adj;
if(final)
dfdxW_r *= rowSU[l] + b[l];
dfdxW_r *= (float)rowSU[l] + (float)b[l];
else
dfdxW_r *= rowSU[l];
dfdxW_r *= (float)rowSU[l];
if(outXW)
rowOutXW[i] += dfdxW_r;
rowOutXW[i] += (T)dfdxW_r;
if(outSU)
rowOutSU[i] += dfdxW_r;
rowOutSU[i] += (T)dfdxW_r;
if(outB)
rowOutB[i] += dfdxW_r;
rowOutB[i] += (T)dfdxW_r;

// df/d(xW_z) ...
T dfdxW_z = m * ((T)1.f - z) * z * (rowState[i] - h) * adj;
float dfdxW_z = m * (1.f - z) * z * ((float)rowState[i] - h) * adj;
if(outXW)
rowOutXW[k] += dfdxW_z;
rowOutXW[k] += (T)dfdxW_z;
if(outSU)
rowOutSU[k] += dfdxW_z;
rowOutSU[k] += (T)dfdxW_z;
if(outB)
rowOutB[k] += dfdxW_z;
rowOutB[k] += (T)dfdxW_z;

// df/d(xW_x) ...
T dfdxW_x = m * t * adj;
float dfdxW_x = m * t * adj;
if(outXW)
rowOutXW[l] += dfdxW_x;
rowOutXW[l] += (T)dfdxW_x;
if(outSU)
rowOutSU[l] += dfdxW_x * r;
rowOutSU[l] += (T)(dfdxW_x * r);
if(outB)
if(final)
rowOutB[l] += dfdxW_x * r;
rowOutB[l] += (T)(dfdxW_x * r);
else
rowOutB[l] += dfdxW_x;
rowOutB[l] += (T)dfdxW_x;
}
}
}