Skip to content

Commit

Permalink
Fix comments style to pass scala style check (#461)
Browse files Browse the repository at this point in the history
* fix comments style

* add a empty line to commit signoff

Signed-off-by: yangjie01 <yangjie01@baidu.com>

* revert empty line

Signed-off-by: yangjie01 <yangjie01@baidu.com>
  • Loading branch information
LuciferYang authored Jul 30, 2020
1 parent 53680ac commit d7935c1
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ class UCX(executorId: Int, usingWakeupFeature: Boolean = true) extends AutoClose
val registeredMemory = new ArrayBuffer[UcpMemory]

/**
* Initializes the UCX context and local worker and starts up the worker progress thread.
* UCX worker/endpoint relationship.
*/
* Initializes the UCX context and local worker and starts up the worker progress thread.
* UCX worker/endpoint relationship.
*/
def init(): Unit = {
synchronized {
if (initialized) {
Expand Down Expand Up @@ -183,12 +183,12 @@ class UCX(executorId: Int, usingWakeupFeature: Boolean = true) extends AutoClose
}

/**
* Starts a TCP server to listen for external clients, returning with
* what port it used.
*
* @param mgmtHost String the hostname to bind to
* @return port bound
*/
* Starts a TCP server to listen for external clients, returning with
* what port it used.
*
* @param mgmtHost String the hostname to bind to
* @return port bound
*/
def startManagementPort(mgmtHost: String): Int = {
var portBindAttempts = 100
var portBound = false
Expand Down Expand Up @@ -326,14 +326,15 @@ class UCX(executorId: Int, usingWakeupFeature: Boolean = true) extends AutoClose
private def ucxWorkerAddress: ByteBuffer = worker.getAddress

/**
* Establish a new [[UcpEndpoint]] given a [[WorkerAddress]]. It also
* caches them s.t. at [[close]] time we can release resources.
* @param endpointId presently an executorId, it is used to distinguish between endpoints
* when routing messages outbound
* @param workerAddress the worker address for the remote endpoint (ucx opaque object)
* @return returns a [[UcpEndpoint]] that can later be used to send on (from the
* Establish a new [[UcpEndpoint]] given a [[WorkerAddress]]. It also
* caches them s.t. at [[close]] time we can release resources.
*
* @param endpointId presently an executorId, it is used to distinguish between endpoints
* when routing messages outbound
* @param workerAddress the worker address for the remote endpoint (ucx opaque object)
* @return returns a [[UcpEndpoint]] that can later be used to send on (from the
* progress thread)
*/
*/
private[ucx] def setupEndpoint(endpointId: Long, workerAddress: WorkerAddress): UcpEndpoint = {
logDebug(s"Starting/reusing an endpoint to $workerAddress with id $endpointId")
// create an UCX endpoint using workerAddress
Expand All @@ -347,12 +348,12 @@ class UCX(executorId: Int, usingWakeupFeature: Boolean = true) extends AutoClose
}

/**
* Connect to a remote UCX management port.
*
* @param peerMgmtHost management TCP host
* @param peerMgmtPort management TCP port
* @return Connection object representing this connection
*/
* Connect to a remote UCX management port.
*
* @param peerMgmtHost management TCP host
* @param peerMgmtPort management TCP port
* @return Connection object representing this connection
*/
def getConnection(peerExecutorId: Int,
peerMgmtHost: String,
peerMgmtPort: Int): ClientConnection = {
Expand Down Expand Up @@ -414,11 +415,11 @@ class UCX(executorId: Int, usingWakeupFeature: Boolean = true) extends AutoClose
executorIdToPeerTag.computeIfAbsent(peerExecutorId, _ => peerTag.incrementAndGet())

/**
* Handle an incoming connection on the TCP management port
* This will fetch the [[WorkerAddress]] from the peer, and establish a UcpEndpoint
*
* @param socket an accepted socket to a remote client
*/
* Handle an incoming connection on the TCP management port
* This will fetch the [[WorkerAddress]] from the peer, and establish a UcpEndpoint
*
* @param socket an accepted socket to a remote client
*/
private[ucx] def handleSocket(socket: Socket): Unit = {
val connectionRange =
new NvtxRange(s"UCX Handle Connection from ${socket.getInetAddress}", NvtxColor.RED)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ import org.openucx.jucx.ucp.UcpRequest
import org.apache.spark.internal.Logging

/**
* This is a private api used within the ucx package.
* It is used by [[Transaction]] to call into the UCX functions. It adds the tag
* as we use that to track the message and for debugging.
*/
* This is a private api used within the ucx package.
* It is used by [[Transaction]] to call into the UCX functions. It adds the tag
* as we use that to track the message and for debugging.
*/
private[ucx] abstract class UCXTagCallback {
def onError(alt: AddressLengthTag, ucsStatus: Int, errorMsg: String): Unit
def onMessageStarted(ucxMessage: UcpRequest): Unit
Expand Down Expand Up @@ -118,10 +118,10 @@ class UCXConnection(peerExecutorId: Int, ucx: UCX) extends Connection with Loggi
private[this] val pendingTransactions = new ConcurrentHashMap[Long, UCXTransaction]()

/**
* 1) client gets upper 28 bits
* 2) then comes the type, which gets 4 bits
* 3) the remaining 32 bits are used for buffer specific tags
*/
* 1) client gets upper 28 bits
* 2) then comes the type, which gets 4 bits
* 3) the remaining 32 bits are used for buffer specific tags
*/
private val requestMsgType: Long = 0x00000000000000000L
private val responseMsgType: Long = 0x0000000000000000AL
private val bufferMsgType: Long = 0x0000000000000000BL
Expand Down Expand Up @@ -333,12 +333,12 @@ object UCXConnection extends Logging {
}

/**
* Given a java `InputStream`, obtain the peer's `WorkerAddress` and executor id,
* returning them as a pair.
*
* @param is management port input stream
* @return a tuple of worker address and the peer executor id
*/
* Given a java `InputStream`, obtain the peer's `WorkerAddress` and executor id,
* returning them as a pair.
*
* @param is management port input stream
* @return a tuple of worker address and the peer executor id
*/
def readHandshakeHeader(is: InputStream): (WorkerAddress, Int) = {
val maxLen = 1024 * 1024

Expand All @@ -359,15 +359,15 @@ object UCXConnection extends Logging {


/**
* Writes a header that is exchanged in the management port. The header contains:
* - UCP Worker address length (4 bytes)
* - UCP Worker address (variable length)
* - Local executor id (4 bytes)
*
* @param os output stream to write to
* @param workerAddress byte buffer that holds the local UCX worker address
* @param localExecutorId The local executorId
*/
* Writes a header that is exchanged in the management port. The header contains:
* - UCP Worker address length (4 bytes)
* - UCP Worker address (variable length)
* - Local executor id (4 bytes)
*
* @param os output stream to write to
* @param workerAddress byte buffer that holds the local UCX worker address
* @param localExecutorId The local executorId
*/
def writeHandshakeHeader(os: OutputStream,
workerAddress: ByteBuffer,
localExecutorId: Int): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ import org.apache.spark.internal.Logging
import org.apache.spark.storage.BlockManagerId

/**
* UCXShuffleTransport is the UCX implementation for the `RapidsShuffleTransport`. It provides
* a way to create a `RapidsShuffleServer` and one `RapidsShuffleClient` per peer, that are
* able to send/receive via UCX.
*
* Additionally, this class maintains pools of memory used to limit the cost of memory
* pinning and registration (bounce buffers), a metadata message pool for small flatbuffers used
* to describe shuffled data, and implements a simple throttle mechanism to keep GPU memory
* usage at bay by way of configuration settings.
*
* @param shuffleServerId `BlockManagerId` for this executor
* @param rapidsConf plugin configuration
*/
* UCXShuffleTransport is the UCX implementation for the `RapidsShuffleTransport`. It provides
* a way to create a `RapidsShuffleServer` and one `RapidsShuffleClient` per peer, that are
* able to send/receive via UCX.
*
* Additionally, this class maintains pools of memory used to limit the cost of memory
* pinning and registration (bounce buffers), a metadata message pool for small flatbuffers used
* to describe shuffled data, and implements a simple throttle mechanism to keep GPU memory
* usage at bay by way of configuration settings.
*
* @param shuffleServerId `BlockManagerId` for this executor
* @param rapidsConf plugin configuration
*/
class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsConf)
extends RapidsShuffleTransport
with Logging {
Expand Down Expand Up @@ -115,20 +115,20 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
}

/**
* Initialize the bounce buffer pools that are to be used to send and receive data against UCX
*
* We have 2 pools for the send side, since buffers may come from spilled memory (host),
* or device memory.
*
* We have 1 pool for the receive side, since all receives are targeted for the GPU.
*
* The size of buffers is the same for all pools, since send/receive sizes need to match. The
* Initialize the bounce buffer pools that are to be used to send and receive data against UCX
*
* We have 2 pools for the send side, since buffers may come from spilled memory (host),
* or device memory.
*
* We have 1 pool for the receive side, since all receives are targeted for the GPU.
*
* The size of buffers is the same for all pools, since send/receive sizes need to match. The
* count can be set independently.
*
* @param bounceBufferSize the size for a single bounce buffer
* @param deviceNumBuffers number of buffers to allocate for the device
* @param hostNumBuffers number of buffers to allocate for the host
*/
*
* @param bounceBufferSize the size for a single bounce buffer
* @param deviceNumBuffers number of buffers to allocate for the device
* @param hostNumBuffers number of buffers to allocate for the host
*/
def initBounceBufferPools(
bounceBufferSize: Long,
deviceNumBuffers: Int,
Expand Down Expand Up @@ -335,11 +335,13 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
.setNameFormat(s"shuffle-server-bss-thread-%d")
.setDaemon(true)
.build))

/**
* Construct a server instance
* @param requestHandler used to get metadata info, and acquire tables used in the shuffle.
* @return the server instance
*/
* Construct a server instance
*
* @param requestHandler used to get metadata info, and acquire tables used in the shuffle.
* @return the server instance
*/
override def makeServer(requestHandler: RapidsShuffleRequestHandler): RapidsShuffleServer = {
new RapidsShuffleServer(
this,
Expand All @@ -353,13 +355,13 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
}

/**
* Returns a sequence of bounce buffers if the transport allows for [[neededAmount]] + its
* Returns a sequence of bounce buffers if the transport allows for [[neededAmount]] + its
* inflight tally to be inflight at this time, and bounce buffers are available.
*
* @param neededAmount amount of bytes needed.
* @return optional bounce buffers to be used to for the client to receive if amount of bytes
* needed was allowed into the inflight amount, None otherwise (caller should try again)
*/
*
* @param neededAmount amount of bytes needed.
* @return optional bounce buffers to be used to for the client to receive if amount of bytes
* needed was allowed into the inflight amount, None otherwise (caller should try again)
*/
private def markBytesInFlight(neededAmount: Long)
: Option[Seq[MemoryBuffer]] = inflightMonitor.synchronized {
// if it would fit, or we are sending nothing (protects against the buffer that is bigger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ import org.openucx.jucx.ucp.UcpRequest
import org.apache.spark.internal.Logging

/**
* Helper enum to describe transaction types supported in UCX
* request = is a send and a receive, with the callback happening after the receive
* send = is a send of 1 or more messages (callback at the end)
* receive is a receive of 1 or more messages (callback at the end)
*/
* Helper enum to describe transaction types supported in UCX
* request = is a send and a receive, with the callback happening after the receive
* send = is a send of 1 or more messages (callback at the end)
* receive is a receive of 1 or more messages (callback at the end)
*/
private[ucx] object UCXTransactionType extends Enumeration {
val Request, Send, Receive = Value
}
Expand All @@ -52,10 +52,10 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
def decrementPendingAndGet: Long = pending.decrementAndGet

/**
* This will mark the tag as being cancelled for debugging purposes.
*
* @param tag the cancelled tag
*/
* This will mark the tag as being cancelled for debugging purposes.
*
* @param tag the cancelled tag
*/
def handleTagCancelled(tag: Long): Unit = {
if (registeredByTag.contains(tag)) {
val origBuff = registeredByTag(tag)
Expand All @@ -64,11 +64,11 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
}

/**
* This will mark the tag as having an error for debugging purposes.
*
* @param tag the tag involved in the error
* @param errorMsg error description from UCX
*/
* This will mark the tag as having an error for debugging purposes.
*
* @param tag the tag involved in the error
* @param errorMsg error description from UCX
*/
def handleTagError(tag: Long, errorMsg: String): Unit = {
if (registeredByTag.contains(tag)) {
val origBuff = registeredByTag(tag)
Expand All @@ -78,10 +78,10 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
}

/**
* This will mark the tag as completed for debugging purposes.
*
* @param tag the successful tag
*/
* This will mark the tag as completed for debugging purposes.
*
* @param tag the successful tag
*/
def handleTagCompleted(tag: Long): Unit = {
if (registeredByTag.contains(tag)){
val origBuff = registeredByTag(tag)
Expand Down Expand Up @@ -159,9 +159,9 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
}

/**
* TODO: Note that this does not handle a timeout. We still need to do this, the version of UCX
* we build against can cancel messages.
*/
* TODO: Note that this does not handle a timeout. We still need to do this, the version of UCX
* we build against can cancel messages.
*/
def waitForCompletion(): Unit = {
while (status != TransactionStatus.Complete &&
status != TransactionStatus.Error) {
Expand All @@ -185,9 +185,10 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
}

/**
* Interal function to register a callback against the callback service
* @param cb callback function to call using the callbackService.
*/
* Internal function to register a callback against the callback service
*
* @param cb callback function to call using the callbackService.
*/
private def registerCb(cb: TransactionCallback): Unit = {
txCallback =
(newStatus: TransactionStatus.Value) => {
Expand Down Expand Up @@ -238,8 +239,8 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
}

/**
* Register an [[AddressLengthTag]] for a send transaction
*/
* Register an [[AddressLengthTag]] for a send transaction
*/
def registerForSend(alt: AddressLengthTag): Unit = {
registeredByTag.put(alt.tag, alt)
registered += alt
Expand All @@ -248,8 +249,8 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
}

/**
* Register an [[AddressLengthTag]] for a receive transaction
*/
* Register an [[AddressLengthTag]] for a receive transaction
*/
def registerForReceive(alt: AddressLengthTag): Unit = {
registered += alt
registeredByTag.put(alt.tag, alt)
Expand All @@ -262,11 +263,12 @@ private[ucx] class UCXTransaction(conn: UCXConnection, val txId: Long)
}

/**
* Internal function to kick off a [[Transaction]]
* @param txType a transaction type to be used for debugging purposes
* @param numPending number of messages we expect to see sent/received
* @param cb callback to call when done/errored
*/
* Internal function to kick off a [[Transaction]]
*
* @param txType a transaction type to be used for debugging purposes
* @param numPending number of messages we expect to see sent/received
* @param cb callback to call when done/errored
*/
private[ucx] def start(
txType: UCXTransactionType.Value,
numPending: Long,
Expand Down

0 comments on commit d7935c1

Please sign in to comment.