Reduced CUDA dropout memory usage (#1005)
* Revert "Build Framework (NeoML-master 2.0.204.0): Incrementing version number."

This reverts commit 9a29b52.

* added test to branch

* added rca22824

* started

* revert some changes

* added some new

* reducing dropout memory usage on cuda

Signed-off-by: daniyalaliev <daniial.aliev@abbyy.com>

* design changes have been made

Signed-off-by: daniyalaliev <daniial.aliev@abbyy.com>

---------

Signed-off-by: daniyalaliev <daniial.aliev@abbyy.com>
Co-authored-by: daniyalaliev <daniial.aliev@abbyy.com>
daniyalaliev authored Dec 20, 2023
1 parent 6c9023d commit 9cf9652
Showing 4 changed files with 101 additions and 22 deletions.
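The saving comes from dropping the per-descriptor mask blob: the previous implementation stored a float mask of batchWidth * objectSize elements and multiplied by it, while the new kernels rederive each keep/drop decision from the descriptor's seed at the moment the output is written. The snippet below is a minimal standalone sketch of that idea using plain cuRAND rather than NeoML's CCudaRandom; the kernel name, Philox generator, and launch shape are assumptions for illustration only.

// Minimal sketch of seed-based (mask-free) dropout, assuming cuRAND's Philox generator
// stands in for NeoML's CCudaRandom. No mask buffer is allocated: every thread recomputes
// its keep/drop decision from (seed, element index).
#include <cuda_runtime.h>
#include <curand_kernel.h>

__global__ void maskFreeDropout( const float* in, float* out, int size,
	unsigned long long seed, float keepRate )
{
	const int i = blockIdx.x * blockDim.x + threadIdx.x;
	if( i < size ) {
		curandStatePhilox4_32_10_t state;
		curand_init( seed, /*subsequence*/ i, /*offset*/ 0, &state );
		const float u = curand_uniform( &state ); // uniform in (0, 1]
		// inverted dropout: keep with probability keepRate and rescale by 1 / keepRate
		out[i] = ( u <= keepRate ) ? in[i] / keepRate : 0.f;
	}
}

// Illustrative launch: maskFreeDropout<<<( size + 255 ) / 256, 256>>>( in, out, size, seed, keepRate );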
1 change: 1 addition & 0 deletions NeoMathEngine/src/CMakeLists.txt
@@ -300,6 +300,7 @@ if((WIN32 OR LINUX) AND CMAKE_SIZEOF_VOID_P EQUAL 8)
GPU/CUDA/CudaCommon.h
GPU/CUDA/CudaDevice.h
GPU/CUDA/CudaMathEngineDnnConvs.h
GPU/CUDA/CudaMathEngineDnnDropout.h
GPU/CUDA/CudaMathEngineDnnPoolings.h
GPU/CUDA/CudaMathEngine.h
GPU/CUDA/CusparseDll.h
47 changes: 29 additions & 18 deletions NeoMathEngine/src/GPU/CUDA/CudaMathEngineDnnDropout.cu
@@ -21,20 +21,37 @@ limitations under the License.
#include <CudaDevice.h>
#include <CudaCommon.h>
#include <MathEngineCommon.h>
#include <MathEngineDnnDropout.h>
#include <MemoryHandleInternal.h>
#include <CudaMathEngineDnnDropout.h>

#include <Kernels/CudaDnnDropoutKernels.h>

namespace NeoML {

void CCudaMathEngine::Dropout( const CDropoutDesc& dropoutDesc, const CFloatHandle& inputData, const CFloatHandle& outputData )
CCudaMathEngineDropoutDesc::CCudaMathEngineDropoutDesc( IMathEngine& mathEngine, float rate, bool isSpatial,
bool isBatchwise, const CBlobDesc& input, const CBlobDesc& output, int seed ) :
Input(input),
Output(output),
ForwardRate(1.f - rate),
IsSpatial(isSpatial),
IsBatchwise(isBatchwise),
seed(seed)
{}

CDropoutDesc* CCudaMathEngine::InitDropout( float rate, bool isSpatial, bool isBatchwise,
const CBlobDesc& input, const CBlobDesc& output, int seed )
{
return new CCudaMathEngineDropoutDesc(mathEngine(), rate, isSpatial, isBatchwise, input, output, seed);
}

void CCudaMathEngine::Dropout( const CDropoutDesc& dropoutDesc,
const CFloatHandle& inputData, const CFloatHandle& outputData )
{
ASSERT_EXPR( inputData.GetMathEngine() == this );
ASSERT_EXPR( outputData.GetMathEngine() == this );
SetCudaDevice( device->DeviceNumber );

const CMathEngineDropoutDesc& desc = static_cast<const CMathEngineDropoutDesc&>( dropoutDesc );
const CCudaMathEngineDropoutDesc& desc = static_cast<const CCudaMathEngineDropoutDesc&>( dropoutDesc );
const CBlobDesc& input = desc.Input;

if( desc.ForwardRate == 1.f ) {
@@ -47,28 +64,22 @@ void CCudaMathEngine::Dropout( const CDropoutDesc& dropoutDesc, const CFloatHand
const int batchWidth = input.ObjectCount() / batchLength;
const int maskSize = batchWidth * objectSize;

ASSERT_EXPR( desc.Mask.Size() == maskSize );

if( !desc.IsSpatial ) {
MultiplyMatrixByDiagMatrix( inputData, batchLength, maskSize, desc.Mask.GetHandle(),
outputData, desc.Output.BlobSize() );
dim3 blockCount;
dim3 threadCount;

getCudaTaskGrid2D(blockCount, threadCount, batchLength, (maskSize + 3) / 4);
RandomMatrixDropout<<<blockCount, threadCount>>>( GetRaw(inputData), batchLength, maskSize,
GetRaw(outputData), desc.seed, desc.ForwardRate );
return;
}

dim3 blockCount;
dim3 threadCount;

getCudaTaskGrid3D( blockCount, threadCount, input.ObjectCount(), input.ObjectSize() / objectSize,
objectSize );
ChannelLastBlobSpatialDropoutKernel<<<blockCount, threadCount>>>( GetRaw( inputData ),
GetRaw( desc.Mask.GetHandle() ), GetRaw( outputData ), input.ObjectCount(), input.ObjectSize(),
batchWidth, objectSize );
}

CDropoutDesc* CCudaMathEngine::InitDropout( float rate, bool isSpatial, bool isBatchwise,
const CBlobDesc& input, const CBlobDesc& output, int seed )
{
return new CMathEngineDropoutDesc( mathEngine(), rate, isSpatial, isBatchwise, input, output, seed );
getCudaTaskGrid3D( blockCount, threadCount, input.ObjectCount(), input.ObjectSize() / objectSize, objectSize );
RandomSpatialDropout<<<blockCount, threadCount>>>( GetRaw( inputData ), GetRaw( outputData ),
input.ObjectCount(), input.ObjectSize(), batchWidth, objectSize, desc.seed, desc.ForwardRate );
}

} // namespace NeoML
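One detail of the new dispatch worth noting: CCudaRandom::Next() returns four 32-bit values at a time (a CIntArray<4>), so the non-spatial path launches one thread per group of four consecutive mask elements, which is why the 2D task grid is sized batchLength x ( maskSize + 3 ) / 4. A generic equivalent of that grid computation is sketched below; NeoML's getCudaTaskGrid2D may choose block dimensions differently, so treat the concrete numbers as an assumption.

// Generic 2D launch-grid computation for a task of size height x width
// (a sketch; the 32x8 block shape is an arbitrary choice, not NeoML's getCudaTaskGrid2D).
#include <cuda_runtime.h>

inline void simpleTaskGrid2D( dim3& blockCount, dim3& threadCount, int height, int width )
{
	threadCount = dim3( 32, 8, 1 ); // 256 threads per block
	blockCount = dim3( ( width + threadCount.x - 1 ) / threadCount.x,
		( height + threadCount.y - 1 ) / threadCount.y, 1 );
}

// For the non-spatial dropout path: simpleTaskGrid2D( blockCount, threadCount, batchLength, ( maskSize + 3 ) / 4 );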
41 changes: 41 additions & 0 deletions NeoMathEngine/src/GPU/CUDA/CudaMathEngineDnnDropout.h
@@ -0,0 +1,41 @@
/* Copyright © 2017-2023 ABBYY Production LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------------------------------------*/

#pragma once

#ifdef NEOML_USE_CUDA

#include <NeoMathEngine/NeoMathEngine.h>
#include <NeoMathEngine/CrtAllocatedObject.h>

namespace NeoML {

// Dropout descriptor without a mask, for use on CUDA
struct CCudaMathEngineDropoutDesc : public CDropoutDesc {
explicit CCudaMathEngineDropoutDesc(IMathEngine& mathEngine, float rate, bool isSpatial, bool isBatchwise,
const CBlobDesc& input, const CBlobDesc& output, int seed);

CBlobDesc Input; // input blob descriptor
CBlobDesc Output; // output blob descriptor
const float ForwardRate; // the probability that an element is not dropped out
const bool IsSpatial; // indicates if whole channels are dropped out
const bool IsBatchwise; // indicates if an element is dropped out of all objects in one batch at the same time
// seed used by the random number generator in the CUDA kernels
const int seed;
};

} // namespace NeoML

#endif // NEOML_USE_CUDA
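For orientation, a hedged usage sketch of the two entry points wired up above: InitDropout builds the descriptor and Dropout consumes it. The concrete rate and seed values and the cleanup at the end are assumptions for illustration, not the library's prescribed usage.

// Hedged usage sketch, assuming a CUDA-backed IMathEngine and already prepared blob
// descriptors and data handles; rate, seed and the `delete` are illustrative only.
#include <NeoMathEngine/NeoMathEngine.h>

using namespace NeoML;

void applyDropout( IMathEngine& engine, const CBlobDesc& inputDesc, const CBlobDesc& outputDesc,
	const CFloatHandle& inputData, const CFloatHandle& outputData )
{
	// The descriptor now carries only shapes, the keep probability, the flags and the seed;
	// no mask blob is allocated on the CUDA path.
	CDropoutDesc* desc = engine.InitDropout( /*rate*/ 0.3f, /*isSpatial*/ false, /*isBatchwise*/ false,
		inputDesc, outputDesc, /*seed*/ 42 );
	engine.Dropout( *desc, inputData, outputData );
	delete desc; // assumption: the caller owns the descriptor
}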
34 changes: 30 additions & 4 deletions NeoMathEngine/src/GPU/CUDA/Kernels/CudaDnnDropoutKernels.h
@@ -16,20 +16,46 @@ limitations under the License.
#pragma once

#include <Kernels/CudaGrid.h>
#include <Kernels/CudaRandom.h>

namespace NeoML {

__global__ void ChannelLastBlobSpatialDropoutKernel( const float* __restrict__ input,
const float* __restrict__ mask, float* output, int inputObjectCount, int inputObjectSize, int maskObjectCount,
int maskObjectSize )
__global__ void RandomMatrixDropout( const float* __restrict__ first, int firstHeight,
int firstWidth, float* res, int seed, float forwardRate )
{
const unsigned int threshold = forwardRate * UINT_MAX;
int row;
int col;
if( GetCudaTaskIndex2D( firstHeight, ( firstWidth + 3 ) / 4, row, col ) ) {
CCudaRandom random(seed);
random.Skip(col);
col *= 4;
const int index = row * firstWidth + col;

CIntArray<4> generated = random.Next();
for(int j = 0; j < 4 && col + j < firstWidth; ++j) {
res[index + j] = (generated[j] <= threshold) ? (first[index + j] / forwardRate) : 0.f;
}
}
}

__global__ void RandomSpatialDropout( const float* __restrict__ input, float* res, int inputObjectCount,
int inputObjectSize, int maskObjectCount, int maskObjectSize, int seed, float forwardRate )
{
const unsigned int threshold = forwardRate * UINT_MAX;
int obj;
int row;
int col;
if( GetCudaTaskIndex3D( inputObjectCount, inputObjectSize / maskObjectSize, maskObjectSize, obj, row, col ) ) {
int pack = obj % maskObjectCount;
int index = obj * inputObjectSize + row * maskObjectSize + col;
output[index] = input[index] * mask[maskObjectSize * pack + col];
int numBlock = ( pack * maskObjectSize + col ) / 4;
int numLeft = ( pack * maskObjectSize + col ) % 4;
CCudaRandom random(seed);
random.Skip(numBlock);

CIntArray<4> generated = random.Next();
res[index] = (generated[numLeft] <= threshold) ? (input[index] / forwardRate) : 0.f;
}
}
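Both kernels rely on the same trick to turn a 32-bit random word into a Bernoulli draw: threshold = forwardRate * UINT_MAX, so generated[j] <= threshold holds with probability approximately forwardRate, and surviving elements are rescaled by 1 / forwardRate (inverted dropout) so the expected output matches the input. Below is a small host-side check of that property; using std::mt19937 in place of CCudaRandom is an assumption, both being sources of uniformly distributed 32-bit words.

// Host-side sanity check of the threshold comparison used by RandomMatrixDropout and
// RandomSpatialDropout (assumption: std::mt19937 stands in for CCudaRandom).
#include <climits>
#include <cstdio>
#include <random>

int main()
{
	const float forwardRate = 0.75f; // probability of keeping an element
	const unsigned int threshold = forwardRate * UINT_MAX;

	std::mt19937 gen( 42 );
	std::uniform_int_distribution<unsigned int> dist( 0u, UINT_MAX );

	const int trials = 1000000;
	int kept = 0;
	for( int i = 0; i < trials; ++i ) {
		if( dist( gen ) <= threshold ) {
			++kept;
		}
	}
	std::printf( "empirical keep rate: %.4f (target %.4f)\n", double( kept ) / trials, forwardRate );
	return 0;
}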

