Skip to content

Commit

Permalink
[SPARK-45302][PYTHON] Remove PID communication between Python workers…
Browse files Browse the repository at this point in the history
… when no demon is used

### What changes were proposed in this pull request?

This PR removes the legacy workaround for JDK 8 in `PythonWorkerFactory`.

### Why are the changes needed?

No need to manually send the PID around through the socket.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

There are existing unittests for the daemon disabled.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43087 from HyukjinKwon/SPARK-45302.

Lead-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Co-authored-by: Hyukjin Kwon <gurwls223@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
HyukjinKwon and HyukjinKwon committed Sep 27, 2023
1 parent 17881eb commit 17430fe
Show file tree
Hide file tree
Showing 10 changed files with 18 additions and 29 deletions.
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class SparkEnv (
pythonExec: String,
workerModule: String,
daemonModule: String,
envVars: Map[String, String]): (PythonWorker, Option[Int]) = {
envVars: Map[String, String]): (PythonWorker, Option[Long]) = {
synchronized {
val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars)
pythonWorkers.getOrElseUpdate(key,
Expand All @@ -139,7 +139,7 @@ class SparkEnv (
private[spark] def createPythonWorker(
pythonExec: String,
workerModule: String,
envVars: Map[String, String]): (PythonWorker, Option[Int]) = {
envVars: Map[String, String]): (PythonWorker, Option[Long]) = {
createPythonWorker(
pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ private object BasePythonRunner {

private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")

private def faultHandlerLogPath(pid: Int): Path = {
private def faultHandlerLogPath(pid: Long): Path = {
new File(faultHandlerLogDir, pid.toString).toPath
}
}
Expand Down Expand Up @@ -200,7 +200,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](

envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))

val (worker: PythonWorker, pid: Option[Int]) = env.createPythonWorker(
val (worker: PythonWorker, pid: Option[Long]) = env.createPythonWorker(
pythonExec, workerModule, daemonModule, envVars.asScala.toMap)
// Whether is the worker released into idle pool or closed. When any codes try to release or
// close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
Expand Down Expand Up @@ -253,7 +253,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
pid: Option[Int],
pid: Option[Long],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[OUT]

Expand Down Expand Up @@ -463,7 +463,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
pid: Option[Int],
pid: Option[Long],
releasedOrClosed: AtomicBoolean,
context: TaskContext)
extends Iterator[OUT] {
Expand Down Expand Up @@ -838,7 +838,7 @@ private[spark] class PythonRunner(
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
pid: Option[Int],
pid: Option[Long],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[Array[Byte]] = {
new ReaderIterator(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ private[spark] class PythonWorkerFactory(
@GuardedBy("self")
private var daemonPort: Int = 0
@GuardedBy("self")
private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, Int]()
private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, Long]()
@GuardedBy("self")
private val idleWorkers = new mutable.Queue[PythonWorker]()
@GuardedBy("self")
Expand All @@ -91,7 +91,7 @@ private[spark] class PythonWorkerFactory(
envVars.getOrElse("PYTHONPATH", ""),
sys.env.getOrElse("PYTHONPATH", ""))

def create(): (PythonWorker, Option[Int]) = {
def create(): (PythonWorker, Option[Long]) = {
if (useDaemon) {
self.synchronized {
if (idleWorkers.nonEmpty) {
Expand All @@ -111,9 +111,9 @@ private[spark] class PythonWorkerFactory(
* processes itself to avoid the high cost of forking from Java. This currently only works
* on UNIX-based systems.
*/
private def createThroughDaemon(): (PythonWorker, Option[Int]) = {
private def createThroughDaemon(): (PythonWorker, Option[Long]) = {

def createWorker(): (PythonWorker, Option[Int]) = {
def createWorker(): (PythonWorker, Option[Long]) = {
val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
// These calls are blocking.
val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
Expand Down Expand Up @@ -153,7 +153,7 @@ private[spark] class PythonWorkerFactory(
/**
* Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
*/
private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Int]) = {
private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Long]) = {
var serverSocketChannel: ServerSocketChannel = null
try {
serverSocketChannel = ServerSocketChannel.open()
Expand Down Expand Up @@ -189,8 +189,7 @@ private[spark] class PythonWorkerFactory(
try {
val socketChannel = serverSocketChannel.accept()
authHelper.authClient(socketChannel.socket())
// TODO: When we drop JDK 8, we can just use workerProcess.pid()
val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
val pid = workerProcess.toHandle.pid()
if (pid < 0) {
throw new IllegalStateException("Python failed to launch worker with code " + pid)
}
Expand Down Expand Up @@ -386,7 +385,7 @@ private[spark] class PythonWorkerFactory(
daemonWorkers.get(worker).foreach { pid =>
// tell daemon to kill worker by pid
val output = new DataOutputStream(daemon.getOutputStream)
output.writeInt(pid)
output.writeLong(pid)
output.flush()
daemon.getOutputStream.flush()
}
Expand Down
4 changes: 2 additions & 2 deletions python/pyspark/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from socket import AF_INET, AF_INET6, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT

from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
from pyspark.serializers import read_long, write_int, write_with_length, UTF8Deserializer

if len(sys.argv) > 1:
import importlib
Expand Down Expand Up @@ -139,7 +139,7 @@ def handle_sigterm(*args):

if 0 in ready_fds:
try:
worker_pid = read_int(stdin_bin)
worker_pid = read_long(stdin_bin)
except EOFError:
# Spark told us to exit by closing stdin
shutdown(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,4 @@ def process(df_id, batch_id): # type: ignore[no-untyped-def]
(sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
# There could be a long time between each micro batch.
sock.settimeout(None)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,4 @@ def process(listener_event_str, listener_event_type): # type: ignore[no-untyped
(sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
# There could be a long time between each listener event.
sock.settimeout(None)
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
3 changes: 0 additions & 3 deletions python/pyspark/sql/worker/analyze_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,4 @@ def main(infile: IO, outfile: IO) -> None:
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
# TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
3 changes: 0 additions & 3 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,4 @@ def process():
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
# TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
write_int(os.getpid(), sock_file)
sock_file.flush()
main(sock_file, sock_file)
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
pid: Option[Int],
pid: Option[Long],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[OUT] = {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ abstract class BasePythonUDFRunner(
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
pid: Option[Int],
pid: Option[Long],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[Array[Byte]] = {
new ReaderIterator(
Expand Down

0 comments on commit 17430fe

Please sign in to comment.