Fix non-streaming RPC endpoint (Azure#542)
nyaghma authored and Navid Yaghmazadeh committed Dec 8, 2020
1 parent a1e4c67 commit 5feb961
Showing 3 changed files with 54 additions and 47 deletions.
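
The driver-side registration of the PartitionPerformanceReceiver RPC endpoint moves from the EventHubsSource companion object into EventHubsSourceProvider, so the endpoint is also set up when the connector is used outside a streaming query; the executor-side lookup in CachedEventHubsReceiver stays as it was. A minimal sketch of the register/lookup pairing the fix relies on (only the calls that appear in the hunks below are taken as given; a live SparkEnv and the connector classes are assumed):

import org.apache.spark.SparkEnv
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.util.RpcUtils

// Driver side (EventHubsSourceProvider after this commit): register the endpoint
// under its well-known name so executors can resolve it by that name.
val driverRef: RpcEndpointRef = SparkEnv.get.rpcEnv
  .setupEndpoint(PartitionPerformanceReceiver.ENDPOINT_NAME, partitionPerformanceReceiver)

// Executor side (CachedEventHubsReceiver, unchanged by this commit): look up the
// driver endpoint. If nothing on the driver registered it (presumably the case for
// non-streaming jobs before this fix), the lookup fails.
val executorSideRef: RpcEndpointRef =
  RpcUtils.makeDriverRef(PartitionPerformanceReceiver.ENDPOINT_NAME,
                         SparkEnv.get.conf,
                         SparkEnv.get.rpcEnv)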
@@ -329,7 +329,7 @@ private[spark] object CachedEventHubsReceiver extends CachedReceiver with Loggin

private[this] val receivers = new MutableMap[String, CachedEventHubsReceiver]()

// RPC endpoint for partition performacne communciation in the executor
// RPC endpoint for partition performance communication in the executor
val partitionPerformanceReceiverRef =
RpcUtils.makeDriverRef(PartitionPerformanceReceiver.ENDPOINT_NAME,
SparkEnv.get.conf,
@@ -30,9 +30,7 @@ import org.apache.spark.eventhubs.rdd.{ EventHubsRDD, OffsetRange }
import org.apache.spark.eventhubs.utils.ThrottlingStatusPlugin
import org.apache.spark.eventhubs.{ EventHubsConf, NameAndPartition, _ }
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.SparkEnv
import org.apache.spark.sql.execution.streaming.{
HDFSMetadataLog,
Offset,
@@ -77,6 +75,7 @@ private[spark] class EventHubsSource private[eventhubs] (sqlContext: SQLContext,

import EventHubsConf._
import EventHubsSource._
import EventHubsSourceProvider._

private lazy val ehClient = EventHubsSourceProvider.clientFactory(parameters)(ehConf)
private lazy val partitionCount: Int = ehClient.partitionCount
@@ -87,9 +86,11 @@ private[spark] class EventHubsSource private[eventhubs] (sqlContext: SQLContext,
private val sc = sqlContext.sparkContext

private val maxOffsetsPerTrigger: Option[Long] =
Option(parameters.get(MaxEventsPerTriggerKey).map(_.toLong).getOrElse(
parameters.get(MaxEventsPerTriggerKeyAlias).map(_.toLong).getOrElse(
partitionCount * 1000)))
Option(parameters
.get(MaxEventsPerTriggerKey)
.map(_.toLong)
.getOrElse(
parameters.get(MaxEventsPerTriggerKeyAlias).map(_.toLong).getOrElse(partitionCount * 1000)))

// set slow partition adjustment flag and static values in the tracker
private val slowPartitionAdjustment: Boolean =
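
For reference, the reformatted maxOffsetsPerTrigger resolution above falls back from MaxEventsPerTriggerKey to its alias MaxEventsPerTriggerKeyAlias, and finally to partitionCount * 1000. A small standalone sketch of the same fallback with made-up keys and values (the real key spellings are the EventHubsConf constants, not shown here):

// Hypothetical settings map and key spellings.
val parameters = Map("maxEventsPerTrigger" -> "50000")
val partitionCount = 32
val maxOffsetsPerTrigger: Option[Long] =
  Option(
    parameters
      .get("maxEventsPerTrigger")
      .map(_.toLong)
      .getOrElse(parameters.get("maxEventsPerTriggerAlias").map(_.toLong).getOrElse(partitionCount * 1000L)))
// -> Some(50000); with neither key set it would be Some(32 * 1000) = Some(32000)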
@@ -147,22 +148,25 @@ private[spark] class EventHubsSource private[eventhubs] (sqlContext: SQLContext,
text.substring(1, text.length).toInt
} catch {
case _: NumberFormatException =>
throw new IllegalStateException(s"Log file was malformed: failed to read correct log " +
s"version from $text.")
throw new IllegalStateException(
s"Log file was malformed: failed to read correct log " +
s"version from $text.")
}
if (version > 0) {
if (version > maxSupportedVersion) {
throw new IllegalStateException(s"UnsupportedLogVersion: maximum supported log version " +
s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " +
s"by a newer version of Spark and cannot be read by this version. Please upgrade.")
throw new IllegalStateException(
s"UnsupportedLogVersion: maximum supported log version " +
s"is v${maxSupportedVersion}, but encountered v$version. The log file was produced " +
s"by a newer version of Spark and cannot be read by this version. Please upgrade.")
} else {
return version
}
}
}
// reaching here means we failed to read the correct log version
throw new IllegalStateException(s"Log file was malformed: failed to read correct log " +
s"version from $text.")
throw new IllegalStateException(
s"Log file was malformed: failed to read correct log " +
s"version from $text.")
}
}
val defaultSeqNos = ehClient
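
The reworked error messages above belong to the metadata-log version check: the token is expected to look like v<N>, anything unparsable is reported as a malformed log file, and a version newer than maxSupportedVersion is rejected. A standalone sketch of that check, assuming the v-prefixed format (the function name and surrounding I/O are not part of the connector):

def parseLogVersion(text: String, maxSupportedVersion: Int): Int = {
  def malformed = new IllegalStateException(
    s"Log file was malformed: failed to read correct log version from $text.")
  if (text == null || !text.startsWith("v")) throw malformed
  val version =
    try text.substring(1).toInt
    catch { case _: NumberFormatException => throw malformed }
  if (version <= 0) throw malformed
  if (version > maxSupportedVersion)
    throw new IllegalStateException(
      s"UnsupportedLogVersion: maximum supported log version is v$maxSupportedVersion, " +
        s"but encountered v$version. The log file was produced by a newer version of " +
        s"Spark and cannot be read by this version. Please upgrade.")
  version
}

parseLogVersion("v1", maxSupportedVersion = 1)   // 1
// parseLogVersion("v2", 1)  -> UnsupportedLogVersion
// parseLogVersion("", 1)    -> Log file was malformed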
@@ -449,14 +453,7 @@ private[eventhubs] object EventHubsSource {
""".stripMargin

private[eventhubs] val VERSION = 1

// RPC endpoint for partition performacne communciation in the driver
private var localBatchId = -1
val partitionsStatusTracker = PartitionsStatusTracker.getPartitionStatusTracker
val partitionPerformanceReceiver: PartitionPerformanceReceiver =
new PartitionPerformanceReceiver(SparkEnv.get.rpcEnv, partitionsStatusTracker)
val partitionPerformanceReceiverRef: RpcEndpointRef = SparkEnv.get.rpcEnv
.setupEndpoint(PartitionPerformanceReceiver.ENDPOINT_NAME, partitionPerformanceReceiver)

def getSortedExecutorList(sc: SparkContext): Array[String] = {
val bm = sc.env.blockManager
@@ -45,6 +45,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.{ AnalysisException, DataFrame, SQLContext, SaveMode }
import org.apache.spark.unsafe.types.UTF8String
import org.json4s.jackson.Serialization
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.SparkEnv

import collection.JavaConverters._

@@ -140,6 +142,12 @@ private[sql] class EventHubsSourceProvider
}

private[sql] object EventHubsSourceProvider extends Serializable {
// RPC endpoint for partition performance communication in the driver
val partitionsStatusTracker = PartitionsStatusTracker.getPartitionStatusTracker
val partitionPerformanceReceiver: PartitionPerformanceReceiver =
new PartitionPerformanceReceiver(SparkEnv.get.rpcEnv, partitionsStatusTracker)
val partitionPerformanceReceiverRef: RpcEndpointRef = SparkEnv.get.rpcEnv
.setupEndpoint(PartitionPerformanceReceiver.ENDPOINT_NAME, partitionPerformanceReceiver)
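
Registering the receiver in the EventHubsSourceProvider companion object means the endpoint is created as soon as the provider object is referenced on the driver, for batch as well as streaming reads, which is what the executor-side RpcUtils.makeDriverRef lookup needs. For orientation, the general shape of such a driver-registered endpoint, using Spark's internal RPC API (the names and the message type below are hypothetical; PartitionPerformanceReceiver's actual messages are not shown in this diff):

import org.apache.spark.SparkEnv
import org.apache.spark.rpc.{ RpcEnv, ThreadSafeRpcEndpoint }

// Hypothetical one-way message an executor might send about a partition's batch.
case class PartitionUpdate(partition: Int, batchId: Long, elapsedMillis: Long)

class ExampleReceiver(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
  // Fire-and-forget messages sent with RpcEndpointRef.send land here.
  override def receive: PartialFunction[Any, Unit] = {
    case PartitionUpdate(p, b, ms) =>
      // the real receiver forwards performance data to PartitionsStatusTracker
      println(s"partition $p, batch $b took $ms ms")
  }
}

// Driver: register under a well-known name; executors resolve it via RpcUtils.makeDriverRef
// and then call send(...) on the returned RpcEndpointRef.
val exampleRef = SparkEnv.get.rpcEnv
  .setupEndpoint("ExampleReceiver", new ExampleReceiver(SparkEnv.get.rpcEnv))
exampleRef.send(PartitionUpdate(0, 1L, 42L))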

def eventHubsSchema: StructType = {
StructType(
Expand Down Expand Up @@ -169,32 +177,32 @@ private[sql] object EventHubsSourceProvider extends Serializable {
new java.sql.Timestamp(ed.getSystemProperties.getEnqueuedTime.toEpochMilli)),
UTF8String.fromString(ed.getSystemProperties.getPublisher),
UTF8String.fromString(ed.getSystemProperties.getPartitionKey),
ArrayBasedMapData(
ed.getProperties.asScala
.mapValues {
case b: Binary =>
val buf = b.asByteBuffer()
val arr = new Array[Byte](buf.remaining)
buf.get(arr)
arr.asInstanceOf[AnyRef]
case d128: Decimal128 => d128.asBytes.asInstanceOf[AnyRef]
case d32: Decimal32 => d32.getBits.asInstanceOf[AnyRef]
case d64: Decimal64 => d64.getBits.asInstanceOf[AnyRef]
case s: Symbol => s.toString.asInstanceOf[AnyRef]
case ub: UnsignedByte => ub.toString.asInstanceOf[AnyRef]
case ui: UnsignedInteger => ui.toString.asInstanceOf[AnyRef]
case ul: UnsignedLong => ul.toString.asInstanceOf[AnyRef]
case us: UnsignedShort => us.toString.asInstanceOf[AnyRef]
case c: Character => c.toString.asInstanceOf[AnyRef]
case d: DescribedType => d.getDescribed
case default => default
ArrayBasedMapData(ed.getProperties.asScala
.mapValues {
case b: Binary =>
val buf = b.asByteBuffer()
val arr = new Array[Byte](buf.remaining)
buf.get(arr)
arr.asInstanceOf[AnyRef]
case d128: Decimal128 => d128.asBytes.asInstanceOf[AnyRef]
case d32: Decimal32 => d32.getBits.asInstanceOf[AnyRef]
case d64: Decimal64 => d64.getBits.asInstanceOf[AnyRef]
case s: Symbol => s.toString.asInstanceOf[AnyRef]
case ub: UnsignedByte => ub.toString.asInstanceOf[AnyRef]
case ui: UnsignedInteger => ui.toString.asInstanceOf[AnyRef]
case ul: UnsignedLong => ul.toString.asInstanceOf[AnyRef]
case us: UnsignedShort => us.toString.asInstanceOf[AnyRef]
case c: Character => c.toString.asInstanceOf[AnyRef]
case d: DescribedType => d.getDescribed
case default => default
}
.map { p =>
p._2 match {
case s: String => UTF8String.fromString(p._1) -> UTF8String.fromString(s)
case default =>
UTF8String.fromString(p._1) -> UTF8String.fromString(Serialization.write(p._2))
}
.map { p =>
p._2 match {
case s: String => UTF8String.fromString(p._1) -> UTF8String.fromString(s)
case default => UTF8String.fromString(p._1) -> UTF8String.fromString(Serialization.write(p._2))
}
}),
}),
ArrayBasedMapData(
// Don't duplicate offset, enqueued time, and seqNo
(ed.getSystemProperties.asScala -- Seq(OffsetAnnotation,
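
The restructured block above normalizes the AMQP-typed event properties into plain JVM values before they are written to the properties map column: binaries are copied out of their byte buffers, decimals are reduced to their bits or bytes, unsigned types, symbols and characters become strings, DescribedType values are unwrapped, and anything that is not already a String is then JSON-encoded via Serialization.write. A minimal sketch of that normalization for two of the cases (the proton-j import is assumed; it is not visible in this hunk):

import org.apache.qpid.proton.amqp.{ Binary, Symbol }

def normalizeAmqpValue(value: AnyRef): AnyRef = value match {
  case b: Binary =>                 // copy the remaining bytes out of the backing buffer
    val buf = b.asByteBuffer()
    val arr = new Array[Byte](buf.remaining)
    buf.get(arr)
    arr
  case s: Symbol => s.toString      // AMQP symbols become plain strings
  case other     => other           // everything else passes through unchanged
}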
@@ -210,8 +218,10 @@ private[sql] object EventHubsSourceProvider extends Serializable {
}
.map { p =>
p._2 match {
case s: String => UTF8String.fromString(p._1) -> UTF8String.fromString(s)
case default => UTF8String.fromString(p._1) -> UTF8String.fromString(Serialization.write(p._2))
case s: String => UTF8String.fromString(p._1) -> UTF8String.fromString(s)
case default =>
UTF8String.fromString(p._1) -> UTF8String.fromString(
Serialization.write(p._2))
}
})
)
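
As the last hunk shows, both property maps share the same final step: String values go into the map column directly, and any other value is first serialized to JSON with json4s before being wrapped as UTF8String. A tiny sketch of that fallback (the implicit Formats instance below is an assumption; the connector supplies its own elsewhere in the file):

import org.json4s.{ DefaultFormats, Formats }
import org.json4s.jackson.Serialization

implicit val formats: Formats = DefaultFormats
Serialization.write(Map("source" -> "eventhubs", "attempt" -> 2))
// -> {"source":"eventhubs","attempt":2}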
