Skip to content

Commit ae86e44

Browse files
committed
Fusion updates; Stages updates
1 parent 48e99fb commit ae86e44

File tree

7 files changed

+80
-30
lines changed

7 files changed

+80
-30
lines changed

docs/src/api.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,8 @@ methods and attributes.
240240

241241
Stages
242242
Stages.open
243-
Stages.download
243+
Stages.download_file
244+
Stages.download_folder
244245
Stages.upload_file
245246
Stages.upload_folder
246247
Stages.info

singlestoredb/config.py

+6
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,12 @@
214214
environ='SINGLESTOREDB_TRACK_ENV',
215215
)
216216

217+
register_option(
218+
'fusion.enabled', 'bool', check_bool, False,
219+
'Should Fusion SQL queries be enabled?',
220+
environ='SINGLESTOREDB_FUSION_ENABLED',
221+
)
222+
217223
#
218224
# Query results options
219225
#

singlestoredb/fusion/handlers/stages.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class DownloadStageFileHandler(SQLHandler):
108108
def run(self, params: Dict[str, Any]) -> Optional[FusionSQLResult]:
109109
wg = get_workspace_group(params)
110110

111-
out = wg.stages.download(
111+
out = wg.stages.download_file(
112112
params['stage_path'],
113113
local_path=params['local_path'] or None,
114114
overwrite=params['overwrite'],

singlestoredb/fusion/registry.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
import os
32
import re
43
from typing import Any
54
from typing import Dict
@@ -11,9 +10,9 @@
1110

1211
from . import result
1312
from .. import connection
13+
from ..config import get_option
1414
from .handler import SQLHandler
1515

16-
_enabled = ('1', 'yes', 'on', 'enabled', 'true')
1716
_handlers: Dict[str, Type[SQLHandler]] = {}
1817
_handlers_re: Optional[Any] = None
1918

@@ -64,7 +63,7 @@ def get_handler(sql: Union[str, bytes]) -> Optional[Type[SQLHandler]]:
6463
None - if no matching handler could be found
6564
6665
"""
67-
if not os.environ.get('SINGLESTOREDB_ENABLE_FUSION', '').lower() in _enabled:
66+
if not get_option('fusion.enabled'):
6867
return None
6968

7069
if isinstance(sql, (bytes, bytearray)):
@@ -103,7 +102,7 @@ def execute(
103102
FusionSQLResult
104103
105104
"""
106-
if not os.environ.get('SINGLESTOREDB_ENABLE_FUSION', '').lower() in _enabled:
105+
if not get_option('fusion.enabled'):
107106
raise RuntimeError('management API queries have not been enabled')
108107

109108
if handler is None:

singlestoredb/management/workspace.py

+53-9
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,13 @@ def download(
198198
msg='No Stages object is associated with this object.',
199199
)
200200

201-
return self._stages.download(
201+
return self._stages.download_file(
202202
self.path, local_path=local_path,
203203
overwrite=overwrite, encoding=encoding,
204204
)
205205

206+
download_file = download
207+
206208
def remove(self) -> None:
207209
"""Delete the stage file."""
208210
if self._stages is None:
@@ -230,7 +232,7 @@ def removedirs(self) -> None:
230232

231233
self._stages.removedirs(self.path)
232234

233-
def rename(self, new_path: PathLike, *, overwrite: bool = False) -> StagesObject:
235+
def rename(self, new_path: PathLike, *, overwrite: bool = False) -> None:
234236
"""
235237
Move the stage file to a new location.
236238
@@ -246,7 +248,10 @@ def rename(self, new_path: PathLike, *, overwrite: bool = False) -> StagesObject
246248
raise ManagementError(
247249
msg='No Stages object is associated with this object.',
248250
)
249-
return self._stages.rename(self.path, new_path, overwrite=overwrite)
251+
out = self._stages.rename(self.path, new_path, overwrite=overwrite)
252+
self.name = out.name
253+
self.path = out.path
254+
return None
250255

251256
def exists(self) -> bool:
252257
"""Does the file / folder exist?"""
@@ -384,7 +389,7 @@ def open(
384389
return StagesObjectTextWriter('', self, stage_path)
385390

386391
if 'r' in mode:
387-
content = self.download(stage_path)
392+
content = self.download_file(stage_path)
388393
if isinstance(content, bytes):
389394
if 'b' in mode:
390395
return StagesObjectBytesReader(content)
@@ -400,7 +405,7 @@ def open(
400405

401406
def upload_file(
402407
self,
403-
local_path: PathLike,
408+
local_path: Union[PathLike, TextIO, BinaryIO],
404409
stage_path: PathLike,
405410
*,
406411
overwrite: bool = False,
@@ -410,15 +415,17 @@ def upload_file(
410415
411416
Parameters
412417
----------
413-
local_path : Path or str
414-
Path to the local file
418+
local_path : Path or str or file-like
419+
Path to the local file or an open file object
415420
stage_path : Path or str
416421
Path to the stage file
417422
overwrite : bool, optional
418423
Should the ``stage_path`` be overwritten if it exists already?
419424
420425
"""
421-
if not os.path.isfile(local_path):
426+
if isinstance(local_path, (TextIO, BinaryIO)):
427+
pass
428+
elif not os.path.isfile(local_path):
422429
raise IsADirectoryError(f'local path is not a file: {local_path}')
423430

424431
if self.exists(stage_path):
@@ -427,6 +434,8 @@ def upload_file(
427434

428435
self.remove(stage_path)
429436

437+
if isinstance(local_path, (TextIO, BinaryIO)):
438+
return self._upload(local_path, stage_path, overwrite=overwrite)
430439
return self._upload(open(local_path, 'rb'), stage_path, overwrite=overwrite)
431440

432441
def upload_folder(
@@ -727,7 +736,7 @@ def listdir(
727736

728737
raise NotADirectoryError(f'stage path is not a directory: {stage_path}')
729738

730-
def download(
739+
def download_file(
731740
self,
732741
stage_path: PathLike,
733742
local_path: Optional[PathLike] = None,
@@ -774,6 +783,41 @@ def download(
774783

775784
return out
776785

786+
def download_folder(
787+
self,
788+
stage_path: PathLike,
789+
local_path: PathLike = '.',
790+
*,
791+
overwrite: bool = False,
792+
) -> None:
793+
"""
794+
Download a Stages folder to a local directory.
795+
796+
Parameters
797+
----------
798+
stage_path : Path or str
799+
Path to the stage file
800+
local_path : Path or str
801+
Path to local directory target location
802+
overwrite : bool, optional
803+
Should an existing directory / files be overwritten if they exist?
804+
805+
"""
806+
if local_path is not None and not overwrite and os.path.exists(local_path):
807+
raise OSError(
808+
'target directory already exists; '
809+
'use overwrite=True to replace',
810+
)
811+
if not self.is_dir(stage_path):
812+
raise NotADirectoryError(f'stage path is not a directory: {stage_path}')
813+
814+
for f in self.listdir(stage_path, recursive=True):
815+
if self.is_dir(f):
816+
continue
817+
target = os.path.normpath(os.path.join(local_path, f))
818+
os.makedirs(os.path.dirname(target), exist_ok=True)
819+
self.download_file(f, target, overwrite=overwrite)
820+
777821
def remove(self, stage_path: PathLike) -> None:
778822
"""
779823
Delete a stage location.

singlestoredb/tests/test_fusion.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,16 @@ def tearDownClass(cls):
3131
utils.drop_database(cls.dbname)
3232

3333
def setUp(self):
34-
self.enabled = os.environ.get('SINGLESTOREDB_ENABLE_FUSION')
35-
os.environ['SINGLESTOREDB_ENABLE_FUSION'] = '1'
34+
self.enabled = os.environ.get('SINGLESTOREDB_FUSION_ENABLED')
35+
os.environ['SINGLESTOREDB_FUSION_ENABLED'] = '1'
3636
self.conn = s2.connect(database=type(self).dbname, local_infile=True)
3737
self.cur = self.conn.cursor()
3838

3939
def tearDown(self):
4040
if self.enabled:
41-
os.environ['SINGLESTOREDB_ENABLE_FUSION'] = self.enabled
41+
os.environ['SINGLESTOREDB_FUSION_ENABLED'] = self.enabled
4242
else:
43-
del os.environ['SINGLESTOREDB_ENABLE_FUSION']
43+
del os.environ['SINGLESTOREDB_FUSION_ENABLED']
4444

4545
try:
4646
if self.cur is not None:
@@ -57,17 +57,17 @@ def tearDown(self):
5757
pass
5858

5959
def test_env_var(self):
60-
os.environ['SINGLESTOREDB_ENABLE_FUSION'] = '0'
60+
os.environ['SINGLESTOREDB_FUSION_ENABLED'] = '0'
6161

6262
with self.assertRaises(s2.ProgrammingError):
6363
self.cur.execute('show fusion commands')
6464

65-
del os.environ['SINGLESTOREDB_ENABLE_FUSION']
65+
del os.environ['SINGLESTOREDB_FUSION_ENABLED']
6666

6767
with self.assertRaises(s2.ProgrammingError):
6868
self.cur.execute('show fusion commands')
6969

70-
os.environ['SINGLESTOREDB_ENABLE_FUSION'] = 'yes'
70+
os.environ['SINGLESTOREDB_FUSION_ENABLED'] = 'yes'
7171

7272
self.cur.execute('show fusion commands')
7373
assert list(self.cur)
@@ -132,16 +132,16 @@ def tearDownClass(cls):
132132
cls.workspace_groups.pop().terminate(force=True)
133133

134134
def setUp(self):
135-
self.enabled = os.environ.get('SINGLESTOREDB_ENABLE_FUSION')
136-
os.environ['SINGLESTOREDB_ENABLE_FUSION'] = '1'
135+
self.enabled = os.environ.get('SINGLESTOREDB_FUSION_ENABLED')
136+
os.environ['SINGLESTOREDB_FUSION_ENABLED'] = '1'
137137
self.conn = s2.connect(database=type(self).dbname, local_infile=True)
138138
self.cur = self.conn.cursor()
139139

140140
def tearDown(self):
141141
if self.enabled:
142-
os.environ['SINGLESTOREDB_ENABLE_FUSION'] = self.enabled
142+
os.environ['SINGLESTOREDB_FUSION_ENABLED'] = self.enabled
143143
else:
144-
del os.environ['SINGLESTOREDB_ENABLE_FUSION']
144+
del os.environ['SINGLESTOREDB_FUSION_ENABLED']
145145

146146
try:
147147
if self.cur is not None:

singlestoredb/tests/test_management.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def test_open(self):
474474
with st.open('open_test.sql', 'w') as wfile:
475475
wfile.write(open(TEST_DIR / 'test2.sql').read())
476476

477-
txt = st.download('open_test.sql', encoding='utf-8')
477+
txt = st.download_file('open_test.sql', encoding='utf-8')
478478

479479
assert txt == open(TEST_DIR / 'test2.sql').read()
480480

@@ -484,7 +484,7 @@ def test_open(self):
484484
wfile.write(line)
485485
wfile.close()
486486

487-
txt = st.download('open_raw_test.sql', encoding='utf-8')
487+
txt = st.download_file('open_raw_test.sql', encoding='utf-8')
488488

489489
assert txt == open(TEST_DIR / 'test.sql').read()
490490

@@ -524,7 +524,7 @@ def test_obj_open(self):
524524
wfile.write(line)
525525
wfile.close()
526526

527-
txt = st.download(f.path, encoding='utf-8')
527+
txt = st.download_file(f.path, encoding='utf-8')
528528

529529
assert txt == open(TEST_DIR / 'test.sql').read()
530530

@@ -786,7 +786,7 @@ def test_stages_object(self):
786786
# rename
787787
assert st.exists('obj_test.sql')
788788
assert not st.exists('obj_test_2.sql')
789-
f1 = f1.rename('obj_test_2.sql')
789+
f1.rename('obj_test_2.sql')
790790
assert not st.exists('obj_test.sql')
791791
assert st.exists('obj_test_2.sql')
792792
assert f1.abspath() == 'obj_test_2.sql'

0 commit comments

Comments
 (0)