Fix non streaming rpc endpoint #542

Merged
merged 10 commits into from
Oct 6, 2020
fixed typos in comments, update formatting
Navid Yaghmazadeh committed Oct 1, 2020
commit 7e6b6af6fdbc0803a6c41953136b03fb53bf8806
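
For context on the endpoint this PR fixes: the connector pushes per-partition performance metrics from executors to the driver over Spark's internal RPC. The driver registers a PartitionPerformanceReceiver endpoint against its RpcEnv, and each executor resolves a reference to it by name via RpcUtils.makeDriverRef and sends PartitionPerformanceMetric messages. Below is a minimal sketch of that pattern, not the connector's actual classes: MetricsEndpoint and PartitionMetric are simplified stand-ins, and since org.apache.spark.rpc is private[spark], code like this only compiles inside the org.apache.spark namespace.

```scala
import org.apache.spark.SparkEnv
import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.util.RpcUtils

// Simplified stand-in for PartitionPerformanceMetric (not the connector's actual class).
case class PartitionMetric(partitionId: Int, batchCount: Int, elapsedTimeMs: Long)

// Driver side: an endpoint that receives metrics pushed by executors.
class MetricsEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  override def receive: PartialFunction[Any, Unit] = {
    case m: PartitionMetric =>
      // The connector feeds these into its PartitionsStatusTracker; this sketch just logs.
      println(s"partition ${m.partitionId}: ${m.batchCount} events in ${m.elapsedTimeMs} ms")
  }
}

object RpcEndpointSketch {
  // Stand-in for PartitionPerformanceReceiver.ENDPOINT_NAME.
  val EndpointName = "partition-performance-sketch"

  // Driver: register the endpoint once so executors can look it up by name.
  def registerOnDriver(): Unit = {
    val env = SparkEnv.get
    env.rpcEnv.setupEndpoint(EndpointName, new MetricsEndpoint(env.rpcEnv))
  }

  // Executor: resolve a reference to the driver endpoint and fire-and-forget a metric,
  // mirroring the RpcUtils.makeDriverRef(...) call in the diff below.
  def reportFromExecutor(metric: PartitionMetric): Unit = {
    val driverRef = RpcUtils.makeDriverRef(EndpointName, SparkEnv.get.conf, SparkEnv.get.rpcEnv)
    driverRef.send(metric)
  }
}
```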
@@ -66,7 +66,7 @@ private[spark] trait CachedReceiver {
private[client] class CachedEventHubsReceiver private (ehConf: EventHubsConf,
nAndP: NameAndPartition,
startSeqNo: SequenceNumber)
extends Logging {
extends Logging {

type AwaitTimeoutException = java.util.concurrent.TimeoutException

@@ -89,11 +89,11 @@ private[client] class CachedEventHubsReceiver private (ehConf: EventHubsConf,
receiverOptions.setIdentifier(s"spark-${SparkEnv.get.executorId}-$taskId")
val consumer = retryJava(
EventHubsUtils.createReceiverInner(client,
ehConf.useExclusiveReceiver,
consumerGroup,
nAndP.partitionId.toString,
EventPosition.fromSequenceNumber(seqNo).convert,
receiverOptions),
ehConf.useExclusiveReceiver,
consumerGroup,
nAndP.partitionId.toString,
EventPosition.fromSequenceNumber(seqNo).convert,
receiverOptions),
"CachedReceiver creation."
)
Await.result(consumer, ehConf.internalOperationTimeout)
@@ -162,7 +162,7 @@ private[client] class CachedEventHubsReceiver private (ehConf: EventHubsConf,
Await.result(lastReceivedOffset(), ehConf.internalOperationTimeout)

if ((lastReceivedSeqNo > -1 && lastReceivedSeqNo + 1 != requestSeqNo) ||
!receiver.getIsOpen) {
!receiver.getIsOpen) {
logInfo(s"(TID $taskId) checkCursor. Recreating a receiver for $nAndP, ${ehConf.consumerGroup.getOrElse(
DefaultConsumerGroup)}. requestSeqNo: $requestSeqNo, lastReceivedSeqNo: $lastReceivedSeqNo, isOpen: ${receiver.getIsOpen}")

@@ -193,11 +193,11 @@ private[client] class CachedEventHubsReceiver private (ehConf: EventHubsConf,
// The event still isn't present. It must be (2).
val info = Await.result(
retryJava(client.getPartitionRuntimeInformation(nAndP.partitionId.toString),
"partitionRuntime"),
"partitionRuntime"),
ehConf.internalOperationTimeout)

if (requestSeqNo < info.getBeginSequenceNumber &&
movedSeqNo == info.getBeginSequenceNumber) {
movedSeqNo == info.getBeginSequenceNumber) {
Future {
movedEvent
}
@@ -238,8 +238,8 @@ private[client] class CachedEventHubsReceiver private (ehConf: EventHubsConf,

val theRest = for { i <- 1 until batchCount } yield
awaitReceiveMessage(receiveOne(ehConf.receiverTimeout.getOrElse(DefaultReceiverTimeout),
s"receive; $nAndP; seqNo: ${requestSeqNo + i}"),
requestSeqNo)
s"receive; $nAndP; seqNo: ${requestSeqNo + i}"),
requestSeqNo)
// Combine and sort the data.
val combined = first ++ theRest.flatten
val sorted = combined.toSeq
@@ -254,10 +254,10 @@ private[client] class CachedEventHubsReceiver private (ehConf: EventHubsConf,
if (ehConf.slowPartitionAdjustment) {
sendPartitionPerformanceToDriver(
PartitionPerformanceMetric(nAndP,
EventHubsUtils.getTaskContextSlim,
requestSeqNo,
batchCount,
elapsedTimeMs))
EventHubsUtils.getTaskContextSlim,
requestSeqNo,
batchCount,
elapsedTimeMs))
}

if (metricPlugin.isDefined) {
@@ -270,10 +270,10 @@ private[client] class CachedEventHubsReceiver private (ehConf: EventHubsConf,
.getOrElse((0, 0L))
metricPlugin.foreach(
_.onReceiveMetric(EventHubsUtils.getTaskContextSlim,
nAndP,
batchCount,
batchSizeInBytes,
elapsedTimeMs))
nAndP,
batchCount,
batchSizeInBytes,
elapsedTimeMs))
assert(validateSize == batchCount)
} else {
assert(validate.size == batchCount)
@@ -329,11 +329,11 @@ private[spark] object CachedEventHubsReceiver extends CachedReceiver with Loggin

private[this] val receivers = new MutableMap[String, CachedEventHubsReceiver]()

// RPC endpoint for partition performacne communciation in the executor
// RPC endpoint for partition performance communication in the executor
val partitionPerformanceReceiverRef =
RpcUtils.makeDriverRef(PartitionPerformanceReceiver.ENDPOINT_NAME,
SparkEnv.get.conf,
SparkEnv.get.rpcEnv)
SparkEnv.get.conf,
SparkEnv.get.rpcEnv)

private def key(ehConf: EventHubsConf, nAndP: NameAndPartition): String = {
(ehConf.connectionString + ehConf.consumerGroup + nAndP.partitionId).toLowerCase
@@ -30,9 +30,7 @@ import org.apache.spark.eventhubs.rdd.{ EventHubsRDD, OffsetRange }
import org.apache.spark.eventhubs.utils.ThrottlingStatusPlugin
import org.apache.spark.eventhubs.{ EventHubsConf, NameAndPartition, _ }
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.SparkEnv
import org.apache.spark.sql.execution.streaming.{
HDFSMetadataLog,
Offset,
@@ -88,9 +86,11 @@ private[spark] class EventHubsSource private[eventhubs] (sqlContext: SQLContext,
private val sc = sqlContext.sparkContext

private val maxOffsetsPerTrigger: Option[Long] =
Option(parameters.get(MaxEventsPerTriggerKey).map(_.toLong).getOrElse(
parameters.get(MaxEventsPerTriggerKeyAlias).map(_.toLong).getOrElse(
partitionCount * 1000)))
Option(parameters
.get(MaxEventsPerTriggerKey)
.map(_.toLong)
.getOrElse(
parameters.get(MaxEventsPerTriggerKeyAlias).map(_.toLong).getOrElse(partitionCount * 1000)))

// set slow partition adjustment flag and static values in the tracker
private val slowPartitionAdjustment: Boolean =
@@ -148,22 +148,25 @@ private[spark] class EventHubsSource private[eventhubs] (sqlContext: SQLContext,
text.substring(1, text.length).toInt
} catch {
case _: NumberFormatException =>
throw new IllegalStateException(s"Log file was malformed: failed to read correct log " +
s"version from $text.")
throw new IllegalStateException(
s"Log file was malformed: failed to read correct log " +
s"version from $text.")
}
if (version > 0) {
if (version > maxSupportedVersion) {
throw new IllegalStateException(s"UnsupportedLogVersion: maximum supported log version " +
s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " +
s"by a newer version of Spark and cannot be read by this version. Please upgrade.")
throw new IllegalStateException(
s"UnsupportedLogVersion: maximum supported log version " +
s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " +
s"by a newer version of Spark and cannot be read by this version. Please upgrade.")
} else {
return version
}
}
}
// reaching here means we failed to read the correct log version
throw new IllegalStateException(s"Log file was malformed: failed to read correct log " +
s"version from $text.")
throw new IllegalStateException(
s"Log file was malformed: failed to read correct log " +
s"version from $text.")
}
}
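
For reference, the hunk above only re-wraps the throw statements; the version-parsing logic itself is unchanged. A self-contained sketch of that check (assuming the version text of the form "v1" has already been read out of the HDFSMetadataLog entry, and that a leading 'v' is required as in Spark's usual metadata-log format):

```scala
// A minimal sketch of the metadata-log version check, with the HDFSMetadataLog plumbing omitted.
def parseLogVersion(text: String, maxSupportedVersion: Int): Int = {
  if (text.nonEmpty && text(0) == 'v') {
    val version =
      try {
        text.substring(1).toInt
      } catch {
        case _: NumberFormatException =>
          throw new IllegalStateException(
            s"Log file was malformed: failed to read correct log version from $text.")
      }
    if (version > 0) {
      if (version > maxSupportedVersion) {
        throw new IllegalStateException(
          s"UnsupportedLogVersion: maximum supported log version is v$maxSupportedVersion, " +
            s"but encountered v$version. Please upgrade.")
      } else {
        return version
      }
    }
  }
  // Reaching here means the text did not carry a positive "v<number>" version.
  throw new IllegalStateException(
    s"Log file was malformed: failed to read correct log version from $text.")
}

// Example: parseLogVersion("v1", maxSupportedVersion = 1) == 1
```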

@@ -142,7 +142,7 @@ private[sql] class EventHubsSourceProvider
}

private[sql] object EventHubsSourceProvider extends Serializable {
// RPC endpoint for partition performacne communciation in the driver
// RPC endpoint for partition performance communication in the driver
val partitionsStatusTracker = PartitionsStatusTracker.getPartitionStatusTracker
val partitionPerformanceReceiver: PartitionPerformanceReceiver =
new PartitionPerformanceReceiver(SparkEnv.get.rpcEnv, partitionsStatusTracker)
@@ -177,32 +177,32 @@ private[sql] object EventHubsSourceProvider extends Serializable {
new java.sql.Timestamp(ed.getSystemProperties.getEnqueuedTime.toEpochMilli)),
UTF8String.fromString(ed.getSystemProperties.getPublisher),
UTF8String.fromString(ed.getSystemProperties.getPartitionKey),
ArrayBasedMapData(
ed.getProperties.asScala
.mapValues {
case b: Binary =>
val buf = b.asByteBuffer()
val arr = new Array[Byte](buf.remaining)
buf.get(arr)
arr.asInstanceOf[AnyRef]
case d128: Decimal128 => d128.asBytes.asInstanceOf[AnyRef]
case d32: Decimal32 => d32.getBits.asInstanceOf[AnyRef]
case d64: Decimal64 => d64.getBits.asInstanceOf[AnyRef]
case s: Symbol => s.toString.asInstanceOf[AnyRef]
case ub: UnsignedByte => ub.toString.asInstanceOf[AnyRef]
case ui: UnsignedInteger => ui.toString.asInstanceOf[AnyRef]
case ul: UnsignedLong => ul.toString.asInstanceOf[AnyRef]
case us: UnsignedShort => us.toString.asInstanceOf[AnyRef]
case c: Character => c.toString.asInstanceOf[AnyRef]
case d: DescribedType => d.getDescribed
case default => default
ArrayBasedMapData(ed.getProperties.asScala
.mapValues {
case b: Binary =>
val buf = b.asByteBuffer()
val arr = new Array[Byte](buf.remaining)
buf.get(arr)
arr.asInstanceOf[AnyRef]
case d128: Decimal128 => d128.asBytes.asInstanceOf[AnyRef]
case d32: Decimal32 => d32.getBits.asInstanceOf[AnyRef]
case d64: Decimal64 => d64.getBits.asInstanceOf[AnyRef]
case s: Symbol => s.toString.asInstanceOf[AnyRef]
case ub: UnsignedByte => ub.toString.asInstanceOf[AnyRef]
case ui: UnsignedInteger => ui.toString.asInstanceOf[AnyRef]
case ul: UnsignedLong => ul.toString.asInstanceOf[AnyRef]
case us: UnsignedShort => us.toString.asInstanceOf[AnyRef]
case c: Character => c.toString.asInstanceOf[AnyRef]
case d: DescribedType => d.getDescribed
case default => default
}
.map { p =>
p._2 match {
case s: String => UTF8String.fromString(p._1) -> UTF8String.fromString(s)
case default =>
UTF8String.fromString(p._1) -> UTF8String.fromString(Serialization.write(p._2))
}
.map { p =>
p._2 match {
case s: String => UTF8String.fromString(p._1) -> UTF8String.fromString(s)
case default => UTF8String.fromString(p._1) -> UTF8String.fromString(Serialization.write(p._2))
}
}),
}),
ArrayBasedMapData(
// Don't duplicate offset, enqueued time, and seqNo
(ed.getSystemProperties.asScala -- Seq(OffsetAnnotation,
@@ -218,8 +218,10 @@ private[sql] object EventHubsSourceProvider extends Serializable {
}
.map { p =>
p._2 match {
case s: String => UTF8String.fromString(p._1) -> UTF8String.fromString(s)
case default => UTF8String.fromString(p._1) -> UTF8String.fromString(Serialization.write(p._2))
case s: String => UTF8String.fromString(p._1) -> UTF8String.fromString(s)
case default =>
UTF8String.fromString(p._1) -> UTF8String.fromString(
Serialization.write(p._2))
}
})
)