diff --git a/Cargo.lock b/Cargo.lock index c0df2e452c3e..02c1b404597e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6059,12 +6059,14 @@ dependencies = [ "futures", "inquire", "itertools 0.10.5", + "madsim-etcd-client", "madsim-tokio", "regex", "risingwave_common", "risingwave_connector", "risingwave_frontend", "risingwave_hummock_sdk", + "risingwave_meta", "risingwave_object_store", "risingwave_pb", "risingwave_rpc_client", diff --git a/ci/scripts/connector-node-integration-test.sh b/ci/scripts/connector-node-integration-test.sh index 0e3a030e1dc7..8baadbfef0e1 100755 --- a/ci/scripts/connector-node-integration-test.sh +++ b/ci/scripts/connector-node-integration-test.sh @@ -86,8 +86,9 @@ tar xf java-binding-integration-test.tar.zst bin echo "--- prepare integration tests" cd ${RISINGWAVE_ROOT}/java/connector-node -pip3 install grpcio grpcio-tools psycopg2 psycopg2-binary pyspark==3.3 -cd python-client && bash gen-stub.sh +pip3 install grpcio grpcio-tools psycopg2 psycopg2-binary pyspark==3.3 black +cd python-client && bash gen-stub.sh && bash format-python.sh --check +export PYTHONPATH=proto echo "--- running streamchunk data format integration tests" cd ${RISINGWAVE_ROOT}/java/connector-node/python-client diff --git a/ci/scripts/release.sh b/ci/scripts/release.sh index cc9f434b9d6b..b222e49c0826 100755 --- a/ci/scripts/release.sh +++ b/ci/scripts/release.sh @@ -11,7 +11,7 @@ if [ "${BUILDKITE_SOURCE}" != "schedule" ] && [ "${BUILDKITE_SOURCE}" != "webhoo fi echo "--- Install java and maven" -yum install -y java-11-openjdk wget python3 +yum install -y java-11-openjdk wget python3 cyrus-sasl-devel pip3 install toml-cli wget https://ci-deps-dist.s3.amazonaws.com/apache-maven-3.9.3-bin.tar.gz && tar -zxvf apache-maven-3.9.3-bin.tar.gz export PATH="${REPO_ROOT}/apache-maven-3.9.3/bin:$PATH" diff --git a/ci/workflows/integration-tests.yml b/ci/workflows/integration-tests.yml index 585e12ba1fed..6aaee9fbf687 100644 --- a/ci/workflows/integration-tests.yml +++ b/ci/workflows/integration-tests.yml @@ -93,7 +93,7 @@ steps: testcase: - "twitter" - "twitter-pulsar" - - "debezium-mongo" + # - "debezium-mongo" - "debezium-postgres" - "tidb-cdc-sink" - "debezium-sqlserver" @@ -105,10 +105,10 @@ steps: testcase: "twitter-pulsar" format: "protobuf" skip: true - - with: - testcase: "debezium-mongo" - format: "protobuf" - skip: true + # - with: + # testcase: "debezium-mongo" + # format: "protobuf" + # skip: true - with: testcase: "debezium-postgres" format: "protobuf" diff --git a/e2e_test/streaming/temporal_join.slt b/e2e_test/streaming/temporal_join/temporal_join.slt similarity index 100% rename from e2e_test/streaming/temporal_join.slt rename to e2e_test/streaming/temporal_join/temporal_join.slt diff --git a/e2e_test/streaming/temporal_join/temporal_join_with_index.slt b/e2e_test/streaming/temporal_join/temporal_join_with_index.slt new file mode 100644 index 000000000000..f714cefcc51b --- /dev/null +++ b/e2e_test/streaming/temporal_join/temporal_join_with_index.slt @@ -0,0 +1,84 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + +statement ok +create table version(id2 int, a2 int, b2 int, primary key (id2)); + +statement ok +create index idx on version (a2); + +statement ok +create materialized view v as select id1, a1, id2, a2 from stream left join idx FOR SYSTEM_TIME AS OF PROCTIME() on b1 = b2 and a1 = a2; + +statement ok +insert into stream values(1, 11, 111); + +statement ok +insert into version values(1, 11, 111); + 
+statement ok +insert into version values(9, 11, 111); + +statement ok +insert into stream values(1, 11, 111); + +statement ok +delete from version; + +query IIII rowsort +select * from v; +---- +1 11 1 11 +1 11 9 11 +1 11 NULL NULL + +statement ok +insert into version values(2, 22, 222); + +statement ok +insert into stream values(2, 22, 222); + +statement ok +insert into version values(8, 22, 222); + +statement ok +insert into stream values(2, 22, 222); + +query IIII rowsort +select * from v; +---- +1 11 1 11 +1 11 9 11 +1 11 NULL NULL +2 22 2 22 +2 22 2 22 +2 22 8 22 + +statement ok +update version set b2 = 333 where id2 = 2; + +statement ok +insert into stream values(2, 22, 222); + +query IIII rowsort +select * from v; +---- +1 11 1 11 +1 11 9 11 +1 11 NULL NULL +2 22 2 22 +2 22 2 22 +2 22 8 22 +2 22 8 22 + +statement ok +drop materialized view v; + +statement ok +drop table stream; + +statement ok +drop table version; diff --git a/java/connector-node/README.md b/java/connector-node/README.md index 6f8cf3c6430c..8c21d68dc4b1 100644 --- a/java/connector-node/README.md +++ b/java/connector-node/README.md @@ -62,13 +62,17 @@ Navigate to the `python-client` directory and run the following command: ``` bash build-venv.sh bash gen-stub.sh -python3 integration_tests.py +PYTHONPATH=proto python3 integration_tests.py ``` Or you can use conda and install the necessary package `grpcio grpcio-tools psycopg2 psycopg2-binary`. The connector service is the server and Python integration test is a client, which will send gRPC request and get response from the connector server. So when running integration_tests, remember to launch the connector service in advance. You can get the gRPC response and check messages or errors in client part. And check the detailed exception information on server side. +### Python file format + +We use `black` as the python file formatter. We can run `format-python.sh` to format the python files. + ### JDBC test We have integration tests that involve the use of several sinks, including file sink, jdbc sink, iceberg sink, and deltalake sink. If you wish to run these tests locally, you will need to configure both MinIO and PostgreSQL. 
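Editorial illustration (not part of the patch): the README section above describes the Python client sending gRPC requests to the connector service, and the refactored `integration_tests.py` below has `load_json_payload` / `load_stream_chunk_payload` return dictionaries keyed by the payload `oneof` field name, which `test_sink` splats into `WriteBatch` with `**payload`. A minimal sketch of that pattern, assuming the stubs have been generated into `proto/` by `gen-stub.sh`, `PYTHONPATH=proto` is exported, and using a made-up single-row insert (`op_type=1`):

```python
# Minimal sketch of the payload-dict pattern used by the refactored
# integration_tests.py. Assumes the generated stubs are importable
# (PYTHONPATH=proto after running gen-stub.sh).
import connector_service_pb2

# One JSON batch containing a single hypothetical insert row (op_type=1).
row_op = connector_service_pb2.SinkStreamRequest.WriteBatch.JsonPayload.RowOp(
    op_type=1, line='{"id": 1, "name": "Alice"}'
)
payload = {
    "json_payload": connector_service_pb2.SinkStreamRequest.WriteBatch.JsonPayload(
        row_ops=[row_op]
    )
}

# The dict key selects the oneof field, so a stream-chunk payload only needs
# "stream_chunk_payload" instead; the request construction stays identical.
request = connector_service_pb2.SinkStreamRequest(
    write=connector_service_pb2.SinkStreamRequest.WriteBatch(
        batch_id=1,
        epoch=1,
        **payload,
    )
)
print(request)
```

This is why `test_sink` no longer needs a `use_json` branch when constructing `WriteBatch` requests.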
diff --git a/java/connector-node/python-client/.gitignore b/java/connector-node/python-client/.gitignore index 322e260731af..600d2d33badf 100644 --- a/java/connector-node/python-client/.gitignore +++ b/java/connector-node/python-client/.gitignore @@ -1,3 +1 @@ -*.py - -!integration_tests.py \ No newline at end of file +.vscode \ No newline at end of file diff --git a/java/connector-node/python-client/build-venv.sh b/java/connector-node/python-client/build-venv.sh index 2e9ebc90e0cc..0fa387ebc962 100755 --- a/java/connector-node/python-client/build-venv.sh +++ b/java/connector-node/python-client/build-venv.sh @@ -1,3 +1,3 @@ virtualenv sink-client-venv source sink-client-venv/bin/activate -pip3 install grpcio grpcio-tools psycopg2 psycopg2-binary +pip3 install grpcio grpcio-tools psycopg2 psycopg2-binary black diff --git a/java/connector-node/python-client/format-python.sh b/java/connector-node/python-client/format-python.sh new file mode 100644 index 000000000000..fc7d5a2710a0 --- /dev/null +++ b/java/connector-node/python-client/format-python.sh @@ -0,0 +1,2 @@ +set -ex +black $@ integration_tests.py pyspark-util.py \ No newline at end of file diff --git a/java/connector-node/python-client/gen-stub.sh b/java/connector-node/python-client/gen-stub.sh index 66d135f0aff1..c7d23075cee2 100755 --- a/java/connector-node/python-client/gen-stub.sh +++ b/java/connector-node/python-client/gen-stub.sh @@ -1 +1 @@ -python3 -m grpc_tools.protoc -I../../../proto/ --python_out=. --grpc_python_out=. ../../../proto/*.proto +python3 -m grpc_tools.protoc -I../../../proto/ --python_out=./proto --grpc_python_out=./proto ../../../proto/*.proto diff --git a/java/connector-node/python-client/integration_tests.py b/java/connector-node/python-client/integration_tests.py index 2bb0063efae6..5481bbbc1ad8 100644 --- a/java/connector-node/python-client/integration_tests.py +++ b/java/connector-node/python-client/integration_tests.py @@ -28,11 +28,13 @@ def make_mock_schema(): schema = connector_service_pb2.TableSchema( columns=[ connector_service_pb2.TableSchema.Column( - name="id", data_type=data_pb2.DataType(type_name=2)), + name="id", data_type=data_pb2.DataType(type_name=2) + ), connector_service_pb2.TableSchema.Column( - name="name", data_type=data_pb2.DataType(type_name=7)) + name="name", data_type=data_pb2.DataType(type_name=7) + ), ], - pk_indices=[0] + pk_indices=[0], ) return schema @@ -41,88 +43,122 @@ def make_mock_schema_stream_chunk(): schema = connector_service_pb2.TableSchema( columns=[ connector_service_pb2.TableSchema.Column( - name="v1", data_type=data_pb2.DataType(type_name=1)), + name="v1", data_type=data_pb2.DataType(type_name=1) + ), connector_service_pb2.TableSchema.Column( - name="v2", data_type=data_pb2.DataType(type_name=2)), + name="v2", data_type=data_pb2.DataType(type_name=2) + ), connector_service_pb2.TableSchema.Column( - name="v3", data_type=data_pb2.DataType(type_name=3)), + name="v3", data_type=data_pb2.DataType(type_name=3) + ), connector_service_pb2.TableSchema.Column( - name="v4", data_type=data_pb2.DataType(type_name=4)), + name="v4", data_type=data_pb2.DataType(type_name=4) + ), connector_service_pb2.TableSchema.Column( - name="v5", data_type=data_pb2.DataType(type_name=5)), + name="v5", data_type=data_pb2.DataType(type_name=5) + ), connector_service_pb2.TableSchema.Column( - name="v6", data_type=data_pb2.DataType(type_name=6)), + name="v6", data_type=data_pb2.DataType(type_name=6) + ), connector_service_pb2.TableSchema.Column( - name="v7", 
data_type=data_pb2.DataType(type_name=7)), + name="v7", data_type=data_pb2.DataType(type_name=7) + ), ], - pk_indices=[0] + pk_indices=[0], ) return schema def load_input(input_file): - with open(input_file, 'r') as file: + with open(input_file, "r") as file: sink_input = json.load(file) return sink_input def load_binary_input(input_file): - with open(input_file, 'rb') as file: + with open(input_file, "rb") as file: sink_input = file.read() return sink_input -def construct_payload(input_file, use_json): +def load_json_payload(input_file): + sink_input = load_input(input_file) payloads = [] - if use_json: - sink_input = load_input(input_file) - for batch in sink_input: - row_ops = [] - for row in batch: - row_ops.append(connector_service_pb2.SinkStreamRequest.WriteBatch.JsonPayload.RowOp( - op_type=row['op_type'], line=str(row['line']))) - payloads.append(connector_service_pb2.SinkStreamRequest.WriteBatch.JsonPayload( - row_ops=row_ops)) - else: - sink_input = load_binary_input(input_file) - payloads.append(connector_service_pb2.SinkStreamRequest.WriteBatch.StreamChunkPayload( - binary_data=sink_input)) + for batch in sink_input: + row_ops = [] + for row in batch: + row_ops.append( + connector_service_pb2.SinkStreamRequest.WriteBatch.JsonPayload.RowOp( + op_type=row["op_type"], line=str(row["line"]) + ) + ) + + payloads.append( + { + "json_payload": connector_service_pb2.SinkStreamRequest.WriteBatch.JsonPayload( + row_ops=row_ops + ) + } + ) + return payloads + + +def load_stream_chunk_payload(input_file): + payloads = [] + sink_input = load_binary_input(input_file) + payloads.append( + { + "stream_chunk_payload": connector_service_pb2.SinkStreamRequest.WriteBatch.StreamChunkPayload( + binary_data=sink_input + ) + } + ) return payloads -def test_sink(type, prop, payload_input, use_json, table_schema=make_mock_schema()): +def test_sink(prop, format, payload_input, table_schema): # read input, Add StartSink request - request_list = [connector_service_pb2.SinkStreamRequest(start=connector_service_pb2.SinkStreamRequest.StartSink( - format=connector_service_pb2.SinkPayloadFormat.JSON if use_json else connector_service_pb2.SinkPayloadFormat.STREAM_CHUNK, - sink_config=connector_service_pb2.SinkConfig( - connector_type=type, - properties=prop, - table_schema=table_schema + request_list = [ + connector_service_pb2.SinkStreamRequest( + start=connector_service_pb2.SinkStreamRequest.StartSink( + format=format, + sink_config=connector_service_pb2.SinkConfig( + connector_type=prop["connector"], + properties=prop, + table_schema=table_schema, + ), + ) ) - ))] + ] - with grpc.insecure_channel('localhost:50051') as channel: + with grpc.insecure_channel("localhost:50051") as channel: stub = connector_service_pb2_grpc.ConnectorServiceStub(channel) - epoch = 0 + epoch = 1 batch_id = 1 # construct request for payload in payload_input: - request_list.append(connector_service_pb2.SinkStreamRequest( - start_epoch=connector_service_pb2.SinkStreamRequest.StartEpoch(epoch=epoch))) - request_list.append(connector_service_pb2.SinkStreamRequest(write=connector_service_pb2.SinkStreamRequest.WriteBatch( - json_payload=payload, - batch_id=batch_id, - epoch=epoch - )) if use_json else - connector_service_pb2.SinkStreamRequest(write=connector_service_pb2.SinkStreamRequest.WriteBatch( - stream_chunk_payload=payload, - batch_id=batch_id, - epoch=epoch - )) + request_list.append( + connector_service_pb2.SinkStreamRequest( + start_epoch=connector_service_pb2.SinkStreamRequest.StartEpoch( + epoch=epoch + ) + ) + ) + 
request_list.append( + connector_service_pb2.SinkStreamRequest( + write=connector_service_pb2.SinkStreamRequest.WriteBatch( + batch_id=batch_id, + epoch=epoch, + **payload, + ) + ) ) - request_list.append(connector_service_pb2.SinkStreamRequest( - sync=connector_service_pb2.SinkStreamRequest.SyncBatch(epoch=epoch))) + request_list.append( + connector_service_pb2.SinkStreamRequest( + sync=connector_service_pb2.SinkStreamRequest.SyncBatch(epoch=epoch) + ) + ) epoch += 1 batch_id += 1 # send request @@ -135,153 +171,195 @@ def test_sink(type, prop, payload_input, use_json, table_schema=make_mock_schema print("Integration test failed: ", e) exit(1) + def validate_jdbc_sink(input_file): conn = psycopg2.connect( - "dbname=test user=test password=connector host=localhost port=5432") + "dbname=test user=test password=connector host=localhost port=5432" + ) cur = conn.cursor() cur.execute("SELECT * FROM test") rows = cur.fetchall() - expected = [list(row.values()) - for batch in load_input(input_file) for row in batch] + expected = [list(row.values()) for batch in load_input(input_file) for row in batch] def convert(b): - return [(item[1]['id'], item[1]['name']) for item in b] + return [(item[1]["id"], item[1]["name"]) for item in b] + expected = convert(expected) if len(rows) != len(expected): - print("Integration test failed: expected {} rows, but got {}".format( - len(expected), len(rows))) + print( + "Integration test failed: expected {} rows, but got {}".format( + len(expected), len(rows) + ) + ) exit(1) for i in range(len(rows)): if len(rows[i]) != len(expected[i]): - print("Integration test failed: expected {} columns, but got {}".format( - len(expected[i]), len(rows[i]))) + print( + "Integration test failed: expected {} columns, but got {}".format( + len(expected[i]), len(rows[i]) + ) + ) exit(1) for j in range(len(rows[i])): if rows[i][j] != expected[i][j]: print( - "Integration test failed: expected {} at row {}, column {}, but got {}".format(expected[i][j], i, j, - rows[i][j])) + "Integration test failed: expected {} at row {}, column {}, but got {}".format( + expected[i][j], i, j, rows[i][j] + ) + ) exit(1) -def test_file_sink(file_name, use_json): - type = "file" - prop = {"output.path": "/tmp/connector", } - test_sink(type, prop, construct_payload( - file_name, use_json), use_json) +def test_file_sink(param): + prop = { + "connector": "file", + "output.path": "/tmp/connector", + } + test_sink(prop, **param) -def test_jdbc_sink(input_file, input_binary_file, use_json): - type = "jdbc" - prop = {"jdbc.url": "jdbc:postgresql://localhost:5432/test?user=test&password=connector", - "table.name": "test", "type" : "upsert"} - file_name = input_file if use_json else input_binary_file - test_sink(type, prop, construct_payload( - file_name, use_json), use_json) +def test_jdbc_sink(input_file, param): + prop = { + "connector": "jdbc", + "jdbc.url": "jdbc:postgresql://localhost:5432/test?user=test&password=connector", + "table.name": "test", + "type": "upsert", + } + test_sink(prop, **param) # validate results validate_jdbc_sink(input_file) -def test_elasticsearch_sink(file_name, use_json): - prop = {"url": "http://127.0.0.1:9200", - "index": "test"} - type = "elasticsearch-7" - test_sink(type, prop, construct_payload( - file_name, use_json), use_json) +def test_elasticsearch_sink(param): + prop = { + "connector": "elasticsearch-7", + "url": "http://127.0.0.1:9200", + "index": "test", + } + test_sink(prop, **param) -def test_iceberg_sink(file_name, use_json): - prop = {"type": "append-only", - 
"warehouse.path": "s3a://bucket", - "s3.endpoint": "http://127.0.0.1:9000", - "s3.access.key": "minioadmin", - "s3.secret.key": "minioadmin", - "database.name": "demo_db", - "table.name": "demo_table"} - type = "iceberg" - test_sink(type, prop, construct_payload( - file_name, use_json), use_json) - - -def test_upsert_iceberg_sink(file_name, use_json): - prop = {"type": "upsert", - "warehouse.path": "s3a://bucket", - "s3.endpoint": "http://127.0.0.1:9000", - "s3.access.key": "minioadmin", - "s3.secret.key": "minioadmin", - "database.name": "demo_db", - "table.name": "demo_table"} +def test_iceberg_sink(param): + prop = { + "connector": "iceberg", + "type": "append-only", + "warehouse.path": "s3a://bucket", + "s3.endpoint": "http://127.0.0.1:9000", + "s3.access.key": "minioadmin", + "s3.secret.key": "minioadmin", + "database.name": "demo_db", + "table.name": "demo_table", + } + test_sink(prop, **param) + + +def test_upsert_iceberg_sink(param): + prop = { + "connector": "iceberg", + "type": "upsert", + "warehouse.path": "s3a://bucket", + "s3.endpoint": "http://127.0.0.1:9000", + "s3.access.key": "minioadmin", + "s3.secret.key": "minioadmin", + "database.name": "demo_db", + "table.name": "demo_table", + } type = "iceberg" # need to make sure all ops as Insert - test_sink(type, prop, construct_payload( - file_name, use_json), use_json) + test_sink(prop, **param) -def test_deltalake_sink(file_name, use_json): +def test_deltalake_sink(param): prop = { + "connector": "deltalake", "location": "s3a://bucket/delta", "s3.access.key": "minioadmin", "s3.secret.key": "minioadmin", "s3.endpoint": "127.0.0.1:9000", } - type = "deltalake" - test_sink(type, prop, construct_payload( - file_name, use_json), use_json) + test_sink(prop, **param) -def test_stream_chunk_data_format(file_name): - type = "file" - prop = {"output.path": "/tmp/connector", } - test_sink(type, prop, construct_payload( - file_name, use_json), use_json, make_mock_schema_stream_chunk()) +def test_stream_chunk_data_format(param): + prop = { + "connector": "file", + "output.path": "/tmp/connector", + } + test_sink(prop, **param) if __name__ == "__main__": parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--file_sink', action='store_true', - help="run file sink test") - parser.add_argument('--jdbc_sink', action='store_true', - help="run jdbc sink test") - parser.add_argument('--stream_chunk_format_test', action='store_true', - help="run print stream chunk sink test") - parser.add_argument('--iceberg_sink', action='store_true', - help="run iceberg sink test") - parser.add_argument('--upsert_iceberg_sink', - action='store_true', help="run upsert iceberg sink test") - parser.add_argument('--deltalake_sink', action='store_true', - help="run deltalake sink test") + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--file_sink", action="store_true", help="run file sink test") + parser.add_argument("--jdbc_sink", action="store_true", help="run jdbc sink test") parser.add_argument( - '--input_file', default="./data/sink_input.json", help="input data to run tests") - parser.add_argument('--input_binary_file', default="./data/sink_input", - help="input stream chunk data to run tests") - parser.add_argument('--es_sink', action='store_true', - help='run elasticsearch sink test') + "--stream_chunk_format_test", + action="store_true", + help="run print stream chunk sink test", + ) parser.add_argument( - '--data_format_use_json', default=True, help="choose json 
or streamchunk") + "--iceberg_sink", action="store_true", help="run iceberg sink test" + ) + parser.add_argument( + "--upsert_iceberg_sink", + action="store_true", + help="run upsert iceberg sink test", + ) + parser.add_argument( + "--deltalake_sink", action="store_true", help="run deltalake sink test" + ) + parser.add_argument( + "--input_file", default="./data/sink_input.json", help="input data to run tests" + ) + parser.add_argument( + "--input_binary_file", + default="./data/sink_input", + help="input stream chunk data to run tests", + ) + parser.add_argument( + "--es_sink", action="store_true", help="run elasticsearch sink test" + ) + parser.add_argument( + "--data_format_use_json", default=True, help="choose json or streamchunk" + ) args = parser.parse_args() - use_json = args.data_format_use_json == True or args.data_format_use_json == 'True' + use_json = args.data_format_use_json == True or args.data_format_use_json == "True" if use_json: - file_name = args.input_file + payload = load_json_payload(args.input_file) + format = connector_service_pb2.SinkPayloadFormat.JSON else: - file_name = args.input_binary_file + payload = load_stream_chunk_payload(args.input_binary_file) + format = connector_service_pb2.SinkPayloadFormat.STREAM_CHUNK # stream chunk format if args.stream_chunk_format_test: - test_stream_chunk_data_format(file_name) + param = { + "format": format, + "payload_input": payload, + "table_schema": make_mock_schema_stream_chunk(), + } + test_stream_chunk_data_format(param) + + param = { + "format": format, + "payload_input": payload, + "table_schema": make_mock_schema(), + } if args.file_sink: - test_file_sink(file_name, use_json) + test_file_sink(param) if args.jdbc_sink: - test_jdbc_sink(args.input_file, args.input_binary_file, use_json) + test_jdbc_sink(args.input_file, param) if args.iceberg_sink: - test_iceberg_sink(file_name, use_json) + test_iceberg_sink(param) if args.deltalake_sink: - test_deltalake_sink(file_name, use_json) + test_deltalake_sink(param) if args.es_sink: - test_elasticsearch_sink(file_name, use_json) + test_elasticsearch_sink(param) # json format if args.upsert_iceberg_sink: - test_upsert_iceberg_sink(file_name, use_json) + test_upsert_iceberg_sink(param) diff --git a/java/connector-node/python-client/proto/.gitignore b/java/connector-node/python-client/proto/.gitignore new file mode 100644 index 000000000000..9956645154d0 --- /dev/null +++ b/java/connector-node/python-client/proto/.gitignore @@ -0,0 +1,3 @@ +__pycache__ +*.py +!__init__.py \ No newline at end of file diff --git a/java/connector-node/python-client/proto/__init__.py b/java/connector-node/python-client/proto/__init__.py new file mode 100644 index 000000000000..847ef40cb6c4 --- /dev/null +++ b/java/connector-node/python-client/proto/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 RisingWave Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ diff --git a/java/connector-node/python-client/pyspark-util.py b/java/connector-node/python-client/pyspark-util.py index c05f2d9487e6..3252fe5ecce5 100644 --- a/java/connector-node/python-client/pyspark-util.py +++ b/java/connector-node/python-client/pyspark-util.py @@ -17,52 +17,77 @@ import json from pyspark.sql import SparkSession, Row + def init_iceberg_spark(): - return SparkSession.builder.master("local").config( - 'spark.jars.packages', 'org.apache.iceberg:iceberg-spark-runtime-3.2_2.12:1.0.0,org.apache.hadoop:hadoop-aws:3.3.2').config( - 'spark.sql.catalog.demo', 'org.apache.iceberg.spark.SparkCatalog').config( - 'spark.sql.catalog.demo.type', 'hadoop').config( - 'spark.sql.catalog.demo.warehouse', 's3a://bucket/').config( - 'spark.sql.catalog.demo.hadoop.fs.s3a.endpoint', 'http://127.0.0.1:9000').config( - 'spark.sql.catalog.demo.hadoop.fs.s3a.access.key', 'minioadmin').config( - 'spark.sql.catalog.demo.hadoop.fs.s3a.secret.key', 'minioadmin').getOrCreate() + return ( + SparkSession.builder.master("local") + .config( + "spark.jars.packages", + "org.apache.iceberg:iceberg-spark-runtime-3.2_2.12:1.0.0,org.apache.hadoop:hadoop-aws:3.3.2", + ) + .config("spark.sql.catalog.demo", "org.apache.iceberg.spark.SparkCatalog") + .config("spark.sql.catalog.demo.type", "hadoop") + .config("spark.sql.catalog.demo.warehouse", "s3a://bucket/") + .config( + "spark.sql.catalog.demo.hadoop.fs.s3a.endpoint", "http://127.0.0.1:9000" + ) + .config("spark.sql.catalog.demo.hadoop.fs.s3a.access.key", "minioadmin") + .config("spark.sql.catalog.demo.hadoop.fs.s3a.secret.key", "minioadmin") + .getOrCreate() + ) + def init_deltalake_spark(): - return SparkSession.builder.master("local").config( - 'spark.jars.packages', 'io.delta:delta-core_2.12:2.2.0,org.apache.hadoop:hadoop-aws:3.3.2').config( - 'spark.sql.extensions', 'io.delta.sql.DeltaSparkSessionExtension').config( - 'spark.sql.catalog.spark_catalog', 'org.apache.spark.sql.delta.catalog.DeltaCatalog').config( - 'spark.hadoop.fs.s3a.endpoint', 'http://127.0.0.1:9000').config( - 'spark.hadoop.fs.s3a.access.key', 'minioadmin').config( - 'spark.hadoop.fs.s3a.secret.key', 'minioadmin').getOrCreate() + return ( + SparkSession.builder.master("local") + .config( + "spark.jars.packages", + "io.delta:delta-core_2.12:2.2.0,org.apache.hadoop:hadoop-aws:3.3.2", + ) + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) + .config("spark.hadoop.fs.s3a.endpoint", "http://127.0.0.1:9000") + .config("spark.hadoop.fs.s3a.access.key", "minioadmin") + .config("spark.hadoop.fs.s3a.secret.key", "minioadmin") + .getOrCreate() + ) + def create_iceberg_table(): spark = init_iceberg_spark() - spark.sql("create table demo.demo_db.demo_table(id int, name string) TBLPROPERTIES ('format-version'='2');") + spark.sql( + "create table demo.demo_db.demo_table(id int, name string) TBLPROPERTIES ('format-version'='2');" + ) print("Table demo.demo_db.demo_table(id int, name string) created") + def drop_iceberg_table(): spark = init_iceberg_spark() spark.sql("drop table demo.demo_db.demo_table;") print("Table demo.demo_db.demo_table dropped") + def read_iceberg_table(): spark = init_iceberg_spark() spark.sql("select * from demo.demo_db.demo_table;").show() + def test_table(input_file, actual_list): actual = [] for row in actual_list: actual.append(row.asDict()) - actual = sorted(actual, key = lambda ele: sorted(ele.items())) + actual = sorted(actual, key=lambda ele: 
sorted(ele.items())) - with open(input_file, 'r') as file: + with open(input_file, "r") as file: sink_input = json.load(file) expected = [] for batch in sink_input: for row in batch: - expected.append(row['line']) - expected = sorted(expected, key = lambda ele: sorted(ele.items())) + expected.append(row["line"]) + expected = sorted(expected, key=lambda ele: sorted(ele.items())) if actual == expected: print("Test passed") @@ -70,38 +95,40 @@ def test_table(input_file, actual_list): print("Expected:", expected, "\nActual:", actual) raise Exception("Test failed") + def test_iceberg_table(input_file): spark = init_iceberg_spark() list = spark.sql("select * from demo.demo_db.demo_table;").collect() test_table(input_file, list) - + + def test_upsert_iceberg_table(input_file): spark = init_iceberg_spark() list = spark.sql("select * from demo.demo_db.demo_table;").collect() actual = [] for row in list: actual.append(row.asDict()) - actual = sorted(actual, key = lambda ele: sorted(ele.items())) + actual = sorted(actual, key=lambda ele: sorted(ele.items())) - with open(input_file, 'r') as file: + with open(input_file, "r") as file: sink_input = json.load(file) expected = [] for batch in sink_input: for row in batch: - match row['op_type']: + match row["op_type"]: case 1: - expected.append(row['line']) + expected.append(row["line"]) case 2: - expected.remove(row['line']) + expected.remove(row["line"]) case 3: - expected.append(row['line']) + expected.append(row["line"]) case 4: - expected.remove(row['line']) + expected.remove(row["line"]) case _: raise Exception("Unknown op_type") - expected = sorted(expected, key = lambda ele: sorted(ele.items())) + expected = sorted(expected, key=lambda ele: sorted(ele.items())) if actual == expected: print("Test passed") @@ -109,14 +136,17 @@ def test_upsert_iceberg_table(input_file): print("Expected:", expected, "\nActual:", actual) raise Exception("Test failed") + def read_deltalake_table(): spark = init_deltalake_spark() spark.sql("select * from delta.`s3a://bucket/delta`;").show() + def create_deltalake_table(): spark = init_deltalake_spark() spark.sql( - "create table IF NOT EXISTS delta.`s3a://bucket/delta`(id int, name string) using delta;") + "create table IF NOT EXISTS delta.`s3a://bucket/delta`(id int, name string) using delta;" + ) print("Table delta.`s3a://bucket/delta`(id int, name string) created") @@ -125,15 +155,23 @@ def delete_deltalake_table_data(): spark.sql("DELETE FROM delta.`s3a://bucket/delta`") print("Table delta.`s3a://bucket/delta` dropped") + def test_deltalake_table(input_file): spark = init_deltalake_spark() list = spark.sql("select * from delta.`s3a://bucket/delta`;").collect() test_table(input_file, list) + if __name__ == "__main__": - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('operation', help="operation on table: read, create, drop, test or test_upsert") - parser.add_argument('--input_file', default="./data/sink_input.json", help="input data to run tests") + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "operation", help="operation on table: read, create, drop, test or test_upsert" + ) + parser.add_argument( + "--input_file", default="./data/sink_input.json", help="input data to run tests" + ) args = parser.parse_args() match args.operation: case "read_iceberg": diff --git a/src/ctl/Cargo.toml b/src/ctl/Cargo.toml index 73f7dbb0968a..96c86d8a2c89 100644 --- a/src/ctl/Cargo.toml +++ 
b/src/ctl/Cargo.toml @@ -19,6 +19,7 @@ bytes = "1" chrono = "0.4" clap = { version = "4", features = ["derive"] } comfy-table = "6" +etcd-client = { version = "0.2", package = "madsim-etcd-client" } futures = { version = "0.3", default-features = false, features = ["alloc"] } inquire = "0.6.2" itertools = "0.10" @@ -27,6 +28,7 @@ risingwave_common = { path = "../common" } risingwave_connector = { path = "../connector" } risingwave_frontend = { path = "../frontend" } risingwave_hummock_sdk = { path = "../storage/hummock_sdk" } +risingwave_meta = { path = "../meta" } risingwave_object_store = { path = "../object_store" } risingwave_pb = { path = "../prost" } risingwave_rpc_client = { path = "../rpc_client" } diff --git a/src/ctl/src/cmd_impl.rs b/src/ctl/src/cmd_impl.rs index 4343cfb60513..97a8f2c24a3f 100644 --- a/src/ctl/src/cmd_impl.rs +++ b/src/ctl/src/cmd_impl.rs @@ -14,6 +14,7 @@ pub mod bench; pub mod compute; +pub mod debug; pub mod hummock; pub mod meta; pub mod profile; diff --git a/src/ctl/src/cmd_impl/debug.rs b/src/ctl/src/cmd_impl/debug.rs new file mode 100644 index 000000000000..2a8a3fa8ff22 --- /dev/null +++ b/src/ctl/src/cmd_impl/debug.rs @@ -0,0 +1,17 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod meta_store; + +pub use meta_store::*; diff --git a/src/ctl/src/cmd_impl/debug/meta_store.rs b/src/ctl/src/cmd_impl/debug/meta_store.rs new file mode 100644 index 000000000000..1e9b4768661c --- /dev/null +++ b/src/ctl/src/cmd_impl/debug/meta_store.rs @@ -0,0 +1,136 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::BTreeSet; +use std::sync::Arc; + +use etcd_client::ConnectOptions; +use risingwave_meta::model::{MetadataModel, TableFragments, Worker}; +use risingwave_meta::storage::meta_store::MetaStore; +use risingwave_meta::storage::{EtcdMetaStore, WrappedEtcdClient}; +use risingwave_pb::user::UserInfo; +use serde_yaml::Value; + +use crate::{DebugCommon, DebugCommonKind, DebugCommonOutputFormat}; + +const KIND_KEY: &str = "kind"; +const ITEM_KEY: &str = "item"; + +macro_rules! yaml_arm { + ($kind:tt, $item:expr) => {{ + let mut mapping = serde_yaml::Mapping::new(); + mapping.insert( + Value::String(KIND_KEY.to_string()), + Value::String($kind.to_string()), + ); + mapping.insert( + Value::String(ITEM_KEY.to_string()), + serde_yaml::to_value($item.to_protobuf()).unwrap(), + ); + serde_yaml::Value::Mapping(mapping) + }}; +} + +macro_rules! 
json_arm { + ($kind:tt, $item:expr) => {{ + let mut mapping = serde_json::Map::new(); + mapping.insert( + KIND_KEY.to_string(), + serde_json::Value::String($kind.to_string()), + ); + mapping.insert( + ITEM_KEY.to_string(), + serde_json::to_value($item.to_protobuf()).unwrap(), + ); + serde_json::Value::Object(mapping) + }}; +} + +enum Item { + Worker(Worker), + User(UserInfo), + Table(TableFragments), +} + +pub async fn dump(common: DebugCommon) -> anyhow::Result<()> { + let DebugCommon { + etcd_endpoints, + etcd_username, + etcd_password, + enable_etcd_auth, + kinds, + format, + } = common; + + let client = if enable_etcd_auth { + let options = ConnectOptions::default().with_user( + etcd_username.unwrap_or_default(), + etcd_password.unwrap_or_default(), + ); + WrappedEtcdClient::connect(etcd_endpoints, Some(options), true).await? + } else { + WrappedEtcdClient::connect(etcd_endpoints, None, false).await? + }; + + let meta_store = Arc::new(EtcdMetaStore::new(client)); + let snapshot = meta_store.snapshot().await; + let kinds: BTreeSet<_> = kinds.into_iter().collect(); + + let mut items = vec![]; + for kind in kinds { + match kind { + DebugCommonKind::Worker => Worker::list_at_snapshot::(&snapshot) + .await? + .into_iter() + .for_each(|worker| items.push(Item::Worker(worker))), + DebugCommonKind::User => UserInfo::list_at_snapshot::(&snapshot) + .await? + .into_iter() + .for_each(|user| items.push(Item::User(user))), + DebugCommonKind::Table => TableFragments::list_at_snapshot::(&snapshot) + .await? + .into_iter() + .for_each(|table| items.push(Item::Table(table))), + }; + } + + let writer = std::io::stdout(); + + match format { + DebugCommonOutputFormat::Yaml => { + let mut seq = serde_yaml::Sequence::new(); + for item in items { + seq.push(match item { + Item::Worker(worker) => yaml_arm!("worker", worker), + Item::User(user) => yaml_arm!("user", user), + Item::Table(table) => yaml_arm!("table", table), + }); + } + serde_yaml::to_writer(writer, &seq).unwrap(); + } + DebugCommonOutputFormat::Json => { + let mut seq = vec![]; + for item in items { + seq.push(match item { + Item::Worker(worker) => json_arm!("worker", worker), + Item::User(user) => json_arm!("user", user), + Item::Table(table) => json_arm!("table", table), + }); + } + serde_json::to_writer_pretty(writer, &seq).unwrap(); + } + } + + Ok(()) +} diff --git a/src/ctl/src/lib.rs b/src/ctl/src/lib.rs index ccb59b21ca47..06f1dc3b5472 100644 --- a/src/ctl/src/lib.rs +++ b/src/ctl/src/lib.rs @@ -64,6 +64,9 @@ enum Commands { /// Commands for Benchmarks #[clap(subcommand)] Bench(BenchCommands), + /// Commands for Debug + #[clap(subcommand)] + Debug(DebugCommands), /// Commands for tracing the compute nodes Trace, // TODO(yuhao): profile other nodes @@ -74,6 +77,55 @@ enum Commands { }, } +#[derive(clap::ValueEnum, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +enum DebugCommonKind { + Worker, + User, + Table, +} + +#[derive(clap::ValueEnum, Clone, Debug)] +enum DebugCommonOutputFormat { + Json, + Yaml, +} + +#[derive(clap::Args, Debug, Clone)] +pub struct DebugCommon { + /// The address of the etcd cluster + #[clap(long, value_delimiter = ',', default_value = "localhost:2388")] + etcd_endpoints: Vec, + + /// The username for etcd authentication, used if `--enable-etcd-auth` is set + #[clap(long)] + etcd_username: Option, + + /// The password for etcd authentication, used if `--enable-etcd-auth` is set + #[clap(long)] + etcd_password: Option, + + /// Whether to enable etcd authentication + #[clap(long, default_value_t = false, 
requires_all = &["etcd_username", "etcd_password"])] + enable_etcd_auth: bool, + + /// Kinds of debug info to dump + #[clap(value_enum, value_delimiter = ',')] + kinds: Vec, + + /// The output format + #[clap(value_enum, long = "output", short = 'o', default_value_t = DebugCommonOutputFormat::Yaml)] + format: DebugCommonOutputFormat, +} + +#[derive(Subcommand, Clone, Debug)] +pub enum DebugCommands { + /// Dump debug info from the raw state store + Dump { + #[command(flatten)] + common: DebugCommon, + }, +} + #[derive(Subcommand)] enum ComputeCommands { /// Show all the configuration parameters on compute node @@ -509,6 +561,7 @@ pub async fn start_impl(opts: CliOpts, context: &CtlContext) -> Result<()> { cmd_impl::scale::update_schedulability(context, workers, Schedulability::Schedulable) .await? } + Commands::Debug(DebugCommands::Dump { common }) => cmd_impl::debug::dump(common).await?, } Ok(()) } diff --git a/src/frontend/planner_test/tests/testdata/input/temporal_join.yaml b/src/frontend/planner_test/tests/testdata/input/temporal_join.yaml index e42cec784158..eeb19aba2548 100644 --- a/src/frontend/planner_test/tests/testdata/input/temporal_join.yaml +++ b/src/frontend/planner_test/tests/testdata/input/temporal_join.yaml @@ -88,3 +88,35 @@ join version2 FOR SYSTEM_TIME AS OF PROCTIME() on stream.id2 = version2.id2 where a1 < 10; expected_outputs: - stream_plan +- name: temporal join with an index (distribution key size = 1) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, primary key (id2)); + create index idx2 on version (a2, b2) distributed by (a2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + expected_outputs: + - stream_plan +- name: temporal join with an index (distribution key size = 2) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, primary key (id2)); + create index idx2 on version (a2, b2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + expected_outputs: + - stream_plan +- name: temporal join with an index (index column size = 1) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, primary key (id2)); + create index idx2 on version (b2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + expected_outputs: + - stream_plan +- name: temporal join with singleton table + sql: | + create table t (a int) append only; + create materialized view v as select count(*) from t; + select * from t left join v FOR SYSTEM_TIME AS OF PROCTIME() on a = count; + expected_outputs: + - stream_plan + diff --git a/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml b/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml index c1c920e6f97c..2b93cce79725 100644 --- a/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml +++ b/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml @@ -152,3 +152,55 @@ │ └─StreamTableScan { table: version1, columns: [version1.id1, version1.x1], pk: [version1.id1], dist: UpstreamHashShard(version1.id1) } └─StreamExchange [no_shuffle] { dist: UpstreamHashShard(version2.id2) } └─StreamTableScan { table: version2, columns: [version2.id2, version2.x2], pk: [version2.id2], dist: UpstreamHashShard(version2.id2) } +- name: temporal join with 
an index (distribution key size = 1) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, primary key (id2)); + create index idx2 on version (a2, b2) distributed by (a2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + stream_plan: |- + StreamMaterialize { columns: [id1, a1, id2, a2, stream._row_id(hidden), stream.b1(hidden)], stream_key: [stream._row_id, id2, a1, stream.b1], pk_columns: [stream._row_id, id2, a1, stream.b1], pk_conflict: NoCheck } + └─StreamTemporalJoin { type: LeftOuter, predicate: stream.a1 = idx2.a2 AND stream.b1 = idx2.b2, output: [stream.id1, stream.a1, idx2.id2, idx2.a2, stream._row_id, stream.b1] } + ├─StreamExchange { dist: HashShard(stream.a1) } + │ └─StreamTableScan { table: stream, columns: [stream.id1, stream.a1, stream.b1, stream._row_id], pk: [stream._row_id], dist: UpstreamHashShard(stream._row_id) } + └─StreamExchange [no_shuffle] { dist: UpstreamHashShard(idx2.a2) } + └─StreamTableScan { table: idx2, columns: [idx2.a2, idx2.b2, idx2.id2], pk: [idx2.id2], dist: UpstreamHashShard(idx2.a2) } +- name: temporal join with an index (distribution key size = 2) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, primary key (id2)); + create index idx2 on version (a2, b2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + stream_plan: |- + StreamMaterialize { columns: [id1, a1, id2, a2, stream._row_id(hidden), stream.b1(hidden)], stream_key: [stream._row_id, id2, a1, stream.b1], pk_columns: [stream._row_id, id2, a1, stream.b1], pk_conflict: NoCheck } + └─StreamTemporalJoin { type: LeftOuter, predicate: stream.a1 = idx2.a2 AND stream.b1 = idx2.b2, output: [stream.id1, stream.a1, idx2.id2, idx2.a2, stream._row_id, stream.b1] } + ├─StreamExchange { dist: HashShard(stream.a1, stream.b1) } + │ └─StreamTableScan { table: stream, columns: [stream.id1, stream.a1, stream.b1, stream._row_id], pk: [stream._row_id], dist: UpstreamHashShard(stream._row_id) } + └─StreamExchange [no_shuffle] { dist: UpstreamHashShard(idx2.a2, idx2.b2) } + └─StreamTableScan { table: idx2, columns: [idx2.a2, idx2.b2, idx2.id2], pk: [idx2.id2], dist: UpstreamHashShard(idx2.a2, idx2.b2) } +- name: temporal join with an index (index column size = 1) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, primary key (id2)); + create index idx2 on version (b2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + stream_plan: |- + StreamMaterialize { columns: [id1, a1, id2, a2, stream._row_id(hidden), stream.b1(hidden)], stream_key: [stream._row_id, id2, stream.b1, a1], pk_columns: [stream._row_id, id2, stream.b1, a1], pk_conflict: NoCheck } + └─StreamTemporalJoin { type: LeftOuter, predicate: stream.b1 = idx2.b2 AND (stream.a1 = idx2.a2), output: [stream.id1, stream.a1, idx2.id2, idx2.a2, stream._row_id, stream.b1] } + ├─StreamExchange { dist: HashShard(stream.b1) } + │ └─StreamTableScan { table: stream, columns: [stream.id1, stream.a1, stream.b1, stream._row_id], pk: [stream._row_id], dist: UpstreamHashShard(stream._row_id) } + └─StreamExchange [no_shuffle] { dist: UpstreamHashShard(idx2.b2) } + └─StreamTableScan { table: idx2, columns: [idx2.b2, idx2.id2, idx2.a2], pk: [idx2.id2], dist: UpstreamHashShard(idx2.b2) } +- name: temporal 
join with singleton table + sql: | + create table t (a int) append only; + create materialized view v as select count(*) from t; + select * from t left join v FOR SYSTEM_TIME AS OF PROCTIME() on a = count; + stream_plan: |- + StreamMaterialize { columns: [a, count, t._row_id(hidden), $expr1(hidden)], stream_key: [t._row_id, $expr1], pk_columns: [t._row_id, $expr1], pk_conflict: NoCheck } + └─StreamTemporalJoin { type: LeftOuter, predicate: AND ($expr1 = v.count), output: [t.a, v.count, t._row_id, $expr1] } + ├─StreamExchange { dist: Single } + │ └─StreamProject { exprs: [t.a, t.a::Int64 as $expr1, t._row_id] } + │ └─StreamTableScan { table: t, columns: [t.a, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamExchange [no_shuffle] { dist: Single } + └─StreamTableScan { table: v, columns: [v.count], pk: [], dist: Single } diff --git a/src/frontend/src/optimizer/plan_node/eq_join_predicate.rs b/src/frontend/src/optimizer/plan_node/eq_join_predicate.rs index cd35ffd83e15..4539d17b2acd 100644 --- a/src/frontend/src/optimizer/plan_node/eq_join_predicate.rs +++ b/src/frontend/src/optimizer/plan_node/eq_join_predicate.rs @@ -14,6 +14,7 @@ use std::fmt; +use itertools::Itertools; use risingwave_common::catalog::Schema; use crate::expr::{ @@ -270,6 +271,43 @@ impl EqJoinPredicate { ) } + /// Retain the prefix of `eq_keys` based on the `prefix_len`. The other part is moved to the + /// other condition. + pub fn retain_prefix_eq_key(self, prefix_len: usize) -> Self { + assert!(prefix_len <= self.eq_keys.len()); + let (retain_eq_key, other_eq_key) = self.eq_keys.split_at(prefix_len); + let mut new_other_conjunctions = self.other_cond.conjunctions; + new_other_conjunctions.extend( + other_eq_key + .iter() + .cloned() + .map(|(l, r, null_safe)| { + FunctionCall::new( + if null_safe { + ExprType::IsNotDistinctFrom + } else { + ExprType::Equal + }, + vec![l.into(), r.into()], + ) + .unwrap() + .into() + }) + .collect_vec(), + ); + + let new_other_cond = Condition { + conjunctions: new_other_conjunctions, + }; + + Self::new( + new_other_cond, + retain_eq_key.to_owned(), + self.left_cols_num, + self.right_cols_num, + ) + } + pub fn rewrite_exprs(&self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self { let mut new = self.clone(); new.other_cond = new.other_cond.rewrite_expr(rewriter); diff --git a/src/frontend/src/optimizer/plan_node/logical_join.rs b/src/frontend/src/optimizer/plan_node/logical_join.rs index 6911bc3f5869..0d72c1d5fb1c 100644 --- a/src/frontend/src/optimizer/plan_node/logical_join.rs +++ b/src/frontend/src/optimizer/plan_node/logical_join.rs @@ -366,19 +366,19 @@ impl LogicalJoin { .expect("dist_key must in order_key"); dist_key_in_order_key_pos.push(pos); } - // The at least prefix of order key that contains distribution key. - let at_least_prefix_len = dist_key_in_order_key_pos + // The shortest prefix of order key that contains distribution key. + let shortest_prefix_len = dist_key_in_order_key_pos .iter() .max() .map_or(0, |pos| pos + 1); // Distributed lookup join can't support lookup table with a singleton distribution. - if at_least_prefix_len == 0 { + if shortest_prefix_len == 0 { return None; } // Reorder the join equal predicate to match the order key. 
- let mut reorder_idx = Vec::with_capacity(at_least_prefix_len); + let mut reorder_idx = Vec::with_capacity(shortest_prefix_len); for order_col_id in order_col_ids { let mut found = false; for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() { @@ -392,7 +392,7 @@ impl LogicalJoin { break; } } - if reorder_idx.len() < at_least_prefix_len { + if reorder_idx.len() < shortest_prefix_len { return None; } let lookup_prefix_len = reorder_idx.len(); @@ -966,18 +966,6 @@ impl LogicalJoin { ) -> Result { assert!(predicate.has_eq()); - let left = self.left().to_stream_with_dist_required( - &RequiredDist::shard_by_key(self.left().schema().len(), &predicate.left_eq_indexes()), - ctx, - )?; - - if !left.append_only() { - return Err(RwError::from(ErrorCode::NotSupported( - "Temporal join requires an append-only left input".into(), - "Please ensure your left input is append-only".into(), - ))); - } - let right = self.right(); let Some(logical_scan) = right.as_logical_scan() else { return Err(RwError::from(ErrorCode::NotSupported( @@ -994,30 +982,76 @@ impl LogicalJoin { } let table_desc = logical_scan.table_desc(); + let output_column_ids = logical_scan.output_column_ids(); - // Verify that right join key columns are the primary key of the lookup table. + // Verify that the right join key columns are a prefix of the primary key and + // also contain the distribution key. let order_col_ids = table_desc.order_column_ids(); - let order_col_ids_len = order_col_ids.len(); - let output_column_ids = logical_scan.output_column_ids(); + let order_key = table_desc.order_column_indices(); + let dist_key = table_desc.distribution_key.clone(); + + let mut dist_key_in_order_key_pos = vec![]; + for d in dist_key { + let pos = order_key + .iter() + .position(|&x| x == d) + .expect("dist_key must in order_key"); + dist_key_in_order_key_pos.push(pos); + } + // The shortest prefix of order key that contains distribution key. + let shortest_prefix_len = dist_key_in_order_key_pos + .iter() + .max() + .map_or(0, |pos| pos + 1); // Reorder the join equal predicate to match the order key. - let mut reorder_idx = vec![]; + let mut reorder_idx = Vec::with_capacity(shortest_prefix_len); for order_col_id in order_col_ids { + let mut found = false; for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() { if order_col_id == output_column_ids[eq_idx] { reorder_idx.push(i); + found = true; break; } } + if !found { + break; + } } - if order_col_ids_len != predicate.eq_keys().len() || reorder_idx.len() < order_col_ids_len { + if reorder_idx.len() < shortest_prefix_len { + // TODO: support index selection for temporal join and refine this error message. return Err(RwError::from(ErrorCode::NotSupported( "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(), "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(), ))); } + let lookup_prefix_len = reorder_idx.len(); let predicate = predicate.reorder(&reorder_idx); + let left = if dist_key_in_order_key_pos.is_empty() { + self.left() + .to_stream_with_dist_required(&RequiredDist::single(), ctx)? + } else { + let left_eq_indexes = predicate.left_eq_indexes(); + let left_dist_key = dist_key_in_order_key_pos + .iter() + .map(|pos| left_eq_indexes[*pos]) + .collect_vec(); + + self.left().to_stream_with_dist_required( + &RequiredDist::shard_by_key(self.left().schema().len(), &left_dist_key), + ctx, + )? 
+ }; + + if !left.append_only() { + return Err(RwError::from(ErrorCode::NotSupported( + "Temporal join requires an append-only left input".into(), + "Please ensure your left input is append-only".into(), + ))); + } + // Extract the predicate from logical scan. Only pure scan is supported. let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up(); // Construct output column to require column mapping @@ -1090,6 +1124,8 @@ impl LogicalJoin { new_join_output_indices, ); + let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len); + Ok(StreamTemporalJoin::new(new_logical_join, new_predicate).into()) } diff --git a/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs b/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs index c3a8b4ab7b1d..f9fb325b8af8 100644 --- a/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs @@ -21,7 +21,6 @@ use risingwave_pb::stream_plan::TemporalJoinNode; use super::utils::{childless_record, watermark_pretty, Distill}; use super::{generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamNode}; use crate::expr::{Expr, ExprRewriter}; -use crate::optimizer::plan_node::generic::GenericPlanRef; use crate::optimizer::plan_node::plan_tree_node::PlanTreeNodeUnary; use crate::optimizer::plan_node::stream::StreamPlanRef; use crate::optimizer::plan_node::utils::IndicesDisplay; @@ -42,7 +41,6 @@ impl StreamTemporalJoin { pub fn new(logical: generic::Join, eq_join_predicate: EqJoinPredicate) -> Self { assert!(logical.join_type == JoinType::Inner || logical.join_type == JoinType::LeftOuter); assert!(logical.left.append_only()); - assert!(logical.right.logical_pk() == eq_join_predicate.right_eq_indexes()); let right = logical.right.clone(); let exchange: &StreamExchange = right .as_stream_exchange() diff --git a/src/meta/src/barrier/progress.rs b/src/meta/src/barrier/progress.rs index d7d3e715a497..d706aa5bf47f 100644 --- a/src/meta/src/barrier/progress.rs +++ b/src/meta/src/barrier/progress.rs @@ -55,6 +55,9 @@ struct Progress { /// Upstream mvs total key count. upstream_total_key_count: u64, + /// Consumed rows + consumed_rows: u64, + /// DDL definition definition: String, } @@ -80,6 +83,7 @@ impl Progress { creating_mv_id, upstream_mv_count, upstream_total_key_count, + consumed_rows: 0, definition, } } @@ -87,15 +91,25 @@ impl Progress { /// Update the progress of `actor`. 
fn update(&mut self, actor: ActorId, new_state: ChainState, upstream_total_key_count: u64) { self.upstream_total_key_count = upstream_total_key_count; - match self.states.get_mut(&actor).unwrap() { - state @ (ChainState::Init | ChainState::ConsumingUpstream(_, _)) => { - if matches!(new_state, ChainState::Done) { - self.done_count += 1; + match self.states.remove(&actor).unwrap() { + ChainState::Init => {} + ChainState::ConsumingUpstream(_, old_consumed_rows) => { + if !matches!(new_state, ChainState::Done) { + self.consumed_rows -= old_consumed_rows; } - *state = new_state; } ChainState::Done => panic!("should not report done multiple times"), - } + }; + match &new_state { + ChainState::Init => {} + ChainState::ConsumingUpstream(_, new_consumed_rows) => { + self.consumed_rows += new_consumed_rows; + } + ChainState::Done => { + self.done_count += 1; + } + }; + self.states.insert(actor, new_state); self.calculate_progress(); } @@ -110,26 +124,16 @@ impl Progress { self.states.keys().cloned() } - /// `progress` = `done_ratio` + (1 - `done_ratio`) * (`consumed_rows` / `remaining_rows`). + /// `progress` = `consumed_rows` / `upstream_total_key_count` fn calculate_progress(&self) -> f64 { if self.is_done() || self.states.is_empty() { return 1.0; } - let done_ratio: f64 = (self.done_count) as f64 / self.states.len() as f64; - let mut remaining_rows = self.upstream_total_key_count as f64 * (1_f64 - done_ratio); - if remaining_rows == 0.0 { - remaining_rows = 1.0; + let mut upstream_total_key_count = self.upstream_total_key_count as f64; + if upstream_total_key_count == 0.0 { + upstream_total_key_count = 1.0 } - let consumed_rows: u64 = self - .states - .values() - .map(|x| match x { - ChainState::ConsumingUpstream(_, rows) => *rows, - _ => 0, - }) - .sum(); - let mut progress = - done_ratio + (1_f64 - done_ratio) * consumed_rows as f64 / remaining_rows; + let mut progress = self.consumed_rows as f64 / upstream_total_key_count; if progress >= 1.0 { progress = 0.99; } diff --git a/src/meta/src/lib.rs b/src/meta/src/lib.rs index 8f858472cb6f..0cbe5a1209c4 100644 --- a/src/meta/src/lib.rs +++ b/src/meta/src/lib.rs @@ -43,7 +43,7 @@ mod dashboard; mod error; pub mod hummock; pub mod manager; -mod model; +pub mod model; mod rpc; pub(crate) mod serving; pub mod storage; diff --git a/src/stream/src/executor/temporal_join.rs b/src/stream/src/executor/temporal_join.rs index 5b96193614b4..dc12b6ccdef0 100644 --- a/src/stream/src/executor/temporal_join.rs +++ b/src/stream/src/executor/temporal_join.rs @@ -13,22 +13,29 @@ // limitations under the License. 
use std::alloc::Global; +use std::collections::HashMap; +use std::ops::{Deref, DerefMut}; use std::pin::pin; use std::sync::Arc; use either::Either; use futures::stream::{self, PollNext}; -use futures::{StreamExt, TryStreamExt}; +use futures::{pin_mut, StreamExt, TryStreamExt}; use futures_async_stream::try_stream; use local_stats_alloc::{SharedStatsAlloc, StatsAlloc}; use lru::DefaultHasher; use risingwave_common::array::{Op, StreamChunk}; use risingwave_common::catalog::Schema; +use risingwave_common::estimate_size::{EstimateSize, KvSize}; +use risingwave_common::hash::{HashKey, NullBitmap}; use risingwave_common::row::{OwnedRow, Row, RowExt}; -use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_common::types::DataType; +use risingwave_common::util::iter_util::ZipEqDebug; use risingwave_expr::expr::BoxedExpression; use risingwave_hummock_sdk::{HummockEpoch, HummockReadEpoch}; +use risingwave_storage::store::PrefetchOptions; use risingwave_storage::table::batch_table::storage_table::StorageTable; +use risingwave_storage::table::TableIter; use risingwave_storage::StateStore; use super::{Barrier, Executor, Message, MessageStream, StreamExecutorError, StreamExecutorResult}; @@ -39,11 +46,11 @@ use crate::executor::monitor::StreamingMetrics; use crate::executor::{ActorContextRef, BoxedExecutor, JoinType, JoinTypePrimitive, PkIndices}; use crate::task::AtomicU64Ref; -pub struct TemporalJoinExecutor { +pub struct TemporalJoinExecutor { ctx: ActorContextRef, left: BoxedExecutor, right: BoxedExecutor, - right_table: TemporalSide, + right_table: TemporalSide, left_join_keys: Vec, right_join_keys: Vec, null_safe: Vec, @@ -58,20 +65,86 @@ pub struct TemporalJoinExecutor { metrics: Arc, } -struct TemporalSide { +#[derive(Default)] +pub struct JoinEntry { + /// pk -> row + cached: HashMap, + kv_heap_size: KvSize, +} + +impl EstimateSize for JoinEntry { + fn estimated_heap_size(&self) -> usize { + // TODO: Add internal size. + // https://github.com/risingwavelabs/risingwave/issues/9713 + self.kv_heap_size.size() + } +} + +impl JoinEntry { + /// Insert into the cache. + pub fn insert(&mut self, key: OwnedRow, value: OwnedRow) { + self.kv_heap_size.add(&key, &value); + self.cached.try_insert(key, value).unwrap(); + } + + /// Delete from the cache. + pub fn remove(&mut self, pk: &OwnedRow) { + if let Some(value) = self.cached.remove(pk) { + self.kv_heap_size.sub(pk, &value); + } else { + panic!("pk {:?} should be in the cache", pk); + } + } + + pub fn is_empty(&self) -> bool { + self.cached.is_empty() + } +} + +struct JoinEntryWrapper(Option); + +impl EstimateSize for JoinEntryWrapper { + fn estimated_heap_size(&self) -> usize { + self.0.estimated_heap_size() + } +} + +impl JoinEntryWrapper { + const MESSAGE: &str = "the state should always be `Some`"; + + /// Take the value out of the wrapper. Panic if the value is `None`. 
+ pub fn take(&mut self) -> JoinEntry { + self.0.take().expect(Self::MESSAGE) + } +} + +impl Deref for JoinEntryWrapper { + type Target = JoinEntry; + + fn deref(&self) -> &Self::Target { + self.0.as_ref().expect(Self::MESSAGE) + } +} + +impl DerefMut for JoinEntryWrapper { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.as_mut().expect(Self::MESSAGE) + } +} + +struct TemporalSide { source: StorageTable, + pk: Vec, table_output_indices: Vec, - cache: ManagedLruCache, DefaultHasher, SharedStatsAlloc>, + cache: ManagedLruCache>, ctx: ActorContextRef, + join_key_data_types: Vec, } -impl TemporalSide { - async fn lookup( - &mut self, - key: impl Row, - epoch: HummockEpoch, - ) -> StreamExecutorResult> { - let key = key.into_owned_row(); +impl TemporalSide { + /// Lookup the temporal side table and return a `JoinEntry` which could be empty if there are no + /// matched records. + async fn lookup(&mut self, key: &K, epoch: HummockEpoch) -> StreamExecutorResult { let table_id_str = self.source.table_id().to_string(); let actor_id_str = self.ctx.id.to_string(); self.ctx @@ -79,36 +152,71 @@ impl TemporalSide { .temporal_join_total_query_cache_count .with_label_values(&[&table_id_str, &actor_id_str]) .inc(); - Ok(match self.cache.get(&key) { - Some(res) => res.clone(), - None => { - // cache miss - self.ctx - .streaming_metrics - .temporal_join_cache_miss_count - .with_label_values(&[&table_id_str, &actor_id_str]) - .inc(); - let res = self - .source - .get_row(key.clone(), HummockReadEpoch::NoWait(epoch)) - .await? - .map(|row| row.project(&self.table_output_indices).into_owned_row()); - self.cache.put(key, res.clone()); - res + + let res = if self.cache.contains(key) { + let mut state = self.cache.peek_mut(key).unwrap(); + state.take() + } else { + // cache miss + self.ctx + .streaming_metrics + .temporal_join_cache_miss_count + .with_label_values(&[&table_id_str, &actor_id_str]) + .inc(); + + let pk_prefix = key.deserialize(&self.join_key_data_types)?; + + let iter = self + .source + .batch_iter_with_pk_bounds( + HummockReadEpoch::NoWait(epoch), + &pk_prefix, + .., + false, + PrefetchOptions::new_for_exhaust_iter(), + ) + .await?; + + let mut entry = JoinEntry::default(); + + pin_mut!(iter); + while let Some(row) = iter.next_row().await? 
{ + entry.insert( + row.as_ref().project(&self.pk).into_owned_row(), + row.project(&self.table_output_indices).into_owned_row(), + ); } - }) + + entry + }; + + Ok(res) } - fn update(&mut self, payload: Vec, join_keys: &[usize]) { - payload.iter().flat_map(|c| c.rows()).for_each(|(op, row)| { - let key = row.project(join_keys).into_owned_row(); - if let Some(mut value) = self.cache.get_mut(&key) { - match op { - Op::Insert | Op::UpdateInsert => *value = Some(row.into_owned_row()), - Op::Delete | Op::UpdateDelete => *value = None, - }; + fn update( + &mut self, + chunks: Vec, + join_keys: &[usize], + ) -> StreamExecutorResult<()> { + for chunk in chunks { + let keys = K::build(join_keys, chunk.data_chunk())?; + for ((op, row), key) in chunk.rows().zip_eq_debug(keys.into_iter()) { + if self.cache.contains(&key) { + // Update cache + let mut entry = self.cache.get_mut(&key).unwrap(); + let pk = row.project(&self.pk).into_owned_row(); + match op { + Op::Insert | Op::UpdateInsert => entry.insert(pk, row.into_owned_row()), + Op::Delete | Op::UpdateDelete => entry.remove(&pk), + }; + } } - }); + } + Ok(()) + } + + pub fn insert_back(&mut self, key: K, state: JoinEntry) { + self.cache.put(key, JoinEntryWrapper(Some(state))); } } @@ -184,7 +292,7 @@ async fn align_input(left: Box, right: Box) { } } -impl TemporalJoinExecutor { +impl TemporalJoinExecutor { #[allow(clippy::too_many_arguments)] pub fn new( ctx: ActorContextRef, @@ -202,6 +310,7 @@ impl TemporalJoinExecutor { watermark_epoch: AtomicU64Ref, metrics: Arc, chunk_size: usize, + join_key_data_types: Vec, ) -> Self { let schema_fields = [left.schema().fields.clone(), right.schema().fields.clone()].concat(); @@ -226,15 +335,19 @@ impl TemporalJoinExecutor { alloc, ); + let pk = table.pk_in_output_indices().unwrap(); + Self { ctx: ctx.clone(), left, right, right_table: TemporalSide { source: table, + pk, table_output_indices, cache, ctx, + join_key_data_types, }, left_join_keys, right_join_keys, @@ -257,6 +370,8 @@ impl TemporalJoinExecutor { self.right.schema().len(), ); + let null_matched = K::Bitmap::from_bool_vec(self.null_safe); + let mut prev_epoch = None; let table_id_str = self.right_table.source.table_id().to_string(); @@ -271,6 +386,8 @@ impl TemporalJoinExecutor { .set(self.right_table.cache.len() as i64); match msg? { InternalMessage::Chunk(chunk) => { + // Compact chunk, otherwise the following keys and chunk rows might fail to zip. + let chunk = chunk.compact(); let mut builder = StreamChunkBuilder::new( self.chunk_size, &self.schema.data_types(), @@ -278,33 +395,33 @@ impl TemporalJoinExecutor { right_map.clone(), ); let epoch = prev_epoch.expect("Chunk data should come after some barrier."); - for (op, left_row) in chunk.rows() { - let key = left_row.project(&self.left_join_keys); - if key - .iter() - .zip_eq_fast(self.null_safe.iter()) - .any(|(datum, can_null)| datum.is_none() && !*can_null) - { - continue; - } - if let Some(right_row) = self.right_table.lookup(key, epoch).await? 
{ - // check join condition - let ok = if let Some(ref mut cond) = self.condition { - let concat_row = left_row.chain(&right_row).into_owned_row(); - cond.eval_row_infallible(&concat_row, |err| { - self.ctx.on_compute_error(err, self.identity.as_str()) - }) - .await - .map(|s| *s.as_bool()) - .unwrap_or(false) - } else { - true - }; - if ok { - if let Some(chunk) = builder.append_row(op, left_row, &right_row) { - yield Message::Chunk(chunk); + let keys = K::build(&self.left_join_keys, chunk.data_chunk())?; + for ((op, left_row), key) in chunk.rows().zip_eq_debug(keys.into_iter()) { + if key.null_bitmap().is_subset(&null_matched) + && let join_entry = self.right_table.lookup(&key, epoch).await? + && !join_entry.is_empty() { + for right_row in join_entry.cached.values() { + // check join condition + let ok = if let Some(ref mut cond) = self.condition { + let concat_row = left_row.chain(&right_row).into_owned_row(); + cond.eval_row_infallible(&concat_row, |err| { + self.ctx.on_compute_error(err, self.identity.as_str()) + }) + .await + .map(|s| *s.as_bool()) + .unwrap_or(false) + } else { + true + }; + + if ok { + if let Some(chunk) = builder.append_row(op, left_row, right_row) { + yield Message::Chunk(chunk); + } } } + // Insert back the state taken from ht. + self.right_table.insert_back(key.clone(), join_entry); } else if T == JoinType::LeftOuter { if let Some(chunk) = builder.append_row_update(op, left_row) { yield Message::Chunk(chunk); @@ -324,7 +441,7 @@ impl TemporalJoinExecutor { } } self.right_table.cache.update_epoch(barrier.epoch.curr); - self.right_table.update(updates, &self.right_join_keys); + self.right_table.update(updates, &self.right_join_keys)?; prev_epoch = Some(barrier.epoch.curr); yield Message::Barrier(barrier) } @@ -333,7 +450,9 @@ impl TemporalJoinExecutor { } } -impl Executor for TemporalJoinExecutor { +impl Executor + for TemporalJoinExecutor +{ fn execute(self: Box) -> super::BoxedMessageStream { self.into_stream().boxed() } diff --git a/src/stream/src/from_proto/temporal_join.rs b/src/stream/src/from_proto/temporal_join.rs index 4c7b8695066b..19df99d2d66e 100644 --- a/src/stream/src/from_proto/temporal_join.rs +++ b/src/stream/src/from_proto/temporal_join.rs @@ -15,6 +15,8 @@ use std::sync::Arc; use risingwave_common::catalog::{ColumnDesc, TableId, TableOption}; +use risingwave_common::hash::{HashKey, HashKeyDispatcher}; +use risingwave_common::types::DataType; use risingwave_common::util::sort_util::OrderType; use risingwave_expr::expr::{build_from_prost, BoxedExpression}; use risingwave_pb::plan_common::{JoinType as JoinTypeProto, StorageTableDesc}; @@ -141,6 +143,11 @@ impl ExecutorBuilder for TemporalJoinExecutorBuilder { .map(|&x| x as usize) .collect_vec(); + let join_key_data_types = left_join_keys + .iter() + .map(|idx| source_l.schema().fields[*idx].data_type()) + .collect_vec(); + let dispatcher_args = TemporalJoinExecutorDispatcherArgs { ctx: params.actor_context, left: source_l, @@ -158,6 +165,7 @@ impl ExecutorBuilder for TemporalJoinExecutorBuilder { chunk_size: params.env.config().developer.chunk_size, metrics: params.executor_stats, join_type_proto: node.get_join_type()?, + join_key_data_types, }; dispatcher_args.dispatch() @@ -181,31 +189,38 @@ struct TemporalJoinExecutorDispatcherArgs { chunk_size: usize, metrics: Arc, join_type_proto: JoinTypeProto, + join_key_data_types: Vec, } -impl TemporalJoinExecutorDispatcherArgs { - pub fn dispatch(self) -> StreamResult { +impl HashKeyDispatcher for TemporalJoinExecutorDispatcherArgs { + type Output = 
StreamResult; + + fn dispatch_impl(self) -> Self::Output { + /// This macro helps to fill the const generic type parameter. macro_rules! build { ($join_type:ident) => { - Ok(Box::new( - TemporalJoinExecutor::::new( - self.ctx, - self.left, - self.right, - self.right_table, - self.left_join_keys, - self.right_join_keys, - self.null_safe, - self.condition, - self.pk_indices, - self.output_indices, - self.table_output_indices, - self.executor_id, - self.watermark_epoch, - self.metrics, - self.chunk_size, - ), - )) + Ok(Box::new(TemporalJoinExecutor::< + K, + S, + { JoinType::$join_type }, + >::new( + self.ctx, + self.left, + self.right, + self.right_table, + self.left_join_keys, + self.right_join_keys, + self.null_safe, + self.condition, + self.pk_indices, + self.output_indices, + self.table_output_indices, + self.executor_id, + self.watermark_epoch, + self.metrics, + self.chunk_size, + self.join_key_data_types, + ))) }; } match self.join_type_proto { @@ -214,4 +229,8 @@ impl TemporalJoinExecutorDispatcherArgs { _ => unreachable!(), } } + + fn data_types(&self) -> &[DataType] { + &self.join_key_data_types + } } diff --git a/src/tests/sqlsmith/src/sql_gen/query.rs b/src/tests/sqlsmith/src/sql_gen/query.rs index 3571911149a6..0879603620ae 100644 --- a/src/tests/sqlsmith/src/sql_gen/query.rs +++ b/src/tests/sqlsmith/src/sql_gen/query.rs @@ -164,7 +164,8 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { fn gen_limit(&mut self, has_order_by: bool) -> Option { if (!self.is_mview || has_order_by) && self.flip_coin() { - Some(self.rng.gen_range(0..=100).to_string()) + let start = if self.is_mview { 1 } else { 0 }; + Some(self.rng.gen_range(start..=100).to_string()) } else { None } @@ -258,17 +259,60 @@ impl<'a, R: Rng> SqlGenerator<'a, R> { /// GROUP BY will constrain the generated columns. fn gen_group_by(&mut self) -> Vec { + // 90% generate simple group by. + // 10% generate grouping sets. + match self.rng.gen_range(0..=9) { + 0 => self.gen_grouping_sets(), + 1..=9 => { + let group_by_cols = self.gen_random_bound_columns(); + self.bound_columns = group_by_cols.clone(); + group_by_cols + .into_iter() + .map(|c| Expr::Identifier(Ident::new_unchecked(c.name))) + .collect_vec() + } + _ => unreachable!(), + } + } + + /// GROUPING SETS will constrain the generated columns. + fn gen_grouping_sets(&mut self) -> Vec { + let grouping_num = self.rng.gen_range(0..=5); + let mut grouping_sets = vec![]; + let mut new_bound_columns = vec![]; + for _i in 0..grouping_num { + let group_by_cols = self.gen_random_bound_columns(); + grouping_sets.push( + group_by_cols + .iter() + .map(|c| Expr::Identifier(Ident::new_unchecked(c.name.clone()))) + .collect_vec(), + ); + new_bound_columns.extend(group_by_cols); + } + if grouping_sets.is_empty() { + self.bound_columns = vec![]; + vec![] + } else { + let grouping_sets = Expr::GroupingSets(grouping_sets); + self.bound_columns = new_bound_columns + .into_iter() + .sorted_by(|a, b| Ord::cmp(&a.name, &b.name)) + .dedup_by(|a, b| a.name == b.name) + .collect(); + + // Currently, grouping sets only support one set. 
+ vec![grouping_sets] + } + } + + fn gen_random_bound_columns(&mut self) -> Vec { let mut available = self.bound_columns.clone(); if !available.is_empty() { available.shuffle(self.rng); let upper_bound = (available.len() + 1) / 2; let n = self.rng.gen_range(1..=upper_bound); - let group_by_cols = available.drain(..n).collect_vec(); - self.bound_columns = group_by_cols.clone(); - group_by_cols - .into_iter() - .map(|c| Expr::Identifier(Ident::new_unchecked(c.name))) - .collect_vec() + available.drain(..n).collect_vec() } else { vec![] }
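For the sqlsmith change above: `gen_group_by` now emits a plain GROUP BY roughly 90% of the time and GROUPING SETS roughly 10% of the time, resetting the bound columns to the deduplicated union of all generated sets. A simplified, self-contained sketch of that weighted choice follows; column names are plain strings instead of sqlparser `Expr`s, and `pick_some` stands in for `gen_random_bound_columns`.

```rust
// Sketch of weighted GROUP BY / GROUPING SETS generation, assuming stand-in types.
use itertools::Itertools;
use rand::seq::SliceRandom;
use rand::Rng;

fn pick_some(rng: &mut impl Rng, bound: &[String]) -> Vec<String> {
    if bound.is_empty() {
        return vec![];
    }
    let mut available = bound.to_vec();
    available.shuffle(rng);
    let upper = (available.len() + 1) / 2;
    let n = rng.gen_range(1..=upper);
    available.drain(..n).collect_vec()
}

fn gen_group_by(rng: &mut impl Rng, bound: &mut Vec<String>) -> Option<String> {
    if bound.is_empty() {
        return None;
    }
    // ~10% of the time emit GROUPING SETS, otherwise a plain GROUP BY list.
    if rng.gen_range(0..10) == 0 {
        let sets: Vec<Vec<String>> = (0..rng.gen_range(1..=5))
            .map(|_| pick_some(rng, bound))
            .collect();
        // Later clauses may only reference the deduplicated union of all sets.
        *bound = sets.iter().flatten().cloned().sorted().dedup().collect();
        Some(format!(
            "GROUP BY GROUPING SETS ({})",
            sets.iter()
                .map(|s| format!("({})", s.join(", ")))
                .join(", ")
        ))
    } else {
        let cols = pick_some(rng, bound);
        *bound = cols.clone();
        Some(format!("GROUP BY {}", cols.join(", ")))
    }
}
```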