Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add risingwave engine adapter support for sqlmesh. #3436

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f652202
feat(rw): support basic risingwave dialect
lin0303-siyuan Nov 28, 2024
a99f828
feat(rw): add risingwave integration support
lin0303-siyuan Nov 28, 2024
32c763e
fix(rw): correct py style
lin0303-siyuan Nov 28, 2024
fe78639
fix(rw): risingwave integration test support, only sushi failed
lin0303-siyuan Dec 17, 2024
9661bf6
fix(rw): use postgres engine adapter as base for rw
lin0303-siyuan Dec 18, 2024
3f7fff6
feat(rw): support drop schema cascade
lin0303-siyuan Dec 18, 2024
ba86e19
Merge branch 'main' into feature/rw
lin0303-siyuan Dec 18, 2024
d5b7b83
fix(rw): fix t.Literal type error for rw connection
lin0303-siyuan Dec 18, 2024
f455e40
fix(rw): skip truncate test for risingwave
lin0303-siyuan Dec 18, 2024
2e4e9ea
fix(rw): use latest rw image
lin0303-siyuan Jan 3, 2025
650112f
Merge branch 'main' into feature/rw
lin0303-siyuan Jan 3, 2025
ad0b722
fix(rw): remove unnecessary log and comment
lin0303-siyuan Jan 3, 2025
8bcc106
fix(rw): remove print in test_integration
lin0303-siyuan Jan 3, 2025
1a7f618
fix(rw): correct default integration init and update rw tests
lin0303-siyuan Jan 8, 2025
d0255b5
fix(rw): set flush true in cursor init for rw
lin0303-siyuan Jan 8, 2025
7ffdc42
fix(rw): impl truncate_table use delete from
lin0303-siyuan Jan 9, 2025
c92bffe
fix(rw): use latest rw image and delete drop schema alternative impl
lin0303-siyuan Jan 21, 2025
e8ebed5
fix(rw): don't skip the truncate table test for rw
lin0303-siyuan Jan 22, 2025
e34a716
fix(rw): password is not required for rw adapter
lin0303-siyuan Jan 23, 2025
badcd70
feat(rw): add risingwave docs
lin0303-siyuan Jan 23, 2025
7be81a8
Merge branch 'main' into feature/rw
lin0303-siyuan Jan 23, 2025
50ce26d
feat(rw): add create sink example for rw docs
lin0303-siyuan Jan 23, 2025
4341e8f
fix(rw): remove unnecessary connection args
lin0303-siyuan Jan 23, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .circleci/continue_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ workflows:
- spark
- clickhouse
- clickhouse-cluster
- risingwave
- engine_tests_cloud:
name: cloud_engine_<< matrix.engine >>
context:
Expand Down
4 changes: 4 additions & 0 deletions .circleci/wait-for-db.sh
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ trino_ready() {
docker compose -f tests/core/engine_adapter/integration/docker/compose.trino.yaml exec trino /bin/bash -c '/usr/lib/trino/bin/health-check'
}

# Readiness check for the risingwave engine: delegates to probe_port on port 4566.
risingwave_ready() {
probe_port 4566
}

echo "Waiting for $ENGINE to be ready..."

READINESS_FUNC="${ENGINE}_ready"
Expand Down
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ spark-test: engine-spark-up
trino-test: engine-trino-up
pytest -n auto -x -m "trino or trino_iceberg or trino_delta" --retries 3 --junitxml=test-results/junit-trino.xml

# Run the tests marked "risingwave"; requires the engine-risingwave-up target first.
risingwave-test: engine-risingwave-up
pytest -n auto -x -m "risingwave" --retries 3 --junitxml=test-results/junit-risingwave.xml

