Skip to content

Commit

Permalink
Add support for external TVFs
Browse files Browse the repository at this point in the history
  • Loading branch information
kesmit13 committed Jan 15, 2025
1 parent fdf99d9 commit 8875d66
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 71 deletions.
1 change: 1 addition & 0 deletions singlestoredb/functions/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .decorator import tvf # noqa: F401
from .decorator import udf # noqa: F401
162 changes: 124 additions & 38 deletions singlestoredb/functions/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,52 +18,17 @@ def listify(x: Any) -> List[Any]:
return [x]


def udf(
def _func(
func: Optional[Callable[..., Any]] = None,
*,
name: Optional[str] = None,
args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None,
returns: Optional[str] = None,
data_format: Optional[str] = None,
include_masks: bool = False,
function_type: str = 'udf',
) -> Callable[..., Any]:
"""
Apply attributes to a UDF.
Parameters
----------
func : callable, optional
The UDF to apply parameters to
name : str, optional
The name to use for the UDF in the database
args : str | Callable | List[str | Callable] | Dict[str, str | Callable], optional
Specifies the data types of the function arguments. Typically,
the function data types are derived from the function parameter
annotations. These annotations can be overridden. If the function
takes a single type for all parameters, `args` can be set to a
SQL string describing all parameters. If the function takes more
than one parameter and all of the parameters are being manually
defined, a list of SQL strings may be used (one for each parameter).
A dictionary of SQL strings may be used to specify a parameter type
for a subset of parameters; the keys are the names of the
function parameters. Callables may also be used for datatypes. This
is primarily for using the functions in the ``dtypes`` module that
are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
returns : str, optional
Specifies the return data type of the function. If not specified,
the type annotation from the function is used.
data_format : str, optional
The data format of each parameter: python, pandas, arrow, polars
include_masks : bool, optional
Should boolean masks be included with each input parameter to indicate
which elements are NULL? This is only used when a input parameters are
configured to a vector type (numpy, pandas, polars, arrow).
Returns
-------
Callable
"""
"""Generic wrapper for UDF and TVF decorators."""
if args is None:
pass
elif isinstance(args, (list, tuple)):
Expand Down Expand Up @@ -114,6 +79,7 @@ def udf(
returns=returns,
data_format=data_format,
include_masks=include_masks,
function_type=function_type,
).items() if v is not None
}

Expand All @@ -136,7 +102,127 @@ def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
return functools.wraps(func)(wrapper)


def udf(
func: Optional[Callable[..., Any]] = None,
*,
name: Optional[str] = None,
args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None,
returns: Optional[str] = None,
data_format: Optional[str] = None,
include_masks: bool = False,
) -> Callable[..., Any]:
"""
Apply attributes to a UDF.
Parameters
----------
func : callable, optional
The UDF to apply parameters to
name : str, optional
The name to use for the UDF in the database
args : str | Callable | List[str | Callable] | Dict[str, str | Callable], optional
Specifies the data types of the function arguments. Typically,
the function data types are derived from the function parameter
annotations. These annotations can be overridden. If the function
takes a single type for all parameters, `args` can be set to a
SQL string describing all parameters. If the function takes more
than one parameter and all of the parameters are being manually
defined, a list of SQL strings may be used (one for each parameter).
A dictionary of SQL strings may be used to specify a parameter type
for a subset of parameters; the keys are the names of the
function parameters. Callables may also be used for datatypes. This
is primarily for using the functions in the ``dtypes`` module that
are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
returns : str, optional
Specifies the return data type of the function. If not specified,
the type annotation from the function is used.
data_format : str, optional
The data format of each parameter: python, pandas, arrow, polars
include_masks : bool, optional
Should boolean masks be included with each input parameter to indicate
which elements are NULL? This is only used when a input parameters are
configured to a vector type (numpy, pandas, polars, arrow).
Returns
-------
Callable
"""
return _func(
func=func,
name=name,
args=args,
returns=returns,
data_format=data_format,
include_masks=include_masks,
function_type='udf',
)


udf.pandas = functools.partial(udf, data_format='pandas') # type: ignore
udf.polars = functools.partial(udf, data_format='polars') # type: ignore
udf.arrow = functools.partial(udf, data_format='arrow') # type: ignore
udf.numpy = functools.partial(udf, data_format='numpy') # type: ignore


def tvf(
func: Optional[Callable[..., Any]] = None,
*,
name: Optional[str] = None,
args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None,
returns: Optional[str] = None,
data_format: Optional[str] = None,
include_masks: bool = False,
) -> Callable[..., Any]:
"""
Apply attributes to a TVF.
Parameters
----------
func : callable, optional
The TVF to apply parameters to
name : str, optional
The name to use for the TVF in the database
args : str | Callable | List[str | Callable] | Dict[str, str | Callable], optional
Specifies the data types of the function arguments. Typically,
the function data types are derived from the function parameter
annotations. These annotations can be overridden. If the function
takes a single type for all parameters, `args` can be set to a
SQL string describing all parameters. If the function takes more
than one parameter and all of the parameters are being manually
defined, a list of SQL strings may be used (one for each parameter).
A dictionary of SQL strings may be used to specify a parameter type
for a subset of parameters; the keys are the names of the
function parameters. Callables may also be used for datatypes. This
is primarily for using the functions in the ``dtypes`` module that
are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
returns : str, optional
Specifies the return data type of the function. If not specified,
the type annotation from the function is used.
data_format : str, optional
The data format of each parameter: python, pandas, arrow, polars
include_masks : bool, optional
Should boolean masks be included with each input parameter to indicate
which elements are NULL? This is only used when a input parameters are
configured to a vector type (numpy, pandas, polars, arrow).
Returns
-------
Callable
"""
return _func(
func=func,
name=name,
args=args,
returns=returns,
data_format=data_format,
include_masks=include_masks,
function_type='tvf',
)


