Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Partho Sarthi <psarthi@nvidia.com>
  • Loading branch information
parthosa committed Mar 14, 2024
1 parent 75bd945 commit 8673f2a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,8 @@ object RapidsPluginUtils extends Logging {
throw new RuntimeException(s"Device architecture $gpuArch is unsupported." +
s" Minimum supported architecture: $minSupportedGpuArch.")
}
val supportedMajorGpuArchs = supportedGpuArchs.map(_/10)
val majorGpuArch = gpuArch/10
val supportedMajorGpuArchs = supportedGpuArchs.map(_ / 10)
val majorGpuArch = gpuArch / 10
// Warn the user if the device's major architecture is not available
if (!supportedMajorGpuArchs.contains(majorGpuArch)) {
val supportedMajorArchStr = supportedMajorGpuArchs.toSeq.sorted.mkString(", ")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ class GpuArchitectureTestSuite extends AnyFunSuite {
}

test("test unsupported architecture") {
assertThrows[RuntimeException] {
val jniSupportedGpuArchs = Set(50, 60, 70)
val cudfSupportedGpuArchs = Set(50, 60, 65, 70)
val gpuArch = 40
val jniSupportedGpuArchs = Set(50, 60, 70)
val cudfSupportedGpuArchs = Set(50, 60, 65, 70)
val gpuArch = 40
val exception = intercept[RuntimeException] {
validateGpuArchitectureInternal(gpuArch, jniSupportedGpuArchs, cudfSupportedGpuArchs)
}
assert(exception.getMessage.contains(s"Device architecture $gpuArch is unsupported"))
}

test("test supported major architecture with higher minor version") {
Expand All @@ -51,11 +52,13 @@ class GpuArchitectureTestSuite extends AnyFunSuite {
}

test("test empty supported architecture set") {
assertThrows[IllegalStateException] {
val jniSupportedGpuArchs = Set(50, 60)
val cudfSupportedGpuArchs = Set(70, 80)
val gpuArch = 60
val jniSupportedGpuArchs = Set(50, 60)
val cudfSupportedGpuArchs = Set(70, 80)
val gpuArch = 60
val exception = intercept[IllegalStateException] {
validateGpuArchitectureInternal(gpuArch, jniSupportedGpuArchs, cudfSupportedGpuArchs)
}
assert(exception.getMessage.contains(
s"Compatibility check failed for GPU architecture $gpuArch"))
}
}

0 comments on commit 8673f2a

Please sign in to comment.