#################
# Cloud Engines #
#################
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@
"sse-starlette>=0.2.2",
"pyarrow",
],
"risingwave": [
"psycopg2",
],
},
classifiers=[
"Intended Audience :: Developers",
Expand Down
42 changes: 42 additions & 0 deletions sqlmesh/core/config/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1730,6 +1730,48 @@ def get_catalog(self) -> t.Optional[str]:
return self.catalog_name


class RisingwaveConnectionConfig(ConnectionConfig):
    """Connection configuration for the ``risingwave`` engine type.

    ``_connection_kwargs_keys`` names the fields that make up the connection
    arguments, ``_connection_factory`` returns ``psycopg2.connect`` and
    ``_engine_adapter`` selects ``RisingwaveEngineAdapter``.
    """

    host: str
    user: str
    # Optional so that users without a password can be configured;
    # existing configs that set a password keep working unchanged.
    password: t.Optional[str] = None
    port: int
    database: str
    keepalives_idle: t.Optional[int] = None
    connect_timeout: int = 10
    role: t.Optional[str] = None
    sslmode: t.Optional[str] = None

    concurrent_tasks: int = 4
    register_comments: bool = True
    pre_ping: bool = True

    type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave")

    @property
    def _connection_kwargs_keys(self) -> t.Set[str]:
        """Names of the config fields that are connection arguments."""
        return {
            "host",
            "user",
            "password",
            "port",
            "database",
            "keepalives_idle",
            "connect_timeout",
            "role",
            "sslmode",
        }

    @property
    def _engine_adapter(self) -> t.Type[EngineAdapter]:
        """The engine adapter class used for this connection type."""
        return engine_adapter.RisingwaveEngineAdapter

    @property
    def _connection_factory(self) -> t.Callable:
        """Return ``psycopg2.connect``; imported lazily so psycopg2 is only needed when used."""
        from psycopg2 import connect

        return connect


CONNECTION_CONFIG_TO_TYPE = {
# Map all subclasses of ConnectionConfig to the value of their `type_` field.
tpe.all_field_infos()["type_"].default: tpe
Expand Down
2 changes: 2 additions & 0 deletions sqlmesh/core/engine_adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter
from sqlmesh.core.engine_adapter.trino import TrinoEngineAdapter
from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter
from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter

DIALECT_TO_ENGINE_ADAPTER = {
"hive": SparkEngineAdapter,
Expand All @@ -33,6 +34,7 @@
"mssql": MSSQLEngineAdapter,
"trino": TrinoEngineAdapter,
"athena": AthenaEngineAdapter,
"risingwave": RisingwaveEngineAdapter,
}

DIALECT_ALIASES = {
Expand Down
99 changes: 99 additions & 0 deletions sqlmesh/core/engine_adapter/risingwave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

import logging
import typing as t


from sqlglot import Dialect, exp

from sqlmesh.core.engine_adapter.postgres import PostgresEngineAdapter
from sqlmesh.core.engine_adapter.shared import (
set_catalog,
CatalogSupport,
CommentCreationView,
DataObjectType,
CommentCreationTable,
)


if t.TYPE_CHECKING:
from sqlmesh.core._typing import SessionProperties
from sqlmesh.core._typing import SchemaName

logger = logging.getLogger(__name__)


@set_catalog()
class RisingwaveEngineAdapter(PostgresEngineAdapter):
    """Engine adapter for RisingWave, layered on the Postgres engine adapter."""

    DIALECT = "risingwave"
    DEFAULT_BATCH_SIZE = 400
    CATALOG_SUPPORT = CatalogSupport.SINGLE_CATALOG_ONLY
    COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY
    COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED
    SUPPORTS_MATERIALIZED_VIEWS = True
    # Temporarily disabled: the test_transaction integration test uses a
    # truncate table operation, which is not supported in RisingWave.
    SUPPORTS_TRANSACTIONS = False

    def _set_flush(self) -> None:
        """Execute ``SET RW_IMPLICIT_FLUSH TO true`` on the current connection."""
        sql = "SET RW_IMPLICIT_FLUSH TO true;"
        self._execute(sql)

    def __init__(
        self,
        connection_factory: t.Callable[[], t.Any],
        dialect: str = "",
        sql_gen_kwargs: t.Optional[t.Dict[str, Dialect | bool | str]] = None,
        multithreaded: bool = False,
        cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
        cursor_init: t.Optional[t.Callable[[t.Any], None]] = None,
        default_catalog: t.Optional[str] = None,
        execute_log_level: int = logging.DEBUG,
        register_comments: bool = True,
        pre_ping: bool = False,
        **kwargs: t.Any,
    ):
        """Initialise the parent adapter, then turn on implicit flush.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(
            connection_factory,
            dialect,
            sql_gen_kwargs,
            multithreaded,
            cursor_kwargs,
            cursor_init,
            default_catalog,
            execute_log_level,
            register_comments,
            pre_ping,
            **kwargs,
        )
        # Only issue the SET statement when a cursor attribute is available.
        if hasattr(self, "cursor"):
            self._set_flush()

    def _begin_session(self, properties: SessionProperties) -> t.Any:
        """Begin a new session by (re-)enabling implicit flush.

        ``properties`` is accepted for interface compatibility and is not used.
        """
        self._set_flush()

    def drop_schema(
        self,
        schema_name: SchemaName,
        ignore_if_not_exists: bool = True,
        cascade: bool = False,
        **drop_args: t.Dict[str, exp.Expression],
    ) -> None:
        """Drop a schema, emulating CASCADE by dropping its objects first.

        Risingwave doesn't support the CASCADE clause so far. When ``cascade``
        is requested, every data object found in the schema is dropped
        individually and the schema itself is then dropped without CASCADE.
        If CASCADE is supported later, this logic can be discarded.

        Args:
            schema_name: The schema to drop.
            ignore_if_not_exists: Whether the drops should tolerate missing objects.
            cascade: Whether to drop the objects contained in the schema first.
            drop_args: Accepted for interface compatibility; not used here.
        """
        if cascade:
            for obj in self._get_data_objects(schema_name):
                qualified_name = ".".join([obj.schema_name, obj.name])
                if obj.type == DataObjectType.MATERIALIZED_VIEW:
                    # Materialized views need DROP MATERIALIZED VIEW; routing
                    # them to drop_table would issue a DROP TABLE instead.
                    self.drop_view(
                        qualified_name,
                        ignore_if_not_exists=ignore_if_not_exists,
                        materialized=True,
                    )
                elif obj.type == DataObjectType.VIEW:
                    self.drop_view(
                        qualified_name,
                        ignore_if_not_exists=ignore_if_not_exists,
                    )
                else:
                    self.drop_table(
                        qualified_name,
                        exists=ignore_if_not_exists,
                    )
        super().drop_schema(schema_name, ignore_if_not_exists=ignore_if_not_exists, cascade=False)
37 changes: 36 additions & 1 deletion tests/core/engine_adapter/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def supports_merge(self) -> bool:
if self.dialect == "athena":
return "hive" not in self.mark

if self.dialect == "risingwave":
return False

return True

@property
Expand All @@ -177,7 +180,7 @@ def get_metadata_results(self, schema: t.Optional[str] = None) -> MetadataResult

def _init_engine_adapter(self) -> None:
schema = self.schema(TEST_SCHEMA)
self.engine_adapter.drop_schema(schema, ignore_if_not_exists=True, cascade=True)
self.engine_adapter.drop_schema(schema, ignore_if_not_exists=True, cascade=False)
erindru marked this conversation as resolved.
Show resolved Hide resolved
self.engine_adapter.create_schema(schema)

def _format_df(self, data: pd.DataFrame, to_datetime: bool = True) -> pd.DataFrame:
Expand Down Expand Up @@ -348,6 +351,20 @@ def get_table_comment(
"""
elif self.dialect == "clickhouse":
query = f"SELECT name, comment FROM system.tables WHERE database = '{schema_name}' AND name = '{table_name}'"
elif self.dialect == "risingwave":
query = f"""
SELECT
c.relname,
d.description
FROM pg_class c
INNER JOIN pg_description d ON c.oid = d.objoid AND d.objsubid = 0
INNER JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE
c.relname = '{table_name}'
AND n.nspname= '{schema_name}'
AND c.relkind = '{'v' if table_kind == "VIEW" else 'r'}'
;
"""

result = self.engine_adapter.fetchall(query)

Expand Down Expand Up @@ -439,6 +456,24 @@ def get_column_comments(
schema_name = '{schema_name}'
AND table_name = '{table_name}'
"""
elif self.dialect == "risingwave":
query = f"""
SELECT
a.attname AS column_name, d.description
FROM
pg_class c
INNER JOIN pg_namespace n ON c.relnamespace = n.oid
INNER JOIN pg_attribute a ON c.oid = a.attrelid
INNER JOIN pg_description d
ON
a.attnum = d.objsubid
AND d.objoid = c.oid
WHERE
n.nspname = '{schema_name}'
AND c.relname = '{table_name}'
AND c.relkind = '{'v' if table_kind == "VIEW" else 'r'}'
;
"""

result = self.engine_adapter.fetchall(query)

Expand Down
8 changes: 8 additions & 0 deletions tests/core/engine_adapter/integration/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ gateways:
cluster: cluster1
state_connection:
type: duckdb
inttest_risingwave:
erindru marked this conversation as resolved.
Show resolved Hide resolved
connection:
type: risingwave
user: root
password: risingwave
database: dev
host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }}
port: 4566


# Cloud databases
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
services:
risingwave:
image: risingwavelabs/risingwave:nightly-20250101
ports:
- "4566:4566"
13 changes: 12 additions & 1 deletion tests/core/engine_adapter/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,14 @@ def test_type(request):
pytest.mark.athena,
],
),
pytest.param(
"risingwave",
marks=[
pytest.mark.docker,
pytest.mark.engine,
pytest.mark.risingwave,
],
),
]
)
def mark_gateway(request) -> t.Tuple[str, str]:
Expand Down Expand Up @@ -370,7 +378,7 @@ def test_create_table(ctx: TestContext):
column_descriptions={"id": "test id column description"},
table_format=ctx.default_table_format,
)
results = ctx.get_metadata_results()
results = ctx.get_metadata_results(schema=table.db)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has changed it for all adapters?

assert len(results.tables) == 1
assert len(results.views) == 0
assert len(results.materialized_views) == 0
Expand Down Expand Up @@ -1261,6 +1269,9 @@ def test_truncate_table(ctx: TestContext):
if ctx.test_type != "query":
pytest.skip("Truncate table test does not change based on input data type")

if ctx.dialect == "risingwave":
pytest.skip("Risingwave doesn't support truncate table")
erindru marked this conversation as resolved.
Show resolved Hide resolved

table = ctx.table("test_table")

ctx.engine_adapter.create_table(
Expand Down
49 changes: 49 additions & 0 deletions tests/core/engine_adapter/test_risingwave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# type: ignore
import typing as t
from unittest.mock import call

import pytest
from sqlglot import parse_one
from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter

pytestmark = [pytest.mark.engine, pytest.mark.postgres, pytest.mark.risingwave]


def test_create_view(make_mocked_engine_adapter: t.Callable):
    """With replace=True the view is dropped before creation; with replace=False it is only created."""
    adapter = make_mocked_engine_adapter(RisingwaveEngineAdapter)

    for replace in (True, False):
        adapter.create_view("db.view", parse_one("SELECT 1"), replace=replace)

    expected_calls = [
        # replace=True: drop, then create
        call('DROP VIEW IF EXISTS "db"."view" CASCADE'),
        call('CREATE VIEW "db"."view" AS SELECT 1'),
        # replace=False: create only
        call('CREATE VIEW "db"."view" AS SELECT 1'),
    ]
    adapter.cursor.execute.assert_has_calls(expected_calls)


def test_drop_view(make_mocked_engine_adapter: t.Callable):
    """Check the SQL emitted for default, materialized and non-cascading view drops."""
    adapter = make_mocked_engine_adapter(RisingwaveEngineAdapter)
    adapter.SUPPORTS_MATERIALIZED_VIEWS = True

    # Each entry is the extra keyword arguments for one drop_view call, in order.
    for extra_kwargs in ({}, {"materialized": True}, {"cascade": False}):
        adapter.drop_view("db.view", **extra_kwargs)

    expected_calls = [
        # default arguments
        call('DROP VIEW IF EXISTS "db"."view" CASCADE'),
        # materialized=True
        call('DROP MATERIALIZED VIEW IF EXISTS "db"."view" CASCADE'),
        # cascade=False
        call('DROP VIEW IF EXISTS "db"."view"'),
    ]
    adapter.cursor.execute.assert_has_calls(expected_calls)