tvf.pandas = functools.partial(tvf, data_format='pandas') # type: ignore
tvf.polars = functools.partial(tvf, data_format='polars') # type: ignore
tvf.arrow = functools.partial(tvf, data_format='arrow') # type: ignore
tvf.numpy = functools.partial(tvf, data_format='numpy') # type: ignore
126 changes: 99 additions & 27 deletions singlestoredb/functions/ext/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,32 +158,91 @@ def make_func(
attrs = getattr(func, '_singlestoredb_attrs', {})
data_format = attrs.get('data_format') or 'python'
include_masks = attrs.get('include_masks', False)
function_type = attrs.get('function_type', 'udf').lower()
info: Dict[str, Any] = {}

if data_format == 'python':
async def do_func(
row_ids: Sequence[int],
rows: Sequence[Sequence[Any]],
) -> Tuple[
Sequence[int],
List[Tuple[Any]],
]:
'''Call function on given rows of data.'''
return row_ids, list(zip(func_map(func, rows)))
if function_type == 'tvf':
if data_format == 'python':
async def do_func(
row_ids: Sequence[int],
rows: Sequence[Sequence[Any]],
) -> Tuple[
Sequence[int],
List[Tuple[Any]],
]:
'''Call function on given rows of data.'''
out_ids: List[int] = []
out = []
for i, res in zip(row_ids, func_map(func, rows)):
out.extend(res)
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
return out_ids, out

else:
# Vector formats use the same function wrapper
async def do_func( # type: ignore
row_ids: Sequence[int],
cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
'''Call function on given cols of data.'''
if include_masks:
out = func(*cols)
assert isinstance(out, tuple)
return row_ids, [out]

out_ids, out = [], []
res = func(*[x[0] for x in cols])
for vec in res:
# C extension only supports Python objects as strings
if data_format == 'numpy' and str(vec.dtype)[:2] in ['<U', '<S']:
vec = vec.astype(object)
out.append((vec, None))

# NOTE: There is no way to determine which row ID belongs to
# each result row, so we just have to use the same
# row ID for all rows in the result.
if data_format == 'numpy':
import numpy as np
out_ids = np.array([row_ids[0]] * len(out[0][0]))
elif data_format == 'polars':
import polars as pl
out_ids = pl.Series([row_ids[0]] * len(out[0][0]))
elif data_format == 'arrow':
import pyarrow as pa
out_ids = pa.array([row_ids[0]] * len(out[0][0]))
elif data_format == 'pandas':
import pandas as pd
out_ids = pd.Series([row_ids[0]] * len(out[0][0]))

return out_ids, out

else:
# Vector formats use the same function wrapper
async def do_func( # type: ignore
row_ids: Sequence[int],
cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
'''Call function on given cols of data.'''
# TODO: only supports a single return value
if include_masks:
out = func(*cols)
assert isinstance(out, tuple)
return row_ids, [out]
return row_ids, [(func(*[x[0] for x in cols]), None)]
if data_format == 'python':
async def do_func(
row_ids: Sequence[int],
rows: Sequence[Sequence[Any]],
) -> Tuple[
Sequence[int],
List[Tuple[Any]],
]:
'''Call function on given rows of data.'''
return row_ids, list(zip(func_map(func, rows)))

else:
# Vector formats use the same function wrapper
async def do_func( # type: ignore
row_ids: Sequence[int],
cols: Sequence[Tuple[Sequence[Any], Optional[Sequence[bool]]]],
) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
'''Call function on given cols of data.'''
if include_masks:
out = func(*cols)
assert isinstance(out, tuple)
return row_ids, [out]
out = func(*[x[0] for x in cols])
if isinstance(out, tuple):
return row_ids, [(x, None) for x in out]
return row_ids, [(out, None)]

do_func.__name__ = name
do_func.__doc__ = func.__doc__
Expand All @@ -196,6 +255,9 @@ async def do_func( # type: ignore
# Set data format
info['data_format'] = data_format

# Set function type
info['function_type'] = function_type

# Setup argument types for rowdat_1 parser
colspec = []
for x in sig['args']:
Expand All @@ -205,11 +267,21 @@ async def do_func( # type: ignore
colspec.append((x['name'], rowdat_1_type_map[dtype]))
info['colspec'] = colspec

def parse_return_type(s: str) -> List[str]:
if s.startswith('tuple['):
return s[6:-1].split(',')
if s.startswith('array[tuple['):
return s[12:-2].split(',')
return [s]

# Setup return type
dtype = sig['returns']['dtype'].replace('?', '')
if dtype not in rowdat_1_type_map:
raise TypeError(f'no data type mapping for {dtype}')
info['returns'] = [rowdat_1_type_map[dtype]]
returns = []
for x in parse_return_type(sig['returns']['dtype']):
dtype = x.replace('?', '')
if dtype not in rowdat_1_type_map:
raise TypeError(f'no data type mapping for {dtype}')
returns.append(rowdat_1_type_map[dtype])
info['returns'] = returns

return do_func, info

Expand All @@ -233,7 +305,7 @@ class Application(object):
* Function aliases : <pkg1>.[<func1@alias1,func2@alias2,...]
* Multiple packages : <pkg1>.<func1>:<pkg2>.<func2>
app_mode : str, optional
The mode of operation for the application: remote or collocated
The mode of operation for the application: remote, managed, or collocated
url : str, optional
The URL of the function API
data_format : str, optional
Expand Down
Loading

0 comments on commit 8875d66

Please sign in to comment.