From fe250670928adc5b75c403b2701390b949644e64 Mon Sep 17 00:00:00 2001
From: Wei Liu <wei.liu@databricks.com>
Date: Fri, 31 Jan 2025 17:32:26 -0800
Subject: [PATCH 1/6] done

---
 .../streaming/worker/foreach_batch_worker.py  | 43 +++++++++++--------
 1 file changed, 24 insertions(+), 19 deletions(-)

diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index 0c92de6372b6f..c673e6b3b239d 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -43,26 +43,33 @@ def main(infile: IO, outfile: IO) -> None:
 
     global spark
 
-    check_python_version(infile)
 
-    # Enable Spark Connect Mode
-    os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
+    log_name = "Streaming ForeachBatch worker"
 
-    connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
-    session_id = utf8_deserializer.loads(infile)
+    def init():
+        check_python_version(infile)
 
-    print(
-        "Streaming foreachBatch worker is starting with "
-        f"url {connect_url} and sessionId {session_id}."
-    )
+        # Enable Spark Connect Mode
+        os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
 
-    # To attach to the existing SparkSession, we're setting the session_id in the URL.
-    connect_url = connect_url + ";session_id=" + session_id
-    spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
-    assert spark_connect_session.session_id == session_id
-    spark = spark_connect_session
+        connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
+        session_id = utf8_deserializer.loads(infile)
+
+        print(
+            f"{log_name} is starting with "
+            f"url {connect_url} and sessionId {session_id}."
+        )
+
+        # To attach to the existing SparkSession, we're setting the session_id in the URL.
+        connect_url = connect_url + ";session_id=" + session_id
+        spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate()
+        assert spark_connect_session.session_id == session_id
+        spark = spark_connect_session
+
+        func = worker.read_command(pickle_ser, infile)
+        write_int(0, outfile)
+        outfile.flush()
 
-    log_name = "Streaming ForeachBatch worker"
 
     def process(df_id, batch_id):  # type: ignore[no-untyped-def]
         global spark
@@ -72,10 +79,8 @@ def process(df_id, batch_id):  # type: ignore[no-untyped-def]
         print(f"{log_name} Completed batch {batch_id} with DF id {df_id}")
 
     try:
-        func = worker.read_command(pickle_ser, infile)
-        write_int(0, outfile)
-        outfile.flush()
-
+        init()
+
         while True:
             df_ref_id = utf8_deserializer.loads(infile)
             batch_id = read_long(infile)

From 7a29f87747a3ae09a2baeba8002680bf7caeaec9 Mon Sep 17 00:00:00 2001
From: Wei Liu <wei.liu@databricks.com>
Date: Fri, 31 Jan 2025 17:32:28 -0800
Subject: [PATCH 2/6] retrigger

From 4280077b7675e15637b4f0d34db9f624371ee70d Mon Sep 17 00:00:00 2001
From: Wei Liu <wei.liu@databricks.com>
Date: Fri, 31 Jan 2025 17:34:38 -0800
Subject: [PATCH 3/6] fmt

---
 .../sql/connect/streaming/worker/foreach_batch_worker.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index c673e6b3b239d..ae0158636da89 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -55,10 +55,7 @@ def init():
         connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"]
         session_id = utf8_deserializer.loads(infile)
 
-        print(
-            f"{log_name} is starting with "
-            f"url {connect_url} and sessionId {session_id}."
-        )
+        print(f"{log_name} is starting with " f"url {connect_url} and sessionId {session_id}.")
 
         # To attach to the existing SparkSession, we're setting the session_id in the URL.
        connect_url = connect_url + ";session_id=" + session_id
@@ -70,7 +67,6 @@ def init():
         write_int(0, outfile)
         outfile.flush()
 
-
     def process(df_id, batch_id):  # type: ignore[no-untyped-def]
         global spark
         print(f"{log_name} Started batch {batch_id} with DF id {df_id}")
@@ -80,7 +76,7 @@ def process(df_id, batch_id):  # type: ignore[no-untyped-def]
 
     try:
         init()
-
+
         while True:
             df_ref_id = utf8_deserializer.loads(infile)
             batch_id = read_long(infile)

From 347253c17356e030cf83b1699faf062511efe8b7 Mon Sep 17 00:00:00 2001
From: Wei Liu <wei.liu@databricks.com>
Date: Mon, 3 Feb 2025 12:22:51 -0800
Subject: [PATCH 4/6] fix

---
 .../api/python/PythonWorkerFactory.scala      |  1 +
 .../api/python/StreamingPythonRunner.scala    |  2 ++
 .../streaming/worker/foreach_batch_worker.py  | 19 ++++++++-----------
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 3221a4900f6ad..47a475842dcbd 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -219,6 +219,7 @@ private[spark] class PythonWorkerFactory(
     if (!blockingMode) {
       socketChannel.configureBlocking(false)
     }
+//    socketChannel.setOption()
     val selector = Selector.open()
     val selectionKey = if (blockingMode) {
       null
diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index ce933337afc35..7df8a630f894e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -82,6 +82,7 @@ private[spark] class StreamingPythonRunner(
       pythonWorker.get.channel.socket().getOutputStream, bufferSize)
     val dataOut = new DataOutputStream(stream)
 
+//    sock.setSoTimeout(10000)
     PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
 
     // Send sessionId
@@ -104,6 +105,7 @@ private[spark] class StreamingPythonRunner(
 
     logInfo(log"Runner initialization succeeded (returned" +
       log" ${MDC(PYTHON_WORKER_RESPONSE, resFromPython)}).")
+    sock.setSoTimeout(0)
 
     (dataOut, dataIn)
   }
diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index ae0158636da89..f144ac49e5bb1 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -46,7 +46,14 @@ def main(infile: IO, outfile: IO) -> None:
 
     log_name = "Streaming ForeachBatch worker"
 
-    def init():
+    def process(df_id, batch_id):  # type: ignore[no-untyped-def]
+        global spark
+        print(f"{log_name} Started batch {batch_id} with DF id {df_id}")
+        batch_df = spark_connect_session._create_remote_dataframe(df_id)
+        func(batch_df, batch_id)
+        print(f"{log_name} Completed batch {batch_id} with DF id {df_id}")
+
+    try:
         check_python_version(infile)
 
         # Enable Spark Connect Mode
@@ -67,16 +74,6 @@ def init():
         write_int(0, outfile)
         outfile.flush()
 
-    def process(df_id, batch_id):  # type: ignore[no-untyped-def]
-        global spark
-        print(f"{log_name} Started batch {batch_id} with DF id {df_id}")
-        batch_df = spark_connect_session._create_remote_dataframe(df_id)
-        func(batch_df, batch_id)
-        print(f"{log_name} Completed batch {batch_id} with DF id {df_id}")
-
-    try:
-        init()
-
         while True:
             df_ref_id = utf8_deserializer.loads(infile)
             batch_id = read_long(infile)

From c8476ce8e9a1760ed5d3980c811ef3792219cbcb Mon Sep 17 00:00:00 2001
From: Wei Liu <wei.liu@databricks.com>
Date: Mon, 3 Feb 2025 14:51:24 -0800
Subject: [PATCH 5/6] retrigger

---
 .../org/apache/spark/api/python/StreamingPythonRunner.scala | 2 --
 1 file changed, 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index 7df8a630f894e..ce933337afc35 100644
--- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -82,7 +82,6 @@ private[spark] class StreamingPythonRunner(
       pythonWorker.get.channel.socket().getOutputStream, bufferSize)
     val dataOut = new DataOutputStream(stream)
 
-//    sock.setSoTimeout(10000)
     PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
 
     // Send sessionId
@@ -104,7 +104,6 @@ private[spark] class StreamingPythonRunner(
 
     logInfo(log"Runner initialization succeeded (returned" +
       log" ${MDC(PYTHON_WORKER_RESPONSE, resFromPython)}).")
-    sock.setSoTimeout(0)
 
     (dataOut, dataIn)
   }

From 153f635a42ca56d22dfe97a0d51d1ad5fe69c415 Mon Sep 17 00:00:00 2001
From: Wei Liu <wei.liu@databricks.com>
Date: Mon, 3 Feb 2025 14:53:18 -0800
Subject: [PATCH 6/6] retrigger

---
 .../scala/org/apache/spark/api/python/PythonWorkerFactory.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 47a475842dcbd..3221a4900f6ad 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -219,7 +219,6 @@ private[spark] class PythonWorkerFactory(
     if (!blockingMode) {
       socketChannel.configureBlocking(false)
     }
-//    socketChannel.setOption()
     val selector = Selector.open()
     val selectionKey = if (blockingMode) {
       null
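
Usage sketch (illustrative, not part of the series): the worker refactored above runs a
user-supplied foreachBatch function against each micro-batch it reads off the socket as a
DataFrame reference id plus a long batch id. A minimal client-side example of how such a
function is wired up through the public Spark Connect API follows; the connect URL, the
rate source, and the table name "events" are assumptions for the example, not taken from
the patches.

from pyspark.sql import SparkSession

# Attach to a Spark Connect server; the URL is illustrative.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

def handle_batch(batch_df, batch_id):
    # batch_df corresponds to the remote DataFrame the worker rebuilds from
    # df_ref_id via _create_remote_dataframe; batch_id is the long read in
    # the worker's while-loop.
    batch_df.write.mode("append").saveAsTable("events")

query = (
    spark.readStream.format("rate").load()
    .writeStream.foreachBatch(handle_batch)
    .start()
)
query.awaitTermination()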