diff --git a/core/src/main/scala/org/apache/spark/eventhubs/client/EventHubsClient.scala b/core/src/main/scala/org/apache/spark/eventhubs/client/EventHubsClient.scala
index 326c5dbfa..3c73f791f 100644
--- a/core/src/main/scala/org/apache/spark/eventhubs/client/EventHubsClient.scala
+++ b/core/src/main/scala/org/apache/spark/eventhubs/client/EventHubsClient.scala
@@ -29,6 +29,7 @@ import org.json4s.jackson.Serialization
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
+import scala.util.{ Failure, Success, Try }
 
 /**
  * Wraps a raw EventHubReceiver to make it easier for unit tests
@@ -79,12 +80,20 @@ private[spark] class EventHubsClient(private val ehConf: EventHubsConf)
     }
   }
 
-  // Note: the EventHubs Java Client will retry this API call on failure
+  @annotation.tailrec
+  final def retry[T](n: Int)(fn: => T): T = {
+    Try { fn } match {
+      case Success(x) => x
+      case Failure(e: EventHubException) if e.getIsTransient && n > 1 =>
+        logInfo("Retrying getRunTimeInfo failure.", e)
+        retry(n - 1)(fn)
+      case Failure(e) => throw e
+    }
+  }
+
   private def getRunTimeInfo(partitionId: PartitionId): PartitionRuntimeInformation = {
-    try {
+    retry(RetryCount) {
       client.getPartitionRuntimeInformation(partitionId.toString).get
-    } catch {
-      case e: Exception => throw e
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/eventhubs/package.scala b/core/src/main/scala/org/apache/spark/eventhubs/package.scala
index b7bc0296f..c2fd6d3b1 100644
--- a/core/src/main/scala/org/apache/spark/eventhubs/package.scala
+++ b/core/src/main/scala/org/apache/spark/eventhubs/package.scala
@@ -40,6 +40,7 @@ package object eventhubs {
   val DefaultUseSimulatedClient = "false"
   val StartingSequenceNumber = 0L
   val DefaultEpoch = 0L
+  val RetryCount = 3
 
   type PartitionId = Int
   val PartitionId = Int
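
The core of this change is the tail-recursive `retry` helper that retries `getRunTimeInfo` only on transient failures, up to `RetryCount` attempts. Below is a minimal, self-contained sketch of the same pattern that can be run without the EventHubs SDK or Spark: `TransientError` is a hypothetical stand-in for `com.microsoft.azure.eventhubs.EventHubException` (which exposes `getIsTransient`), and `println` replaces Spark's `logInfo`. It is an illustration of the pattern, not part of the patch.

```scala
import scala.util.{ Failure, Success, Try }

object RetryDemo {
  // Hypothetical stand-in for EventHubException; the real class reports
  // whether a failure is transient via getIsTransient.
  final case class TransientError(msg: String) extends Exception(msg) {
    def getIsTransient: Boolean = true
  }

  // Same shape as the retry helper in the diff: tail-recursive, retries only
  // transient failures while attempts remain, otherwise rethrows.
  @annotation.tailrec
  def retry[T](n: Int)(fn: => T): T =
    Try(fn) match {
      case Success(x) => x
      case Failure(e: TransientError) if e.getIsTransient && n > 1 =>
        println(s"Transient failure, retrying ($n attempts left): ${e.msg}")
        retry(n - 1)(fn)
      case Failure(e) => throw e
    }

  def main(args: Array[String]): Unit = {
    var calls = 0
    // Fails twice with a transient error, then succeeds on the third attempt.
    val result = retry(3) {
      calls += 1
      if (calls < 3) throw TransientError(s"attempt $calls failed") else calls
    }
    println(s"succeeded on attempt $result")
  }
}
```

Because `fn` is a by-name parameter, the block passed to `retry` is re-evaluated on every attempt, and the `@annotation.tailrec` check guarantees the loop compiles to constant stack depth regardless of `RetryCount`.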