Skip to content

Commit f0b0047

Browse files
authored
add CODE_TAR for trainer worker (#698)
* add code_tar to .sh * complete test * fix Co-authored-by: xiangyuxuan.prs <[email protected]>
1 parent 8826853 commit f0b0047

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

deploy/scripts/env_to_args.sh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ pull_code() {
2020
wget $1 -O code.tar.gz
2121
elif [[ $1 == "oss://"* ]]; then
2222
python -c "import tensorflow as tf; import tensorflow_io; open('code.tar.gz', 'wb').write(tf.io.gfile.GFile('$1', 'rb').read())"
23+
elif [[ $1 == "base64://"* ]]; then
24+
python -c "import base64; f = open('code.tar.gz', 'wb'); f.write(base64.b64decode('$1'[9:])); f.close()"
2325
else
2426
cp $1 code.tar.gz
25-
fi
27+
fi
2628
tar -zxvf code.tar.gz
2729
cd $cwd
2830
}

deploy/scripts/trainer/run_trainer_worker.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@ for i in "${WORKER_GROUPS[@]}"; do
4343
done
4444
fi
4545

46-
pull_code ${CODE_KEY} $PWD
46+
if [[ -n "${CODE_KEY}" ]]; then
47+
pull_code ${CODE_KEY} $PWD
48+
else
49+
pull_code ${CODE_TAR} $PWD
50+
fi
51+
4752
cd ${ROLE}
4853

4954
mode=$(normalize_env_to_args "--mode" "$MODE")

web_console_v2/api/test/fedlearner_webconsole/job/yaml_formatter_test.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
# coding: utf-8
1616
import unittest
17-
18-
from fedlearner_webconsole.job.yaml_formatter import format_yaml
17+
import tarfile
18+
import base64
19+
from io import BytesIO
20+
from fedlearner_webconsole.job.yaml_formatter import format_yaml, code_dict_encode
1921

2022

2123
class YamlFormatterTest(unittest.TestCase):
@@ -62,6 +64,22 @@ def test_format_yaml_unknown_ph(self):
6264
format_yaml('$x.y is ${i.j}', x=x)
6365
self.assertEqual(str(cm.exception), 'Unknown placeholder: i.j')
6466

67+
def test_encode_code(self):
68+
test_data = {'test/a.py': 'awefawefawefawefwaef',
69+
'test1/b.py': 'asdfasd',
70+
'c.py': '',
71+
'test/d.py': 'asdf'}
72+
code_base64 = code_dict_encode(test_data)
73+
code_dict = {}
74+
if code_base64.startswith('base64://'):
75+
tar_binary = BytesIO(base64.b64decode(code_base64[9:]))
76+
with tarfile.open(fileobj=tar_binary) as tar:
77+
for file in tar.getmembers():
78+
code_dict[file.name] = str(tar.extractfile(file).read(),
79+
encoding='utf-8')
80+
self.assertEqual(code_dict, test_data)
81+
82+
6583

6684
if __name__ == '__main__':
6785
unittest.main()

0 commit comments

Comments
 (0)