Commit f81d8a4

update

Signed-off-by: Weichen Xu <[email protected]>
1 parent 25358eb

File tree

1 file changed

+58
-8
lines changed

1 file changed

+58
-8
lines changed

python/pyspark/ml/tuning.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
 
 import os
 import sys
+import uuid
 import itertools
 from multiprocessing.pool import ThreadPool
 from typing import (
@@ -75,6 +76,15 @@
 ]
 
 
+_SPARKML_TUNING_TEMP_DFS_PATH = "SPARKML_TUNING_TEMP_DFS_PATH"
+
+
+def _get_temp_dfs_path():
+    return os.environ.get(_SPARKML_TUNING_TEMP_DFS_PATH)
+
+
+
 def _parallelFitTasks(
     est: Estimator,
     train: DataFrame,
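
The helper above gates the new behavior on an environment variable, so nothing changes for existing jobs unless they opt in. A minimal sketch of opting in; the DFS location shown is hypothetical:

    import os

    # Any DFS directory writable by the Spark application; this path is illustrative.
    os.environ["SPARKML_TUNING_TEMP_DFS_PATH"] = "hdfs:///tmp/spark-ml-tuning"
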
@@ -847,9 +857,20 @@ def _fit(self, dataset: DataFrame) -> "CrossValidatorModel":
             subModels = [[None for j in range(numModels)] for i in range(nFolds)]
 
         datasets = self._kFold(dataset)
+
+        tmp_dfs_path = _get_temp_dfs_path()
         for i in range(nFolds):
-            validation = datasets[i][1].cache()
-            train = datasets[i][0].cache()
+            validation = datasets[i][1]
+            train = datasets[i][0]
+
+            if tmp_dfs_path:
+                validation_tmp_path = os.path.join(tmp_dfs_path, uuid.uuid4().hex)
+                validation.write.save(validation_tmp_path)
+                train_tmp_path = os.path.join(tmp_dfs_path, uuid.uuid4().hex)
+                train.write.save(train_tmp_path)
+            else:
+                validation.cache()
+                train.cache()
 
             tasks = map(
                 inheritable_thread_target(dataset.sparkSession),
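
As written, the hunk above materializes each fold to the temporary DFS path, but the fit tasks still close over the original lazy `validation`/`train` DataFrames. Presumably a read-back of the saved copies is intended, so the tasks scan the saved files rather than recompute the split; a sketch under that assumption, reusing the hunk's variables (the same would apply to the TrainValidationSplit hunk further down):

    # Assumption: reload the materialized folds so the parallel fit tasks
    # read the saved files instead of recomputing each fold per task.
    spark_session = dataset.sparkSession
    validation = spark_session.read.load(validation_tmp_path)
    train = spark_session.read.load(train_tmp_path)
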
@@ -861,8 +882,17 @@ def _fit(self, dataset: DataFrame) -> "CrossValidatorModel":
                     assert subModels is not None
                     subModels[i][j] = subModel
 
-            validation.unpersist()
-            train.unpersist()
+            if tmp_dfs_path:
+                # TODO: Spark does not have an FS API for deleting a path on
+                # distributed storage, so as a workaround we overwrite the
+                # temporary directories with an empty DataFrame to reclaim the
+                # space. This can be improved once Spark adds an FS deletion API.
+                spark_session = SparkSession.getActiveSession()
+                empty_df = spark_session.range(0)
+                empty_df.write.mode("overwrite").save(validation_tmp_path)
+                empty_df.write.mode("overwrite").save(train_tmp_path)
+            else:
+                validation.unpersist()
+                train.unpersist()
 
         metrics, std_metrics = CrossValidator._gen_avg_and_std_metrics(metrics_all)
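
The empty-DataFrame overwrite reclaims most of the temporary data but leaves the directories themselves behind. Where the JVM gateway is available (classic PySpark, not Spark Connect, which is likely why the commit avoids it), one alternative is Hadoop's FileSystem API; a sketch, reusing the variables from the hunk above:

    # Hypothetical cleanup via Hadoop's FileSystem API (requires JVM access).
    sc = spark_session.sparkContext
    hadoop_path = sc._jvm.org.apache.hadoop.fs.Path(train_tmp_path)
    fs = hadoop_path.getFileSystem(sc._jsc.hadoopConfiguration())
    fs.delete(hadoop_path, True)  # recursive delete

The same consideration applies to the TrainValidationSplit cleanup below.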

@@ -1475,8 +1505,19 @@ def _fit(self, dataset: DataFrame) -> "TrainValidationSplitModel":
         randCol = self.uid + "_rand"
         df = dataset.select("*", F.rand(seed).alias(randCol))
         condition = df[randCol] >= tRatio
-        validation = df.filter(condition).cache()
-        train = df.filter(~condition).cache()
+
+        validation = df.filter(condition)
+        train = df.filter(~condition)
+
+        tmp_dfs_path = _get_temp_dfs_path()
+        if tmp_dfs_path:
+            validation_tmp_path = os.path.join(tmp_dfs_path, uuid.uuid4().hex)
+            validation.write.save(validation_tmp_path)
+            train_tmp_path = os.path.join(tmp_dfs_path, uuid.uuid4().hex)
+            train.write.save(train_tmp_path)
+        else:
+            validation.cache()
+            train.cache()
 
         subModels = None
         collectSubModelsParam = self.getCollectSubModels()
@@ -1495,8 +1536,17 @@ def _fit(self, dataset: DataFrame) -> "TrainValidationSplitModel":
                 assert subModels is not None
                 subModels[j] = subModel
 
-        train.unpersist()
-        validation.unpersist()
+        if tmp_dfs_path:
+            # TODO: Spark does not have an FS API for deleting a path on
+            # distributed storage, so as a workaround we overwrite the
+            # temporary directories with an empty DataFrame to reclaim the
+            # space. This can be improved once Spark adds an FS deletion API.
+            spark_session = SparkSession.getActiveSession()
+            empty_df = spark_session.range(0)
+            empty_df.write.mode("overwrite").save(validation_tmp_path)
+            empty_df.write.mode("overwrite").save(train_tmp_path)
+        else:
+            train.unpersist()
+            validation.unpersist()
 
         if eva.isLargerBetter():
             bestIndex = np.argmax(cast(List[float], metrics))
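
Putting the pieces together, a sketch of exercising the new code path end to end; the estimator, parameter grid, input DataFrame, and path are illustrative and not part of this commit:

    import os

    # Hypothetical DFS location; must be set before fit() is called.
    os.environ["SPARKML_TUNING_TEMP_DFS_PATH"] = "hdfs:///tmp/spark-ml-tuning"

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit

    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build()
    tvs = TrainValidationSplit(
        estimator=lr,
        estimatorParamMaps=grid,
        evaluator=BinaryClassificationEvaluator(),
        trainRatio=0.8,
    )
    # train_df is an assumed input DataFrame with "features" and "label" columns.
    model = tvs.fit(train_df)  # splits are written under the DFS path instead of cached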
