@@ -17,6 +17,7 @@
 
 import os
 import sys
+import uuid
 import itertools
 from multiprocessing.pool import ThreadPool
 from typing import (
@@ -75,6 +76,15 @@
 ]
 
 
+_SPARKML_TUNING_TEMP_DFS_PATH = "SPARKML_TUNING_TEMP_DFS_PATH"
+
+
+def _get_temp_dfs_path():
+    return os.environ.get(_SPARKML_TUNING_TEMP_DFS_PATH)
+
+
+
+
 def _parallelFitTasks(
     est: Estimator,
     train: DataFrame,
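The helper above reads an opt-in environment variable: when
SPARKML_TUNING_TEMP_DFS_PATH points at a DFS directory writable by the
cluster, the tuners below materialize each train/validation split there
instead of caching it in executor memory. A minimal sketch of how a caller
might opt in; the path and the estimator setup are hypothetical, not part of
this patch:

import os

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# Hypothetical scratch location; must be reachable from every executor.
os.environ["SPARKML_TUNING_TEMP_DFS_PATH"] = "hdfs:///tmp/spark_ml_tuning"

lr = LogisticRegression(maxIter=10)
cv = CrossValidator(
    estimator=lr,
    estimatorParamMaps=ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build(),
    evaluator=BinaryClassificationEvaluator(),
    numFolds=3,
)
# cv.fit(train_df) now writes each fold under the DFS path rather than
# calling .cache() on it.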
@@ -847,9 +857,24 @@ def _fit(self, dataset: DataFrame) -> "CrossValidatorModel":
             subModels = [[None for j in range(numModels)] for i in range(nFolds)]
 
         datasets = self._kFold(dataset)
+
+        tmp_dfs_path = _get_temp_dfs_path()
         for i in range(nFolds):
-            validation = datasets[i][1].cache()
-            train = datasets[i][0].cache()
+            validation = datasets[i][1]
+            train = datasets[i][0]
+
+            if tmp_dfs_path:
+                validation_tmp_path = os.path.join(tmp_dfs_path, uuid.uuid4().hex)
+                validation.write.save(validation_tmp_path)
+                # Read the materialized fold back so the fits below reuse it
+                # instead of recomputing the split from the input plan.
+                validation = dataset.sparkSession.read.load(validation_tmp_path)
+                train_tmp_path = os.path.join(tmp_dfs_path, uuid.uuid4().hex)
+                train.write.save(train_tmp_path)
+                train = dataset.sparkSession.read.load(train_tmp_path)
+            else:
+                validation.cache()
+                train.cache()
 
             tasks = map(
                 inheritable_thread_target(dataset.sparkSession),
@@ -861,8 +886,17 @@ def _fit(self, dataset: DataFrame) -> "CrossValidatorModel":
                     assert subModels is not None
                     subModels[i][j] = subModel
 
-            validation.unpersist()
-            train.unpersist()
+            if tmp_dfs_path:
+                # TODO: Spark has no FS API for deleting a path on distributed
+                #  storage, so overwrite the temporary directories with an empty
+                #  dataset as a workaround. Improve this once Spark adds one.
+                spark_session = SparkSession.getActiveSession()
+                empty_df = spark_session.range(0)
+                empty_df.write.mode("overwrite").save(validation_tmp_path)
+                empty_df.write.mode("overwrite").save(train_tmp_path)
+            else:
+                validation.unpersist()
+                train.unpersist()
 
         metrics, std_metrics = CrossValidator._gen_avg_and_std_metrics(metrics_all)
 
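Taken together, the two branches above form a materialize-and-clean pattern:
save the fold once, read the saved copy for all fits, then shrink the
directory by overwriting it. A self-contained sketch of that pattern against
a local session (the scratch path is hypothetical):

import os
import uuid

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()
tmp_dfs_path = "/tmp/spark_ml_tuning_demo"  # assumed scratch directory

df = spark.range(100)
tmp_path = os.path.join(tmp_dfs_path, uuid.uuid4().hex)
df.write.save(tmp_path)            # materialize once (Parquet by default)
saved = spark.read.load(tmp_path)  # downstream work reads the saved copy
assert saved.count() == 100

# "Delete" by overwriting with an empty dataset, since Spark exposes no
# filesystem deletion API to Python.
spark.range(0).write.mode("overwrite").save(tmp_path)
assert spark.read.load(tmp_path).count() == 0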
@@ -1475,8 +1509,22 @@ def _fit(self, dataset: DataFrame) -> "TrainValidationSplitModel":
         randCol = self.uid + "_rand"
         df = dataset.select("*", F.rand(seed).alias(randCol))
         condition = df[randCol] >= tRatio
-        validation = df.filter(condition).cache()
-        train = df.filter(~condition).cache()
+
+        validation = df.filter(condition)
+        train = df.filter(~condition)
+
+        tmp_dfs_path = _get_temp_dfs_path()
+        if tmp_dfs_path:
+            validation_tmp_path = os.path.join(tmp_dfs_path, uuid.uuid4().hex)
+            validation.write.save(validation_tmp_path)
+            # Read the materialized splits back so the fits below reuse them.
+            validation = dataset.sparkSession.read.load(validation_tmp_path)
+            train_tmp_path = os.path.join(tmp_dfs_path, uuid.uuid4().hex)
+            train.write.save(train_tmp_path)
+            train = dataset.sparkSession.read.load(train_tmp_path)
+        else:
+            validation.cache()
+            train.cache()
 
         subModels = None
         collectSubModelsParam = self.getCollectSubModels()
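For reference, the rand-column split above yields complementary subsets: with
trainRatio = 0.75, validation keeps the rows whose random draw is >= 0.75 and
train keeps the rest. A toy illustration (seed and sizes are hypothetical):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.range(1000).select("*", F.rand(42).alias("_rand"))
tRatio = 0.75                        # trainRatio
condition = df["_rand"] >= tRatio
validation = df.filter(condition)    # roughly 25% of rows
train = df.filter(~condition)        # the complementary ~75%
print(train.count(), validation.count())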
@@ -1495,8 +1543,17 @@ def _fit(self, dataset: DataFrame) -> "TrainValidationSplitModel":
             if collectSubModelsParam:
                 assert subModels is not None
                 subModels[j] = subModel
-        train.unpersist()
-        validation.unpersist()
+        if tmp_dfs_path:
+            # TODO: Spark has no FS API for deleting a path on distributed
+            #  storage, so overwrite the temporary directories with an empty
+            #  dataset as a workaround. Improve this once Spark adds one.
+            spark_session = SparkSession.getActiveSession()
+            empty_df = spark_session.range(0)
+            empty_df.write.mode("overwrite").save(validation_tmp_path)
+            empty_df.write.mode("overwrite").save(train_tmp_path)
+        else:
+            train.unpersist()
+            validation.unpersist()
 
         if eva.isLargerBetter():
             bestIndex = np.argmax(cast(List[float], metrics))
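With the environment variable set, TrainValidationSplit behaves analogously
to CrossValidator: both splits are written once under the DFS path and, after
fitting, overwritten with empty datasets. A hypothetical end-to-end opt-in
(names and values are illustrative only):

import os

from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit

os.environ["SPARKML_TUNING_TEMP_DFS_PATH"] = "hdfs:///tmp/spark_ml_tuning"

lr = LinearRegression()
tvs = TrainValidationSplit(
    estimator=lr,
    estimatorParamMaps=ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build(),
    evaluator=RegressionEvaluator(),
    trainRatio=0.75,
)
# tvs.fit(train_df) leaves only the two overwritten (empty) directories
# behind under the DFS path.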