Skip to content

Commit

Permalink
Add support for using LOAD DATA with an in-memory stream
Browse files Browse the repository at this point in the history
  • Loading branch information
pmishchenko-ua committed Jan 3, 2024
1 parent 5d42aca commit a1d1e7e
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 26 deletions.
72 changes: 48 additions & 24 deletions singlestoredb/mysql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def cursor(self):
return self.cursorclass(self)

# The following methods are INTERNAL USE ONLY (called from Cursor)
def query(self, sql, unbuffered=False):
def query(self, sql, unbuffered=False, data_input_stream=None):
"""
Run a query on the server.
Expand All @@ -773,7 +773,9 @@ def query(self, sql, unbuffered=False):
if isinstance(sql, str):
sql = sql.encode(self.encoding, 'surrogateescape')
self._execute_command(COMMAND.COM_QUERY, sql)
self._affected_rows = self._read_query_result(unbuffered=unbuffered)
self._affected_rows = self._read_query_result(
unbuffered=unbuffered,
data_input_stream=data_input_stream)
return self._affected_rows

def next_result(self, unbuffered=False):
Expand Down Expand Up @@ -1120,13 +1122,13 @@ def _write_bytes(self, data):
CR.CR_SERVER_GONE_ERROR, f'MySQL server has gone away ({e!r})',
)

def _read_query_result(self, unbuffered=False):
def _read_query_result(self, unbuffered=False, data_input_stream=None):
self._result = None
if unbuffered:
result = self.resultclass(self, unbuffered=unbuffered)
else:
result = self.resultclass(self)
result.read()
result.read(data_input_stream)
self._result = result
if result.server_status is not None:
self.server_status = result.server_status
Expand Down Expand Up @@ -1532,14 +1534,14 @@ def __del__(self):
if self.unbuffered_active:
self._finish_unbuffered_query()

def read(self):
def read(self, data_input_stream=None):
try:
first_packet = self.connection._read_packet()

if first_packet.is_ok_packet():
self._read_ok_packet(first_packet)
elif first_packet.is_load_local_packet():
self._read_load_local_packet(first_packet)
self._read_load_local_packet(first_packet, data_input_stream)
else:
self._read_result_packet(first_packet)
finally:
Expand Down Expand Up @@ -1584,13 +1586,13 @@ def _read_ok_packet(self, first_packet):
self.message = ok_packet.message
self.has_next = ok_packet.has_next

def _read_load_local_packet(self, first_packet):
def _read_load_local_packet(self, first_packet, data_input_stream=None):
if not self.connection._local_infile:
raise RuntimeError(
'**WARN**: Received LOAD_LOCAL packet but local_infile option is false.',
)
load_packet = LoadLocalPacketWrapper(first_packet)
sender = LoadLocalFile(load_packet.filename, self.connection)
sender = LoadLocalFile(load_packet.filename, self.connection, data_input_stream)
try:
sender.send_data()
except Exception:
Expand Down Expand Up @@ -1765,32 +1767,54 @@ def __init__(self, connection, unbuffered=False):

class LoadLocalFile:
    """Streams the payload of a LOAD DATA LOCAL statement to the server.

    The data comes either from a file on disk (``filename``) or, when
    ``data_input_stream`` is not ``None``, from an in-memory stream object
    that implements ``seek(pos)`` and ``read(size: int) -> bytes``.
    """

    def __init__(self, filename, connection, data_input_stream):
        # filename is still recorded even when a stream is supplied; it is
        # only opened when input_stream is None.
        self.filename = filename
        self.connection = connection
        self.input_stream = data_input_stream

    def send_data(self):
        """Send data packets from the local file or stream to the server.

        Raises
        ------
        err.InterfaceError
            If the connection socket is already gone.
        err.OperationalError
            If the file cannot be opened, or the stream object does not
            implement the required ``seek``/``read`` interface.
        """
        if not self.connection._sock:
            raise err.InterfaceError(0, '')
        conn = self.connection

        try:
            if self.input_stream is None:
                try:
                    with open(self.filename, 'rb') as open_file:
                        packet_size = min(
                            conn.max_allowed_packet, 16 * 1024,
                        )  # 16KB is efficient enough
                        while True:
                            chunk = open_file.read(packet_size)
                            if not chunk:
                                break
                            conn.write_packet(chunk)
                except OSError as exc:
                    raise err.OperationalError(
                        ER.FILE_NOT_FOUND,
                        f"Can't find file '{self.filename}'",
                    ) from exc
            else:
                try:
                    packet_size = min(
                        conn.max_allowed_packet, 256 * 1024,
                    )  # 256KB per packet for in-memory data
                    # Rewind so the full stream is sent even if it was
                    # read (or partially read) before.
                    self.input_stream.seek(0)
                    while True:
                        chunk = self.input_stream.read(packet_size)
                        if not chunk:
                            break
                        conn.write_packet(chunk)
                except AttributeError as exc:
                    # The supplied object lacks seek()/read().
                    raise err.OperationalError(
                        ER.CHECKREAD,
                        "Can't read the stream attached to LOAD DATA statement. "
                        "Make sure that the object implements read(size: int) -> bytes",
                    ) from exc
        finally:
            if not conn._closed:
                # send the empty packet to signify we are done sending data
                conn.write_packet(b'')
40 changes: 38 additions & 2 deletions singlestoredb/mysql/cursors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# type: ignore
import re
from collections import namedtuple
import io

from . import err
from ..connection import Cursor as BaseCursor
Expand Down Expand Up @@ -277,6 +278,41 @@ def _do_execute_many(
self.rowcount = rows
return rows

def load_data(self, query, data_input_stream: io.BytesIO, args=None):
    """
    Execute a LOAD DATA query, reading the data from data_input_stream
    instead of a file.

    If args is a list or tuple, :1, :2, etc. can be used as a
    placeholder in the query. If args is a dict, :name can be used
    as a placeholder in the query.

    Parameters
    ----------
    query : str
        Query to execute. The file name in the statement is ignored.
    data_input_stream : io.BytesIO
        AVRO-formatted io.BytesIO object.
        TODO: accept an iterable of rows
    args : Sequence[Any], optional
        Sequence of sequences or mappings. It is used as parameter.

    Returns
    -------
    int : Number of affected rows.

    """
    # Drain any result sets left over from a previous multi-statement
    # query before issuing a new command.
    while self.nextset():
        pass

    log_query(query, args)

    # Interpolate parameters client-side; the server receives final SQL.
    query = self.mogrify(query, args)

    # The stream is threaded through to the connection so the
    # LOAD_LOCAL packet is answered with stream contents, not file data.
    result = self._query(query, data_input_stream)
    self._executed = query
    return result

def callproc(self, procname, args=()):
"""
Execute stored procedure procname with args.
Expand Down Expand Up @@ -380,10 +416,10 @@ def scroll(self, value, mode='relative'):
raise IndexError('out of range')
self._rownumber = r

def _query(self, q, data_input_stream=None):
    """Send query *q* to the server and load the result into this cursor.

    Parameters
    ----------
    q : str or bytes
        Final (already interpolated) SQL text to execute.
    data_input_stream : io.BytesIO, optional
        When given, supplies the payload for a LOAD DATA LOCAL
        statement in place of a file on disk.

    Returns
    -------
    int : Number of affected rows.
    """
    conn = self._get_db()
    # Discard any previous result state before running the new query.
    self._clear_result()
    conn.query(q, data_input_stream=data_input_stream)
    self._do_get_result()
    return self.rowcount

Expand Down

0 comments on commit a1d1e7e

Please sign in to comment.