Skip to content

Commit

Permalink
Fix small bug in ShuffleBufferCatalog.hasActiveShuffle (NVIDIA#397)
Browse files Browse the repository at this point in the history
* Fix small bug in ShuffleBufferCatalog.hasActiveShuffle

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>

* When unregistering a shuffle, it is no longer active

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>

* Remove unnecesary imports

Signed-off-by: Alessandro Bellina <abellina@nvidia.com>
  • Loading branch information
abellina authored Jul 21, 2020
1 parent cef27c1 commit 9d65260
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class ShuffleBufferCatalog(
}
}

def hasActiveShuffle(shuffleId: Int): Boolean = activeShuffles.contains(shuffleId)
def hasActiveShuffle(shuffleId: Int): Boolean = activeShuffles.containsKey(shuffleId)

/** Get all the buffer IDs that correspond to a shuffle block identifier. */
def blockIdToBuffersIds(blockId: ShuffleBlockId): Array[ShuffleBufferId] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import org.scalatest.FunSuite
import org.scalatest.mockito.MockitoSugar

import org.apache.spark.sql.rapids.RapidsDiskBlockManager

class ShuffleBufferCatalogSuite extends FunSuite with MockitoSugar {
test("registered shuffles should be active") {
val catalog = mock[RapidsBufferCatalog]
val rapidsDiskBlockManager = mock[RapidsDiskBlockManager]
val shuffleCatalog = new ShuffleBufferCatalog(catalog, rapidsDiskBlockManager)

assertResult(false)(shuffleCatalog.hasActiveShuffle(123))
shuffleCatalog.registerShuffle(123)
assertResult(true)(shuffleCatalog.hasActiveShuffle(123))
shuffleCatalog.unregisterShuffle(123)
assertResult(false)(shuffleCatalog.hasActiveShuffle(123))
}
}

0 comments on commit 9d65260

Please sign in to comment.