From 1905e4c3ff27bb594d5e9327812d134e1c171b2e Mon Sep 17 00:00:00 2001 From: xxchan Date: Thu, 30 May 2024 12:03:35 +0800 Subject: [PATCH 01/20] chore: add precommit hook (#17004) Signed-off-by: xxchan --- .pre-commit-config.yaml | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000..ab8ba3d9d7eb9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,26 @@ +# Usage: install pre-commit, and then run `pre-commit install` to install git hooks +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: local + hooks: + - id: rustfmt + name: rustfmt + entry: rustfmt --edition 2021 + language: system + types: [rust] + - id: typos + name: typos + entry: typos -w + language: system + - id: cargo sort + name: cargo sort + entry: cargo sort -g -w + language: system + files: 'Cargo.toml' + pass_filenames: false From cb0b9632b5074da0581ebd62f6d66ee0c12de5dc Mon Sep 17 00:00:00 2001 From: xxchan Date: Thu, 30 May 2024 12:05:06 +0800 Subject: [PATCH 02/20] feat: support schema registry in risedev (#17001) Signed-off-by: xxchan --- ci/docker-compose.yml | 4 +- ci/scripts/e2e-kafka-sink-test.sh | 8 +-- ci/scripts/e2e-source-test.sh | 8 +-- e2e_test/schema_registry/alter_sr.slt | 8 +-- e2e_test/schema_registry/pb.slt | 4 +- e2e_test/sink/kafka/avro.slt | 12 ++-- e2e_test/sink/kafka/protobuf.slt | 8 +-- e2e_test/source/basic/nosim_kafka.slt | 31 +++++---- e2e_test/source/basic/schema_registry.slt | 10 +-- risedev.yml | 26 +++++++- scripts/source/prepare_ci_kafka.sh | 6 +- src/risedevtool/src/bin/risedev-compose.rs | 7 +- src/risedevtool/src/bin/risedev-dev.rs | 16 ++++- src/risedevtool/src/config.rs | 3 + src/risedevtool/src/risedev_env.rs | 9 +++ src/risedevtool/src/service_config.rs | 37 ++++++++++- src/risedevtool/src/task.rs | 2 + src/risedevtool/src/task/docker_service.rs | 4 +- src/risedevtool/src/task/kafka_service.rs | 18 +++-- .../src/task/schema_registry_service.rs | 65 +++++++++++++++++++ 20 files changed, 223 insertions(+), 63 deletions(-) create mode 100644 src/risedevtool/src/task/schema_registry_service.rs diff --git a/ci/docker-compose.yml b/ci/docker-compose.yml index 60d2d8946717c..15274be94be9b 100644 --- a/ci/docker-compose.yml +++ b/ci/docker-compose.yml @@ -61,7 +61,8 @@ services: - "29092:29092" - "9092:9092" - "9644:9644" - - "8081:8081" + # Don't use Redpanda's schema registry, use the separated service instead + # - "8081:8081" environment: {} container_name: message_queue healthcheck: @@ -89,6 +90,7 @@ services: - mysql - db - message_queue + - schemaregistry - elasticsearch - clickhouse-server - redis-server diff --git a/ci/scripts/e2e-kafka-sink-test.sh b/ci/scripts/e2e-kafka-sink-test.sh index 206ce4ba1d75d..7cab1ae1f76f7 100755 --- a/ci/scripts/e2e-kafka-sink-test.sh +++ b/ci/scripts/e2e-kafka-sink-test.sh @@ -154,16 +154,16 @@ cp src/connector/src/test_data/proto_recursive/recursive.pb ./proto-recursive rpk topic create test-rw-sink-append-only-protobuf rpk topic create test-rw-sink-append-only-protobuf-csr-a rpk topic create test-rw-sink-append-only-protobuf-csr-hi -python3 e2e_test/sink/kafka/register_schema.py 'http://message_queue:8081' 'test-rw-sink-append-only-protobuf-csr-a-value' 
src/connector/src/test_data/test-index-array.proto -python3 e2e_test/sink/kafka/register_schema.py 'http://message_queue:8081' 'test-rw-sink-append-only-protobuf-csr-hi-value' src/connector/src/test_data/test-index-array.proto +python3 e2e_test/sink/kafka/register_schema.py 'http://schemaregistry:8082' 'test-rw-sink-append-only-protobuf-csr-a-value' src/connector/src/test_data/test-index-array.proto +python3 e2e_test/sink/kafka/register_schema.py 'http://schemaregistry:8082' 'test-rw-sink-append-only-protobuf-csr-hi-value' src/connector/src/test_data/test-index-array.proto sqllogictest -p 4566 -d dev 'e2e_test/sink/kafka/protobuf.slt' rpk topic delete test-rw-sink-append-only-protobuf rpk topic delete test-rw-sink-append-only-protobuf-csr-a rpk topic delete test-rw-sink-append-only-protobuf-csr-hi echo "testing avro" -python3 e2e_test/sink/kafka/register_schema.py 'http://message_queue:8081' 'test-rw-sink-upsert-avro-value' src/connector/src/test_data/all-types.avsc -python3 e2e_test/sink/kafka/register_schema.py 'http://message_queue:8081' 'test-rw-sink-upsert-avro-key' src/connector/src/test_data/all-types.avsc 'string_field,int32_field' +python3 e2e_test/sink/kafka/register_schema.py 'http://schemaregistry:8082' 'test-rw-sink-upsert-avro-value' src/connector/src/test_data/all-types.avsc +python3 e2e_test/sink/kafka/register_schema.py 'http://schemaregistry:8082' 'test-rw-sink-upsert-avro-key' src/connector/src/test_data/all-types.avsc 'string_field,int32_field' rpk topic create test-rw-sink-upsert-avro sqllogictest -p 4566 -d dev 'e2e_test/sink/kafka/avro.slt' rpk topic delete test-rw-sink-upsert-avro diff --git a/ci/scripts/e2e-source-test.sh b/ci/scripts/e2e-source-test.sh index 5127731256c6b..35b7965f12bb3 100755 --- a/ci/scripts/e2e-source-test.sh +++ b/ci/scripts/e2e-source-test.sh @@ -137,11 +137,11 @@ export RISINGWAVE_CI=true RUST_LOG="info,risingwave_stream=info,risingwave_batch=info,risingwave_storage=info" \ risedev ci-start ci-1cn-1fe python3 -m pip install --break-system-packages requests protobuf confluent-kafka -python3 e2e_test/schema_registry/pb.py "message_queue:29092" "http://message_queue:8081" "sr_pb_test" 20 user +python3 e2e_test/schema_registry/pb.py "message_queue:29092" "http://schemaregistry:8082" "sr_pb_test" 20 user echo "make sure google/protobuf/source_context.proto is NOT in schema registry" -curl --silent 'http://message_queue:8081/subjects'; echo -# curl --silent --head -X GET 'http://message_queue:8081/subjects/google%2Fprotobuf%2Fsource_context.proto/versions' | grep 404 -curl --silent 'http://message_queue:8081/subjects' | grep -v 'google/protobuf/source_context.proto' +curl --silent 'http://schemaregistry:8082/subjects'; echo +# curl --silent --head -X GET 'http://schemaregistry:8082/subjects/google%2Fprotobuf%2Fsource_context.proto/versions' | grep 404 +curl --silent 'http://schemaregistry:8082/subjects' | grep -v 'google/protobuf/source_context.proto' risedev slt './e2e_test/schema_registry/pb.slt' risedev slt './e2e_test/schema_registry/alter_sr.slt' diff --git a/e2e_test/schema_registry/alter_sr.slt b/e2e_test/schema_registry/alter_sr.slt index 8daf41d87b633..d703c0401a35e 100644 --- a/e2e_test/schema_registry/alter_sr.slt +++ b/e2e_test/schema_registry/alter_sr.slt @@ -9,7 +9,7 @@ CREATE SOURCE src_user WITH ( scan.startup.mode = 'earliest' ) FORMAT PLAIN ENCODE PROTOBUF( - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', message = 'test.User' ); @@ -24,7 +24,7 @@ CREATE TABLE t_user WITH ( 
scan.startup.mode = 'earliest' ) FORMAT PLAIN ENCODE PROTOBUF( - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', message = 'test.User' ); @@ -36,7 +36,7 @@ SELECT age FROM t_user; # Push more events with extended fields system ok -python3 e2e_test/schema_registry/pb.py "message_queue:29092" "http://message_queue:8081" "sr_pb_test" 5 user_with_more_fields +python3 e2e_test/schema_registry/pb.py "message_queue:29092" "http://schemaregistry:8082" "sr_pb_test" 5 user_with_more_fields sleep 5s @@ -58,7 +58,7 @@ SELECT COUNT(*), MAX(age), MIN(age), SUM(age) FROM mv_user_more; # Push more events with extended fields system ok -python3 e2e_test/schema_registry/pb.py "message_queue:29092" "http://message_queue:8081" "sr_pb_test" 5 user_with_more_fields +python3 e2e_test/schema_registry/pb.py "message_queue:29092" "http://schemaregistry:8082" "sr_pb_test" 5 user_with_more_fields sleep 5s diff --git a/e2e_test/schema_registry/pb.slt b/e2e_test/schema_registry/pb.slt index d9c0edca1b21c..7b60b4fa8d7a4 100644 --- a/e2e_test/schema_registry/pb.slt +++ b/e2e_test/schema_registry/pb.slt @@ -9,7 +9,7 @@ create table sr_pb_test with ( properties.bootstrap.server = 'message_queue:29092', scan.startup.mode = 'earliest') FORMAT plain ENCODE protobuf( - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', message = 'test.User' ); @@ -21,7 +21,7 @@ create table sr_pb_test_bk with ( properties.bootstrap.server = 'message_queue:29092', scan.startup.mode = 'earliest') FORMAT plain ENCODE protobuf( - schema.registry = 'http://message_queue:8081,http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082,http://schemaregistry:8082', message = 'test.User' ); diff --git a/e2e_test/sink/kafka/avro.slt b/e2e_test/sink/kafka/avro.slt index d9fa53bc589ac..1cf27b811d9be 100644 --- a/e2e_test/sink/kafka/avro.slt +++ b/e2e_test/sink/kafka/avro.slt @@ -6,7 +6,7 @@ with ( topic = 'test-rw-sink-upsert-avro', properties.bootstrap.server = 'message_queue:29092') format upsert encode avro ( - schema.registry = 'http://message_queue:8081'); + schema.registry = 'http://schemaregistry:8082'); statement ok create table into_kafka ( @@ -40,7 +40,7 @@ create sink sink0 from into_kafka with ( properties.bootstrap.server = 'message_queue:29092', primary_key = 'int32_field,string_field') format upsert encode avro ( - schema.registry = 'http://message_queue:8081'); + schema.registry = 'http://schemaregistry:8082'); sleep 2s @@ -72,7 +72,7 @@ create sink sink_err from into_kafka with ( properties.bootstrap.server = 'message_queue:29092', primary_key = 'int32_field,string_field') format upsert encode avro ( - schema.registry = 'http://message_queue:8081'); + schema.registry = 'http://schemaregistry:8082'); statement error field not in avro create sink sink_err as select 1 as extra_column, * from into_kafka with ( @@ -81,7 +81,7 @@ create sink sink_err as select 1 as extra_column, * from into_kafka with ( properties.bootstrap.server = 'message_queue:29092', primary_key = 'int32_field,string_field') format upsert encode avro ( - schema.registry = 'http://message_queue:8081'); + schema.registry = 'http://schemaregistry:8082'); statement error unrecognized create sink sink_err from into_kafka with ( @@ -90,7 +90,7 @@ create sink sink_err from into_kafka with ( properties.bootstrap.server = 'message_queue:29092', primary_key = 'int32_field,string_field') format upsert encode avro ( - schema.registry = 'http://message_queue:8081', + 
schema.registry = 'http://schemaregistry:8082', schema.registry.name.strategy = 'typo'); statement error empty field key.message @@ -100,7 +100,7 @@ create sink sink_err from into_kafka with ( properties.bootstrap.server = 'message_queue:29092', primary_key = 'int32_field,string_field') format upsert encode avro ( - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', schema.registry.name.strategy = 'record_name_strategy'); statement ok diff --git a/e2e_test/sink/kafka/protobuf.slt b/e2e_test/sink/kafka/protobuf.slt index 0c74cc8a0b369..c3f6f0d3ad8e2 100644 --- a/e2e_test/sink/kafka/protobuf.slt +++ b/e2e_test/sink/kafka/protobuf.slt @@ -13,7 +13,7 @@ create table from_kafka_csr_trivial with ( topic = 'test-rw-sink-append-only-protobuf-csr-a', properties.bootstrap.server = 'message_queue:29092') format plain encode protobuf ( - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', message = 'test.package.MessageA'); statement ok @@ -22,7 +22,7 @@ create table from_kafka_csr_nested with ( topic = 'test-rw-sink-append-only-protobuf-csr-hi', properties.bootstrap.server = 'message_queue:29092') format plain encode protobuf ( - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', message = 'test.package.MessageH.MessageI'); statement ok @@ -68,7 +68,7 @@ create sink sink_csr_trivial as select string_field as field_a from into_kafka w properties.bootstrap.server = 'message_queue:29092') format plain encode protobuf ( force_append_only = true, - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', message = 'test.package.MessageA'); statement ok @@ -78,7 +78,7 @@ create sink sink_csr_nested as select sint32_field as field_i from into_kafka wi properties.bootstrap.server = 'message_queue:29092') format plain encode protobuf ( force_append_only = true, - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', message = 'test.package.MessageH.MessageI'); sleep 2s diff --git a/e2e_test/source/basic/nosim_kafka.slt b/e2e_test/source/basic/nosim_kafka.slt index 12626b6926fdf..f143471e0f269 100644 --- a/e2e_test/source/basic/nosim_kafka.slt +++ b/e2e_test/source/basic/nosim_kafka.slt @@ -1,3 +1,6 @@ +control substitution on + +# FIXME: does this really work?? 
# Start with nosim to avoid running in deterministic test @@ -7,18 +10,18 @@ CREATE TABLE upsert_avro_json_default_key ( primary key (rw_key) ) INCLUDE KEY AS rw_key WITH ( connector = 'kafka', - properties.bootstrap.server = 'message_queue:29092', + properties.bootstrap.server = '${RISEDEV_KAFKA_BOOTSTRAP_SERVERS}', topic = 'upsert_avro_json') -FORMAT UPSERT ENCODE AVRO (schema.registry = 'http://message_queue:8081'); +FORMAT UPSERT ENCODE AVRO (schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}'); statement ok CREATE TABLE upsert_student_avro_json ( primary key (rw_key) ) INCLUDE KEY AS rw_key WITH ( connector = 'kafka', - properties.bootstrap.server = 'message_queue:29092', + properties.bootstrap.server = '${RISEDEV_KAFKA_BOOTSTRAP_SERVERS}', topic = 'upsert_student_avro_json') -FORMAT UPSERT ENCODE AVRO (schema.registry = 'http://message_queue:8081'); +FORMAT UPSERT ENCODE AVRO (schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}'); # TODO: Uncomment this when we add test data kafka key with format `"ID":id` @@ -28,35 +31,35 @@ FORMAT UPSERT ENCODE AVRO (schema.registry = 'http://message_queue:8081'); # ) # WITH ( # connector = 'kafka', -# properties.bootstrap.server = 'message_queue:29092', +# properties.bootstrap.server = '${RISEDEV_KAFKA_BOOTSTRAP_SERVERS}', # topic = 'upsert_avro_json') -# FORMAT UPSERT ENCODE AVRO (schema.registry = 'http://message_queue:8081'); +# FORMAT UPSERT ENCODE AVRO (schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}'); statement ok CREATE TABLE debezium_non_compact (PRIMARY KEY(order_id)) with ( connector = 'kafka', kafka.topic = 'debezium_non_compact_avro_json', - kafka.brokers = 'message_queue:29092', + kafka.brokers = '${RISEDEV_KAFKA_BOOTSTRAP_SERVERS}', kafka.scan.startup.mode = 'earliest' -) FORMAT DEBEZIUM ENCODE AVRO (schema.registry = 'http://message_queue:8081'); +) FORMAT DEBEZIUM ENCODE AVRO (schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}'); statement ok CREATE TABLE debezium_compact (PRIMARY KEY(order_id)) with ( connector = 'kafka', kafka.topic = 'debezium_compact_avro_json', - kafka.brokers = 'message_queue:29092', + kafka.brokers = '${RISEDEV_KAFKA_BOOTSTRAP_SERVERS}', kafka.scan.startup.mode = 'earliest' -) FORMAT DEBEZIUM ENCODE AVRO (schema.registry = 'http://message_queue:8081'); +) FORMAT DEBEZIUM ENCODE AVRO (schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}'); statement ok CREATE TABLE kafka_json_schema_plain with ( connector = 'kafka', kafka.topic = 'kafka_json_schema', - kafka.brokers = 'message_queue:29092', + kafka.brokers = '${RISEDEV_KAFKA_BOOTSTRAP_SERVERS}', kafka.scan.startup.mode = 'earliest' -) FORMAT PLAIN ENCODE JSON (schema.registry = 'http://schemaregistry:8082'); +) FORMAT PLAIN ENCODE JSON (schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}'); statement ok CREATE TABLE kafka_json_schema_upsert (PRIMARY KEY(rw_key)) @@ -64,9 +67,9 @@ INCLUDE KEY AS rw_key with ( connector = 'kafka', kafka.topic = 'kafka_upsert_json_schema', - kafka.brokers = 'message_queue:29092', + kafka.brokers = '${RISEDEV_KAFKA_BOOTSTRAP_SERVERS}', kafka.scan.startup.mode = 'earliest' -) FORMAT UPSERT ENCODE JSON (schema.registry = 'http://schemaregistry:8082'); +) FORMAT UPSERT ENCODE JSON (schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}'); statement ok flush; diff --git a/e2e_test/source/basic/schema_registry.slt b/e2e_test/source/basic/schema_registry.slt index 76f867b2b1d0e..4673e441e80c6 100644 --- a/e2e_test/source/basic/schema_registry.slt +++ b/e2e_test/source/basic/schema_registry.slt @@ -5,7 +5,7 @@ create source s1 () with ( topic = 
'upsert_avro_json-record', properties.bootstrap.server = 'message_queue:29092' ) format plain encode avro ( - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', schema.registry.name.strategy = 'no sense', message = 'CPLM.OBJ_ATTRIBUTE_VALUE', ); @@ -17,7 +17,7 @@ create source s1 () with ( topic = 'upsert_avro_json-record', properties.bootstrap.server = 'message_queue:29092' ) format plain encode avro ( - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', schema.registry.name.strategy = 'record_name_strategy', message = 'CPLM.OBJ_ATTRIBUTE_VALUE', key.message = 'string' @@ -29,7 +29,7 @@ create source s1 () with ( topic = 'upsert_avro_json-record', properties.bootstrap.server = 'message_queue:29092' ) format plain encode avro ( - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', schema.registry.name.strategy = 'record_name_strategy', message = 'CPLM.OBJ_ATTRIBUTE_VALUE', ); @@ -41,7 +41,7 @@ create table t1 () with ( topic = 'upsert_avro_json-topic-record', properties.bootstrap.server = 'message_queue:29092' ) format upsert encode avro ( - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', schema.registry.name.strategy = 'topic_record_name_strategy', message = 'CPLM.OBJ_ATTRIBUTE_VALUE' ); @@ -54,7 +54,7 @@ with ( topic = 'upsert_avro_json-topic-record', properties.bootstrap.server = 'message_queue:29092' ) format upsert encode avro ( - schema.registry = 'http://message_queue:8081', + schema.registry = 'http://schemaregistry:8082', schema.registry.name.strategy = 'topic_record_name_strategy', message = 'CPLM.OBJ_ATTRIBUTE_VALUE', key.message = 'string' diff --git a/risedev.yml b/risedev.yml index df69da7cb2457..f730a69d5f444 100644 --- a/risedev.yml +++ b/risedev.yml @@ -861,6 +861,10 @@ profile: user-managed: true address: message_queue port: 29092 + - use: schema-registry + user-managed: true + address: schemaregistry + port: 8082 ci-inline-source-test: config-path: src/config/ci-recovery.toml @@ -1433,9 +1437,8 @@ template: # Listen port of KRaft controller controller-port: 29093 - - # Listen address - listen-address: ${address} + # Listen port for other services in docker (schema-registry) + docker-port: 29094 # The docker image. Can be overridden to use a different version. image: "confluentinc/cp-kafka:7.6.1" @@ -1448,6 +1451,23 @@ template: user-managed: false + schema-registry: + # Id to be picked-up by services + id: schema-registry-${port} + + # Advertise address + address: "127.0.0.1" + + # Listen port of Schema Registry + port: 8081 + + # The docker image. Can be overridden to use a different version. 
+ image: "confluentinc/cp-schema-registry:7.6.1" + + user-managed: false + + provide-kafka: "kafka*" + # Google pubsub emulator service pubsub: id: pubsub-${port} diff --git a/scripts/source/prepare_ci_kafka.sh b/scripts/source/prepare_ci_kafka.sh index e50229a73759f..9f3e2f473ca9b 100755 --- a/scripts/source/prepare_ci_kafka.sh +++ b/scripts/source/prepare_ci_kafka.sh @@ -56,7 +56,7 @@ for filename in $kafka_data_files; do if [[ "$topic" = *bin ]]; then kcat -P -b message_queue:29092 -t "$topic" "$filename" elif [[ "$topic" = *avro_json ]]; then - python3 source/schema_registry_producer.py "message_queue:29092" "http://message_queue:8081" "$filename" "topic" "avro" + python3 source/schema_registry_producer.py "message_queue:29092" "http://schemaregistry:8082" "$filename" "topic" "avro" elif [[ "$topic" = *json_schema ]]; then python3 source/schema_registry_producer.py "message_queue:29092" "http://schemaregistry:8082" "$filename" "topic" "json" else @@ -72,9 +72,9 @@ for i in {0..100}; do echo "key$i:{\"a\": $i}" | kcat -P -b message_queue:29092 # write schema with name strategy ## topic: upsert_avro_json-record, key subject: string, value subject: CPLM.OBJ_ATTRIBUTE_VALUE -(python3 source/schema_registry_producer.py "message_queue:29092" "http://message_queue:8081" source/test_data/upsert_avro_json.1 "record" "avro") & +(python3 source/schema_registry_producer.py "message_queue:29092" "http://schemaregistry:8082" source/test_data/upsert_avro_json.1 "record" "avro") & ## topic: upsert_avro_json-topic-record, ## key subject: upsert_avro_json-topic-record-string ## value subject: upsert_avro_json-topic-record-CPLM.OBJ_ATTRIBUTE_VALUE -(python3 source/schema_registry_producer.py "message_queue:29092" "http://message_queue:8081" source/test_data/upsert_avro_json.1 "topic-record" "avro") & +(python3 source/schema_registry_producer.py "message_queue:29092" "http://schemaregistry:8082" source/test_data/upsert_avro_json.1 "topic-record" "avro") & wait diff --git a/src/risedevtool/src/bin/risedev-compose.rs b/src/risedevtool/src/bin/risedev-compose.rs index ec805a840fa71..5ff56916deca6 100644 --- a/src/risedevtool/src/bin/risedev-compose.rs +++ b/src/risedevtool/src/bin/risedev-compose.rs @@ -219,9 +219,10 @@ fn main() -> Result<()> { volumes.insert(c.id.clone(), ComposeVolume::default()); (c.address.clone(), c.compose(&compose_config)?) 
} - ServiceConfig::Redis(_) | ServiceConfig::MySql(_) | ServiceConfig::Postgres(_) => { - return Err(anyhow!("not supported")) - } + ServiceConfig::Redis(_) + | ServiceConfig::MySql(_) + | ServiceConfig::Postgres(_) + | ServiceConfig::SchemaRegistry(_) => return Err(anyhow!("not supported")), }; compose.container_name = service.id().to_string(); if opts.deploy { diff --git a/src/risedevtool/src/bin/risedev-dev.rs b/src/risedevtool/src/bin/risedev-dev.rs index 5a7ab843ddae2..8dbe155bcd086 100644 --- a/src/risedevtool/src/bin/risedev-dev.rs +++ b/src/risedevtool/src/bin/risedev-dev.rs @@ -27,8 +27,8 @@ use risedev::{ generate_risedev_env, preflight_check, CompactorService, ComputeNodeService, ConfigExpander, ConfigureTmuxTask, DummyService, EnsureStopService, ExecuteContext, FrontendService, GrafanaService, KafkaService, MetaNodeService, MinioService, MySqlService, PostgresService, - PrometheusService, PubsubService, RedisService, ServiceConfig, SqliteConfig, Task, - TempoService, RISEDEV_NAME, + PrometheusService, PubsubService, RedisService, SchemaRegistryService, ServiceConfig, + SqliteConfig, Task, TempoService, RISEDEV_NAME, }; use tempfile::tempdir; use thiserror_ext::AsReport; @@ -279,6 +279,18 @@ fn task_main( ctx.pb .set_message(format!("kafka {}:{}", c.address, c.port)); } + ServiceConfig::SchemaRegistry(c) => { + let mut ctx = + ExecuteContext::new(&mut logger, manager.new_progress(), status_dir.clone()); + let mut service = SchemaRegistryService::new(c.clone()); + service.execute(&mut ctx)?; + let mut task = + risedev::TcpReadyCheckTask::new(c.address.clone(), c.port, c.user_managed)?; + task.execute(&mut ctx)?; + ctx.pb + .set_message(format!("schema registry http://{}:{}", c.address, c.port)); + } + ServiceConfig::Pubsub(c) => { let mut ctx = ExecuteContext::new(&mut logger, manager.new_progress(), status_dir.clone()); diff --git a/src/risedevtool/src/config.rs b/src/risedevtool/src/config.rs index 839ebc22486ee..bf768f8e68cd1 100644 --- a/src/risedevtool/src/config.rs +++ b/src/risedevtool/src/config.rs @@ -175,6 +175,9 @@ impl ConfigExpander { "redpanda" => ServiceConfig::RedPanda(serde_yaml::from_str(&out_str)?), "mysql" => ServiceConfig::MySql(serde_yaml::from_str(&out_str)?), "postgres" => ServiceConfig::Postgres(serde_yaml::from_str(&out_str)?), + "schema-registry" => { + ServiceConfig::SchemaRegistry(serde_yaml::from_str(&out_str)?) 
+ } other => return Err(anyhow!("unsupported use type: {}", other)), }; Ok(result) diff --git a/src/risedevtool/src/risedev_env.rs b/src/risedevtool/src/risedev_env.rs index a45864f097854..2b6cc367b2e71 100644 --- a/src/risedevtool/src/risedev_env.rs +++ b/src/risedevtool/src/risedev_env.rs @@ -77,6 +77,15 @@ pub fn generate_risedev_env(services: &Vec) -> String { writeln!(env, r#"RISEDEV_KAFKA_WITH_OPTIONS_COMMON="connector='kafka',properties.bootstrap.server='{brokers}'""#).unwrap(); writeln!(env, r#"RPK_BROKERS="{brokers}""#).unwrap(); } + ServiceConfig::SchemaRegistry(c) => { + let address = &c.address; + let port = &c.port; + writeln!( + env, + r#"RISEDEV_SCHEMA_REGISTRY_URL="http://{address}:{port}""#, + ) + .unwrap(); + } ServiceConfig::MySql(c) => { let host = &c.address; let port = &c.port; diff --git a/src/risedevtool/src/service_config.rs b/src/risedevtool/src/service_config.rs index 88c1594fb1153..71461b0f58bcc 100644 --- a/src/risedevtool/src/service_config.rs +++ b/src/risedevtool/src/service_config.rs @@ -271,12 +271,16 @@ pub struct KafkaConfig { phantom_use: Option, pub id: String, + /// Advertise address pub address: String, #[serde(with = "string")] pub port: u16, + /// Port for other services in docker. They need to connect to `host.docker.internal`, while the host + /// need to connect to `localhost`. + pub docker_port: u16, + #[serde(with = "string")] pub controller_port: u16, - pub listen_address: String, pub image: String, pub persist_data: bool, @@ -284,6 +288,28 @@ pub struct KafkaConfig { pub user_managed: bool, } + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +#[serde(deny_unknown_fields)] +pub struct SchemaRegistryConfig { + #[serde(rename = "use")] + phantom_use: Option, + + pub id: String, + + pub address: String, + #[serde(with = "string")] + pub port: u16, + + pub provide_kafka: Option>, + + pub image: String, + /// Redpanda supports schema registry natively. You can configure a `user_managed` schema registry + /// to use with redpanda. + pub user_managed: bool, +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] #[serde(deny_unknown_fields)] @@ -380,6 +406,7 @@ pub enum ServiceConfig { Opendal(OpendalConfig), AwsS3(AwsS3Config), Kafka(KafkaConfig), + SchemaRegistry(SchemaRegistryConfig), Pubsub(PubsubConfig), Redis(RedisConfig), RedPanda(RedPandaConfig), @@ -407,10 +434,12 @@ impl ServiceConfig { Self::RedPanda(c) => &c.id, Self::Opendal(c) => &c.id, Self::MySql(c) => &c.id, - ServiceConfig::Postgres(c) => &c.id, + Self::Postgres(c) => &c.id, + Self::SchemaRegistry(c) => &c.id, } } + /// Used to check whether the port is occupied before running the service. 
pub fn port(&self) -> Option { match self { Self::ComputeNode(c) => Some(c.port), @@ -430,7 +459,8 @@ impl ServiceConfig { Self::RedPanda(_c) => None, Self::Opendal(_) => None, Self::MySql(c) => Some(c.port), - ServiceConfig::Postgres(c) => Some(c.port), + Self::Postgres(c) => Some(c.port), + Self::SchemaRegistry(c) => Some(c.port), } } @@ -454,6 +484,7 @@ impl ServiceConfig { Self::Opendal(_c) => false, Self::MySql(c) => c.user_managed, Self::Postgres(c) => c.user_managed, + Self::SchemaRegistry(c) => c.user_managed, } } } diff --git a/src/risedevtool/src/task.rs b/src/risedevtool/src/task.rs index e34cddd908b7f..21b6f20eec5ee 100644 --- a/src/risedevtool/src/task.rs +++ b/src/risedevtool/src/task.rs @@ -29,6 +29,7 @@ mod postgres_service; mod prometheus_service; mod pubsub_service; mod redis_service; +mod schema_registry_service; mod task_configure_minio; mod task_etcd_ready_check; mod task_kafka_ready_check; @@ -68,6 +69,7 @@ pub use self::postgres_service::*; pub use self::prometheus_service::*; pub use self::pubsub_service::*; pub use self::redis_service::*; +pub use self::schema_registry_service::SchemaRegistryService; pub use self::task_configure_minio::*; pub use self::task_etcd_ready_check::*; pub use self::task_kafka_ready_check::*; diff --git a/src/risedevtool/src/task/docker_service.rs b/src/risedevtool/src/task/docker_service.rs index 58ff2b59648c0..b87ee8a6a8aef 100644 --- a/src/risedevtool/src/task/docker_service.rs +++ b/src/risedevtool/src/task/docker_service.rs @@ -100,7 +100,9 @@ where cmd.arg("run") .arg("--rm") .arg("--name") - .arg(format!("risedev-{}", self.id())); + .arg(format!("risedev-{}", self.id())) + .arg("--add-host") + .arg("host.docker.internal:host-gateway"); for (k, v) in self.config.envs() { cmd.arg("-e").arg(format!("{k}={v}")); diff --git a/src/risedevtool/src/task/kafka_service.rs b/src/risedevtool/src/task/kafka_service.rs index 52bdd227a72a4..7c415b6d9749a 100644 --- a/src/risedevtool/src/task/kafka_service.rs +++ b/src/risedevtool/src/task/kafka_service.rs @@ -37,15 +37,18 @@ impl DockerServiceConfig for KafkaConfig { ), ( "KAFKA_LISTENERS".to_owned(), - "PLAINTEXT://:9092,CONTROLLER://:9093".to_owned(), + "HOST://:9092,CONTROLLER://:9093,DOCKER://:9094".to_owned(), ), ( "KAFKA_ADVERTISED_LISTENERS".to_owned(), - format!("PLAINTEXT://{}:{}", self.address, self.port), + format!( + "HOST://{}:{},DOCKER://host.docker.internal:{}", + self.address, self.port, self.docker_port + ), ), ( "KAFKA_LISTENER_SECURITY_PROTOCOL_MAP".to_owned(), - "PLAINTEXT:PLAINTEXT,CONTROLLER:PLAINTEXT".to_owned(), + "HOST:PLAINTEXT,CONTROLLER:PLAINTEXT,DOCKER:PLAINTEXT".to_owned(), ), ( "KAFKA_CONTROLLER_QUORUM_VOTERS".to_owned(), @@ -55,12 +58,19 @@ impl DockerServiceConfig for KafkaConfig { "KAFKA_CONTROLLER_LISTENER_NAMES".to_owned(), "CONTROLLER".to_owned(), ), + ( + "KAFKA_INTER_BROKER_LISTENER_NAME".to_owned(), + "HOST".to_owned(), + ), ("CLUSTER_ID".to_owned(), "RiseDevRiseDevRiseDev1".to_owned()), ] } fn ports(&self) -> Vec<(String, String)> { - vec![(self.port.to_string(), "9092".to_owned())] + vec![ + (self.port.to_string(), "9092".to_owned()), + (self.docker_port.to_string(), "9094".to_owned()), + ] } fn data_path(&self) -> Option { diff --git a/src/risedevtool/src/task/schema_registry_service.rs b/src/risedevtool/src/task/schema_registry_service.rs new file mode 100644 index 0000000000000..5c5eba4fa8f35 --- /dev/null +++ b/src/risedevtool/src/task/schema_registry_service.rs @@ -0,0 +1,65 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache 
License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::docker_service::{DockerService, DockerServiceConfig}; +use crate::SchemaRegistryConfig; + +impl DockerServiceConfig for SchemaRegistryConfig { + fn id(&self) -> String { + self.id.clone() + } + + fn is_user_managed(&self) -> bool { + self.user_managed + } + + fn image(&self) -> String { + self.image.clone() + } + + fn envs(&self) -> Vec<(String, String)> { + // https://docs.confluent.io/platform/current/installation/docker/config-reference.html#sr-long-configuration + // https://docs.confluent.io/platform/current/schema-registry/installation/config.html + let kafka = self + .provide_kafka + .as_ref() + .expect("Kafka is required for Schema Registry"); + if kafka.len() != 1 { + panic!("More than one Kafka is not supported yet"); + } + let kafka = &kafka[0]; + vec![ + ("SCHEMA_REGISTRY_HOST_NAME".to_owned(), self.address.clone()), + ( + "SCHEMA_REGISTRY_LISTENERS".to_owned(), + format!("http://{}:{}", self.address, self.port), + ), + ( + "SCHEMA_REGISTRY_KAFKASTORE_BOOTSTRAP_SERVERS".to_owned(), + format!("host.docker.internal:{}", kafka.docker_port), + ), + ] + } + + fn ports(&self) -> Vec<(String, String)> { + vec![(self.port.to_string(), "8081".to_owned())] + } + + fn data_path(&self) -> Option { + None + } +} + +/// Docker-backed Schema Registry service. 
+pub type SchemaRegistryService = DockerService; From 0d4d2530a5b8cd30fda39b8bd84adf0ce6e26816 Mon Sep 17 00:00:00 2001 From: StrikeW Date: Thu, 30 May 2024 13:19:31 +0800 Subject: [PATCH 03/20] chore: upgrade to openjdk17 in dockerfile (#17013) --- ci/Dockerfile | 2 +- ci/scripts/release.sh | 2 +- docker/Dockerfile | 2 +- docker/Dockerfile.hdfs | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ci/Dockerfile b/ci/Dockerfile index 616af35fd118e..a1b6857a45c20 100644 --- a/ci/Dockerfile +++ b/ci/Dockerfile @@ -12,7 +12,7 @@ ENV LANG en_US.utf8 RUN sed -i 's|http://archive.ubuntu.com/ubuntu|http://us-east-2.ec2.archive.ubuntu.com/ubuntu/|g' /etc/apt/sources.list RUN apt-get update -yy && \ DEBIAN_FRONTEND=noninteractive apt-get -y install sudo make build-essential cmake protobuf-compiler curl parallel python3 python3-pip python3-venv software-properties-common \ - openssl libssl-dev libsasl2-dev libcurl4-openssl-dev pkg-config bash openjdk-11-jdk wget unzip git tmux lld postgresql-client kcat netcat-openbsd mysql-client \ + openssl libssl-dev libsasl2-dev libcurl4-openssl-dev pkg-config bash openjdk-17-jdk wget unzip git tmux lld postgresql-client kcat netcat-openbsd mysql-client \ maven zstd libzstd-dev locales \ python3.12 python3.12-dev \ && rm -rf /var/lib/{apt,dpkg,cache,log}/ diff --git a/ci/scripts/release.sh b/ci/scripts/release.sh index ee6479362f2ed..94fd38e2c9c75 100755 --- a/ci/scripts/release.sh +++ b/ci/scripts/release.sh @@ -23,7 +23,7 @@ echo "--- Install dependencies" dnf install -y perl-core wget python3 python3-devel cyrus-sasl-devel rsync openssl-devel echo "--- Install java and maven" -dnf install -y java-11-openjdk java-11-openjdk-devel +dnf install -y java-17-openjdk java-17-openjdk-devel pip3 install toml-cli wget https://rw-ci-deps-dist.s3.amazonaws.com/apache-maven-3.9.3-bin.tar.gz && tar -zxvf apache-maven-3.9.3-bin.tar.gz export PATH="${REPO_ROOT}/apache-maven-3.9.3/bin:$PATH" diff --git a/docker/Dockerfile b/docker/Dockerfile index 167815a988131..b4d2cf73ee85f 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -3,7 +3,7 @@ FROM ubuntu:24.04 AS base ENV LANG en_US.utf8 RUN apt-get update \ - && apt-get -y install ca-certificates build-essential libsasl2-dev openjdk-11-jdk software-properties-common python3.12 python3.12-dev openssl pkg-config + && apt-get -y install ca-certificates build-essential libsasl2-dev openjdk-17-jdk software-properties-common python3.12 python3.12-dev openssl pkg-config FROM base AS rust-base diff --git a/docker/Dockerfile.hdfs b/docker/Dockerfile.hdfs index 2e49564ccb570..b6eba07c421c0 100644 --- a/docker/Dockerfile.hdfs +++ b/docker/Dockerfile.hdfs @@ -3,7 +3,7 @@ FROM ubuntu:24.04 AS base ENV LANG en_US.utf8 RUN apt-get update \ - && apt-get -y install ca-certificates build-essential libsasl2-dev openjdk-11-jdk software-properties-common python3.12 python3.12-dev openssl pkg-config + && apt-get -y install ca-certificates build-essential libsasl2-dev openjdk-17-jdk software-properties-common python3.12 python3.12-dev openssl pkg-config FROM base AS dashboard-builder @@ -113,7 +113,7 @@ RUN cd /risingwave/java && mvn -B package -Dmaven.test.skip=true -Dno-build-rust tar -zxvf /risingwave/java/connector-node/assembly/target/risingwave-connector-1.0.0.tar.gz -C /risingwave/bin/connector-node FROM ubuntu:24.04 as image-base -RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -y install ca-certificates openjdk-11-jdk wget libsasl2-dev && rm -rf /var/lib/{apt,dpkg,cache,log}/ +RUN apt-get update && 
DEBIAN_FRONTEND=noninteractive apt-get -y install ca-certificates openjdk-17-jdk wget libsasl2-dev && rm -rf /var/lib/{apt,dpkg,cache,log}/ FROM image-base as risingwave LABEL org.opencontainers.image.source https://github.com/risingwavelabs/risingwave From 4e64389601dfb14c566f2d4dc52fc2a753266c49 Mon Sep 17 00:00:00 2001 From: William Wen <44139337+wenym1@users.noreply.github.com> Date: Thu, 30 May 2024 14:05:37 +0800 Subject: [PATCH 04/20] fix(log-store): rebuild log store iter when exists for a timeout (#17009) --- .../log_store_impl/kv_log_store/reader.rs | 271 ++++++++++++++++-- .../log_store_impl/kv_log_store/serde.rs | 2 +- 2 files changed, 252 insertions(+), 21 deletions(-) diff --git a/src/stream/src/common/log_store_impl/kv_log_store/reader.rs b/src/stream/src/common/log_store_impl/kv_log_store/reader.rs index 21ee99ec91d08..e2e767c8d2038 100644 --- a/src/stream/src/common/log_store_impl/kv_log_store/reader.rs +++ b/src/stream/src/common/log_store_impl/kv_log_store/reader.rs @@ -13,12 +13,14 @@ // limitations under the License. use std::future::Future; +use std::ops::Bound; use std::ops::Bound::{Excluded, Included, Unbounded}; use std::pin::Pin; -use std::time::Duration; +use std::time::{Duration, Instant}; use anyhow::anyhow; use await_tree::InstrumentAwait; +use bytes::Bytes; use foyer::CacheContext; use futures::future::{try_join_all, BoxFuture}; use futures::{FutureExt, TryFutureExt}; @@ -31,11 +33,14 @@ use risingwave_common::util::epoch::EpochExt; use risingwave_connector::sink::log_store::{ ChunkId, LogReader, LogStoreReadItem, LogStoreResult, TruncateOffset, }; -use risingwave_hummock_sdk::key::prefixed_range_with_vnode; +use risingwave_hummock_sdk::key::{prefixed_range_with_vnode, FullKey, TableKey, TableKeyRange}; use risingwave_hummock_sdk::HummockEpoch; +use risingwave_storage::error::StorageResult; use risingwave_storage::hummock::CachePolicy; -use risingwave_storage::store::{PrefetchOptions, ReadOptions}; -use risingwave_storage::StateStore; +use risingwave_storage::store::{ + PrefetchOptions, ReadOptions, StateStoreIterItemRef, StateStoreRead, +}; +use risingwave_storage::{StateStore, StateStoreIter}; use tokio::sync::watch; use tokio::time::sleep; use tokio_stream::StreamExt; @@ -113,7 +118,7 @@ pub struct KvLogStoreReader { first_write_epoch: Option, /// `Some` means consuming historical log data - state_store_stream: Option>>>, + state_store_stream: Option>>>>, /// Store the future that attempts to read a flushed stream chunk. /// This is for cancellation safety. Since it is possible that the future of `next_item` @@ -180,12 +185,141 @@ impl KvLogStoreReader { } } +struct AutoRebuildStateStoreReadIter { + state_store: S, + iter: S::Iter, + // call to get whether to rebuild the iter. Once return true, the closure should reset itself. 
+ should_rebuild: F, + end_bound: Bound>, + epoch: HummockEpoch, + options: ReadOptions, +} + +impl bool> AutoRebuildStateStoreReadIter { + async fn new( + state_store: S, + should_rebuild: F, + range: TableKeyRange, + epoch: HummockEpoch, + options: ReadOptions, + ) -> StorageResult { + let (start_bound, end_bound) = range; + let iter = state_store + .iter((start_bound, end_bound.clone()), epoch, options.clone()) + .await?; + Ok(Self { + state_store, + iter, + should_rebuild, + end_bound, + epoch, + options, + }) + } +} + +type TimeoutAutoRebuildIter = + AutoRebuildStateStoreReadIter bool + Send>; + +async fn iter_with_timeout_rebuild( + state_store: S, + range: TableKeyRange, + epoch: HummockEpoch, + options: ReadOptions, + timeout: Duration, +) -> StorageResult> { + const CHECK_TIMEOUT_PERIOD: usize = 100; + // use a struct here to avoid accidental copy instead of move on primitive usize + struct Count(usize); + let mut check_count = Count(0); + let mut total_count = Count(0); + let mut curr_iter_item_count = Count(0); + let mut start_time = Instant::now(); + let initial_start_time = start_time; + AutoRebuildStateStoreReadIter::new( + state_store, + move || { + check_count.0 += 1; + curr_iter_item_count.0 += 1; + total_count.0 += 1; + if check_count.0 == CHECK_TIMEOUT_PERIOD { + check_count.0 = 0; + if start_time.elapsed() > timeout { + let prev_iter_item_count = curr_iter_item_count.0; + curr_iter_item_count.0 = 0; + start_time = Instant::now(); + info!( + table_id = options.table_id.table_id, + iter_exist_time_secs = initial_start_time.elapsed().as_secs(), + prev_iter_item_count, + total_iter_item_count = total_count.0, + "kv log store iter is rebuilt" + ); + true + } else { + false + } + } else { + false + } + }, + range, + epoch, + options, + ) + .await +} + +impl bool + Send> StateStoreIter + for AutoRebuildStateStoreReadIter +{ + async fn try_next(&mut self) -> StorageResult>> { + let should_rebuild = (self.should_rebuild)(); + if should_rebuild { + let Some((key, _value)) = self.iter.try_next().await? else { + return Ok(None); + }; + let key: FullKey<&[u8]> = key; + let range_start = Bytes::copy_from_slice(key.user_key.table_key.as_ref()); + let new_iter = self + .state_store + .iter( + ( + Included(TableKey(range_start.clone())), + self.end_bound.clone(), + ), + self.epoch, + self.options.clone(), + ) + .await?; + self.iter = new_iter; + let item: Option> = self.iter.try_next().await?; + if let Some((key, value)) = item { + assert_eq!( + key.user_key.table_key.0, + range_start.as_ref(), + "the first key should be the previous key" + ); + Ok(Some((key, value))) + } else { + unreachable!( + "the first key should be the previous key {:?}, but get None", + range_start + ) + } + } else { + self.iter.try_next().await + } + } +} + impl KvLogStoreReader { fn read_persisted_log_store( &self, last_persisted_epoch: Option, - ) -> impl Future>>>> + Send - { + ) -> impl Future< + Output = LogStoreResult>>>>, + > + Send { let range_start = if let Some(last_persisted_epoch) = last_persisted_epoch { // start from the next epoch of last_persisted_epoch Included( @@ -210,19 +344,21 @@ impl KvLogStoreReader { ); let state_store = self.state_store.clone(); async move { - state_store - .iter( - key_range, - HummockEpoch::MAX, - ReadOptions { - // This stream lives too long, the connection of prefetch object may break. So use a short connection prefetch. 
- prefetch_options: PrefetchOptions::prefetch_for_small_range_scan(), - cache_policy: CachePolicy::Fill(CacheContext::LruPriorityLow), - table_id, - ..Default::default() - }, - ) - .await + // rebuild the iter every 10 minutes to avoid pinning hummock version for too long + iter_with_timeout_rebuild( + state_store, + key_range, + HummockEpoch::MAX, + ReadOptions { + // This stream lives too long, the connection of prefetch object may break. So use a short connection prefetch. + prefetch_options: PrefetchOptions::prefetch_for_small_range_scan(), + cache_policy: CachePolicy::Fill(CacheContext::LruPriorityLow), + table_id, + ..Default::default() + }, + Duration::from_secs(10 * 60), + ) + .await } })); @@ -500,3 +636,98 @@ impl LogReader for KvLogStoreReader { Ok((true, Some((**self.serde.vnodes()).clone()))) } } + +#[cfg(test)] +mod tests { + use std::ops::Bound::Unbounded; + + use bytes::Bytes; + use itertools::Itertools; + use risingwave_common::util::epoch::test_epoch; + use risingwave_hummock_sdk::key::TableKey; + use risingwave_storage::hummock::iterator::test_utils::{ + iterator_test_table_key_of, iterator_test_value_of, + }; + use risingwave_storage::memory::MemoryStateStore; + use risingwave_storage::storage_value::StorageValue; + use risingwave_storage::store::{ReadOptions, StateStoreRead, StateStoreWrite, WriteOptions}; + use risingwave_storage::StateStoreIter; + + use crate::common::log_store_impl::kv_log_store::reader::AutoRebuildStateStoreReadIter; + use crate::common::log_store_impl::kv_log_store::test_utils::TEST_TABLE_ID; + + #[tokio::test] + async fn test_auto_rebuild_iter() { + let state_store = MemoryStateStore::new(); + let key_count = 100; + let pairs = (0..key_count) + .map(|i| { + let key = iterator_test_table_key_of(i); + let value = iterator_test_value_of(i); + (TableKey(Bytes::from(key)), StorageValue::new_put(value)) + }) + .collect_vec(); + let epoch = test_epoch(1); + state_store + .ingest_batch( + pairs.clone(), + vec![], + WriteOptions { + epoch, + table_id: TEST_TABLE_ID, + }, + ) + .unwrap(); + + async fn validate( + mut kv_iter: impl Iterator, StorageValue)>, + mut iter: impl StateStoreIter, + ) { + while let Some((key, value)) = iter.try_next().await.unwrap() { + let (k, v) = kv_iter.next().unwrap(); + assert_eq!(key.user_key.table_key, k.to_ref()); + assert_eq!(v.user_value.as_deref(), Some(value)); + } + assert!(kv_iter.next().is_none()); + } + + let read_options = ReadOptions { + table_id: TEST_TABLE_ID, + ..Default::default() + }; + + let kv_iter = pairs.clone().into_iter(); + let iter = state_store + .iter((Unbounded, Unbounded), epoch, read_options.clone()) + .await + .unwrap(); + validate(kv_iter, iter).await; + + let kv_iter = pairs.clone().into_iter(); + let mut count = 0; + let count_mut_ref = &mut count; + let rebuild_period = 8; + let mut rebuild_count = 0; + let rebuild_count_mut_ref = &mut rebuild_count; + let iter = AutoRebuildStateStoreReadIter::new( + state_store, + move || { + *count_mut_ref += 1; + if *count_mut_ref % rebuild_period == 0 { + *rebuild_count_mut_ref += 1; + true + } else { + false + } + }, + (Unbounded, Unbounded), + epoch, + read_options, + ) + .await + .unwrap(); + validate(kv_iter, iter).await; + assert_eq!(count, key_count + 1); // with an extra call on the last None + assert_eq!(rebuild_count, key_count / rebuild_period); + } +} diff --git a/src/stream/src/common/log_store_impl/kv_log_store/serde.rs b/src/stream/src/common/log_store_impl/kv_log_store/serde.rs index 9eb7faf237ead..9871139bafddc 100644 --- 
a/src/stream/src/common/log_store_impl/kv_log_store/serde.rs +++ b/src/stream/src/common/log_store_impl/kv_log_store/serde.rs @@ -544,7 +544,7 @@ impl LogStoreRowOpStream { } } -pub(crate) type LogStoreItemMergeStream = +pub(crate) type LogStoreItemMergeStream = impl Stream>; pub(crate) fn merge_log_store_item_stream( iters: Vec, From 8dfae832334ae1dc0585dcae4b3071e9ee6c9b1d Mon Sep 17 00:00:00 2001 From: Li0k Date: Thu, 30 May 2024 14:13:48 +0800 Subject: [PATCH 05/20] fix(storage): Remove ambiguous configuration max_sub_compaction (#16960) --- proto/hummock.proto | 2 ++ src/common/src/config.rs | 8 ----- src/config/docs.md | 1 - src/config/example.toml | 1 - src/meta/src/hummock/manager/compaction.rs | 1 + .../src/hummock/compactor/compaction_utils.rs | 32 ++++++++++++------- src/storage/src/opts.rs | 3 -- 7 files changed, 23 insertions(+), 25 deletions(-) diff --git a/proto/hummock.proto b/proto/hummock.proto index 8d68ec168ef21..7caf27e155deb 100644 --- a/proto/hummock.proto +++ b/proto/hummock.proto @@ -382,6 +382,8 @@ message CompactTask { map table_watermarks = 24; // The table schemas that are at least as new as the one used to create `input_ssts`. map table_schemas = 25; + // Max sub compaction task numbers + uint32 max_sub_compaction = 26; } message LevelHandler { diff --git a/src/common/src/config.rs b/src/common/src/config.rs index c8da0f6dce5e9..26e8bcaf1f56b 100644 --- a/src/common/src/config.rs +++ b/src/common/src/config.rs @@ -721,10 +721,6 @@ pub struct StorageConfig { #[serde(default = "default::storage::min_sst_size_for_streaming_upload")] pub min_sst_size_for_streaming_upload: u64, - /// Max sub compaction task numbers - #[serde(default = "default::storage::max_sub_compaction")] - pub max_sub_compaction: u32, - #[serde(default = "default::storage::max_concurrent_compaction_task_number")] pub max_concurrent_compaction_task_number: u64, @@ -1461,10 +1457,6 @@ pub mod default { 32 * 1024 * 1024 } - pub fn max_sub_compaction() -> u32 { - 4 - } - pub fn max_concurrent_compaction_task_number() -> u64 { 16 } diff --git a/src/config/docs.md b/src/config/docs.md index 018c9dd41087c..0a024ba992db0 100644 --- a/src/config/docs.md +++ b/src/config/docs.md @@ -121,7 +121,6 @@ This page is automatically generated by `./risedev generate-example-config` | max_prefetch_block_number | max prefetch block number | 16 | | max_preload_io_retry_times | | 3 | | max_preload_wait_time_mill | | 0 | -| max_sub_compaction | Max sub compaction task numbers | 4 | | max_version_pinning_duration_sec | | 10800 | | mem_table_spill_threshold | The spill threshold for mem table. | 4194304 | | meta_cache_capacity_mb | DEPRECATED: This config will be deprecated in the future version, use `storage.cache.meta_cache_capacity_mb` instead. 
| | diff --git a/src/config/example.toml b/src/config/example.toml index 00b1ef759e5f9..93546c7bdd238 100644 --- a/src/config/example.toml +++ b/src/config/example.toml @@ -136,7 +136,6 @@ compactor_memory_available_proportion = 0.8 sstable_id_remote_fetch_number = 10 min_sstable_size_mb = 32 min_sst_size_for_streaming_upload = 33554432 -max_sub_compaction = 4 max_concurrent_compaction_task_number = 16 max_preload_wait_time_mill = 0 max_version_pinning_duration_sec = 10800 diff --git a/src/meta/src/hummock/manager/compaction.rs b/src/meta/src/hummock/manager/compaction.rs index addb416893b08..16ca79a30962d 100644 --- a/src/meta/src/hummock/manager/compaction.rs +++ b/src/meta/src/hummock/manager/compaction.rs @@ -774,6 +774,7 @@ impl HummockManager { target_sub_level_id: compact_task.input.target_sub_level_id, task_type: compact_task.compaction_task_type as i32, split_weight_by_vnode: vnode_partition_count, + max_sub_compaction: group_config.compaction_config.max_sub_compaction, ..Default::default() }; diff --git a/src/storage/src/hummock/compactor/compaction_utils.rs b/src/storage/src/hummock/compactor/compaction_utils.rs index d0e5fe93c62ee..63b59366195f0 100644 --- a/src/storage/src/hummock/compactor/compaction_utils.rs +++ b/src/storage/src/hummock/compactor/compaction_utils.rs @@ -28,8 +28,7 @@ use risingwave_hummock_sdk::table_stats::TableStatsMap; use risingwave_hummock_sdk::{can_concat, EpochWithGap, KeyComparator}; use risingwave_pb::hummock::compact_task::TaskType; use risingwave_pb::hummock::{ - compact_task, BloomFilterType, CompactTask, KeyRange as KeyRange_vec, LevelType, SstableInfo, - TableSchema, + compact_task, BloomFilterType, CompactTask, LevelType, PbKeyRange, SstableInfo, TableSchema, }; use tokio::time::Instant; @@ -178,7 +177,8 @@ fn generate_splits_fast( sstable_infos: &Vec, compaction_size: u64, context: &CompactorContext, -) -> Vec { + max_sub_compaction: u32, +) -> Vec { let worker_num = context.compaction_executor.worker_num(); let parallel_compact_size = (context.storage_opts.parallel_compact_size_mb as u64) << 20; @@ -186,7 +186,7 @@ fn generate_splits_fast( worker_num, parallel_compact_size, compaction_size, - context.storage_opts.max_sub_compaction, + max_sub_compaction, ); let mut indexes = vec![]; for sst in sstable_infos { @@ -213,13 +213,13 @@ fn generate_splits_fast( } let mut splits = vec![]; - splits.push(KeyRange_vec::new(vec![], vec![])); + splits.push(PbKeyRange::new(vec![], vec![])); let parallel_key_count = indexes.len() / parallelism; let mut last_split_key_count = 0; for key in indexes { if last_split_key_count >= parallel_key_count { splits.last_mut().unwrap().right.clone_from(&key); - splits.push(KeyRange_vec::new(key.clone(), vec![])); + splits.push(PbKeyRange::new(key.clone(), vec![])); last_split_key_count = 0; } last_split_key_count += 1; @@ -232,7 +232,8 @@ pub async fn generate_splits( sstable_infos: &Vec, compaction_size: u64, context: &CompactorContext, -) -> HummockResult> { + max_sub_compaction: u32, +) -> HummockResult> { const MAX_FILE_COUNT: usize = 32; let parallel_compact_size = (context.storage_opts.parallel_compact_size_mb as u64) << 20; if compaction_size > parallel_compact_size { @@ -241,6 +242,7 @@ pub async fn generate_splits( sstable_infos, compaction_size, context, + max_sub_compaction, )); } let mut indexes = vec![]; @@ -269,13 +271,13 @@ pub async fn generate_splits( // sort by key, as for every data block has the same size; indexes.sort_by(|a, b| KeyComparator::compare_encoded_full_key(a.1.as_ref(), b.1.as_ref())); 
let mut splits = vec![]; - splits.push(KeyRange_vec::new(vec![], vec![])); + splits.push(PbKeyRange::new(vec![], vec![])); let parallelism = calculate_task_parallelism_impl( context.compaction_executor.worker_num(), parallel_compact_size, compaction_size, - context.storage_opts.max_sub_compaction, + max_sub_compaction, ); let sub_compaction_data_size = @@ -291,7 +293,7 @@ pub async fn generate_splits( && remaining_size > parallel_compact_size { splits.last_mut().unwrap().right.clone_from(&key); - splits.push(KeyRange_vec::new(key.clone(), vec![])); + splits.push(PbKeyRange::new(key.clone(), vec![])); last_buffer_size = data_size; } else { last_buffer_size += data_size; @@ -577,7 +579,13 @@ pub async fn generate_splits_for_task( .sum::(); if !optimize_by_copy_block { - let splits = generate_splits(&sstable_infos, compaction_size, context).await?; + let splits = generate_splits( + &sstable_infos, + compaction_size, + context, + compact_task.get_max_sub_compaction(), + ) + .await?; if !splits.is_empty() { compact_task.splits = splits; } @@ -659,7 +667,7 @@ pub fn calculate_task_parallelism(compact_task: &CompactTask, context: &Compacto context.compaction_executor.worker_num(), parallel_compact_size, compaction_size, - context.storage_opts.max_sub_compaction, + compact_task.get_max_sub_compaction(), ) } diff --git a/src/storage/src/opts.rs b/src/storage/src/opts.rs index aa4fd4cbb9630..5a7bca2c30b42 100644 --- a/src/storage/src/opts.rs +++ b/src/storage/src/opts.rs @@ -74,8 +74,6 @@ pub struct StorageOpts { pub sstable_id_remote_fetch_number: u32, /// Whether to enable streaming upload for sstable. pub min_sst_size_for_streaming_upload: u64, - /// Max sub compaction task numbers - pub max_sub_compaction: u32, pub max_concurrent_compaction_task_number: u64, pub max_version_pinning_duration_sec: u64, pub compactor_iter_max_io_retry_times: usize, @@ -176,7 +174,6 @@ impl From<(&RwConfig, &SystemParamsReader, &StorageMemoryConfig)> for StorageOpt compactor_memory_limit_mb: s.compactor_memory_limit_mb, sstable_id_remote_fetch_number: c.storage.sstable_id_remote_fetch_number, min_sst_size_for_streaming_upload: c.storage.min_sst_size_for_streaming_upload, - max_sub_compaction: c.storage.max_sub_compaction, max_concurrent_compaction_task_number: c.storage.max_concurrent_compaction_task_number, max_version_pinning_duration_sec: c.storage.max_version_pinning_duration_sec, data_file_cache_dir: c.storage.data_file_cache.dir.clone(), From e154a374fe2a67d914eb039749f1446f9ef5fb18 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Thu, 30 May 2024 15:01:05 +0800 Subject: [PATCH 06/20] refactor(connector): do not expose internal implementation for parser benchmarks (#16996) Signed-off-by: Bugen Zhao --- src/connector/Cargo.toml | 8 +- src/connector/benches/debezium_json_parser.rs | 86 +++++++ src/connector/benches/json_common/mod.rs | 57 +++++ src/connector/benches/json_parser.rs | 226 ------------------ ...ser.rs => json_parser_case_insensitive.rs} | 56 ++--- src/connector/benches/json_vs_plain_parser.rs | 93 +++++++ src/connector/benches/nexmark_integration.rs | 22 +- src/connector/src/parser/mod.rs | 7 +- src/connector/src/source/base.rs | 13 + 9 files changed, 302 insertions(+), 266 deletions(-) create mode 100644 src/connector/benches/debezium_json_parser.rs create mode 100644 src/connector/benches/json_common/mod.rs delete mode 100644 src/connector/benches/json_parser.rs rename src/connector/benches/{parser.rs => json_parser_case_insensitive.rs} (73%) create mode 100644 
src/connector/benches/json_vs_plain_parser.rs diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 64e3d159daa71..e73fb35e63267 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -179,7 +179,7 @@ prost-build = "0.12" protobuf-src = "1" [[bench]] -name = "parser" +name = "debezium_json_parser" harness = false [[bench]] @@ -187,7 +187,11 @@ name = "nexmark_integration" harness = false [[bench]] -name = "json_parser" +name = "json_parser_case_insensitive" +harness = false + +[[bench]] +name = "json_vs_plain_parser" harness = false [lints] diff --git a/src/connector/benches/debezium_json_parser.rs b/src/connector/benches/debezium_json_parser.rs new file mode 100644 index 0000000000000..e448fa17ad1db --- /dev/null +++ b/src/connector/benches/debezium_json_parser.rs @@ -0,0 +1,86 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Benchmark for Debezium JSON records with `DebeziumParser`. + +mod json_common; + +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use futures::executor::block_on; +use json_common::*; +use paste::paste; +use rand::Rng; +use risingwave_connector::parser::{DebeziumParser, SourceStreamChunkBuilder}; + +fn generate_debezium_json_row(rng: &mut impl Rng, change_event: &str) -> String { + let source = r#"{"version":"1.7.1.Final","connector":"mysql","name":"dbserver1","ts_ms":1639547113601,"snapshot":"true","db":"inventory","sequence":null,"table":"products","server_id":0,"gtid":null,"file":"mysql-bin.000003","pos":156,"row":0,"thread":null,"query":null}"#; + let (before, after) = match change_event { + "c" => ("null".to_string(), generate_json_row(rng)), + "r" => ("null".to_string(), generate_json_row(rng)), + "u" => (generate_json_row(rng), generate_json_row(rng)), + "d" => (generate_json_row(rng), "null".to_string()), + _ => unreachable!(), + }; + format!("{{\"before\": {before}, \"after\": {after}, \"source\": {source}, \"op\": \"{change_event}\", \"ts_ms\":1639551564960, \"transaction\":null}}") +} + +macro_rules! create_debezium_bench_helpers { + ($op:ident, $op_sym:expr, $bench_function:expr) => { + paste! 
{ + fn [](c: &mut Criterion) { + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + // Generate records + let mut rng = rand::thread_rng(); + let mut records = Vec::with_capacity(NUM_RECORDS); + for _ in 0..NUM_RECORDS { + let json_row = generate_debezium_json_row(&mut rng, $op_sym); + records.push(Some(json_row.into_bytes())); + } + + c.bench_function($bench_function, |b| { + b.to_async(&rt).iter_batched( + || (block_on(DebeziumParser::new_for_test(get_descs())).unwrap(), records.clone()) , + | (mut parser, records) | async move { + let mut builder = + SourceStreamChunkBuilder::with_capacity(get_descs(), NUM_RECORDS); + for record in records { + let writer = builder.row_writer(); + parser.parse_inner(None, record, writer).await.unwrap(); + } + }, + BatchSize::SmallInput, + ) + }); + } + } + }; +} + +create_debezium_bench_helpers!(create, "c", "bench_debezium_json_parser_create"); +create_debezium_bench_helpers!(read, "r", "bench_debezium_json_parser_read"); +create_debezium_bench_helpers!(update, "u", "bench_debezium_json_parser_update"); +create_debezium_bench_helpers!(delete, "d", "bench_debezium_json_parser_delete"); + +criterion_group!( + benches, + bench_debezium_json_parser_create, + bench_debezium_json_parser_read, + bench_debezium_json_parser_update, + bench_debezium_json_parser_delete +); +criterion_main!(benches); diff --git a/src/connector/benches/json_common/mod.rs b/src/connector/benches/json_common/mod.rs new file mode 100644 index 0000000000000..cb67c4cb3d547 --- /dev/null +++ b/src/connector/benches/json_common/mod.rs @@ -0,0 +1,57 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Common utilities shared by JSON parser benchmarks. 
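+//!
+//! A minimal sketch of how the sibling benchmarks are expected to use these
+//! helpers (the real setup lives in `debezium_json_parser.rs` and
+//! `json_vs_plain_parser.rs`); everything outside this module is illustrative:
+//!
+//! ```ignore
+//! let mut rng = rand::thread_rng();
+//! // Each record is a JSON object whose fields line up with `get_descs()`.
+//! let records: Vec<Vec<u8>> = (0..NUM_RECORDS)
+//!     .map(|_| generate_json_row(&mut rng).into_bytes())
+//!     .collect();
+//! assert_eq!(records.len(), NUM_RECORDS);
+//! ```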
+ +use rand::distributions::Alphanumeric; +use rand::prelude::*; +use risingwave_common::catalog::ColumnId; +use risingwave_common::types::{DataType, Date, Timestamp}; +use risingwave_connector::source::SourceColumnDesc; + +pub const NUM_RECORDS: usize = 1 << 18; // ~ 250,000 + +pub fn generate_json_row(rng: &mut impl Rng) -> String { + format!("{{\"i32\":{},\"bool\":{},\"i16\":{},\"i64\":{},\"f32\":{},\"f64\":{},\"varchar\":\"{}\",\"date\":\"{}\",\"timestamp\":\"{}\"}}", + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.gen::(), + rng.sample_iter(&Alphanumeric) + .take(7) + .map(char::from) + .collect::(), + Date::from_num_days_from_ce_uncheck((rng.gen::() % (1 << 20)) as i32).0, + { + let datetime = Timestamp::from_timestamp_uncheck((rng.gen::() % (1u32 << 28)) as i64, 0).0; + format!("{:?} {:?}", datetime.date(), datetime.time()) + } + ) +} + +pub fn get_descs() -> Vec { + vec![ + SourceColumnDesc::simple("i32", DataType::Int32, ColumnId::from(0)), + SourceColumnDesc::simple("bool", DataType::Boolean, ColumnId::from(2)), + SourceColumnDesc::simple("i16", DataType::Int16, ColumnId::from(3)), + SourceColumnDesc::simple("i64", DataType::Int64, ColumnId::from(4)), + SourceColumnDesc::simple("f32", DataType::Float32, ColumnId::from(5)), + SourceColumnDesc::simple("f64", DataType::Float64, ColumnId::from(6)), + SourceColumnDesc::simple("varchar", DataType::Varchar, ColumnId::from(7)), + SourceColumnDesc::simple("date", DataType::Date, ColumnId::from(8)), + SourceColumnDesc::simple("timestamp", DataType::Timestamp, ColumnId::from(9)), + ] +} diff --git a/src/connector/benches/json_parser.rs b/src/connector/benches/json_parser.rs deleted file mode 100644 index 5a12dec735cab..0000000000000 --- a/src/connector/benches/json_parser.rs +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; -use futures::executor::block_on; -use paste::paste; -use rand::distributions::Alphanumeric; -use rand::prelude::*; -use risingwave_common::catalog::ColumnId; -use risingwave_common::types::{DataType, Date, Timestamp}; -use risingwave_connector::parser::plain_parser::PlainParser; -use risingwave_connector::parser::{ - DebeziumParser, JsonParser, SourceStreamChunkBuilder, SpecificParserConfig, -}; -use risingwave_connector::source::{SourceColumnDesc, SourceContext}; - -macro_rules! create_debezium_bench_helpers { - ($op:ident, $op_sym:expr, $bench_function:expr) => { - paste! 
{ - fn [](c: &mut Criterion) { - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); - - // Generate records - let mut rng = rand::thread_rng(); - let mut records = Vec::with_capacity(NUM_RECORDS); - for _ in 0..NUM_RECORDS { - let json_row = generate_debezium_json_row(&mut rng, $op_sym); - records.push(Some(json_row.into_bytes())); - } - - c.bench_function($bench_function, |b| { - b.to_async(&rt).iter_batched( - || (block_on(DebeziumParser::new_for_test(get_descs())).unwrap(), records.clone()) , - | (mut parser, records) | async move { - let mut builder = - SourceStreamChunkBuilder::with_capacity(get_descs(), NUM_RECORDS); - for record in records { - let writer = builder.row_writer(); - parser.parse_inner(None, record, writer).await.unwrap(); - } - }, - BatchSize::SmallInput, - ) - }); - } - } - }; -} - -create_debezium_bench_helpers!(create, "c", "bench_debezium_json_parser_create"); -create_debezium_bench_helpers!(read, "r", "bench_debezium_json_parser_read"); -create_debezium_bench_helpers!(update, "u", "bench_debezium_json_parser_update"); -create_debezium_bench_helpers!(delete, "d", "bench_debezium_json_parser_delete"); - -const NUM_RECORDS: usize = 1 << 18; // ~ 250,000 - -fn generate_json_row(rng: &mut impl Rng) -> String { - format!("{{\"i32\":{},\"bool\":{},\"i16\":{},\"i64\":{},\"f32\":{},\"f64\":{},\"varchar\":\"{}\",\"date\":\"{}\",\"timestamp\":\"{}\"}}", - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.gen::(), - rng.sample_iter(&Alphanumeric) - .take(7) - .map(char::from) - .collect::(), - Date::from_num_days_from_ce_uncheck((rng.gen::() % (1 << 20)) as i32).0, - { - let datetime = Timestamp::from_timestamp_uncheck((rng.gen::() % (1u32 << 28)) as i64, 0).0; - format!("{:?} {:?}", datetime.date(), datetime.time()) - } - ) -} - -fn generate_json_rows() -> Vec> { - let mut rng = rand::thread_rng(); - let mut records = Vec::with_capacity(NUM_RECORDS); - for _ in 0..NUM_RECORDS { - records.push(generate_json_row(&mut rng).into_bytes()); - } - records -} - -fn generate_debezium_json_row(rng: &mut impl Rng, change_event: &str) -> String { - let source = r#"{"version":"1.7.1.Final","connector":"mysql","name":"dbserver1","ts_ms":1639547113601,"snapshot":"true","db":"inventory","sequence":null,"table":"products","server_id":0,"gtid":null,"file":"mysql-bin.000003","pos":156,"row":0,"thread":null,"query":null}"#; - let (before, after) = match change_event { - "c" => ("null".to_string(), generate_json_row(rng)), - "r" => ("null".to_string(), generate_json_row(rng)), - "u" => (generate_json_row(rng), generate_json_row(rng)), - "d" => (generate_json_row(rng), "null".to_string()), - _ => unreachable!(), - }; - format!("{{\"before\": {before}, \"after\": {after}, \"source\": {source}, \"op\": \"{change_event}\", \"ts_ms\":1639551564960, \"transaction\":null}}") -} - -fn get_descs() -> Vec { - vec![ - SourceColumnDesc::simple("i32", DataType::Int32, ColumnId::from(0)), - SourceColumnDesc::simple("bool", DataType::Boolean, ColumnId::from(2)), - SourceColumnDesc::simple("i16", DataType::Int16, ColumnId::from(3)), - SourceColumnDesc::simple("i64", DataType::Int64, ColumnId::from(4)), - SourceColumnDesc::simple("f32", DataType::Float32, ColumnId::from(5)), - SourceColumnDesc::simple("f64", DataType::Float64, ColumnId::from(6)), - SourceColumnDesc::simple("varchar", DataType::Varchar, ColumnId::from(7)), - SourceColumnDesc::simple("date", DataType::Date, ColumnId::from(8)), - SourceColumnDesc::simple("timestamp", 
DataType::Timestamp, ColumnId::from(9)), - ] -} - -fn bench_json_parser(c: &mut Criterion) { - let descs = get_descs(); - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap(); - let records = generate_json_rows(); - let ctx = Arc::new(SourceContext::dummy()); - c.bench_function("json_parser", |b| { - b.to_async(&rt).iter_batched( - || records.clone(), - |records| async { - let mut parser = block_on(PlainParser::new( - SpecificParserConfig::DEFAULT_PLAIN_JSON, - descs.clone(), - ctx.clone(), - )) - .unwrap(); - let mut builder = - SourceStreamChunkBuilder::with_capacity(descs.clone(), NUM_RECORDS); - for record in records { - let writer = builder.row_writer(); - parser - .parse_inner(None, Some(record), writer) - .await - .unwrap(); - } - }, - BatchSize::SmallInput, - ) - }); -} - -fn bench_plain_parser_and_json_parser(c: &mut Criterion) { - let rt = tokio::runtime::Runtime::new().unwrap(); - let records = generate_json_rows(); - - let mut group = c.benchmark_group("plain parser and json parser comparison"); - - group.bench_function("plain_parser", |b| { - b.to_async(&rt).iter_batched( - || { - let parser = block_on(PlainParser::new( - SpecificParserConfig::DEFAULT_PLAIN_JSON, - get_descs(), - SourceContext::dummy().into(), - )) - .unwrap(); - (parser, records.clone()) - }, - |(mut parser, records)| async move { - let mut builder = SourceStreamChunkBuilder::with_capacity(get_descs(), NUM_RECORDS); - for record in records { - let writer = builder.row_writer(); - parser - .parse_inner(None, Some(record), writer) - .await - .unwrap(); - } - }, - BatchSize::SmallInput, - ) - }); - - group.bench_function("json_parser", |b| { - b.to_async(&rt).iter_batched( - || { - let parser = JsonParser::new( - SpecificParserConfig::DEFAULT_PLAIN_JSON, - get_descs(), - SourceContext::dummy().into(), - ) - .unwrap(); - (parser, records.clone()) - }, - |(parser, records)| async move { - let mut builder = SourceStreamChunkBuilder::with_capacity(get_descs(), NUM_RECORDS); - for record in records { - let writer = builder.row_writer(); - parser.parse_inner(record, writer).await.unwrap(); - } - }, - BatchSize::SmallInput, - ) - }); - - group.finish(); -} - -criterion_group!( - benches, - bench_json_parser, - bench_plain_parser_and_json_parser, - bench_debezium_json_parser_create, - bench_debezium_json_parser_read, - bench_debezium_json_parser_update, - bench_debezium_json_parser_delete -); -criterion_main!(benches); diff --git a/src/connector/benches/parser.rs b/src/connector/benches/json_parser_case_insensitive.rs similarity index 73% rename from src/connector/benches/parser.rs rename to src/connector/benches/json_parser_case_insensitive.rs index 21ce72dd1b2b1..17fd439e6ccc1 100644 --- a/src/connector/benches/parser.rs +++ b/src/connector/benches/json_parser_case_insensitive.rs @@ -12,24 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Benchmarking JSON parsers for scenarios with exact key matches and case-insensitive key matches. 
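+//!
+//! As a rough illustration (see `gen_input` below for the exact payloads): the
+//! "exact match" case feeds records like
+//! `{"alpha": 1, "bravo": 2, "charlie": 3, "delta": 4}`, whose keys equal the
+//! column names, while the other case uses keys that only match the column
+//! names when compared case-insensitively, forcing the slower lookup path.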
+ use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use futures::StreamExt; use maplit::hashmap; use rand::Rng; use risingwave_common::types::DataType; use risingwave_connector::parser::{ - EncodingProperties, JsonParser, JsonProperties, ProtocolProperties, SourceStreamChunkBuilder, - SpecificParserConfig, + ByteStreamSourceParserImpl, CommonParserConfig, ParserConfig, SpecificParserConfig, }; -use risingwave_connector::source::{SourceColumnDesc, SourceContext}; +use risingwave_connector::source::{SourceColumnDesc, SourceMessage}; use serde_json::json; use tokio::runtime::Runtime; -fn gen_input(mode: &str, chunk_size: usize, chunk_num: usize) -> Vec>> { +type Input = Vec>; +type Parser = ByteStreamSourceParserImpl; + +fn gen_input(mode: &str, chunk_size: usize, chunk_num: usize) -> Input { let mut input = Vec::with_capacity(chunk_num); for _ in 0..chunk_num { let mut input_inner = Vec::with_capacity(chunk_size); for _ in 0..chunk_size { - input_inner.push(match mode { + let payload = match mode { "match" => r#"{"alpha": 1, "bravo": 2, "charlie": 3, "delta": 4}"# .as_bytes() .to_vec(), @@ -55,6 +60,10 @@ fn gen_input(mode: &str, chunk_size: usize, chunk_num: usize) -> Vec serde_json::to_string(&value).unwrap().as_bytes().to_vec() } _ => unreachable!(), + }; + input_inner.push(SourceMessage { + payload: Some(payload), + ..SourceMessage::dummy() }); } input.push(input_inner); @@ -62,40 +71,27 @@ fn gen_input(mode: &str, chunk_size: usize, chunk_num: usize) -> Vec input } -fn create_parser( - chunk_size: usize, - chunk_num: usize, - mode: &str, -) -> (JsonParser, Vec, Vec>>) { +fn create_parser(chunk_size: usize, chunk_num: usize, mode: &str) -> (Parser, Input) { let desc = vec![ SourceColumnDesc::simple("alpha", DataType::Int16, 0.into()), SourceColumnDesc::simple("bravo", DataType::Int32, 1.into()), SourceColumnDesc::simple("charlie", DataType::Int64, 2.into()), SourceColumnDesc::simple("delta", DataType::Int64, 3.into()), ]; - let props = SpecificParserConfig { - key_encoding_config: None, - encoding_config: EncodingProperties::Json(JsonProperties { - use_schema_registry: false, - timestamptz_handling: None, - }), - protocol_config: ProtocolProperties::Plain, + let config = ParserConfig { + common: CommonParserConfig { rw_columns: desc }, + specific: SpecificParserConfig::DEFAULT_PLAIN_JSON, }; - let parser = JsonParser::new(props, desc.clone(), SourceContext::dummy().into()).unwrap(); + let parser = ByteStreamSourceParserImpl::create_for_test(config).unwrap(); let input = gen_input(mode, chunk_size, chunk_num); - (parser, desc, input) + (parser, input) } -async fn parse(parser: JsonParser, column_desc: Vec, input: Vec>>) { - for input_inner in input { - let mut builder = - SourceStreamChunkBuilder::with_capacity(column_desc.clone(), input_inner.len()); - for payload in input_inner { - let row_writer = builder.row_writer(); - parser.parse_inner(payload, row_writer).await.unwrap(); - } - builder.finish(); - } +async fn parse(parser: Parser, input: Input) { + parser + .into_stream(futures::stream::iter(input.into_iter().map(Ok)).boxed()) + .count() // consume the stream + .await; } fn do_bench(c: &mut Criterion, mode: &str) { @@ -110,7 +106,7 @@ fn do_bench(c: &mut Criterion, mode: &str) { let chunk_num = TOTAL_SIZE / chunk_size; b.to_async(&rt).iter_batched( || create_parser(chunk_size, chunk_num, mode), - |(parser, column_desc, input)| parse(parser, column_desc, input), + |(parser, input)| parse(parser, input), BatchSize::SmallInput, ); }, diff --git 
a/src/connector/benches/json_vs_plain_parser.rs b/src/connector/benches/json_vs_plain_parser.rs new file mode 100644 index 0000000000000..5e904c88786e6 --- /dev/null +++ b/src/connector/benches/json_vs_plain_parser.rs @@ -0,0 +1,93 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Benchmark for comparing the performance of parsing JSON records directly +//! through the `JsonParser` versus indirectly through the `PlainParser`. + +mod json_common; + +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use futures::executor::block_on; +use json_common::*; +use risingwave_connector::parser::plain_parser::PlainParser; +use risingwave_connector::parser::{JsonParser, SourceStreamChunkBuilder, SpecificParserConfig}; +use risingwave_connector::source::SourceContext; + +fn generate_json_rows() -> Vec> { + let mut rng = rand::thread_rng(); + let mut records = Vec::with_capacity(NUM_RECORDS); + for _ in 0..NUM_RECORDS { + records.push(generate_json_row(&mut rng).into_bytes()); + } + records +} + +fn bench_plain_parser_and_json_parser(c: &mut Criterion) { + let rt = tokio::runtime::Runtime::new().unwrap(); + let records = generate_json_rows(); + + let mut group = c.benchmark_group("plain parser and json parser comparison"); + + group.bench_function("plain_parser", |b| { + b.to_async(&rt).iter_batched( + || { + let parser = block_on(PlainParser::new( + SpecificParserConfig::DEFAULT_PLAIN_JSON, + get_descs(), + SourceContext::dummy().into(), + )) + .unwrap(); + (parser, records.clone()) + }, + |(mut parser, records)| async move { + let mut builder = SourceStreamChunkBuilder::with_capacity(get_descs(), NUM_RECORDS); + for record in records { + let writer = builder.row_writer(); + parser + .parse_inner(None, Some(record), writer) + .await + .unwrap(); + } + }, + BatchSize::SmallInput, + ) + }); + + group.bench_function("json_parser", |b| { + b.to_async(&rt).iter_batched( + || { + let parser = JsonParser::new( + SpecificParserConfig::DEFAULT_PLAIN_JSON, + get_descs(), + SourceContext::dummy().into(), + ) + .unwrap(); + (parser, records.clone()) + }, + |(parser, records)| async move { + let mut builder = SourceStreamChunkBuilder::with_capacity(get_descs(), NUM_RECORDS); + for record in records { + let writer = builder.row_writer(); + parser.parse_inner(record, writer).await.unwrap(); + } + }, + BatchSize::SmallInput, + ) + }); + + group.finish(); +} + +criterion_group!(benches, bench_plain_parser_and_json_parser,); +criterion_main!(benches); diff --git a/src/connector/benches/nexmark_integration.rs b/src/connector/benches/nexmark_integration.rs index 1c05147eeafbb..28596e26eec19 100644 --- a/src/connector/benches/nexmark_integration.rs +++ b/src/connector/benches/nexmark_integration.rs @@ -12,6 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Integration benchmark for parsing Nexmark events. +//! +//! 
To cover the code path in real-world scenarios, the parser is created through +//! `ByteStreamSourceParserImpl::create` based on the given configuration, rather +//! than depending on a specific internal implementation. + #![feature(lazy_cell)] use std::sync::LazyLock; @@ -23,11 +29,10 @@ use risingwave_common::array::StreamChunk; use risingwave_common::catalog::ColumnId; use risingwave_common::types::DataType; use risingwave_connector::parser::{ - ByteStreamSourceParser, JsonParser, SourceParserIntoStreamExt, SpecificParserConfig, + ByteStreamSourceParserImpl, CommonParserConfig, ParserConfig, SpecificParserConfig, }; use risingwave_connector::source::{ - BoxChunkSourceStream, BoxSourceStream, SourceColumnDesc, SourceContext, SourceMessage, - SourceMeta, + BoxChunkSourceStream, BoxSourceStream, SourceColumnDesc, SourceMessage, SourceMeta, }; use tracing::Level; use tracing_subscriber::prelude::*; @@ -71,8 +76,8 @@ fn make_data_stream() -> BoxSourceStream { .boxed() } -fn make_parser() -> impl ByteStreamSourceParser { - let columns = [ +fn make_parser() -> ByteStreamSourceParserImpl { + let rw_columns = [ ("auction", DataType::Int64), ("bidder", DataType::Int64), ("price", DataType::Int64), @@ -86,9 +91,12 @@ fn make_parser() -> impl ByteStreamSourceParser { .map(|(i, (n, t))| SourceColumnDesc::simple(n, t, ColumnId::new(i as _))) .collect_vec(); - let props = SpecificParserConfig::DEFAULT_PLAIN_JSON; + let config = ParserConfig { + common: CommonParserConfig { rw_columns }, + specific: SpecificParserConfig::DEFAULT_PLAIN_JSON, + }; - JsonParser::new(props, columns, SourceContext::dummy().into()).unwrap() + ByteStreamSourceParserImpl::create_for_test(config).unwrap() } fn make_stream_iter() -> impl Iterator { diff --git a/src/connector/src/parser/mod.rs b/src/connector/src/parser/mod.rs index a249807c6a6bb..2c0643af67109 100644 --- a/src/connector/src/parser/mod.rs +++ b/src/connector/src/parser/mod.rs @@ -866,7 +866,7 @@ impl AccessBuilderImpl { /// The entrypoint of parsing. It parses [`SourceMessage`] stream (byte stream) into [`StreamChunk`] stream. /// Used by [`crate::source::into_chunk_stream`]. #[derive(Debug)] -pub(crate) enum ByteStreamSourceParserImpl { +pub enum ByteStreamSourceParserImpl { Csv(CsvParser), Json(JsonParser), Debezium(DebeziumParser), @@ -937,6 +937,11 @@ impl ByteStreamSourceParserImpl { _ => unreachable!(), } } + + /// Create a parser for testing purposes. + pub fn create_for_test(parser_config: ParserConfig) -> ConnectorResult { + futures::executor::block_on(Self::create(parser_config, SourceContext::dummy().into())) + } } #[derive(Debug, Clone, Default)] diff --git a/src/connector/src/source/base.rs b/src/connector/src/source/base.rs index 9c77382a0143d..b670568dc6e42 100644 --- a/src/connector/src/source/base.rs +++ b/src/connector/src/source/base.rs @@ -543,6 +543,19 @@ pub struct SourceMessage { pub meta: SourceMeta, } +impl SourceMessage { + /// Create a dummy `SourceMessage` with all fields unset for testing purposes. 
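+    ///
+    /// Tests and benches typically combine this with struct-update syntax to
+    /// set only the fields they care about, e.g. (sketch):
+    ///
+    /// ```ignore
+    /// let msg = SourceMessage {
+    ///     payload: Some(b"{\"v1\": 1}".to_vec()),
+    ///     ..SourceMessage::dummy()
+    /// };
+    /// ```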
+ pub fn dummy() -> Self { + Self { + key: None, + payload: None, + offset: "".to_string(), + split_id: "".into(), + meta: SourceMeta::Empty, + } + } +} + #[derive(Debug, Clone)] pub enum SourceMeta { Kafka(KafkaMeta), From 4bef0868a4272877e5f8c188a5036e66a1e38c05 Mon Sep 17 00:00:00 2001 From: Xinhao Xu <84456268+xxhZs@users.noreply.github.com> Date: Thu, 30 May 2024 16:09:04 +0800 Subject: [PATCH 07/20] feat(frontend): support fetch n from subscription cursor (#16764) --- e2e_test/subscription/main.py | 113 ++++++++++++++---- src/frontend/src/handler/declare_cursor.rs | 3 +- .../optimizer/plan_node/generic/log_scan.rs | 13 ++ src/frontend/src/session/cursor_manager.rs | 74 +++++++----- src/utils/pgwire/src/pg_field_descriptor.rs | 2 +- 5 files changed, 153 insertions(+), 52 deletions(-) diff --git a/e2e_test/subscription/main.py b/e2e_test/subscription/main.py index c7fcc56a35ac5..3ffaefd02cee6 100644 --- a/e2e_test/subscription/main.py +++ b/e2e_test/subscription/main.py @@ -33,8 +33,7 @@ def execute_insert(sql,conn): conn.commit() cur.close() -def check_rows_data(expect_vec,rows,status): - row = rows[0] +def check_rows_data(expect_vec,row,status): value_len = len(row) for index, value in enumerate(row): if index == value_len - 1: @@ -56,7 +55,7 @@ def test_cursor_snapshot(): execute_insert("declare cur subscription cursor for sub",conn) row = execute_query("fetch next from cur",conn) - check_rows_data([1,2],row,1) + check_rows_data([1,2],row[0],1) row = execute_query("fetch next from cur",conn) assert row == [] execute_insert("close cur",conn) @@ -75,7 +74,7 @@ def test_cursor_snapshot_log_store(): execute_insert("declare cur subscription cursor for sub",conn) row = execute_query("fetch next from cur",conn) - check_rows_data([1,2],row,1) + check_rows_data([1,2],row[0],1) row = execute_query("fetch next from cur",conn) assert row == [] execute_insert("insert into t1 values(4,4)",conn) @@ -83,9 +82,9 @@ def test_cursor_snapshot_log_store(): execute_insert("insert into t1 values(5,5)",conn) execute_insert("flush",conn) row = execute_query("fetch next from cur",conn) - check_rows_data([4,4],row,1) + check_rows_data([4,4],row[0],1) row = execute_query("fetch next from cur",conn) - check_rows_data([5,5],row,1) + check_rows_data([5,5],row[0],1) row = execute_query("fetch next from cur",conn) assert row == [] execute_insert("close cur",conn) @@ -109,11 +108,11 @@ def test_cursor_since_begin(): execute_insert("insert into t1 values(6,6)",conn) execute_insert("flush",conn) row = execute_query("fetch next from cur",conn) - check_rows_data([4,4],row,1) + check_rows_data([4,4],row[0],1) row = execute_query("fetch next from cur",conn) - check_rows_data([5,5],row,1) + check_rows_data([5,5],row[0],1) row = execute_query("fetch next from cur",conn) - check_rows_data([6,6],row,1) + check_rows_data([6,6],row[0],1) row = execute_query("fetch next from cur",conn) assert row == [] execute_insert("close cur",conn) @@ -138,7 +137,7 @@ def test_cursor_since_now(): execute_insert("insert into t1 values(6,6)",conn) execute_insert("flush",conn) row = execute_query("fetch next from cur",conn) - check_rows_data([6,6],row,1) + check_rows_data([6,6],row[0],1) row = execute_query("fetch next from cur",conn) assert row == [] execute_insert("close cur",conn) @@ -164,27 +163,27 @@ def test_cursor_since_rw_timestamp(): row = execute_query("fetch next from cur",conn) valuelen = len(row[0]) rw_timestamp_1 = row[0][valuelen - 1] - check_rows_data([4,4],row,1) + check_rows_data([4,4],row[0],1) row = execute_query("fetch next 
from cur",conn) valuelen = len(row[0]) rw_timestamp_2 = row[0][valuelen - 1] - 1 - check_rows_data([5,5],row,1) + check_rows_data([5,5],row[0],1) row = execute_query("fetch next from cur",conn) valuelen = len(row[0]) rw_timestamp_3 = row[0][valuelen - 1] + 1 - check_rows_data([6,6],row,1) + check_rows_data([6,6],row[0],1) row = execute_query("fetch next from cur",conn) assert row == [] execute_insert("close cur",conn) execute_insert(f"declare cur subscription cursor for sub since {rw_timestamp_1}",conn) row = execute_query("fetch next from cur",conn) - check_rows_data([4,4],row,1) + check_rows_data([4,4],row[0],1) execute_insert("close cur",conn) execute_insert(f"declare cur subscription cursor for sub since {rw_timestamp_2}",conn) row = execute_query("fetch next from cur",conn) - check_rows_data([5,5],row,1) + check_rows_data([5,5],row[0],1) execute_insert("close cur",conn) execute_insert(f"declare cur subscription cursor for sub since {rw_timestamp_3}",conn) @@ -206,7 +205,7 @@ def test_cursor_op(): execute_insert("declare cur subscription cursor for sub",conn) row = execute_query("fetch next from cur",conn) - check_rows_data([1,2],row,1) + check_rows_data([1,2],row[0],1) row = execute_query("fetch next from cur",conn) assert row == [] @@ -215,24 +214,96 @@ def test_cursor_op(): execute_insert("update t1 set v2 = 10 where v1 = 4",conn) execute_insert("flush",conn) row = execute_query("fetch next from cur",conn) - check_rows_data([4,4],row,1) + check_rows_data([4,4],row[0],1) row = execute_query("fetch next from cur",conn) - check_rows_data([4,4],row,4) + check_rows_data([4,4],row[0],4) row = execute_query("fetch next from cur",conn) - check_rows_data([4,10],row,3) + check_rows_data([4,10],row[0],3) row = execute_query("fetch next from cur",conn) assert row == [] execute_insert("delete from t1 where v1 = 4",conn) execute_insert("flush",conn) row = execute_query("fetch next from cur",conn) - check_rows_data([4,10],row,2) + check_rows_data([4,10],row[0],2) row = execute_query("fetch next from cur",conn) assert row == [] execute_insert("close cur",conn) drop_table_subscription() +def test_cursor_with_table_alter(): + print(f"test_cursor_with_table_alter") + create_table_subscription() + conn = psycopg2.connect( + host="localhost", + port="4566", + user="root", + database="dev" + ) + + execute_insert("declare cur subscription cursor for sub",conn) + execute_insert("alter table t1 add v3 int",conn) + execute_insert("insert into t1 values(4,4,4)",conn) + execute_insert("flush",conn) + row = execute_query("fetch next from cur",conn) + check_rows_data([1,2],row[0],1) + row = execute_query("fetch next from cur",conn) + check_rows_data([4,4,4],row[0],1) + execute_insert("insert into t1 values(5,5,5)",conn) + execute_insert("flush",conn) + row = execute_query("fetch next from cur",conn) + check_rows_data([5,5,5],row[0],1) + execute_insert("alter table t1 drop column v2",conn) + execute_insert("insert into t1 values(6,6)",conn) + execute_insert("flush",conn) + row = execute_query("fetch next from cur",conn) + check_rows_data([6,6],row[0],1) + drop_table_subscription() + +def test_cursor_fetch_n(): + print(f"test_cursor_with_table_alter") + create_table_subscription() + conn = psycopg2.connect( + host="localhost", + port="4566", + user="root", + database="dev" + ) + + execute_insert("declare cur subscription cursor for sub",conn) + execute_insert("insert into t1 values(4,4)",conn) + execute_insert("flush",conn) + execute_insert("insert into t1 values(5,5)",conn) + execute_insert("flush",conn) + 
execute_insert("insert into t1 values(6,6)",conn) + execute_insert("flush",conn) + execute_insert("insert into t1 values(7,7)",conn) + execute_insert("flush",conn) + execute_insert("insert into t1 values(8,8)",conn) + execute_insert("flush",conn) + execute_insert("insert into t1 values(9,9)",conn) + execute_insert("flush",conn) + execute_insert("insert into t1 values(10,10)",conn) + execute_insert("flush",conn) + execute_insert("update t1 set v2 = 100 where v1 = 10",conn) + execute_insert("flush",conn) + row = execute_query("fetch 6 from cur",conn) + assert len(row) == 6 + check_rows_data([1,2],row[0],1) + check_rows_data([4,4],row[1],1) + check_rows_data([5,5],row[2],1) + check_rows_data([6,6],row[3],1) + check_rows_data([7,7],row[4],1) + check_rows_data([8,8],row[5],1) + row = execute_query("fetch 6 from cur",conn) + assert len(row) == 4 + check_rows_data([9,9],row[0],1) + check_rows_data([10,10],row[1],1) + check_rows_data([10,10],row[2],4) + check_rows_data([10,100],row[3],3) + drop_table_subscription() + if __name__ == "__main__": test_cursor_snapshot() test_cursor_op() @@ -240,3 +311,5 @@ def test_cursor_op(): test_cursor_since_rw_timestamp() test_cursor_since_now() test_cursor_since_begin() + test_cursor_with_table_alter() + test_cursor_fetch_n() diff --git a/src/frontend/src/handler/declare_cursor.rs b/src/frontend/src/handler/declare_cursor.rs index 6bd4e300ec0fa..25e146fa714ce 100644 --- a/src/frontend/src/handler/declare_cursor.rs +++ b/src/frontend/src/handler/declare_cursor.rs @@ -58,7 +58,6 @@ async fn handle_declare_subscription_cursor( let cursor_from_subscription_name = sub_name.0.last().unwrap().real_value().clone(); let subscription = session.get_subscription_by_name(schema_name, &cursor_from_subscription_name)?; - let table = session.get_table_by_id(&subscription.dependent_table_id)?; // Start the first query of cursor, which includes querying the table and querying the subscription's logstore let start_rw_timestamp = match rw_timestamp { Some(risingwave_sqlparser::ast::Since::TimestampMsNum(start_rw_timestamp)) => { @@ -81,8 +80,8 @@ async fn handle_declare_subscription_cursor( .add_subscription_cursor( cursor_name.clone(), start_rw_timestamp, + subscription.dependent_table_id, subscription, - table, &handle_args, ) .await?; diff --git a/src/frontend/src/optimizer/plan_node/generic/log_scan.rs b/src/frontend/src/optimizer/plan_node/generic/log_scan.rs index cd5ddebdc0724..498d4a44b0fcc 100644 --- a/src/frontend/src/optimizer/plan_node/generic/log_scan.rs +++ b/src/frontend/src/optimizer/plan_node/generic/log_scan.rs @@ -141,6 +141,19 @@ impl LogScan { Schema { fields } } + pub(crate) fn schema_without_table_name(&self) -> Schema { + let mut fields: Vec<_> = self + .output_col_idx + .iter() + .map(|tb_idx| { + let col = &self.table_desc.columns[*tb_idx]; + Field::from(col) + }) + .collect(); + fields.push(Field::with_name(OP_TYPE, OP_NAME)); + Schema { fields } + } + pub(crate) fn ctx(&self) -> OptimizerContextRef { self.ctx.clone() } diff --git a/src/frontend/src/session/cursor_manager.rs b/src/frontend/src/session/cursor_manager.rs index 13eaec03b1663..46eca3beb9966 100644 --- a/src/frontend/src/session/cursor_manager.rs +++ b/src/frontend/src/session/cursor_manager.rs @@ -30,6 +30,7 @@ use risingwave_sqlparser::ast::{Ident, ObjectName, Statement}; use super::SessionImpl; use crate::catalog::subscription_catalog::SubscriptionCatalog; +use crate::catalog::TableId; use crate::error::{ErrorCode, Result}; use crate::handler::declare_cursor::create_stream_for_cursor_stmt; 
use crate::handler::query::{create_stream, gen_batch_plan_fragmenter, BatchQueryPlanResult}; @@ -136,7 +137,7 @@ enum State { pub struct SubscriptionCursor { cursor_name: String, subscription: Arc, - table: Arc, + dependent_table_id: TableId, cursor_need_drop_time: Instant, state: State, } @@ -146,7 +147,7 @@ impl SubscriptionCursor { cursor_name: String, start_timestamp: Option, subscription: Arc, - table: Arc, + dependent_table_id: TableId, handle_args: &HandlerArgs, ) -> Result { let state = if let Some(start_timestamp) = start_timestamp { @@ -160,7 +161,7 @@ impl SubscriptionCursor { // // TODO: is this the right behavior? Should we delay the query stream initiation till the first fetch? let (row_stream, pg_descs) = - Self::initiate_query(None, &table, handle_args.clone()).await?; + Self::initiate_query(None, &dependent_table_id, handle_args.clone()).await?; let pinned_epoch = handle_args .session .get_pinned_snapshot() @@ -191,15 +192,16 @@ impl SubscriptionCursor { Ok(Self { cursor_name, subscription, - table, + dependent_table_id, cursor_need_drop_time, state, }) } - pub async fn next_row( + async fn next_row( &mut self, - handle_args: HandlerArgs, + handle_args: &HandlerArgs, + expected_pg_descs: &Vec, ) -> Result<(Option, Vec)> { loop { match &mut self.state { @@ -212,7 +214,7 @@ impl SubscriptionCursor { // Initiate a new batch query to continue fetching match Self::get_next_rw_timestamp( *seek_timestamp, - self.table.id.table_id, + self.dependent_table_id.table_id, *expected_timestamp, handle_args.clone(), ) @@ -221,7 +223,7 @@ impl SubscriptionCursor { Ok((Some(rw_timestamp), expected_timestamp)) => { let (mut row_stream, pg_descs) = Self::initiate_query( Some(rw_timestamp), - &self.table, + &self.dependent_table_id, handle_args.clone(), ) .await?; @@ -235,10 +237,15 @@ impl SubscriptionCursor { from_snapshot, rw_timestamp, row_stream, - pg_descs, + pg_descs: pg_descs.clone(), remaining_rows, expected_timestamp, }; + if (!expected_pg_descs.is_empty()) && expected_pg_descs.ne(&pg_descs) { + // If the user alters the table upstream of the sub, there will be different descs here. + // So we should output data for different descs in two separate batches + return Ok((None, vec![])); + } } Ok((None, _)) => return Ok((None, vec![])), Err(e) => { @@ -313,20 +320,25 @@ impl SubscriptionCursor { ) .into()); } - // `FETCH NEXT` is equivalent to `FETCH 1`. 
- if count != 1 { - Err(crate::error::ErrorCode::InternalError( - "FETCH count with subscription is not supported".to_string(), - ) - .into()) - } else { - let (row, pg_descs) = self.next_row(handle_args).await?; - if let Some(row) = row { - Ok((vec![row], pg_descs)) - } else { - Ok((vec![], pg_descs)) + + let mut ans = Vec::with_capacity(std::cmp::min(100, count) as usize); + let mut cur = 0; + let mut pg_descs_ans = vec![]; + while cur < count { + let (row, descs_ans) = self.next_row(&handle_args, &pg_descs_ans).await?; + match row { + Some(row) => { + pg_descs_ans = descs_ans; + cur += 1; + ans.push(row); + } + None => { + break; + } } } + + Ok((ans, pg_descs_ans)) } async fn get_next_rw_timestamp( @@ -358,16 +370,17 @@ impl SubscriptionCursor { async fn initiate_query( rw_timestamp: Option, - table_catalog: &TableCatalog, + dependent_table_id: &TableId, handle_args: HandlerArgs, ) -> Result<(PgResponseStream, Vec)> { + let session = handle_args.clone().session; + let table_catalog = session.get_table_by_id(dependent_table_id)?; let (row_stream, pg_descs) = if let Some(rw_timestamp) = rw_timestamp { - let context = OptimizerContext::from_handler_args(handle_args.clone()); - let session = handle_args.session; + let context = OptimizerContext::from_handler_args(handle_args); let plan_fragmenter_result = gen_batch_plan_fragmenter( &session, Self::create_batch_plan_for_cursor( - table_catalog, + &table_catalog, &session, context.into(), rw_timestamp, @@ -458,7 +471,11 @@ impl SubscriptionCursor { new_epoch, ); let batch_log_seq_scan = BatchLogSeqScan::new(core); - let out_fields = FixedBitSet::from_iter(0..batch_log_seq_scan.core().schema().len()); + let schema = batch_log_seq_scan + .core() + .schema_without_table_name() + .clone(); + let out_fields = FixedBitSet::from_iter(0..schema.len()); let out_names = batch_log_seq_scan.core().column_names(); // Here we just need a plan_root to call the method, only out_fields and out_names will be used let plan_root = PlanRoot::new_with_batch_plan( @@ -468,7 +485,6 @@ impl SubscriptionCursor { out_fields, out_names, ); - let schema = batch_log_seq_scan.core().schema().clone(); let (batch_log_seq_scan, query_mode) = match session.config().query_mode() { QueryMode::Auto => (plan_root.gen_batch_local_plan()?, QueryMode::Local), QueryMode::Local => (plan_root.gen_batch_local_plan()?, QueryMode::Local), @@ -497,15 +513,15 @@ impl CursorManager { &self, cursor_name: String, start_timestamp: Option, + dependent_table_id: TableId, subscription: Arc, - table: Arc, handle_args: &HandlerArgs, ) -> Result<()> { let cursor = SubscriptionCursor::new( cursor_name.clone(), start_timestamp, subscription, - table, + dependent_table_id, handle_args, ) .await?; diff --git a/src/utils/pgwire/src/pg_field_descriptor.rs b/src/utils/pgwire/src/pg_field_descriptor.rs index 0b33c5743c107..82d75c78f7956 100644 --- a/src/utils/pgwire/src/pg_field_descriptor.rs +++ b/src/utils/pgwire/src/pg_field_descriptor.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct PgFieldDescriptor { name: String, table_oid: i32, From 0c8b0360463293bb8de6bdad682ec48996feb3dc Mon Sep 17 00:00:00 2001 From: xxchan Date: Thu, 30 May 2024 17:41:27 +0800 Subject: [PATCH 08/20] refactor: minor refactor on avro (#17024) Signed-off-by: xxchan --- .typos.toml | 1 + scripts/source/schema_registry_producer.py | 7 +- src/connector/src/parser/avro/parser.rs | 10 +-- .../src/parser/avro/schema_resolver.rs | 7 +- .../src/parser/debezium/avro_parser.rs | 10 +-- src/connector/src/parser/json_parser.rs | 4 +- src/connector/src/parser/unified/avro.rs | 69 ++++++++++--------- 7 files changed, 59 insertions(+), 49 deletions(-) diff --git a/.typos.toml b/.typos.toml index c062e9de44d2d..7dcf4af6257d4 100644 --- a/.typos.toml +++ b/.typos.toml @@ -6,6 +6,7 @@ inout = "inout" # This is a SQL keyword! numer = "numer" # numerator nd = "nd" # N-dimentional / 2nd steam = "stream" # You played with Steam games too much. +ser = "ser" # Serialization # Some weird short variable names ot = "ot" bui = "bui" # BackwardUserIterator diff --git a/scripts/source/schema_registry_producer.py b/scripts/source/schema_registry_producer.py index 79a3d4db1b40f..a88861b65bd26 100644 --- a/scripts/source/schema_registry_producer.py +++ b/scripts/source/schema_registry_producer.py @@ -39,8 +39,11 @@ def load_avro_json(encoded, schema): if __name__ == '__main__': - if len(sys.argv) < 5: - print("datagen.py ") + if len(sys.argv) <= 5: + print( + "usage: schema_registry_producer.py " + ) + exit(1) broker_list = sys.argv[1] schema_registry_url = sys.argv[2] file = sys.argv[3] diff --git a/src/connector/src/parser/avro/parser.rs b/src/connector/src/parser/avro/parser.rs index 7f498a055ac7e..b37417c41ee40 100644 --- a/src/connector/src/parser/avro/parser.rs +++ b/src/connector/src/parser/avro/parser.rs @@ -21,7 +21,7 @@ use apache_avro::{from_avro_datum, Reader, Schema}; use risingwave_common::{bail, try_match_expand}; use risingwave_pb::plan_common::ColumnDesc; -use super::schema_resolver::ConfluentSchemaResolver; +use super::schema_resolver::ConfluentSchemaCache; use super::util::avro_schema_to_column_descs; use crate::error::ConnectorResult; use crate::parser::unified::avro::{AvroAccess, AvroParseOptions}; @@ -36,7 +36,7 @@ use crate::schema::schema_registry::{ #[derive(Debug)] pub struct AvroAccessBuilder { schema: Arc, - pub schema_resolver: Option>, + pub schema_resolver: Option>, value: Option, } @@ -45,7 +45,7 @@ impl AccessBuilder for AvroAccessBuilder { self.value = self.parse_avro_value(&payload, Some(&*self.schema)).await?; Ok(AccessImpl::Avro(AvroAccess::new( self.value.as_ref().unwrap(), - AvroParseOptions::default().with_schema(&self.schema), + AvroParseOptions::create(&self.schema), ))) } } @@ -100,7 +100,7 @@ impl AvroAccessBuilder { pub struct AvroParserConfig { pub schema: Arc, pub key_schema: Option>, - pub schema_resolver: Option>, + pub schema_resolver: Option>, pub map_handling: Option, } @@ -122,7 +122,7 @@ impl AvroParserConfig { let url = handle_sr_list(schema_location.as_str())?; if use_schema_registry { let client = Client::new(url, &client_config)?; - let resolver = ConfluentSchemaResolver::new(client); + let resolver = ConfluentSchemaCache::new(client); let subject_key = if enable_upsert { Some(get_subject_by_strategy( diff --git a/src/connector/src/parser/avro/schema_resolver.rs b/src/connector/src/parser/avro/schema_resolver.rs index cdc52de7accee..72410e51ab162 100644 --- 
a/src/connector/src/parser/avro/schema_resolver.rs +++ b/src/connector/src/parser/avro/schema_resolver.rs @@ -21,13 +21,14 @@ use moka::future::Cache; use crate::error::ConnectorResult; use crate::schema::schema_registry::{Client, ConfluentSchema}; +/// TODO: support protobuf #[derive(Debug)] -pub struct ConfluentSchemaResolver { +pub struct ConfluentSchemaCache { writer_schemas: Cache>, confluent_client: Client, } -impl ConfluentSchemaResolver { +impl ConfluentSchemaCache { async fn parse_and_cache_schema( &self, raw_schema: ConfluentSchema, @@ -43,7 +44,7 @@ impl ConfluentSchemaResolver { /// Create a new `ConfluentSchemaResolver` pub fn new(client: Client) -> Self { - ConfluentSchemaResolver { + ConfluentSchemaCache { writer_schemas: Cache::new(u64::MAX), confluent_client: client, } diff --git a/src/connector/src/parser/debezium/avro_parser.rs b/src/connector/src/parser/debezium/avro_parser.rs index 8d73a789b2669..50762171106fc 100644 --- a/src/connector/src/parser/debezium/avro_parser.rs +++ b/src/connector/src/parser/debezium/avro_parser.rs @@ -22,7 +22,7 @@ use risingwave_pb::catalog::PbSchemaRegistryNameStrategy; use risingwave_pb::plan_common::ColumnDesc; use crate::error::ConnectorResult; -use crate::parser::avro::schema_resolver::ConfluentSchemaResolver; +use crate::parser::avro::schema_resolver::ConfluentSchemaCache; use crate::parser::avro::util::avro_schema_to_column_descs; use crate::parser::unified::avro::{ avro_extract_field_schema, avro_schema_skip_union, AvroAccess, AvroParseOptions, @@ -41,7 +41,7 @@ const PAYLOAD: &str = "payload"; #[derive(Debug)] pub struct DebeziumAvroAccessBuilder { schema: Schema, - schema_resolver: Arc, + schema_resolver: Arc, key_schema: Option>, value: Option, encoding_type: EncodingType, @@ -59,7 +59,7 @@ impl AccessBuilder for DebeziumAvroAccessBuilder { }; Ok(AccessImpl::Avro(AvroAccess::new( self.value.as_mut().unwrap(), - AvroParseOptions::default().with_schema(match self.encoding_type { + AvroParseOptions::create(match self.encoding_type { EncodingType::Key => self.key_schema.as_mut().unwrap(), EncodingType::Value => &self.schema, }), @@ -96,7 +96,7 @@ impl DebeziumAvroAccessBuilder { pub struct DebeziumAvroParserConfig { pub key_schema: Arc, pub outer_schema: Arc, - pub schema_resolver: Arc, + pub schema_resolver: Arc, } impl DebeziumAvroParserConfig { @@ -107,7 +107,7 @@ impl DebeziumAvroParserConfig { let kafka_topic = &avro_config.topic; let url = handle_sr_list(schema_location)?; let client = Client::new(url, client_config)?; - let resolver = ConfluentSchemaResolver::new(client); + let resolver = ConfluentSchemaCache::new(client); let name_strategy = &PbSchemaRegistryNameStrategy::Unspecified; let key_subject = get_subject_by_strategy(name_strategy, kafka_topic, None, true)?; diff --git a/src/connector/src/parser/json_parser.rs b/src/connector/src/parser/json_parser.rs index 3621fbc2724b3..f9f5b1c848c46 100644 --- a/src/connector/src/parser/json_parser.rs +++ b/src/connector/src/parser/json_parser.rs @@ -21,7 +21,7 @@ use jst::{convert_avro, Context}; use risingwave_common::{bail, try_match_expand}; use risingwave_pb::plan_common::ColumnDesc; -use super::avro::schema_resolver::ConfluentSchemaResolver; +use super::avro::schema_resolver::ConfluentSchemaCache; use super::unified::Access; use super::util::{bytes_from_url, get_kafka_topic}; use super::{EncodingProperties, JsonProperties, SchemaRegistryAuth, SpecificParserConfig}; @@ -161,7 +161,7 @@ pub async fn schema_to_columns( let json_schema = if let Some(schema_registry_auth) = 
schema_registry_auth { let client = Client::new(url, &schema_registry_auth)?; let topic = get_kafka_topic(props)?; - let resolver = ConfluentSchemaResolver::new(client); + let resolver = ConfluentSchemaCache::new(client); let content = resolver .get_raw_schema_by_subject_name(&format!("{}-value", topic)) .await? diff --git a/src/connector/src/parser/unified/avro.rs b/src/connector/src/parser/unified/avro.rs index bbab918f5be1d..2c94eb47ccfd1 100644 --- a/src/connector/src/parser/unified/avro.rs +++ b/src/connector/src/parser/unified/avro.rs @@ -34,27 +34,23 @@ use crate::parser::avro::util::avro_to_jsonb; #[derive(Clone)] /// Options for parsing an `AvroValue` into Datum, with an optional avro schema. pub struct AvroParseOptions<'a> { + /// Currently, this schema is only used for decimal pub schema: Option<&'a Schema>, /// Strict Mode /// If strict mode is disabled, an int64 can be parsed from an `AvroInt` (int32) value. pub relax_numeric: bool, } -impl<'a> Default for AvroParseOptions<'a> { - fn default() -> Self { +impl<'a> AvroParseOptions<'a> { + pub fn create(schema: &'a Schema) -> Self { Self { - schema: None, + schema: Some(schema), relax_numeric: true, } } } impl<'a> AvroParseOptions<'a> { - pub fn with_schema(mut self, schema: &'a Schema) -> Self { - self.schema = Some(schema); - self - } - fn extract_inner_schema(&self, key: Option<&'a str>) -> Option<&'a Schema> { self.schema .map(|schema| avro_extract_field_schema(schema, key)) @@ -71,15 +67,23 @@ impl<'a> AvroParseOptions<'a> { } /// Parse an avro value into expected type. - /// 3 kinds of type info are used to parsing things. - /// - `type_expected`. The type that we expect the value is. - /// - value type. The type info together with the value argument. - /// - schema. The `AvroSchema` provided in option. - /// If both `type_expected` and schema are provided, it will check both strictly. - /// If only `type_expected` is provided, it will try to match the value type and the - /// `type_expected`, converting the value if possible. If only value is provided (without - /// schema and `type_expected`), the `DateType` will be inferred. - pub fn parse<'b>(&self, value: &'b Value, type_expected: Option<&'b DataType>) -> AccessResult + /// + /// 3 kinds of type info are used to parsing: + /// - `type_expected`. The type that we expect the value is. + /// - value type. The type info together with the value argument. + /// - schema. The `AvroSchema` provided in option. + /// + /// Cases: (FIXME: Is this precise?) + /// - If both `type_expected` and schema are provided, it will check both strictly. + /// - If only `type_expected` is provided, it will try to match the value type and the + /// `type_expected`, converting the value if possible. + /// - If only value is provided (without schema and `type_expected`), + /// the `DataType` will be inferred. + pub fn convert_to_datum<'b>( + &self, + value: &'b Value, + type_expected: Option<&'b DataType>, + ) -> AccessResult where 'b: 'a, { @@ -97,7 +101,7 @@ impl<'a> AvroParseOptions<'a> { schema, relax_numeric: self.relax_numeric, } - .parse(v, type_expected); + .convert_to_datum(v, type_expected); } // ---- Boolean ----- (Some(DataType::Boolean) | None, Value::Boolean(b)) => (*b).into(), @@ -224,7 +228,7 @@ impl<'a> AvroParseOptions<'a> { schema, relax_numeric: self.relax_numeric, } - .parse(value, Some(field_type))?) + .convert_to_datum(value, Some(field_type))?) 
} else { Ok(None) } @@ -241,7 +245,7 @@ impl<'a> AvroParseOptions<'a> { schema, relax_numeric: self.relax_numeric, } - .parse(field_value, None) + .convert_to_datum(field_value, None) }) .collect::, AccessError>>()?; ScalarImpl::Struct(StructValue::new(rw_values)) @@ -255,7 +259,7 @@ impl<'a> AvroParseOptions<'a> { schema, relax_numeric: self.relax_numeric, } - .parse(v, Some(item_type))?; + .convert_to_datum(v, Some(item_type))?; builder.append(value); } builder.finish() @@ -325,7 +329,7 @@ where Err(create_error())?; } - options.parse(value, type_expected) + options.convert_to_datum(value, type_expected) } } @@ -484,12 +488,9 @@ mod tests { value_schema: &Schema, shape: &DataType, ) -> crate::error::ConnectorResult { - AvroParseOptions { - schema: Some(value_schema), - relax_numeric: true, - } - .parse(&value, Some(shape)) - .map_err(Into::into) + AvroParseOptions::create(value_schema) + .convert_to_datum(&value, Some(shape)) + .map_err(Into::into) } #[test] @@ -529,8 +530,10 @@ mod tests { .unwrap(); let bytes = vec![0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f, 0x3f]; let value = Value::Decimal(AvroDecimal::from(bytes)); - let options = AvroParseOptions::default().with_schema(&schema); - let resp = options.parse(&value, Some(&DataType::Decimal)).unwrap(); + let options = AvroParseOptions::create(&schema); + let resp = options + .convert_to_datum(&value, Some(&DataType::Decimal)) + .unwrap(); assert_eq!( resp, Some(ScalarImpl::Decimal(Decimal::Normalized( @@ -566,8 +569,10 @@ mod tests { ("value".to_string(), Value::Bytes(vec![0x01, 0x02, 0x03])), ]); - let options = AvroParseOptions::default().with_schema(&schema); - let resp = options.parse(&value, Some(&DataType::Decimal)).unwrap(); + let options = AvroParseOptions::create(&schema); + let resp = options + .convert_to_datum(&value, Some(&DataType::Decimal)) + .unwrap(); assert_eq!(resp, Some(ScalarImpl::Decimal(Decimal::from(66051)))); } } From 9edfd72a86a5e165eb8309a9c83b964c4eb3de80 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Fri, 31 May 2024 10:46:34 +0800 Subject: [PATCH 09/20] refactor(connector): remove `JsonParser` from production code (#17016) Signed-off-by: Bugen Zhao --- src/connector/benches/json_vs_plain_parser.rs | 82 ++++- src/connector/src/parser/json_parser.rs | 279 ++++-------------- src/connector/src/parser/mod.rs | 52 +++- src/connector/src/source/base.rs | 12 +- 4 files changed, 201 insertions(+), 224 deletions(-) diff --git a/src/connector/benches/json_vs_plain_parser.rs b/src/connector/benches/json_vs_plain_parser.rs index 5e904c88786e6..a176e3b2b0203 100644 --- a/src/connector/benches/json_vs_plain_parser.rs +++ b/src/connector/benches/json_vs_plain_parser.rs @@ -20,10 +20,90 @@ mod json_common; use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; use futures::executor::block_on; use json_common::*; +use old_json_parser::JsonParser; use risingwave_connector::parser::plain_parser::PlainParser; -use risingwave_connector::parser::{JsonParser, SourceStreamChunkBuilder, SpecificParserConfig}; +use risingwave_connector::parser::{SourceStreamChunkBuilder, SpecificParserConfig}; use risingwave_connector::source::SourceContext; +// The original implementation used to parse JSON prior to #13707. 
+mod old_json_parser { + use anyhow::Context as _; + use itertools::{Either, Itertools as _}; + use risingwave_common::{bail, try_match_expand}; + use risingwave_connector::error::ConnectorResult; + use risingwave_connector::parser::{ + Access as _, EncodingProperties, JsonAccess, SourceStreamChunkRowWriter, + }; + use risingwave_connector::source::{SourceColumnDesc, SourceContextRef}; + + use super::*; + + /// Parser for JSON format + #[derive(Debug)] + pub struct JsonParser { + _rw_columns: Vec, + _source_ctx: SourceContextRef, + // If schema registry is used, the starting index of payload is 5. + payload_start_idx: usize, + } + + impl JsonParser { + pub fn new( + props: SpecificParserConfig, + rw_columns: Vec, + source_ctx: SourceContextRef, + ) -> ConnectorResult { + let json_config = try_match_expand!(props.encoding_config, EncodingProperties::Json)?; + let payload_start_idx = if json_config.use_schema_registry { + 5 + } else { + 0 + }; + Ok(Self { + _rw_columns: rw_columns, + _source_ctx: source_ctx, + payload_start_idx, + }) + } + + #[allow(clippy::unused_async)] + pub async fn parse_inner( + &self, + mut payload: Vec, + mut writer: SourceStreamChunkRowWriter<'_>, + ) -> ConnectorResult<()> { + let value = simd_json::to_borrowed_value(&mut payload[self.payload_start_idx..]) + .context("failed to parse json payload")?; + let values = if let simd_json::BorrowedValue::Array(arr) = value { + Either::Left(arr.into_iter()) + } else { + Either::Right(std::iter::once(value)) + }; + + let mut errors = Vec::new(); + for value in values { + let accessor = JsonAccess::new(value); + match writer + .insert(|column| accessor.access(&[&column.name], Some(&column.data_type))) + { + Ok(_) => {} + Err(err) => errors.push(err), + } + } + + if errors.is_empty() { + Ok(()) + } else { + bail!( + "failed to parse {} row(s) in a single json message: {}", + errors.len(), + errors.iter().format(", ") + ); + } + } + } +} + fn generate_json_rows() -> Vec> { let mut rng = rand::thread_rng(); let mut records = Vec::with_capacity(NUM_RECORDS); diff --git a/src/connector/src/parser/json_parser.rs b/src/connector/src/parser/json_parser.rs index f9f5b1c848c46..701fa78322967 100644 --- a/src/connector/src/parser/json_parser.rs +++ b/src/connector/src/parser/json_parser.rs @@ -12,29 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Note on this file: +// +// There's no struct named `JsonParser` anymore since #13707. `ENCODE JSON` will be +// dispatched to `PlainParser` or `UpsertParser` with `JsonAccessBuilder` instead. +// +// This file now only contains utilities and tests for JSON parsing. Also, to avoid +// rely on the internal implementation and allow that to be changed, the tests use +// `ByteStreamSourceParserImpl` to create a parser instance. 
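+//
+// For instance (a sketch based on code elsewhere in this crate, not a fixed
+// API): a plain JSON parser can be obtained either through
+// `PlainParser::new(SpecificParserConfig::DEFAULT_PLAIN_JSON, columns, ctx)`
+// or, as the tests below do, through `ByteStreamSourceParserImpl::create_for_test`
+// with a `ParserConfig` whose `specific` field is `DEFAULT_PLAIN_JSON`.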
+ use std::collections::HashMap; use anyhow::Context as _; use apache_avro::Schema; -use itertools::{Either, Itertools}; use jst::{convert_avro, Context}; -use risingwave_common::{bail, try_match_expand}; use risingwave_pb::plan_common::ColumnDesc; use super::avro::schema_resolver::ConfluentSchemaCache; -use super::unified::Access; use super::util::{bytes_from_url, get_kafka_topic}; -use super::{EncodingProperties, JsonProperties, SchemaRegistryAuth, SpecificParserConfig}; +use super::{JsonProperties, SchemaRegistryAuth}; use crate::error::ConnectorResult; -use crate::only_parse_payload; use crate::parser::avro::util::avro_schema_to_column_descs; use crate::parser::unified::json::{JsonAccess, JsonParseOptions}; use crate::parser::unified::AccessImpl; -use crate::parser::{ - AccessBuilder, ByteStreamSourceParser, ParserFormat, SourceStreamChunkRowWriter, -}; +use crate::parser::AccessBuilder; use crate::schema::schema_registry::{handle_sr_list, Client}; -use crate::source::{SourceColumnDesc, SourceContext, SourceContextRef}; #[derive(Debug)] pub struct JsonAccessBuilder { @@ -78,80 +80,6 @@ impl JsonAccessBuilder { } } -/// Parser for JSON format -#[derive(Debug)] -pub struct JsonParser { - rw_columns: Vec, - source_ctx: SourceContextRef, - // If schema registry is used, the starting index of payload is 5. - payload_start_idx: usize, -} - -impl JsonParser { - pub fn new( - props: SpecificParserConfig, - rw_columns: Vec, - source_ctx: SourceContextRef, - ) -> ConnectorResult { - let json_config = try_match_expand!(props.encoding_config, EncodingProperties::Json)?; - let payload_start_idx = if json_config.use_schema_registry { - 5 - } else { - 0 - }; - Ok(Self { - rw_columns, - source_ctx, - payload_start_idx, - }) - } - - #[cfg(test)] - pub fn new_for_test(rw_columns: Vec) -> ConnectorResult { - Ok(Self { - rw_columns, - source_ctx: SourceContext::dummy().into(), - payload_start_idx: 0, - }) - } - - #[allow(clippy::unused_async)] - pub async fn parse_inner( - &self, - mut payload: Vec, - mut writer: SourceStreamChunkRowWriter<'_>, - ) -> ConnectorResult<()> { - let value = simd_json::to_borrowed_value(&mut payload[self.payload_start_idx..]) - .context("failed to parse json payload")?; - let values = if let simd_json::BorrowedValue::Array(arr) = value { - Either::Left(arr.into_iter()) - } else { - Either::Right(std::iter::once(value)) - }; - - let mut errors = Vec::new(); - for value in values { - let accessor = JsonAccess::new(value); - match writer.insert(|column| accessor.access(&[&column.name], Some(&column.data_type))) - { - Ok(_) => {} - Err(err) => errors.push(err), - } - } - - if errors.is_empty() { - Ok(()) - } else { - // TODO(error-handling): multiple errors - bail!( - "failed to parse {} row(s) in a single json message: {}", - errors.len(), - errors.iter().format(", ") - ); - } - } -} - pub async fn schema_to_columns( schema_location: &str, schema_registry_auth: Option, @@ -179,29 +107,6 @@ pub async fn schema_to_columns( avro_schema_to_column_descs(&schema, None) } -impl ByteStreamSourceParser for JsonParser { - fn columns(&self) -> &[SourceColumnDesc] { - &self.rw_columns - } - - fn source_ctx(&self) -> &SourceContext { - &self.source_ctx - } - - fn parser_format(&self) -> ParserFormat { - ParserFormat::Json - } - - async fn parse_one<'a>( - &'a mut self, - _key: Option>, - payload: Option>, - writer: SourceStreamChunkRowWriter<'a>, - ) -> ConnectorResult<()> { - only_parse_payload!(self, payload, writer) - } -} - #[cfg(test)] mod tests { use std::vec; @@ -215,13 +120,31 @@ mod 
tests { use risingwave_pb::plan_common::additional_column::ColumnType as AdditionalColumnType; use risingwave_pb::plan_common::{AdditionalColumn, AdditionalColumnKey}; - use super::JsonParser; - use crate::parser::upsert_parser::UpsertParser; + use crate::parser::test_utils::ByteStreamSourceParserImplTestExt as _; use crate::parser::{ - EncodingProperties, JsonProperties, ProtocolProperties, SourceColumnDesc, - SourceStreamChunkBuilder, SpecificParserConfig, + ByteStreamSourceParserImpl, CommonParserConfig, ParserConfig, ProtocolProperties, + SourceColumnDesc, SpecificParserConfig, }; - use crate::source::{SourceColumnType, SourceContext}; + use crate::source::SourceColumnType; + + fn make_parser(rw_columns: Vec) -> ByteStreamSourceParserImpl { + ByteStreamSourceParserImpl::create_for_test(ParserConfig { + common: CommonParserConfig { rw_columns }, + specific: SpecificParserConfig::DEFAULT_PLAIN_JSON, + }) + .unwrap() + } + + fn make_upsert_parser(rw_columns: Vec) -> ByteStreamSourceParserImpl { + ByteStreamSourceParserImpl::create_for_test(ParserConfig { + common: CommonParserConfig { rw_columns }, + specific: SpecificParserConfig { + protocol_config: ProtocolProperties::Upsert, + ..SpecificParserConfig::DEFAULT_PLAIN_JSON + }, + }) + .unwrap() + } fn get_payload() -> Vec> { vec![ @@ -251,21 +174,8 @@ mod tests { SourceColumnDesc::simple("interval", DataType::Interval, 11.into()), ]; - let parser = JsonParser::new( - SpecificParserConfig::DEFAULT_PLAIN_JSON, - descs.clone(), - SourceContext::dummy().into(), - ) - .unwrap(); - - let mut builder = SourceStreamChunkBuilder::with_capacity(descs, 2); - - for payload in get_payload() { - let writer = builder.row_writer(); - parser.parse_inner(payload, writer).await.unwrap(); - } - - let chunk = builder.finish(); + let parser = make_parser(descs); + let chunk = parser.parse(get_payload()).await; let mut rows = chunk.rows(); @@ -361,38 +271,20 @@ mod tests { SourceColumnDesc::simple("v2", DataType::Int16, 1.into()), SourceColumnDesc::simple("v3", DataType::Varchar, 2.into()), ]; - let parser = JsonParser::new( - SpecificParserConfig::DEFAULT_PLAIN_JSON, - descs.clone(), - SourceContext::dummy().into(), - ) - .unwrap(); - let mut builder = SourceStreamChunkBuilder::with_capacity(descs, 3); - - // Parse a correct record. - { - let writer = builder.row_writer(); - let payload = br#"{"v1": 1, "v2": 2, "v3": "3"}"#.to_vec(); - parser.parse_inner(payload, writer).await.unwrap(); - } - // Parse an incorrect record. - { - let writer = builder.row_writer(); + let parser = make_parser(descs); + let payloads = vec![ + // Parse a correct record. + br#"{"v1": 1, "v2": 2, "v3": "3"}"#.to_vec(), + // Parse an incorrect record. // `v2` overflowed. - let payload = br#"{"v1": 1, "v2": 65536, "v3": "3"}"#.to_vec(); // ignored the error, and fill None at v2. - parser.parse_inner(payload, writer).await.unwrap(); - } - - // Parse a correct record. - { - let writer = builder.row_writer(); - let payload = br#"{"v1": 1, "v2": 2, "v3": "3"}"#.to_vec(); - parser.parse_inner(payload, writer).await.unwrap(); - } + br#"{"v1": 1, "v2": 65536, "v3": "3"}"#.to_vec(), + // Parse a correct record. 
+ br#"{"v1": 1, "v2": 2, "v3": "3"}"#.to_vec(), + ]; + let chunk = parser.parse(payloads).await; - let chunk = builder.finish(); assert!(chunk.valid()); assert_eq!(chunk.cardinality(), 3); @@ -432,12 +324,7 @@ mod tests { .map(SourceColumnDesc::from) .collect_vec(); - let parser = JsonParser::new( - SpecificParserConfig::DEFAULT_PLAIN_JSON, - descs.clone(), - SourceContext::dummy().into(), - ) - .unwrap(); + let parser = make_parser(descs); let payload = br#" { "data": { @@ -456,12 +343,8 @@ mod tests { "VarcharCastToI64": "1598197865760800768" } "#.to_vec(); - let mut builder = SourceStreamChunkBuilder::with_capacity(descs, 1); - { - let writer = builder.row_writer(); - parser.parse_inner(payload, writer).await.unwrap(); - } - let chunk = builder.finish(); + let chunk = parser.parse(vec![payload]).await; + let (op, row) = chunk.rows().next().unwrap(); assert_eq!(op, Op::Insert); let row = row.into_owned_row().into_inner(); @@ -504,24 +387,15 @@ mod tests { .map(SourceColumnDesc::from) .collect_vec(); - let parser = JsonParser::new( - SpecificParserConfig::DEFAULT_PLAIN_JSON, - descs.clone(), - SourceContext::dummy().into(), - ) - .unwrap(); + let parser = make_parser(descs); let payload = br#" { "struct": "{\"varchar\": \"varchar\", \"boolean\": true}" } "# .to_vec(); - let mut builder = SourceStreamChunkBuilder::with_capacity(descs, 1); - { - let writer = builder.row_writer(); - parser.parse_inner(payload, writer).await.unwrap(); - } - let chunk = builder.finish(); + let chunk = parser.parse(vec![payload]).await; + let (op, row) = chunk.rows().next().unwrap(); assert_eq!(op, Op::Insert); let row = row.into_owned_row().into_inner(); @@ -550,12 +424,7 @@ mod tests { .map(SourceColumnDesc::from) .collect_vec(); - let parser = JsonParser::new( - SpecificParserConfig::DEFAULT_PLAIN_JSON, - descs.clone(), - SourceContext::dummy().into(), - ) - .unwrap(); + let parser = make_parser(descs); let payload = br#" { "struct": { @@ -564,12 +433,8 @@ mod tests { } "# .to_vec(); - let mut builder = SourceStreamChunkBuilder::with_capacity(descs, 1); - { - let writer = builder.row_writer(); - parser.parse_inner(payload, writer).await.unwrap(); - } - let chunk = builder.finish(); + let chunk = parser.parse(vec![payload]).await; + let (op, row) = chunk.rows().next().unwrap(); assert_eq!(op, Op::Insert); let row = row.into_owned_row().into_inner(); @@ -591,7 +456,10 @@ mod tests { (r#"{"a":2}"#, r#"{"a":2,"b":2}"#), (r#"{"a":2}"#, r#""#), ] - .to_vec(); + .into_iter() + .map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec())) + .collect_vec(); + let key_column_desc = SourceColumnDesc { name: "rw_key".into(), data_type: DataType::Bytea, @@ -609,34 +477,9 @@ mod tests { SourceColumnDesc::simple("b", DataType::Int32, 1.into()), key_column_desc, ]; - let props = SpecificParserConfig { - key_encoding_config: None, - encoding_config: EncodingProperties::Json(JsonProperties { - use_schema_registry: false, - timestamptz_handling: None, - }), - protocol_config: ProtocolProperties::Upsert, - }; - let mut parser = UpsertParser::new(props, descs.clone(), SourceContext::dummy().into()) - .await - .unwrap(); - let mut builder = SourceStreamChunkBuilder::with_capacity(descs, 4); - for item in items { - parser - .parse_inner( - Some(item.0.as_bytes().to_vec()), - if !item.1.is_empty() { - Some(item.1.as_bytes().to_vec()) - } else { - None - }, - builder.row_writer(), - ) - .await - .unwrap(); - } - let chunk = builder.finish(); + let parser = make_upsert_parser(descs); + let chunk = parser.parse_upsert(items).await; // 
expected chunk // +---+---+---+------------------+ diff --git a/src/connector/src/parser/mod.rs b/src/connector/src/parser/mod.rs index 2c0643af67109..be697d990a39a 100644 --- a/src/connector/src/parser/mod.rs +++ b/src/connector/src/parser/mod.rs @@ -45,7 +45,8 @@ pub use self::mysql::mysql_row_to_owned_row; use self::plain_parser::PlainParser; pub use self::postgres::postgres_row_to_owned_row; use self::simd_json_parser::DebeziumJsonAccessBuilder; -pub use self::unified::json::TimestamptzHandling; +pub use self::unified::json::{JsonAccess, TimestamptzHandling}; +pub use self::unified::Access; use self::unified::AccessImpl; use self::upsert_parser::UpsertParser; use self::util::get_kafka_topic; @@ -868,7 +869,6 @@ impl AccessBuilderImpl { #[derive(Debug)] pub enum ByteStreamSourceParserImpl { Csv(CsvParser), - Json(JsonParser), Debezium(DebeziumParser), Plain(PlainParser), Upsert(UpsertParser), @@ -883,7 +883,6 @@ impl ByteStreamSourceParserImpl { #[auto_enum(futures03::Stream)] let stream = match self { Self::Csv(parser) => parser.into_stream(msg_stream), - Self::Json(parser) => parser.into_stream(msg_stream), Self::Debezium(parser) => parser.into_stream(msg_stream), Self::DebeziumMongoJson(parser) => parser.into_stream(msg_stream), Self::Maxwell(parser) => parser.into_stream(msg_stream), @@ -944,6 +943,53 @@ impl ByteStreamSourceParserImpl { } } +/// Test utilities for [`ByteStreamSourceParserImpl`]. +#[cfg(test)] +pub mod test_utils { + use futures::StreamExt as _; + use itertools::Itertools as _; + + use super::*; + + #[easy_ext::ext(ByteStreamSourceParserImplTestExt)] + pub(crate) impl ByteStreamSourceParserImpl { + /// Parse the given payloads into a [`StreamChunk`]. + async fn parse(self, payloads: Vec>) -> StreamChunk { + let source_messages = payloads + .into_iter() + .map(|p| SourceMessage { + payload: (!p.is_empty()).then_some(p), + ..SourceMessage::dummy() + }) + .collect_vec(); + + self.into_stream(futures::stream::once(async { Ok(source_messages) }).boxed()) + .next() + .await + .unwrap() + .unwrap() + } + + /// Parse the given key-value pairs into a [`StreamChunk`]. + async fn parse_upsert(self, kvs: Vec<(Vec, Vec)>) -> StreamChunk { + let source_messages = kvs + .into_iter() + .map(|(k, v)| SourceMessage { + key: (!k.is_empty()).then_some(k), + payload: (!v.is_empty()).then_some(v), + ..SourceMessage::dummy() + }) + .collect_vec(); + + self.into_stream(futures::stream::once(async { Ok(source_messages) }).boxed()) + .next() + .await + .unwrap() + .unwrap() + } + } +} + #[derive(Debug, Clone, Default)] pub struct ParserConfig { pub common: CommonParserConfig, diff --git a/src/connector/src/source/base.rs b/src/connector/src/source/base.rs index b670568dc6e42..a4996eabbf82e 100644 --- a/src/connector/src/source/base.rs +++ b/src/connector/src/source/base.rs @@ -316,8 +316,16 @@ pub fn extract_source_struct(info: &PbStreamSourceInfo) -> Result /// Stream of [`SourceMessage`]. pub type BoxSourceStream = BoxStream<'static, crate::error::ConnectorResult>>; -pub trait ChunkSourceStream = - Stream> + Send + 'static; +// Manually expand the trait alias to improve IDE experience. 
+pub trait ChunkSourceStream: + Stream> + Send + 'static +{ +} +impl ChunkSourceStream for T where + T: Stream> + Send + 'static +{ +} + pub type BoxChunkSourceStream = BoxStream<'static, crate::error::ConnectorResult>; pub type BoxTryStream = BoxStream<'static, crate::error::ConnectorResult>; From 1c1f34992b8d5fa992214422963184dc9679ed21 Mon Sep 17 00:00:00 2001 From: Dylan Date: Fri, 31 May 2024 11:20:08 +0800 Subject: [PATCH 10/20] feat(batch): add spill at least memory for hash agg (#17021) --- src/batch/src/executor/hash_agg.rs | 70 +++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/src/batch/src/executor/hash_agg.rs b/src/batch/src/executor/hash_agg.rs index cb4adcecdc8c7..00f7366655e0d 100644 --- a/src/batch/src/executor/hash_agg.rs +++ b/src/batch/src/executor/hash_agg.rs @@ -188,6 +188,8 @@ pub struct HashAggExecutor { chunk_size: usize, mem_context: MemoryContext, enable_spill: bool, + /// The upper bound of memory usage for this executor. + memory_upper_bound: Option, shutdown_rx: ShutdownToken, _phantom: PhantomData, } @@ -205,7 +207,7 @@ impl HashAggExecutor { enable_spill: bool, shutdown_rx: ShutdownToken, ) -> Self { - Self::new_with_init_agg_state( + Self::new_inner( aggs, group_key_columns, group_key_types, @@ -216,12 +218,13 @@ impl HashAggExecutor { chunk_size, mem_context, enable_spill, + None, shutdown_rx, ) } #[allow(clippy::too_many_arguments)] - fn new_with_init_agg_state( + fn new_inner( aggs: Arc>, group_key_columns: Vec, group_key_types: Vec, @@ -232,6 +235,7 @@ impl HashAggExecutor { chunk_size: usize, mem_context: MemoryContext, enable_spill: bool, + memory_upper_bound: Option, shutdown_rx: ShutdownToken, ) -> Self { HashAggExecutor { @@ -245,6 +249,7 @@ impl HashAggExecutor { chunk_size, mem_context, enable_spill, + memory_upper_bound, shutdown_rx, _phantom: PhantomData, } @@ -461,6 +466,22 @@ impl AggSpillManager { Ok(Self::read_stream(r)) } + async fn estimate_partition_size(&self, partition: usize) -> Result { + let agg_state_partition_file_name = format!("agg-state-p{}", partition); + let agg_state_size = self + .op + .stat(&agg_state_partition_file_name) + .await? + .content_length(); + let input_partition_file_name = format!("input-chunks-p{}", partition); + let input_size = self + .op + .stat(&input_partition_file_name) + .await? + .content_length(); + Ok(agg_state_size + input_size) + } + async fn clear_partition(&mut self, partition: usize) -> Result<()> { let agg_state_partition_file_name = format!("agg-state-p{}", partition); self.op.delete(&agg_state_partition_file_name).await?; @@ -470,11 +491,18 @@ impl AggSpillManager { } } +const SPILL_AT_LEAST_MEMORY: u64 = 1024 * 1024; + impl HashAggExecutor { #[try_stream(boxed, ok = DataChunk, error = BatchError)] async fn do_execute(self: Box) { let child_schema = self.child.schema().clone(); let mut need_to_spill = false; + // If the memory upper bound is less than 1MB, we don't need to check memory usage. 
+ let check_memory = match self.memory_upper_bound { + Some(upper_bound) => upper_bound > SPILL_AT_LEAST_MEMORY, + None => true, + }; // hash map for each agg groups let mut groups = AggHashMap::::with_hasher_in( @@ -508,7 +536,7 @@ impl HashAggExecutor { groups.try_insert(key, agg_states).unwrap(); } - if !self.mem_context.add(memory_usage_diff) { + if !self.mem_context.add(memory_usage_diff) && check_memory { warn!("not enough memory to load one partition agg state after spill which is not a normal case, so keep going"); } } @@ -553,7 +581,7 @@ impl HashAggExecutor { } } // update memory usage - if !self.mem_context.add(memory_usage_diff) { + if !self.mem_context.add(memory_usage_diff) && check_memory { if self.enable_spill { need_to_spill = true; break; @@ -624,26 +652,28 @@ impl HashAggExecutor { // Process each partition one by one. for i in 0..agg_spill_manager.partition_num { + let partition_size = agg_spill_manager.estimate_partition_size(i).await?; + let agg_state_stream = agg_spill_manager.read_agg_state_partition(i).await?; let input_stream = agg_spill_manager.read_input_partition(i).await?; - let sub_hash_agg_executor: HashAggExecutor = - HashAggExecutor::new_with_init_agg_state( - self.aggs.clone(), - self.group_key_columns.clone(), - self.group_key_types.clone(), + let sub_hash_agg_executor: HashAggExecutor = HashAggExecutor::new_inner( + self.aggs.clone(), + self.group_key_columns.clone(), + self.group_key_types.clone(), + self.schema.clone(), + Box::new(WrapStreamExecutor::new(child_schema.clone(), input_stream)), + Some(Box::new(WrapStreamExecutor::new( self.schema.clone(), - Box::new(WrapStreamExecutor::new(child_schema.clone(), input_stream)), - Some(Box::new(WrapStreamExecutor::new( - self.schema.clone(), - agg_state_stream, - ))), - format!("{}-sub{}", self.identity.clone(), i), - self.chunk_size, - self.mem_context.clone(), - self.enable_spill, - self.shutdown_rx.clone(), - ); + agg_state_stream, + ))), + format!("{}-sub{}", self.identity.clone(), i), + self.chunk_size, + self.mem_context.clone(), + self.enable_spill, + Some(partition_size), + self.shutdown_rx.clone(), + ); debug!( "create sub_hash_agg {} for hash_agg {} to spill", From 669358e735ae9da6fdd2a7080ff1baed572e5f8f Mon Sep 17 00:00:00 2001 From: xxchan Date: Fri, 31 May 2024 12:02:14 +0800 Subject: [PATCH 11/20] refactor(source): explain writer_schema and union handling for avro (#17031) Signed-off-by: xxchan --- src/connector/src/parser/avro/parser.rs | 39 +++++++++---------- .../src/parser/avro/schema_resolver.rs | 30 +++++++------- src/connector/src/parser/avro/util.rs | 13 +++++-- .../src/parser/debezium/avro_parser.rs | 6 +-- src/connector/src/parser/json_parser.rs | 11 ++---- 5 files changed, 50 insertions(+), 49 deletions(-) diff --git a/src/connector/src/parser/avro/parser.rs b/src/connector/src/parser/avro/parser.rs index b37417c41ee40..59bc10b084539 100644 --- a/src/connector/src/parser/avro/parser.rs +++ b/src/connector/src/parser/avro/parser.rs @@ -36,13 +36,14 @@ use crate::schema::schema_registry::{ #[derive(Debug)] pub struct AvroAccessBuilder { schema: Arc, - pub schema_resolver: Option>, + /// Refer to [`AvroParserConfig::writer_schema_cache`]. 
+ pub writer_schema_cache: Option>, value: Option, } impl AccessBuilder for AvroAccessBuilder { async fn generate_accessor(&mut self, payload: Vec) -> ConnectorResult> { - self.value = self.parse_avro_value(&payload, Some(&*self.schema)).await?; + self.value = self.parse_avro_value(&payload).await?; Ok(AccessImpl::Avro(AvroAccess::new( self.value.as_ref().unwrap(), AvroParseOptions::create(&self.schema), @@ -55,7 +56,7 @@ impl AvroAccessBuilder { let AvroParserConfig { schema, key_schema, - schema_resolver, + writer_schema_cache, .. } = config; Ok(Self { @@ -63,35 +64,29 @@ impl AvroAccessBuilder { EncodingType::Key => key_schema.context("Avro with empty key schema")?, EncodingType::Value => schema, }, - schema_resolver, + writer_schema_cache, value: None, }) } - async fn parse_avro_value( - &self, - payload: &[u8], - reader_schema: Option<&Schema>, - ) -> ConnectorResult> { + async fn parse_avro_value(&self, payload: &[u8]) -> ConnectorResult> { // parse payload to avro value // if use confluent schema, get writer schema from confluent schema registry - if let Some(resolver) = &self.schema_resolver { + if let Some(resolver) = &self.writer_schema_cache { let (schema_id, mut raw_payload) = extract_schema_id(payload)?; - let writer_schema = resolver.get(schema_id).await?; + let writer_schema = resolver.get_by_id(schema_id).await?; Ok(Some(from_avro_datum( writer_schema.as_ref(), &mut raw_payload, - reader_schema, + Some(self.schema.as_ref()), )?)) - } else if let Some(schema) = reader_schema { - let mut reader = Reader::with_schema(schema, payload)?; + } else { + let mut reader = Reader::with_schema(self.schema.as_ref(), payload)?; match reader.next() { Some(Ok(v)) => Ok(Some(v)), Some(Err(e)) => Err(e)?, None => bail!("avro parse unexpected eof"), } - } else { - unreachable!("both schema_resolver and reader_schema not exist"); } } } @@ -100,7 +95,9 @@ impl AvroAccessBuilder { pub struct AvroParserConfig { pub schema: Arc, pub key_schema: Option>, - pub schema_resolver: Option>, + /// Writer schema is the schema used to write the data. When parsing Avro data, the exactly same schema + /// must be used to decode the message, and then convert it with the reader schema. + pub writer_schema_cache: Option>, pub map_handling: Option, } @@ -146,13 +143,13 @@ impl AvroParserConfig { tracing::debug!("infer key subject {subject_key:?}, value subject {subject_value}"); Ok(Self { - schema: resolver.get_by_subject_name(&subject_value).await?, + schema: resolver.get_by_subject(&subject_value).await?, key_schema: if let Some(subject_key) = subject_key { - Some(resolver.get_by_subject_name(&subject_key).await?) + Some(resolver.get_by_subject(&subject_key).await?) } else { None }, - schema_resolver: Some(Arc::new(resolver)), + writer_schema_cache: Some(Arc::new(resolver)), map_handling, }) } else { @@ -166,7 +163,7 @@ impl AvroParserConfig { Ok(Self { schema: Arc::new(schema), key_schema: None, - schema_resolver: None, + writer_schema_cache: None, map_handling, }) } diff --git a/src/connector/src/parser/avro/schema_resolver.rs b/src/connector/src/parser/avro/schema_resolver.rs index 72410e51ab162..058f9bcbf7ea3 100644 --- a/src/connector/src/parser/avro/schema_resolver.rs +++ b/src/connector/src/parser/avro/schema_resolver.rs @@ -21,7 +21,13 @@ use moka::future::Cache; use crate::error::ConnectorResult; use crate::schema::schema_registry::{Client, ConfluentSchema}; -/// TODO: support protobuf +/// Fetch schemas from confluent schema registry and cache them. 
+/// +/// Background: This is mainly used for Avro **writer schema** (during schema evolution): When decoding an Avro message, +/// we must get the message's schema id, and use the *exactly same schema* to decode the message, and then +/// convert it with the reader schema. (This is also why Avro has to be used with a schema registry instead of a static schema file.) +/// +/// TODO: support protobuf (not sure if it's needed) #[derive(Debug)] pub struct ConfluentSchemaCache { writer_schemas: Cache>, @@ -50,23 +56,17 @@ impl ConfluentSchemaCache { } } - pub async fn get_by_subject_name(&self, subject_name: &str) -> ConnectorResult> { - let raw_schema = self.get_raw_schema_by_subject_name(subject_name).await?; - self.parse_and_cache_schema(raw_schema).await - } - - pub async fn get_raw_schema_by_subject_name( - &self, - subject_name: &str, - ) -> ConnectorResult { - self.confluent_client + /// Gets the latest schema by subject name, which is used as *reader schema*. + pub async fn get_by_subject(&self, subject_name: &str) -> ConnectorResult> { + let raw_schema = self + .confluent_client .get_schema_by_subject(subject_name) - .await - .map_err(Into::into) + .await?; + self.parse_and_cache_schema(raw_schema).await } - // get the writer schema by id - pub async fn get(&self, schema_id: i32) -> ConnectorResult> { + /// Gets the a specific schema by id, which is used as *writer schema*. + pub async fn get_by_id(&self, schema_id: i32) -> ConnectorResult> { // TODO: use `get_with` if let Some(schema) = self.writer_schemas.get(&schema_id).await { Ok(schema) diff --git a/src/connector/src/parser/avro/util.rs b/src/connector/src/parser/avro/util.rs index 4f36b15e5ce76..ab3a200b513ea 100644 --- a/src/connector/src/parser/avro/util.rs +++ b/src/connector/src/parser/avro/util.rs @@ -147,11 +147,18 @@ fn avro_type_mapping( DataType::List(Box::new(item_type)) } Schema::Union(union_schema) => { - let nested_schema = union_schema - .variants() + // We only support using union to represent nullable fields, not general unions. + let variants = union_schema.variants(); + if variants.len() != 2 || !variants.contains(&Schema::Null) { + bail!( + "unsupported Avro type, only unions like [null, T] is supported: {:?}", + schema + ); + } + let nested_schema = variants .iter() .find_or_first(|s| !matches!(s, Schema::Null)) - .ok_or_else(|| anyhow::format_err!("unsupported Avro type: {:?}", union_schema))?; + .unwrap(); avro_type_mapping(nested_schema, map_handling)? 
} diff --git a/src/connector/src/parser/debezium/avro_parser.rs b/src/connector/src/parser/debezium/avro_parser.rs index 50762171106fc..6f4041ab5d39c 100644 --- a/src/connector/src/parser/debezium/avro_parser.rs +++ b/src/connector/src/parser/debezium/avro_parser.rs @@ -51,7 +51,7 @@ pub struct DebeziumAvroAccessBuilder { impl AccessBuilder for DebeziumAvroAccessBuilder { async fn generate_accessor(&mut self, payload: Vec) -> ConnectorResult> { let (schema_id, mut raw_payload) = extract_schema_id(&payload)?; - let schema = self.schema_resolver.get(schema_id).await?; + let schema = self.schema_resolver.get_by_id(schema_id).await?; self.value = Some(from_avro_datum(schema.as_ref(), &mut raw_payload, None)?); self.key_schema = match self.encoding_type { EncodingType::Key => Some(schema), @@ -112,8 +112,8 @@ impl DebeziumAvroParserConfig { let name_strategy = &PbSchemaRegistryNameStrategy::Unspecified; let key_subject = get_subject_by_strategy(name_strategy, kafka_topic, None, true)?; let val_subject = get_subject_by_strategy(name_strategy, kafka_topic, None, false)?; - let key_schema = resolver.get_by_subject_name(&key_subject).await?; - let outer_schema = resolver.get_by_subject_name(&val_subject).await?; + let key_schema = resolver.get_by_subject(&key_subject).await?; + let outer_schema = resolver.get_by_subject(&val_subject).await?; Ok(Self { key_schema, diff --git a/src/connector/src/parser/json_parser.rs b/src/connector/src/parser/json_parser.rs index 701fa78322967..5c511af9efb40 100644 --- a/src/connector/src/parser/json_parser.rs +++ b/src/connector/src/parser/json_parser.rs @@ -28,7 +28,6 @@ use apache_avro::Schema; use jst::{convert_avro, Context}; use risingwave_pb::plan_common::ColumnDesc; -use super::avro::schema_resolver::ConfluentSchemaCache; use super::util::{bytes_from_url, get_kafka_topic}; use super::{JsonProperties, SchemaRegistryAuth}; use crate::error::ConnectorResult; @@ -89,12 +88,10 @@ pub async fn schema_to_columns( let json_schema = if let Some(schema_registry_auth) = schema_registry_auth { let client = Client::new(url, &schema_registry_auth)?; let topic = get_kafka_topic(props)?; - let resolver = ConfluentSchemaCache::new(client); - let content = resolver - .get_raw_schema_by_subject_name(&format!("{}-value", topic)) - .await? - .content; - serde_json::from_str(&content)? + let schema = client + .get_schema_by_subject(&format!("{}-value", topic)) + .await?; + serde_json::from_str(&schema.content)? 
} else { let url = url.first().unwrap(); let bytes = bytes_from_url(url, None).await?; From cab8403d8327e753feb7524b6df610810e915849 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Fri, 31 May 2024 14:32:38 +0800 Subject: [PATCH 12/20] refactor(parser): eliminate the gap between parser v1 and v2 (#17019) Signed-off-by: Runji Wang Signed-off-by: TennyZhuang Co-authored-by: TennyZhuang --- .typos.toml | 3 +- e2e_test/error_ui/extended/main.slt | 4 +- e2e_test/error_ui/simple/main.slt | 4 +- e2e_test/source/basic/ddl.slt | 2 +- .../src/expr/function_impl/cast_regclass.rs | 22 +- src/sqlparser/src/ast/legacy_source.rs | 33 +- src/sqlparser/src/ast/mod.rs | 17 +- src/sqlparser/src/ast/statement.rs | 126 +- src/sqlparser/src/bin/sqlparser.rs | 27 +- src/sqlparser/src/parser.rs | 1356 +++++++---------- src/sqlparser/src/parser_v2/data_type.rs | 2 +- src/sqlparser/src/parser_v2/impl_.rs | 72 +- src/sqlparser/src/parser_v2/mod.rs | 89 +- src/sqlparser/src/parser_v2/number.rs | 32 +- src/sqlparser/src/test_utils.rs | 4 +- src/sqlparser/src/tokenizer.rs | 85 +- src/sqlparser/tests/sqlparser_postgres.rs | 8 +- src/sqlparser/tests/testdata/array.yaml | 32 +- src/sqlparser/tests/testdata/create.yaml | 38 +- src/sqlparser/tests/testdata/insert.yaml | 4 +- src/sqlparser/tests/testdata/select.yaml | 43 +- src/sqlparser/tests/testdata/set.yaml | 6 +- src/sqlparser/tests/testdata/struct.yaml | 2 +- src/sqlparser/tests/testdata/subquery.yaml | 8 +- 24 files changed, 885 insertions(+), 1134 deletions(-) diff --git a/.typos.toml b/.typos.toml index 7dcf4af6257d4..567904f5c319b 100644 --- a/.typos.toml +++ b/.typos.toml @@ -9,7 +9,7 @@ steam = "stream" # You played with Steam games too much. ser = "ser" # Serialization # Some weird short variable names ot = "ot" -bui = "bui" # BackwardUserIterator +bui = "bui" # BackwardUserIterator mosquitto = "mosquitto" # This is a MQTT broker. 
abd = "abd" iy = "iy" @@ -22,6 +22,7 @@ extend-exclude = [ "e2e_test", "**/*.svg", "scripts", + "src/sqlparser/tests/testdata/", "src/frontend/planner_test/tests/testdata", "src/tests/sqlsmith/tests/freeze", "Cargo.lock", diff --git a/e2e_test/error_ui/extended/main.slt b/e2e_test/error_ui/extended/main.slt index eb6669ce89d71..99de2de06908d 100644 --- a/e2e_test/error_ui/extended/main.slt +++ b/e2e_test/error_ui/extended/main.slt @@ -4,9 +4,9 @@ selet 1; db error: ERROR: Failed to prepare the statement Caused by: - sql parser error: expected an SQL statement, found: selet at line 1, column 1 + sql parser error: expected statement, found: selet LINE 1: selet 1; - ^ + ^ query error diff --git a/e2e_test/error_ui/simple/main.slt b/e2e_test/error_ui/simple/main.slt index c569560af631a..4445bedee968b 100644 --- a/e2e_test/error_ui/simple/main.slt +++ b/e2e_test/error_ui/simple/main.slt @@ -4,9 +4,9 @@ selet 1; db error: ERROR: Failed to run the query Caused by: - sql parser error: expected an SQL statement, found: selet at line 1, column 1 + sql parser error: expected statement, found: selet LINE 1: selet 1; - ^ + ^ statement error diff --git a/e2e_test/source/basic/ddl.slt b/e2e_test/source/basic/ddl.slt index 33b79dfda9b67..465e0f19344e9 100644 --- a/e2e_test/source/basic/ddl.slt +++ b/e2e_test/source/basic/ddl.slt @@ -4,7 +4,7 @@ create source s; db error: ERROR: Failed to run the query Caused by: - sql parser error: expected description of the format, found: ; at line 1, column 16 + sql parser error: expected description of the format, found: ; LINE 1: create source s; ^ diff --git a/src/frontend/src/expr/function_impl/cast_regclass.rs b/src/frontend/src/expr/function_impl/cast_regclass.rs index b1ec47a2d3508..c350d3984ab97 100644 --- a/src/frontend/src/expr/function_impl/cast_regclass.rs +++ b/src/frontend/src/expr/function_impl/cast_regclass.rs @@ -15,7 +15,6 @@ use risingwave_common::session_config::SearchPath; use risingwave_expr::{capture_context, function, ExprError}; use risingwave_sqlparser::parser::{Parser, ParserError}; -use risingwave_sqlparser::tokenizer::{Token, Tokenizer}; use thiserror::Error; use thiserror_ext::AsReport; @@ -63,7 +62,11 @@ fn resolve_regclass_inner( db_name: &str, class_name: &str, ) -> Result { - let obj = parse_object_name(class_name)?; + // We use the full parser here because this function needs to accept every legal way + // of identifying an object in PG SQL as a valid value for the varchar + // literal. For example: 'foo', 'public.foo', '"my table"', and + // '"my schema".foo' must all work as values passed pg_table_size. + let obj = Parser::parse_object_name_str(class_name)?; if obj.0.len() == 1 { let class_name = obj.0[0].real_value(); @@ -81,21 +84,6 @@ fn resolve_regclass_inner( } } -fn parse_object_name(name: &str) -> Result { - // We use the full parser here because this function needs to accept every legal way - // of identifying an object in PG SQL as a valid value for the varchar - // literal. For example: 'foo', 'public.foo', '"my table"', and - // '"my schema".foo' must all work as values passed pg_table_size. 
- let mut tokenizer = Tokenizer::new(name); - let tokens = tokenizer - .tokenize_with_location() - .map_err(ParserError::from)?; - let mut parser = Parser::new(tokens); - let object = parser.parse_object_name()?; - parser.expect_token(&Token::EOF)?; - Ok(object) -} - #[function("cast_regclass(varchar) -> int4")] fn cast_regclass(class_name: &str) -> Result { let oid = resolve_regclass_impl_captured(class_name)?; diff --git a/src/sqlparser/src/ast/legacy_source.rs b/src/sqlparser/src/ast/legacy_source.rs index e5f957231f5ad..6a079688c2d4e 100644 --- a/src/sqlparser/src/ast/legacy_source.rs +++ b/src/sqlparser/src/ast/legacy_source.rs @@ -20,14 +20,15 @@ use std::fmt; use itertools::Itertools as _; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use winnow::PResult; use crate::ast::{ AstString, AstVec, ConnectorSchema, Encode, Format, Ident, ObjectName, ParseTo, SqlOption, Value, }; use crate::keywords::Keyword; -use crate::parser::{Parser, ParserError}; -use crate::{impl_fmt_display, impl_parse_to}; +use crate::parser::{Parser, StrError}; +use crate::{impl_fmt_display, impl_parse_to, parser_err}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] @@ -64,12 +65,10 @@ impl From for CompatibleSourceSchema { } } -pub fn parse_source_schema(p: &mut Parser) -> Result { +pub fn parse_source_schema(p: &mut Parser<'_>) -> PResult { if let Some(schema_v2) = p.parse_schema()? { if schema_v2.key_encode.is_some() { - return Err(ParserError::ParserError( - "key encode clause is not supported in source schema".to_string(), - )); + parser_err!("key encode clause is not supported in source schema"); } Ok(CompatibleSourceSchema::V2(schema_v2)) } else if p.peek_nth_any_of_keywords(0, &[Keyword::ROW]) @@ -109,16 +108,15 @@ pub fn parse_source_schema(p: &mut Parser) -> Result SourceSchema::Bytes, _ => { - return Err(ParserError::ParserError( + parser_err!( "expected JSON | UPSERT_JSON | PROTOBUF | DEBEZIUM_JSON | DEBEZIUM_AVRO \ | AVRO | UPSERT_AVRO | MAXWELL | CANAL_JSON | BYTES | NATIVE after ROW FORMAT" - .to_string(), - )); + ); } }; Ok(CompatibleSourceSchema::RowFormat(schema)) } else { - p.expected("description of the format", p.peek_token()) + p.expected("description of the format") } } @@ -286,7 +284,7 @@ pub struct ProtobufSchema { } impl ParseTo for ProtobufSchema { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!([Keyword::MESSAGE], p); impl_parse_to!(message_name: AstString, p); impl_parse_to!([Keyword::ROW, Keyword::SCHEMA, Keyword::LOCATION], p); @@ -324,7 +322,7 @@ pub struct AvroSchema { } impl ParseTo for AvroSchema { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!([Keyword::ROW, Keyword::SCHEMA, Keyword::LOCATION], p); impl_parse_to!(use_schema_registry => [Keyword::CONFLUENT, Keyword::SCHEMA, Keyword::REGISTRY], p); impl_parse_to!(row_schema_location: AstString, p); @@ -371,7 +369,7 @@ impl fmt::Display for DebeziumAvroSchema { } impl ParseTo for DebeziumAvroSchema { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!( [ Keyword::ROW, @@ -397,19 +395,18 @@ pub struct CsvInfo { pub has_header: bool, } -pub fn get_delimiter(chars: &str) -> Result { +pub fn get_delimiter(chars: &str) -> Result { match chars { "," => Ok(b','), // comma "\t" => Ok(b'\t'), // tab - other => Err(ParserError::ParserError(format!( - "The delimiter should be one of ',', E'\\t', but got 
{:?}", - other + other => Err(StrError(format!( + "The delimiter should be one of ',', E'\\t', but got {other:?}", ))), } } impl ParseTo for CsvInfo { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!(without_header => [Keyword::WITHOUT, Keyword::HEADER], p); impl_parse_to!([Keyword::DELIMITED, Keyword::BY], p); impl_parse_to!(delimiter: AstString, p); diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index 4b96565a0d683..0ce72be34bf82 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -33,6 +33,7 @@ use std::sync::Arc; use itertools::Itertools; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use winnow::PResult; pub use self::data_type::{DataType, StructField}; pub use self::ddl::{ @@ -59,7 +60,7 @@ pub use crate::ast::ddl::{ AlterViewOperation, }; use crate::keywords::Keyword; -use crate::parser::{IncludeOption, IncludeOptionItem, Parser, ParserError}; +use crate::parser::{IncludeOption, IncludeOptionItem, Parser, ParserError, StrError}; pub type RedactSqlOptionKeywordsRef = Arc>; @@ -191,7 +192,7 @@ impl From<&str> for Ident { } impl ParseTo for Ident { - fn parse_to(parser: &mut Parser) -> Result { + fn parse_to(parser: &mut Parser<'_>) -> PResult { parser.parse_identifier() } } @@ -235,7 +236,7 @@ impl fmt::Display for ObjectName { } impl ParseTo for ObjectName { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { p.parse_object_name() } } @@ -2560,7 +2561,7 @@ impl fmt::Display for ObjectType { } impl ParseTo for ObjectType { - fn parse_to(parser: &mut Parser) -> Result { + fn parse_to(parser: &mut Parser<'_>) -> PResult { let object_type = if parser.parse_keyword(Keyword::TABLE) { ObjectType::Table } else if parser.parse_keyword(Keyword::VIEW) { @@ -2588,7 +2589,6 @@ impl ParseTo for ObjectType { } else { return parser.expected( "TABLE, VIEW, INDEX, MATERIALIZED VIEW, SOURCE, SINK, SUBSCRIPTION, SCHEMA, DATABASE, USER, SECRET or CONNECTION after DROP", - parser.peek_token(), ); }; Ok(object_type) @@ -3007,7 +3007,7 @@ impl CreateFunctionWithOptions { /// TODO(kwannoel): Generate from the struct definition instead. 
impl TryFrom> for CreateFunctionWithOptions { - type Error = ParserError; + type Error = StrError; fn try_from(with_options: Vec) -> Result { let mut always_retry_on_network_error = None; @@ -3015,10 +3015,7 @@ impl TryFrom> for CreateFunctionWithOptions { if option.name.to_string().to_lowercase() == "always_retry_on_network_error" { always_retry_on_network_error = Some(option.value == Value::Boolean(true)); } else { - return Err(ParserError::ParserError(format!( - "Unsupported option: {}", - option.name - ))); + return Err(StrError(format!("Unsupported option: {}", option.name))); } } Ok(Self { diff --git a/src/sqlparser/src/ast/statement.rs b/src/sqlparser/src/ast/statement.rs index 2e5b281d1938f..0c92209471450 100644 --- a/src/sqlparser/src/ast/statement.rs +++ b/src/sqlparser/src/ast/statement.rs @@ -19,6 +19,7 @@ use std::fmt::Write; use itertools::Itertools; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +use winnow::PResult; use super::ddl::SourceWatermark; use super::legacy_source::{parse_source_schema, CompatibleSourceSchema}; @@ -27,12 +28,14 @@ use crate::ast::{ display_comma_separated, display_separated, ColumnDef, ObjectName, SqlOption, TableConstraint, }; use crate::keywords::Keyword; -use crate::parser::{IncludeOption, IsOptional, Parser, ParserError, UPSTREAM_SOURCE_KEY}; +use crate::parser::{IncludeOption, IsOptional, Parser, UPSTREAM_SOURCE_KEY}; +use crate::parser_err; +use crate::parser_v2::literal_u32; use crate::tokenizer::Token; /// Consumes token from the parser into an AST node. pub trait ParseTo: Sized { - fn parse_to(parser: &mut Parser) -> Result; + fn parse_to(parser: &mut Parser<'_>) -> PResult; } #[macro_export] @@ -132,7 +135,7 @@ impl fmt::Display for Format { } impl Format { - pub fn from_keyword(s: &str) -> Result { + pub fn from_keyword(s: &str) -> PResult { Ok(match s { "DEBEZIUM" => Format::Debezium, "DEBEZIUM_MONGO" => Format::DebeziumMongo, @@ -142,12 +145,9 @@ impl Format { "UPSERT" => Format::Upsert, "NATIVE" => Format::Native, // used internally for schema change "NONE" => Format::None, // used by iceberg - _ => { - return Err(ParserError::ParserError( - "expected CANAL | PROTOBUF | DEBEZIUM | MAXWELL | PLAIN | NATIVE | NONE after FORMAT" - .to_string(), - )); - } + _ => parser_err!( + "expected CANAL | PROTOBUF | DEBEZIUM | MAXWELL | PLAIN | NATIVE | NONE after FORMAT" + ), }) } } @@ -188,7 +188,7 @@ impl fmt::Display for Encode { } impl Encode { - pub fn from_keyword(s: &str) -> Result { + pub fn from_keyword(s: &str) -> PResult { Ok(match s { "AVRO" => Encode::Avro, "TEXT" => Encode::Text, @@ -199,10 +199,9 @@ impl Encode { "TEMPLATE" => Encode::Template, "NATIVE" => Encode::Native, // used internally for schema change "NONE" => Encode::None, // used by iceberg - _ => return Err(ParserError::ParserError( + _ => parser_err!( "expected AVRO | BYTES | CSV | PROTOBUF | JSON | NATIVE | TEMPLATE | NONE after Encode" - .to_string(), - )), + ), }) } } @@ -217,7 +216,7 @@ pub struct ConnectorSchema { pub key_encode: Option, } -impl Parser { +impl Parser<'_> { /// Peek the next tokens to see if it is `FORMAT` or `ROW FORMAT` (for compatibility). 
fn peek_source_schema_format(&mut self) -> bool { (self.peek_nth_any_of_keywords(0, &[Keyword::ROW]) @@ -230,7 +229,7 @@ impl Parser { &mut self, connector: &str, cdc_source_job: bool, - ) -> Result { + ) -> PResult { // row format for cdc source must be debezium json // row format for nexmark source must be native // default row format for datagen source is native @@ -247,10 +246,10 @@ impl Parser { if self.peek_source_schema_format() { let schema = parse_source_schema(self)?.into_v2(); if schema != expected { - return Err(ParserError::ParserError(format!( + parser_err!( "Row format for CDC connectors should be \ either omitted or set to `{expected}`", - ))); + ); } } Ok(expected.into()) @@ -259,10 +258,10 @@ impl Parser { if self.peek_source_schema_format() { let schema = parse_source_schema(self)?.into_v2(); if schema != expected { - return Err(ParserError::ParserError(format!( + parser_err!( "Row format for nexmark connectors should be \ either omitted or set to `{expected}`", - ))); + ); } } Ok(expected.into()) @@ -277,10 +276,10 @@ impl Parser { if self.peek_source_schema_format() { let schema = parse_source_schema(self)?.into_v2(); if schema != expected { - return Err(ParserError::ParserError(format!( + parser_err!( "Row format for iceberg connectors should be \ either omitted or set to `{expected}`", - ))); + ); } } Ok(expected.into()) @@ -290,7 +289,7 @@ impl Parser { } /// Parse `FORMAT ... ENCODE ... (...)`. - pub fn parse_schema(&mut self) -> Result, ParserError> { + pub fn parse_schema(&mut self) -> PResult> { if !self.parse_keyword(Keyword::FORMAT) { return Ok(None); } @@ -389,7 +388,7 @@ impl fmt::Display for ConnectorSchema { } impl ParseTo for CreateSourceStatement { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!(if_not_exists => [Keyword::IF, Keyword::NOT, Keyword::EXISTS], p); impl_parse_to!(source_name: ObjectName, p); @@ -405,9 +404,7 @@ impl ParseTo for CreateSourceStatement { let connector: String = option.map(|opt| opt.value.to_string()).unwrap_or_default(); let cdc_source_job = connector.contains("-cdc"); if cdc_source_job && (!columns.is_empty() || !constraints.is_empty()) { - return Err(ParserError::ParserError( - "CDC source cannot define columns and constraints".to_string(), - )); + parser_err!("CDC source cannot define columns and constraints"); } // row format for nexmark source must be native @@ -524,7 +521,7 @@ pub struct CreateSinkStatement { } impl ParseTo for CreateSinkStatement { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!(if_not_exists => [Keyword::IF, Keyword::NOT, Keyword::EXISTS], p); impl_parse_to!(sink_name: ObjectName, p); @@ -544,7 +541,7 @@ impl ParseTo for CreateSinkStatement { let query = Box::new(p.parse_query()?); CreateSink::AsQuery(query) } else { - p.expected("FROM or AS after CREATE SINK sink_name", p.peek_token())? + p.expected("FROM or AS after CREATE SINK sink_name")? }; let emit_mode: Option = p.parse_emit_mode()?; @@ -552,14 +549,12 @@ impl ParseTo for CreateSinkStatement { // This check cannot be put into the `WithProperties::parse_to`, since other // statements may not need the with properties. if !p.peek_nth_any_of_keywords(0, &[Keyword::WITH]) && into_table_name.is_none() { - p.expected("WITH", p.peek_token())? + p.expected("WITH")? 
} impl_parse_to!(with_properties: WithProperties, p); if with_properties.0.is_empty() && into_table_name.is_none() { - return Err(ParserError::ParserError( - "sink properties not provided".to_string(), - )); + parser_err!("sink properties not provided"); } let sink_schema = p.parse_schema()?; @@ -616,7 +611,7 @@ pub struct CreateSubscriptionStatement { } impl ParseTo for CreateSubscriptionStatement { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!(if_not_exists => [Keyword::IF, Keyword::NOT, Keyword::EXISTS], p); impl_parse_to!(subscription_name: ObjectName, p); @@ -624,10 +619,7 @@ impl ParseTo for CreateSubscriptionStatement { impl_parse_to!(from_name: ObjectName, p); from_name } else { - p.expected( - "FROM after CREATE SUBSCRIPTION subscription_name", - p.peek_token(), - )? + p.expected("FROM after CREATE SUBSCRIPTION subscription_name")? }; // let emit_mode = p.parse_emit_mode()?; @@ -635,14 +627,12 @@ impl ParseTo for CreateSubscriptionStatement { // This check cannot be put into the `WithProperties::parse_to`, since other // statements may not need the with properties. if !p.peek_nth_any_of_keywords(0, &[Keyword::WITH]) { - p.expected("WITH", p.peek_token())? + p.expected("WITH")? } impl_parse_to!(with_properties: WithProperties, p); if with_properties.0.is_empty() { - return Err(ParserError::ParserError( - "subscription properties not provided".to_string(), - )); + parser_err!("subscription properties not provided"); } Ok(Self { @@ -703,7 +693,7 @@ pub struct DeclareCursorStatement { } impl ParseTo for DeclareCursorStatement { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!(cursor_name: ObjectName, p); let declare_cursor = if !p.parse_keyword(Keyword::SUBSCRIPTION) { @@ -746,14 +736,11 @@ pub struct FetchCursorStatement { } impl ParseTo for FetchCursorStatement { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { let count = if p.parse_keyword(Keyword::NEXT) { 1 } else { - let count_str = p.parse_number_value()?; - count_str.parse::().map_err(|e| { - ParserError::ParserError(format!("Could not parse '{}' as i32: {}", count_str, e)) - })? + literal_u32(p)? 
}; p.expect_keyword(Keyword::FROM)?; impl_parse_to!(cursor_name: ObjectName, p); @@ -786,7 +773,7 @@ pub struct CloseCursorStatement { } impl ParseTo for CloseCursorStatement { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { let cursor_name = if p.parse_keyword(Keyword::ALL) { None } else { @@ -823,14 +810,12 @@ pub struct CreateConnectionStatement { } impl ParseTo for CreateConnectionStatement { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!(if_not_exists => [Keyword::IF, Keyword::NOT, Keyword::EXISTS], p); impl_parse_to!(connection_name: ObjectName, p); impl_parse_to!(with_properties: WithProperties, p); if with_properties.0.is_empty() { - return Err(ParserError::ParserError( - "connection properties not provided".to_string(), - )); + parser_err!("connection properties not provided"); } Ok(Self { @@ -861,7 +846,7 @@ pub struct CreateSecretStatement { } impl ParseTo for CreateSecretStatement { - fn parse_to(parser: &mut Parser) -> Result { + fn parse_to(parser: &mut Parser<'_>) -> PResult { impl_parse_to!(if_not_exists => [Keyword::IF, Keyword::NOT, Keyword::EXISTS], parser); impl_parse_to!(secret_name: ObjectName, parser); impl_parse_to!(with_properties: WithProperties, parser); @@ -907,7 +892,7 @@ impl fmt::Display for AstVec { pub struct WithProperties(pub Vec); impl ParseTo for WithProperties { - fn parse_to(parser: &mut Parser) -> Result { + fn parse_to(parser: &mut Parser<'_>) -> PResult { Ok(Self( parser.parse_options_with_preceding_keyword(Keyword::WITH)?, )) @@ -950,7 +935,7 @@ pub struct RowSchemaLocation { } impl ParseTo for RowSchemaLocation { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!([Keyword::ROW, Keyword::SCHEMA, Keyword::LOCATION], p); impl_parse_to!(value: AstString, p); Ok(Self { value }) @@ -973,7 +958,7 @@ impl fmt::Display for RowSchemaLocation { pub struct AstString(pub String); impl ParseTo for AstString { - fn parse_to(parser: &mut Parser) -> Result { + fn parse_to(parser: &mut Parser<'_>) -> PResult { Ok(Self(parser.parse_literal_string()?)) } } @@ -996,7 +981,7 @@ pub enum AstOption { } impl ParseTo for AstOption { - fn parse_to(parser: &mut Parser) -> Result { + fn parse_to(parser: &mut Parser<'_>) -> PResult { match T::parse_to(parser) { Ok(t) => Ok(AstOption::Some(t)), Err(_) => Ok(AstOption::None), @@ -1116,17 +1101,14 @@ impl UserOptionsBuilder { } impl ParseTo for UserOptions { - fn parse_to(parser: &mut Parser) -> Result { + fn parse_to(parser: &mut Parser<'_>) -> PResult { let mut builder = UserOptionsBuilder::default(); let add_option = |item: &mut Option, user_option| { let old_value = item.replace(user_option); if old_value.is_some() { - Err(ParserError::ParserError( - "conflicting or redundant options".to_string(), - )) - } else { - Ok(()) + parser_err!("conflicting or redundant options"); } + Ok(()) }; let _ = parser.parse_keyword(Keyword::WITH); loop { @@ -1136,6 +1118,7 @@ impl ParseTo for UserOptions { } if let Token::Word(ref w) = token.token { + let checkpoint = *parser; parser.next_token(); let (item_mut_ref, user_option) = match w.keyword { Keyword::SUPERUSER => (&mut builder.super_user, UserOption::SuperUser), @@ -1168,10 +1151,10 @@ impl ParseTo for UserOptions { (&mut builder.password, UserOption::OAuth(options)) } _ => { - parser.expected( + parser.expected_at( + checkpoint, "SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN \ | NOLOGIN | CREATEUSER | NOCREATEUSER | 
[ENCRYPTED] PASSWORD | NULL | OAUTH", - token, )?; unreachable!() } @@ -1181,7 +1164,6 @@ impl ParseTo for UserOptions { parser.expected( "SUPERUSER | NOSUPERUSER | CREATEDB | NOCREATEDB | LOGIN | NOLOGIN \ | CREATEUSER | NOCREATEUSER | [ENCRYPTED] PASSWORD | NULL | OAUTH", - token, )? } } @@ -1200,7 +1182,7 @@ impl fmt::Display for UserOptions { } impl ParseTo for CreateUserStatement { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!(user_name: ObjectName, p); impl_parse_to!(with_options: UserOptions, p); @@ -1243,7 +1225,7 @@ impl fmt::Display for AlterUserStatement { } impl ParseTo for AlterUserStatement { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!(user_name: ObjectName, p); impl_parse_to!(mode: AlterUserMode, p); @@ -1252,7 +1234,7 @@ impl ParseTo for AlterUserStatement { } impl ParseTo for AlterUserMode { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { if p.parse_keyword(Keyword::RENAME) { p.expect_keyword(Keyword::TO)?; impl_parse_to!(new_name: ObjectName, p); @@ -1285,7 +1267,7 @@ pub struct DropStatement { // drop_mode: AstOption, // }); impl ParseTo for DropStatement { - fn parse_to(p: &mut Parser) -> Result { + fn parse_to(p: &mut Parser<'_>) -> PResult { impl_parse_to!(object_type: ObjectType, p); impl_parse_to!(if_exists => [Keyword::IF, Keyword::EXISTS], p); let object_name = p.parse_object_name()?; @@ -1318,13 +1300,13 @@ pub enum DropMode { } impl ParseTo for DropMode { - fn parse_to(parser: &mut Parser) -> Result { + fn parse_to(parser: &mut Parser<'_>) -> PResult { let drop_mode = if parser.parse_keyword(Keyword::CASCADE) { DropMode::Cascade } else if parser.parse_keyword(Keyword::RESTRICT) { DropMode::Restrict } else { - return parser.expected("CASCADE | RESTRICT", parser.peek_token()); + return parser.expected("CASCADE | RESTRICT"); }; Ok(drop_mode) } diff --git a/src/sqlparser/src/bin/sqlparser.rs b/src/sqlparser/src/bin/sqlparser.rs index be2ec51bc78fc..57a984a8c1cd1 100644 --- a/src/sqlparser/src/bin/sqlparser.rs +++ b/src/sqlparser/src/bin/sqlparser.rs @@ -1,3 +1,19 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#![feature(register_tool)] +#![register_tool(rw)] +#![allow(rw::format_error)] // test code + use std::io; use risingwave_sqlparser::parser::Parser; @@ -12,5 +28,14 @@ fn main() { let mut buffer = String::new(); io::stdin().read_line(&mut buffer).unwrap(); let result = Parser::parse_sql(&buffer); - println!("{:#?}", result); + match result { + Ok(statements) => { + for statement in statements { + println!("{:#?}", statement); + } + } + Err(e) => { + eprintln!("{}", e); + } + } } diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 71a2099e0e6fd..b1a8e2ba22b90 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -24,10 +24,13 @@ use core::fmt; use itertools::Itertools; use tracing::{debug, instrument}; +use winnow::combinator::{alt, cut_err, dispatch, fail, opt, peek, preceded, repeat, separated}; +use winnow::{PResult, Parser as _}; use crate::ast::*; use crate::keywords::{self, Keyword}; use crate::parser_v2; +use crate::parser_v2::{keyword, literal_i64, literal_uint, ParserExt as _}; use crate::tokenizer::*; pub(crate) const UPSTREAM_SOURCE_KEY: &str = "connector"; @@ -45,14 +48,33 @@ impl ParserError { } } } + +#[derive(Debug, thiserror::Error)] +#[error("{0}")] +pub struct StrError(pub String); + // Use `Parser::expected` instead, if possible #[macro_export] macro_rules! parser_err { - ($MSG:expr) => { - Err(ParserError::ParserError($MSG.to_string())) + ($($arg:tt)*) => { + return Err(winnow::error::ErrMode::Backtrack(>::from_external_error( + &Parser::default(), + winnow::error::ErrorKind::Fail, + $crate::parser::StrError(format!($($arg)*)), + ))) }; } +impl From for winnow::error::ErrMode { + fn from(e: StrError) -> Self { + winnow::error::ErrMode::Backtrack(>::from_external_error( + &Parser::default(), + winnow::error::ErrorKind::Fail, + e, + )) + } +} + // Returns a successful result if the optional expression is some macro_rules! return_ok_if_some { ($e:expr) => {{ @@ -169,70 +191,20 @@ pub enum Precedence { DoubleColon, // 50 in upstream } -pub struct Parser { - tokens: Vec, - /// The index of the first unprocessed token in `self.tokens` - index: usize, -} - -impl Parser { - /// Parse the specified tokens - pub fn new(tokens: Vec) -> Self { - Parser { tokens, index: 0 } - } - - /// Adaptor for [`parser_v2`]. - /// - /// You can call a v2 parser from original parser by using this method. 
- pub(crate) fn parse_v2<'a, O>( - &'a mut self, - mut parse_next: impl winnow::Parser< - winnow::Located>, - O, - winnow::error::ContextError, - >, - ) -> Result { - use winnow::stream::Location; - - let mut token_stream = winnow::Located::new(parser_v2::TokenStreamWrapper { - tokens: &self.tokens[self.index..], - }); - let output = parse_next.parse_next(&mut token_stream).map_err(|e| { - let msg = if let Some(e) = e.into_inner() - && let Some(cause) = e.cause() - { - format!(": {}", cause) - } else { - "".to_string() - }; - ParserError::ParserError(format!( - "Unexpected {}{}", - if self.index + token_stream.location() >= self.tokens.len() { - &"EOF" as &dyn std::fmt::Display - } else { - &self.tokens[self.index + token_stream.location()] as &dyn std::fmt::Display - }, - msg - )) - }); - let offset = token_stream.location(); - self.index += offset; - output - } +#[derive(Clone, Copy, Default)] +pub struct Parser<'a>(pub(crate) &'a [TokenWithLocation]); +impl Parser<'_> { /// Parse a SQL statement and produce an Abstract Syntax Tree (AST) #[instrument(level = "debug")] pub fn parse_sql(sql: &str) -> Result, ParserError> { let mut tokenizer = Tokenizer::new(sql); let tokens = tokenizer.tokenize_with_location()?; - let mut parser = Parser::new(tokens); - let ast = parser.parse_statements().map_err(|e| { + let parser = Parser(&tokens); + let stmts = Parser::parse_statements.parse(parser).map_err(|e| { // append SQL context to the error message, e.g.: // LINE 1: SELECT 1::int(2); - // ^ - // XXX: the cursor location is not accurate - // it may be offset one token forward because the error token has been consumed - let loc = match parser.tokens.get(parser.index) { + let loc = match tokens.get(e.offset()) { Some(token) => token.location.clone(), None => { // get location of EOF @@ -247,17 +219,27 @@ impl Parser { let cursor = " ".repeat(prefix.len() + loc.column as usize - 1); ParserError::ParserError(format!( "{}\n{}{}\n{}^", - e.inner_msg(), + e.inner().to_string().replace('\n', ": "), prefix, sql_line, cursor )) })?; - Ok(ast) + Ok(stmts) + } + + /// Parse object name from a string. + pub fn parse_object_name_str(s: &str) -> Result { + let mut tokenizer = Tokenizer::new(s); + let tokens = tokenizer.tokenize_with_location()?; + let parser = Parser(&tokens); + Parser::parse_object_name + .parse(parser) + .map_err(|e| ParserError::ParserError(e.inner().to_string())) } - /// Parse a list of semicolon-separated SQL statements. - pub fn parse_statements(&mut self) -> Result, ParserError> { + /// Parse a list of semicolon-separated statements. + fn parse_statements(&mut self) -> PResult> { let mut stmts = Vec::new(); let mut expecting_statement_delimiter = false; loop { @@ -270,7 +252,7 @@ impl Parser { break; } if expecting_statement_delimiter { - return self.expected("end of statement", self.peek_token()); + return self.expected("end of statement"); } let statement = self.parse_statement()?; @@ -283,14 +265,15 @@ impl Parser { /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.), /// stopping before the statement separator, if any. 
- pub fn parse_statement(&mut self) -> Result { + pub fn parse_statement(&mut self) -> PResult { + let checkpoint = *self; let token = self.next_token(); match token.token { Token::Word(w) => match w.keyword { Keyword::EXPLAIN => Ok(self.parse_explain()?), Keyword::ANALYZE => Ok(self.parse_analyze()?), Keyword::SELECT | Keyword::WITH | Keyword::VALUES => { - self.prev_token(); + *self = checkpoint; Ok(Statement::Query(Box::new(self.parse_query()?))) } Keyword::DECLARE => Ok(self.parse_declare()?), @@ -337,28 +320,23 @@ impl Parser { Keyword::FLUSH => Ok(Statement::Flush), Keyword::WAIT => Ok(Statement::Wait), Keyword::RECOVER => Ok(Statement::Recover), - _ => self.expected( - "an SQL statement", - Token::Word(w).with_location(token.location), - ), + _ => self.expected_at(checkpoint, "statement"), }, Token::LParen => { - self.prev_token(); + *self = checkpoint; Ok(Statement::Query(Box::new(self.parse_query()?))) } - unexpected => { - self.expected("an SQL statement", unexpected.with_location(token.location)) - } + _ => self.expected_at(checkpoint, "statement"), } } - pub fn parse_truncate(&mut self) -> Result { + pub fn parse_truncate(&mut self) -> PResult { let _ = self.parse_keyword(Keyword::TABLE); let table_name = self.parse_object_name()?; Ok(Statement::Truncate { table_name }) } - pub fn parse_analyze(&mut self) -> Result { + pub fn parse_analyze(&mut self) -> PResult { let table_name = self.parse_object_name()?; Ok(Statement::Analyze { table_name }) @@ -372,14 +350,14 @@ impl Parser { /// contain parentheses. /// - Selecting all columns from a table. In this case, it is a /// [`WildcardOrExpr::QualifiedWildcard`] or a [`WildcardOrExpr::Wildcard`]. - pub fn parse_wildcard_or_expr(&mut self) -> Result { - let index = self.index; + pub fn parse_wildcard_or_expr(&mut self) -> PResult { + let checkpoint = *self; match self.next_token().token { Token::Word(w) if self.peek_token() == Token::Period => { // Since there's no parenthesis, `w` must be a column or a table // So what follows must be dot-delimited identifiers, e.g. `a.b.c.*` - let wildcard_expr = self.parse_simple_wildcard_expr(index)?; + let wildcard_expr = self.parse_simple_wildcard_expr(checkpoint)?; return self.word_concat_wildcard_expr(w.to_ident()?, wildcard_expr); } Token::Mul => { @@ -396,14 +374,14 @@ impl Parser { } // Now that we have an expr, what follows must be // dot-delimited identifiers, e.g. 
`b.c.*` in `(a).b.c.*` - let wildcard_expr = self.parse_simple_wildcard_expr(index)?; + let wildcard_expr = self.parse_simple_wildcard_expr(checkpoint)?; return self.expr_concat_wildcard_expr(expr, wildcard_expr); } } _ => (), }; - self.index = index; + *self = checkpoint; self.parse_expr().map(WildcardOrExpr::Expr) } @@ -412,7 +390,7 @@ impl Parser { &mut self, ident: Ident, simple_wildcard_expr: WildcardOrExpr, - ) -> Result { + ) -> PResult { let mut idents = vec![ident]; let mut except_cols = vec![]; match simple_wildcard_expr { @@ -445,7 +423,7 @@ impl Parser { &mut self, expr: Expr, simple_wildcard_expr: WildcardOrExpr, - ) -> Result { + ) -> PResult { if let WildcardOrExpr::Expr(e) = simple_wildcard_expr { return Ok(WildcardOrExpr::Expr(e)); } @@ -475,19 +453,13 @@ impl Parser { match simple_wildcard_expr { WildcardOrExpr::QualifiedWildcard(ids, except) => { if except.is_some() { - return self.expected( - "Expr quantified wildcard does not support except", - self.peek_token(), - ); + return self.expected("Expr quantified wildcard does not support except"); } idents.extend(ids.0); } WildcardOrExpr::Wildcard(except) => { if except.is_some() { - return self.expected( - "Expr quantified wildcard does not support except", - self.peek_token(), - ); + return self.expected("Expr quantified wildcard does not support except"); } } WildcardOrExpr::ExprQualifiedWildcard(_, _) => unreachable!(), @@ -499,12 +471,10 @@ impl Parser { /// Tries to parses a wildcard expression without any parentheses. /// /// If wildcard is not found, go back to `index` and parse an expression. - pub fn parse_simple_wildcard_expr( - &mut self, - index: usize, - ) -> Result { + pub fn parse_simple_wildcard_expr(&mut self, checkpoint: Self) -> PResult { let mut id_parts = vec![]; while self.consume_token(&Token::Period) { + let ckpt = *self; let token = self.next_token(); match token.token { Token::Word(w) => id_parts.push(w.to_ident()?), @@ -518,43 +488,38 @@ impl Parser { )) }; } - unexpected => { - return self.expected( - "an identifier or a '*' after '.'", - unexpected.with_location(token.location), - ); + _ => { + *self = ckpt; + return self.expected("an identifier or a '*' after '.'"); } } } - self.index = index; + *self = checkpoint; self.parse_expr().map(WildcardOrExpr::Expr) } - pub fn parse_except(&mut self) -> Result>, ParserError> { + pub fn parse_except(&mut self) -> PResult>> { if !self.parse_keyword(Keyword::EXCEPT) { return Ok(None); } if !self.consume_token(&Token::LParen) { - return self.expected("EXCEPT should be followed by (", self.peek_token()); + return self.expected("EXCEPT should be followed by ("); } let exprs = self.parse_comma_separated(Parser::parse_expr)?; if self.consume_token(&Token::RParen) { Ok(Some(exprs)) } else { - self.expected( - "( should be followed by ) after column names", - self.peek_token(), - ) + self.expected("( should be followed by ) after column names") } } /// Parse a new expression - pub fn parse_expr(&mut self) -> Result { + pub fn parse_expr(&mut self) -> PResult { self.parse_subexpr(Precedence::Zero) } /// Parse tokens until the precedence changes - pub fn parse_subexpr(&mut self, precedence: Precedence) -> Result { + pub fn parse_subexpr(&mut self, precedence: Precedence) -> PResult { debug!("parsing expr, current token: {:?}", self.peek_token().token); let mut expr = self.parse_prefix()?; debug!("prefix: {:?}", expr); @@ -572,7 +537,7 @@ impl Parser { } /// Parse an expression prefix - pub fn parse_prefix(&mut self) -> Result { + pub fn parse_prefix(&mut self) -> 
PResult { // PostgreSQL allows any string literal to be preceded by a type name, indicating that the // string literal represents a literal of that type. Some examples: // @@ -607,11 +572,12 @@ impl Parser { } })); + let checkpoint = *self; let token = self.next_token(); let expr = match token.token.clone() { Token::Word(w) => match w.keyword { Keyword::TRUE | Keyword::FALSE | Keyword::NULL => { - self.prev_token(); + *self = checkpoint; Ok(Expr::Value(self.parse_value()?)) } Keyword::CASE => self.parse_case_expr(), @@ -657,7 +623,7 @@ impl Parser { // TODO: support `all/any/some(subquery)`. if let Expr::Subquery(_) = &sub { - parser_err!("ANY/SOME/ALL(Subquery) is not implemented")?; + parser_err!("ANY/SOME/ALL(Subquery) is not implemented"); } Ok(match keyword { @@ -668,7 +634,7 @@ impl Parser { }) } k if keywords::RESERVED_FOR_COLUMN_OR_TABLE_NAME.contains(&k) => { - parser_err!(format!("syntax error at or near {token}")) + parser_err!("syntax error at or near {token}") } // Here `w` is a word, check if it's a part of a multi-part // identifier, a function call, or a simple identifier: @@ -676,20 +642,18 @@ impl Parser { Token::LParen | Token::Period => { let mut id_parts: Vec = vec![w.to_ident()?]; while self.consume_token(&Token::Period) { + let ckpt = *self; let token = self.next_token(); match token.token { Token::Word(w) => id_parts.push(w.to_ident()?), - unexpected => { - return self.expected( - "an identifier or a '*' after '.'", - unexpected.with_location(token.location), - ); + _ => { + *self = ckpt; + return self.expected("an identifier or a '*' after '.'"); } } } - if self.consume_token(&Token::LParen) { - self.prev_token(); + if self.peek_token().token == Token::LParen { self.parse_function(ObjectName(id_parts)) } else { Ok(Expr::CompoundIdentifier(id_parts)) @@ -743,7 +707,7 @@ impl Parser { | Token::NationalStringLiteral(_) | Token::HexStringLiteral(_) | Token::CstyleEscapesString(_) => { - self.prev_token(); + *self = checkpoint; Ok(Expr::Value(self.parse_value()?)) } Token::Parameter(number) => self.parse_param(number), @@ -757,18 +721,17 @@ impl Parser { }) } Token::LParen => { - let expr = - if self.parse_keyword(Keyword::SELECT) || self.parse_keyword(Keyword::WITH) { - self.prev_token(); - Expr::Subquery(Box::new(self.parse_query()?)) + let expr = if matches!(self.peek_token().token, Token::Word(w) if w.keyword == Keyword::SELECT || w.keyword == Keyword::WITH) + { + Expr::Subquery(Box::new(self.parse_query()?)) + } else { + let mut exprs = self.parse_comma_separated(Parser::parse_expr)?; + if exprs.len() == 1 { + Expr::Nested(Box::new(exprs.pop().unwrap())) } else { - let mut exprs = self.parse_comma_separated(Parser::parse_expr)?; - if exprs.len() == 1 { - Expr::Nested(Box::new(exprs.pop().unwrap())) - } else { - Expr::Row(exprs) - } - }; + Expr::Row(exprs) + } + }; self.expect_token(&Token::RParen)?; if self.peek_token() == Token::Period && matches!(expr, Expr::Nested(_)) { self.parse_struct_selection(expr) @@ -776,7 +739,7 @@ impl Parser { Ok(expr) } } - unexpected => self.expected("an expression:", unexpected.with_location(token.location)), + _ => self.expected_at(checkpoint, "an expression"), }?; if self.parse_keyword(Keyword::COLLATE) { @@ -789,16 +752,15 @@ impl Parser { } } - fn parse_param(&mut self, param: String) -> Result { - Ok(Expr::Parameter { - index: param.parse().map_err(|_| { - ParserError::ParserError(format!("Parameter symbol has a invalid index {}.", param)) - })?, - }) + fn parse_param(&mut self, param: String) -> PResult { + let Ok(index) = 
param.parse() else { + parser_err!("Parameter symbol has a invalid index {}.", param); + }; + Ok(Expr::Parameter { index }) } /// Parses a field selection expression. See also [`Expr::FieldIdentifier`]. - pub fn parse_struct_selection(&mut self, expr: Expr) -> Result { + pub fn parse_struct_selection(&mut self, expr: Expr) -> PResult { let mut nested_expr = expr; // Unwrap parentheses while let Expr::Nested(inner) = nested_expr { @@ -809,35 +771,21 @@ impl Parser { } /// Parses consecutive field identifiers after a period. i.e., `.foo.bar.baz` - pub fn parse_fields(&mut self) -> Result, ParserError> { - let mut idents = vec![]; - while self.consume_token(&Token::Period) { - let token = self.next_token(); - match token.token { - Token::Word(w) => { - idents.push(w.to_ident()?); - } - unexpected => { - return self.expected( - "an identifier after '.'", - unexpected.with_location(token.location), - ); - } - } - } - Ok(idents) + pub fn parse_fields(&mut self) -> PResult> { + repeat(.., preceded(Token::Period, cut_err(Self::parse_identifier))).parse_next(self) } - pub fn parse_qualified_operator(&mut self) -> Result { + pub fn parse_qualified_operator(&mut self) -> PResult { self.expect_token(&Token::LParen)?; + let checkpoint = *self; let schema = match self.parse_identifier_non_reserved() { Ok(ident) => { self.expect_token(&Token::Period)?; Some(ident) } Err(_) => { - self.prev_token(); + *self = checkpoint; None } }; @@ -853,11 +801,12 @@ impl Parser { // // To support custom operators and be fully compatible with PostgreSQL later, the // tokenizer should also be updated. + let checkpoint = *self; let token = self.next_token(); let name = token.token.to_string(); if !name.trim_matches(OP_CHARS).is_empty() { - self.prev_token(); - return self.expected(&format!("one of {}", OP_CHARS.iter().join(" ")), token); + return self + .expected_at(checkpoint, &format!("one of {}", OP_CHARS.iter().join(" "))); } name }; @@ -866,7 +815,7 @@ impl Parser { Ok(QualifiedOperator { schema, name }) } - pub fn parse_function(&mut self, name: ObjectName) -> Result { + pub fn parse_function(&mut self, name: ObjectName) -> PResult { self.expect_token(&Token::LParen)?; let distinct = self.parse_all_or_distinct()?; let (args, order_by, variadic) = self.parse_optional_args()?; @@ -914,10 +863,7 @@ impl Parser { let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) { self.expect_token(&Token::LParen)?; self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?; - let order_by_parsed = self.parse_comma_separated(Parser::parse_order_by_expr)?; - let order_by = order_by_parsed.iter().exactly_one().map_err(|_| { - ParserError::ParserError("only one arg in order by is expected here".to_string()) - })?; + let order_by = self.parse_order_by_expr()?; self.expect_token(&Token::RParen)?; Some(Box::new(order_by.clone())) } else { @@ -936,26 +882,18 @@ impl Parser { })) } - pub fn parse_window_frame_units(&mut self) -> Result { - let token = self.next_token(); - match token.token { - Token::Word(w) => match w.keyword { - Keyword::ROWS => Ok(WindowFrameUnits::Rows), - Keyword::RANGE => Ok(WindowFrameUnits::Range), - Keyword::GROUPS => Ok(WindowFrameUnits::Groups), - _ => self.expected( - "ROWS, RANGE, GROUPS", - Token::Word(w).with_location(token.location), - )?, - }, - unexpected => self.expected( - "ROWS, RANGE, GROUPS", - unexpected.with_location(token.location), - ), + pub fn parse_window_frame_units(&mut self) -> PResult { + dispatch! 
{ peek(keyword); + Keyword::ROWS => keyword.value(WindowFrameUnits::Rows), + Keyword::RANGE => keyword.value(WindowFrameUnits::Range), + Keyword::GROUPS => keyword.value(WindowFrameUnits::Groups), + _ => fail, } + .expect("ROWS, RANGE, or GROUPS") + .parse_next(self) } - pub fn parse_window_frame(&mut self) -> Result { + pub fn parse_window_frame(&mut self) -> PResult { let units = self.parse_window_frame_units()?; let (start_bound, end_bound) = if self.parse_keyword(Keyword::BETWEEN) { let start_bound = self.parse_window_frame_bound()?; @@ -979,7 +917,7 @@ impl Parser { } /// Parse `CURRENT ROW` or `{ | UNBOUNDED } { PRECEDING | FOLLOWING }` - pub fn parse_window_frame_bound(&mut self) -> Result { + pub fn parse_window_frame_bound(&mut self) -> PResult { if self.parse_keywords(&[Keyword::CURRENT, Keyword::ROW]) { Ok(WindowFrameBound::CurrentRow) } else { @@ -993,12 +931,12 @@ impl Parser { } else if self.parse_keyword(Keyword::FOLLOWING) { Ok(WindowFrameBound::Following(rows)) } else { - self.expected("PRECEDING or FOLLOWING", self.peek_token()) + self.expected("PRECEDING or FOLLOWING") } } } - pub fn parse_window_frame_exclusion(&mut self) -> Result { + pub fn parse_window_frame_exclusion(&mut self) -> PResult { if self.parse_keywords(&[Keyword::CURRENT, Keyword::ROW]) { Ok(WindowFrameExclusion::CurrentRow) } else if self.parse_keyword(Keyword::GROUP) { @@ -1008,13 +946,13 @@ impl Parser { } else if self.parse_keywords(&[Keyword::NO, Keyword::OTHERS]) { Ok(WindowFrameExclusion::NoOthers) } else { - self.expected("CURRENT ROW, GROUP, TIES, or NO OTHERS", self.peek_token()) + self.expected("CURRENT ROW, GROUP, TIES, or NO OTHERS") } } /// parse a group by expr. a group by expr can be one of group sets, roll up, cube, or simple /// expr. - fn parse_group_by_expr(&mut self) -> Result { + fn parse_group_by_expr(&mut self) -> PResult { if self.parse_keywords(&[Keyword::GROUPING, Keyword::SETS]) { self.expect_token(&Token::LParen)?; let result = self.parse_comma_separated(|p| p.parse_tuple(true, true))?; @@ -1038,11 +976,7 @@ impl Parser { /// parse a tuple with `(` and `)`. /// If `lift_singleton` is true, then a singleton tuple is lifted to a tuple of length 1, /// otherwise it will fail. If `allow_empty` is true, then an empty tuple is allowed. - fn parse_tuple( - &mut self, - lift_singleton: bool, - allow_empty: bool, - ) -> Result, ParserError> { + fn parse_tuple(&mut self, lift_singleton: bool, allow_empty: bool) -> PResult> { if lift_singleton { if self.consume_token(&Token::LParen) { let result = if allow_empty && self.consume_token(&Token::RParen) { @@ -1069,7 +1003,7 @@ impl Parser { } } - pub fn parse_case_expr(&mut self) -> Result { + pub fn parse_case_expr(&mut self) -> PResult { let mut operand = None; if !self.parse_keyword(Keyword::WHEN) { operand = Some(Box::new(self.parse_expr()?)); @@ -1100,7 +1034,7 @@ impl Parser { } /// Parse a SQL CAST function e.g. `CAST(expr AS FLOAT)` - pub fn parse_cast_expr(&mut self) -> Result { + pub fn parse_cast_expr(&mut self) -> PResult { self.expect_token(&Token::LParen)?; let expr = self.parse_expr()?; self.expect_keyword(Keyword::AS)?; @@ -1113,7 +1047,7 @@ impl Parser { } /// Parse a SQL TRY_CAST function e.g. `TRY_CAST(expr AS FLOAT)` - pub fn parse_try_cast_expr(&mut self) -> Result { + pub fn parse_try_cast_expr(&mut self) -> PResult { self.expect_token(&Token::LParen)?; let expr = self.parse_expr()?; self.expect_keyword(Keyword::AS)?; @@ -1126,14 +1060,14 @@ impl Parser { } /// Parse a SQL EXISTS expression e.g. 
`WHERE EXISTS(SELECT ...)`. - pub fn parse_exists_expr(&mut self) -> Result { + pub fn parse_exists_expr(&mut self) -> PResult { self.expect_token(&Token::LParen)?; let exists_node = Expr::Exists(Box::new(self.parse_query()?)); self.expect_token(&Token::RParen)?; Ok(exists_node) } - pub fn parse_extract_expr(&mut self) -> Result { + pub fn parse_extract_expr(&mut self) -> PResult { self.expect_token(&Token::LParen)?; let field = self.parse_date_time_field_in_extract()?; self.expect_keyword(Keyword::FROM)?; @@ -1145,7 +1079,7 @@ impl Parser { }) } - pub fn parse_substring_expr(&mut self) -> Result { + pub fn parse_substring_expr(&mut self) -> PResult { // PARSE SUBSTRING (EXPR [FROM 1] [FOR 3]) self.expect_token(&Token::LParen)?; let expr = self.parse_expr()?; @@ -1168,7 +1102,7 @@ impl Parser { } /// `POSITION( IN )` - pub fn parse_position_expr(&mut self) -> Result { + pub fn parse_position_expr(&mut self) -> PResult { self.expect_token(&Token::LParen)?; // Logically `parse_expr`, but limited to those with precedence higher than `BETWEEN`/`IN`, @@ -1187,7 +1121,7 @@ impl Parser { } /// `OVERLAY( PLACING FROM [ FOR ])` - pub fn parse_overlay_expr(&mut self) -> Result { + pub fn parse_overlay_expr(&mut self) -> PResult { self.expect_token(&Token::LParen)?; let expr = self.parse_expr()?; @@ -1215,7 +1149,7 @@ impl Parser { /// `TRIM ([WHERE] ['text'] FROM 'text')`\ /// `TRIM ([WHERE] [FROM] 'text' [, 'text'])` - pub fn parse_trim_expr(&mut self) -> Result { + pub fn parse_trim_expr(&mut self) -> PResult { self.expect_token(&Token::LParen)?; let mut trim_where = None; if let Token::Word(word) = self.peek_token().token { @@ -1251,26 +1185,19 @@ impl Parser { }) } - pub fn parse_trim_where(&mut self) -> Result { - let token = self.next_token(); - match token.token { - Token::Word(w) => match w.keyword { - Keyword::BOTH => Ok(TrimWhereField::Both), - Keyword::LEADING => Ok(TrimWhereField::Leading), - Keyword::TRAILING => Ok(TrimWhereField::Trailing), - _ => self.expected( - "trim_where field", - Token::Word(w).with_location(token.location), - )?, - }, - unexpected => { - self.expected("trim_where field", unexpected.with_location(token.location)) - } + pub fn parse_trim_where(&mut self) -> PResult { + dispatch! { peek(keyword); + Keyword::BOTH => keyword.value(TrimWhereField::Both), + Keyword::LEADING => keyword.value(TrimWhereField::Leading), + Keyword::TRAILING => keyword.value(TrimWhereField::Trailing), + _ => fail } + .expect("BOTH, LEADING, or TRAILING") + .parse_next(self) } /// Parses an array expression `[ex1, ex2, ..]` - pub fn parse_array_expr(&mut self) -> Result { + pub fn parse_array_expr(&mut self) -> PResult { let mut expected_depth = None; let exprs = self.parse_array_inner(0, &mut expected_depth)?; Ok(Expr::Array(Array { @@ -1284,12 +1211,12 @@ impl Parser { &mut self, depth: usize, expected_depth: &mut Option, - ) -> Result, ParserError> { + ) -> PResult> { self.expect_token(&Token::LBracket)?; if let Some(expected_depth) = *expected_depth && depth > expected_depth { - return self.expected("]", self.peek_token()); + return self.expected("]"); } let exprs = if self.peek_token() == Token::LBracket { self.parse_comma_separated(|parser| { @@ -1302,7 +1229,7 @@ impl Parser { } else { if let Some(expected_depth) = *expected_depth { if depth < expected_depth { - return self.expected("[", self.peek_token()); + return self.expected("["); } } else { *expected_depth = Some(depth); @@ -1317,25 +1244,18 @@ impl Parser { } // This function parses date/time fields for interval qualifiers. 
- pub fn parse_date_time_field(&mut self) -> Result { - let token = self.next_token(); - match token.token { - Token::Word(w) => match w.keyword { - Keyword::YEAR => Ok(DateTimeField::Year), - Keyword::MONTH => Ok(DateTimeField::Month), - Keyword::DAY => Ok(DateTimeField::Day), - Keyword::HOUR => Ok(DateTimeField::Hour), - Keyword::MINUTE => Ok(DateTimeField::Minute), - Keyword::SECOND => Ok(DateTimeField::Second), - _ => self.expected( - "date/time field", - Token::Word(w).with_location(token.location), - )?, - }, - unexpected => { - self.expected("date/time field", unexpected.with_location(token.location)) - } + pub fn parse_date_time_field(&mut self) -> PResult { + dispatch! { peek(keyword); + Keyword::YEAR => keyword.value(DateTimeField::Year), + Keyword::MONTH => keyword.value(DateTimeField::Month), + Keyword::DAY => keyword.value(DateTimeField::Day), + Keyword::HOUR => keyword.value(DateTimeField::Hour), + Keyword::MINUTE => keyword.value(DateTimeField::Minute), + Keyword::SECOND => keyword.value(DateTimeField::Second), + _ => fail, } + .expect("date/time field") + .parse_next(self) } // This function parses date/time fields for the EXTRACT function-like operator. PostgreSQL @@ -1347,13 +1267,15 @@ impl Parser { // select extract("invaLId" from null::date); // select extract('invaLId' from null::date); // ``` - pub fn parse_date_time_field_in_extract(&mut self) -> Result { + pub fn parse_date_time_field_in_extract(&mut self) -> PResult { + let checkpoint = *self; let token = self.next_token(); match token.token { Token::Word(w) => Ok(w.value.to_uppercase()), Token::SingleQuotedString(s) => Ok(s.to_uppercase()), - unexpected => { - self.expected("date/time field", unexpected.with_location(token.location)) + _ => { + *self = checkpoint; + self.expected("date/time field") } } } @@ -1370,7 +1292,7 @@ impl Parser { /// 6. `INTERVAL '1:1' HOUR (5) TO MINUTE (5)` /// /// Note that we do not currently attempt to parse the quoted value. - pub fn parse_literal_interval(&mut self) -> Result { + pub fn parse_literal_interval(&mut self) -> PResult { // The SQL standard allows an optional sign before the value string, but // it is not clear if any implementations support that syntax, so we // don't currently try to parse it. (The sign can instead be included @@ -1438,7 +1360,8 @@ impl Parser { } /// Parse an operator following an expression - pub fn parse_infix(&mut self, expr: Expr, precedence: Precedence) -> Result { + pub fn parse_infix(&mut self, expr: Expr, precedence: Precedence) -> PResult { + let checkpoint = *self; let tok = self.next_token(); debug!("parsing infix {:?}", tok.token); let regular_binary_operator = match &tok.token { @@ -1508,7 +1431,7 @@ impl Parser { // // TODO: support `all/any/some(subquery)`. 
// if let Expr::Subquery(_) = &right { - // parser_err!("ANY/SOME/ALL(Subquery) is not implemented")?; + // parser_err!("ANY/SOME/ALL(Subquery) is not implemented"); // } // let right = match keyword { @@ -1564,27 +1487,20 @@ impl Parser { } else { self.expected( "[NOT] { TRUE | FALSE | UNKNOWN | NULL | DISTINCT FROM | JSON } after IS", - self.peek_token(), ) } } } Keyword::AT => { - if self.parse_keywords(&[Keyword::TIME, Keyword::ZONE]) { - let token = self.next_token(); - match token.token { - Token::SingleQuotedString(time_zone) => Ok(Expr::AtTimeZone { - timestamp: Box::new(expr), - time_zone, - }), - unexpected => self.expected( - "Expected Token::SingleQuotedString after AT TIME ZONE", - unexpected.with_location(token.location), - ), - } - } else { - self.expected("Expected Token::Word after AT", tok) - } + let time_zone = preceded( + (Keyword::TIME, Keyword::ZONE), + cut_err(Self::parse_literal_string), + ) + .parse_next(self)?; + Ok(Expr::AtTimeZone { + timestamp: Box::new(expr), + time_zone, + }) } keyword @ (Keyword::ALL | Keyword::ANY | Keyword::SOME) => { self.expect_token(&Token::LParen)?; @@ -1595,7 +1511,7 @@ impl Parser { // TODO: support `all/any/some(subquery)`. if let Expr::Subquery(_) = &sub { - parser_err!("ANY/SOME/ALL(Subquery) is not implemented")?; + parser_err!("ANY/SOME/ALL(Subquery) is not implemented"); } Ok(match keyword { @@ -1611,7 +1527,7 @@ impl Parser { | Keyword::LIKE | Keyword::ILIKE | Keyword::SIMILAR => { - self.prev_token(); + *self = checkpoint; let negated = self.parse_keyword(Keyword::NOT); if self.parse_keyword(Keyword::IN) { self.parse_in(expr, negated) @@ -1639,11 +1555,11 @@ impl Parser { escape_char: self.parse_escape()?, }) } else { - self.expected("IN, BETWEEN or SIMILAR TO after NOT", self.peek_token()) + self.expected("IN, BETWEEN or SIMILAR TO after NOT") } } // Can only happen if `get_next_precedence` got out of sync with this function - _ => parser_err!(format!("No infix parser for token {:?}", tok)), + _ => parser_err!("No infix parser for token {:?}", tok), } } else if Token::DoubleColon == tok { self.parse_pg_cast(expr) @@ -1657,20 +1573,18 @@ impl Parser { self.parse_array_index(expr) } else { // Can only happen if `get_next_precedence` got out of sync with this function - parser_err!(format!("No infix parser for token {:?}", tok)) + parser_err!("No infix parser for token {:?}", tok) } } /// parse the ESCAPE CHAR portion of LIKE, ILIKE, and SIMILAR TO - pub fn parse_escape(&mut self) -> Result, ParserError> { + pub fn parse_escape(&mut self) -> PResult> { if self.parse_keyword(Keyword::ESCAPE) { let s = self.parse_literal_string()?; let mut chs = s.chars(); if let Some(ch) = chs.next() { if chs.next().is_some() { - parser_err!(format!( - "Escape string must be empty or one character, found {s:?}" - )) + parser_err!("Escape string must be empty or one character, found {s:?}") } else { Ok(Some(EscapeChar::escape(ch))) } @@ -1684,7 +1598,7 @@ impl Parser { /// We parse both `array[1,9][1]`, `array[1,9][1:2]`, `array[1,9][:2]`, `array[1,9][1:]` and /// `array[1,9][:]` in this function. 
- pub fn parse_array_index(&mut self, expr: Expr) -> Result { + pub fn parse_array_index(&mut self, expr: Expr) -> PResult { let new_expr = match self.peek_token().token { Token::Colon => { // [:] or [:N] @@ -1749,7 +1663,7 @@ impl Parser { } /// Parses the optional constraints following the `IS [NOT] JSON` predicate - pub fn parse_is_json(&mut self, expr: Expr, negated: bool) -> Result { + pub fn parse_is_json(&mut self, expr: Expr, negated: bool) -> PResult { let item_type = match self.peek_token().token { Token::Word(w) => match w.keyword { Keyword::VALUE => Some(JsonPredicateType::Value), @@ -1781,10 +1695,10 @@ impl Parser { } /// Parses the parens following the `[ NOT ] IN` operator - pub fn parse_in(&mut self, expr: Expr, negated: bool) -> Result { + pub fn parse_in(&mut self, expr: Expr, negated: bool) -> PResult { self.expect_token(&Token::LParen)?; - let in_op = if self.parse_keyword(Keyword::SELECT) || self.parse_keyword(Keyword::WITH) { - self.prev_token(); + let in_op = if matches!(self.peek_token().token, Token::Word(w) if w.keyword == Keyword::SELECT || w.keyword == Keyword::WITH) + { Expr::InSubquery { expr: Box::new(expr), subquery: Box::new(self.parse_query()?), @@ -1802,7 +1716,7 @@ impl Parser { } /// Parses `BETWEEN AND `, assuming the `BETWEEN` keyword was already consumed - pub fn parse_between(&mut self, expr: Expr, negated: bool) -> Result { + pub fn parse_between(&mut self, expr: Expr, negated: bool) -> PResult { // Stop parsing subexpressions for and on tokens with // precedence lower than that of `BETWEEN`, such as `AND`, `IS`, etc. let low = self.parse_subexpr(Precedence::Between)?; @@ -1817,7 +1731,7 @@ impl Parser { } /// Parse a postgresql casting style which is in the form of `expr::datatype` - pub fn parse_pg_cast(&mut self, expr: Expr) -> Result { + pub fn parse_pg_cast(&mut self, expr: Expr) -> PResult { Ok(Expr::Cast { expr: Box::new(expr), data_type: self.parse_data_type()?, @@ -1825,7 +1739,7 @@ impl Parser { } /// Get the precedence of the next token - pub fn get_next_precedence(&self) -> Result { + pub fn get_next_precedence(&self) -> PResult { use Precedence as P; let token = self.peek_token(); @@ -1930,17 +1844,15 @@ impl Parser { /// Return nth non-whitespace token that has not yet been processed pub fn peek_nth_token(&self, mut n: usize) -> TokenWithLocation { - let mut index = self.index; + let mut index = 0; loop { + let token = self.0.get(index); index += 1; - let token = self.tokens.get(index - 1); match token.map(|x| &x.token) { Some(Token::Whitespace(_)) => continue, _ => { if n == 0 { - return token - .cloned() - .unwrap_or(TokenWithLocation::wrap(Token::EOF)); + return token.cloned().unwrap_or(TokenWithLocation::eof()); } n -= 1; } @@ -1953,44 +1865,37 @@ impl Parser { /// repeatedly after reaching EOF. pub fn next_token(&mut self) -> TokenWithLocation { loop { - self.index += 1; - let token = self.tokens.get(self.index - 1); - match token.map(|x| &x.token) { - Some(Token::Whitespace(_)) => continue, - _ => { - return token - .cloned() - .unwrap_or(TokenWithLocation::wrap(Token::EOF)); - } + let Some(token) = self.0.first() else { + return TokenWithLocation::eof(); + }; + self.0 = &self.0[1..]; + match token.token { + Token::Whitespace(_) => continue, + _ => return token.clone(), } } } /// Return the first unprocessed token, possibly whitespace. 
pub fn next_token_no_skip(&mut self) -> Option<&TokenWithLocation> { - self.index += 1; - self.tokens.get(self.index - 1) + if self.0.is_empty() { + None + } else { + let (first, rest) = self.0.split_at(1); + self.0 = rest; + Some(&first[0]) + } } - /// Push back the last one non-whitespace token. Must be called after - /// `next_token()`, otherwise might panic. OK to call after - /// `next_token()` indicates an EOF. - pub fn prev_token(&mut self) { - loop { - assert!(self.index > 0); - self.index -= 1; - if let Some(token) = self.tokens.get(self.index) - && let Token::Whitespace(_) = token.token - { - continue; - } - return; - } + /// Report an expected error at the current position. + pub fn expected(&self, expected: &str) -> PResult { + parser_err!("expected {}, found: {}", expected, self.peek_token().token) } - /// Report unexpected token - pub fn expected(&self, expected: &str, found: TokenWithLocation) -> Result { - parser_err!(format!("expected {}, found: {}", expected, found)) + /// Revert the parser to a previous position and report an expected error. + pub fn expected_at(&mut self, checkpoint: Self, expected: &str) -> PResult { + *self = checkpoint; + self.expected(expected) } /// Look for an expected keyword and consume it if it exists @@ -2008,12 +1913,12 @@ impl Parser { /// Look for an expected sequence of keywords and consume them if they exist #[must_use] pub fn parse_keywords(&mut self, keywords: &[Keyword]) -> bool { - let index = self.index; + let checkpoint = *self; for &keyword in keywords { if !self.parse_keyword(keyword) { // println!("parse_keywords aborting .. did not find {:?}", keyword); // reset index and return immediately - self.index = index; + *self = checkpoint; return false; } } @@ -2045,30 +1950,27 @@ impl Parser { } /// Bail out if the current token is not one of the expected keywords, or consume it if it is - pub fn expect_one_of_keywords(&mut self, keywords: &[Keyword]) -> Result { + pub fn expect_one_of_keywords(&mut self, keywords: &[Keyword]) -> PResult { if let Some(keyword) = self.parse_one_of_keywords(keywords) { Ok(keyword) } else { let keywords: Vec = keywords.iter().map(|x| format!("{:?}", x)).collect(); - self.expected( - &format!("one of {}", keywords.join(" or ")), - self.peek_token(), - ) + self.expected(&format!("one of {}", keywords.join(" or "))) } } /// Bail out if the current token is not an expected keyword, or consume it if it is - pub fn expect_keyword(&mut self, expected: Keyword) -> Result<(), ParserError> { + pub fn expect_keyword(&mut self, expected: Keyword) -> PResult<()> { if self.parse_keyword(expected) { Ok(()) } else { - self.expected(format!("{:?}", &expected).as_str(), self.peek_token()) + self.expected(format!("{:?}", &expected).as_str()) } } /// Bail out if the following tokens are not the expected sequence of /// keywords, or consume them if they are. 
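Because the parser is now a `Copy` wrapper over a token slice, `prev_token()` and manual index bookkeeping are replaced by saving the whole parser and assigning it back, as `expected_at` and `maybe_parse` do. A standalone sketch of that idiom, with illustrative names that are not from this crate:

    // Illustrative only: a Copy cursor over a slice, advanced by re-slicing as in
    // `next_token` above, backtracked by restoring a saved copy as in `maybe_parse`.
    #[derive(Clone, Copy)]
    struct Cursor<'a> {
        tokens: &'a [&'a str],
    }

    impl<'a> Cursor<'a> {
        // Consume and return the next token, advancing by re-slicing.
        fn next(&mut self) -> Option<&'a str> {
            let first = *self.tokens.first()?;
            self.tokens = &self.tokens[1..];
            Some(first)
        }

        // Try `f`; on failure restore the pre-call state by copying the checkpoint back.
        fn maybe<T>(&mut self, f: impl FnOnce(&mut Self) -> Option<T>) -> Option<T> {
            let checkpoint = *self;
            let result = f(self);
            if result.is_none() {
                *self = checkpoint;
            }
            result
        }
    }

    fn main() {
        let mut cur = Cursor { tokens: &["SELECT", "1"] };
        // The failed attempt consumes nothing: restoring the checkpoint undoes `next()`.
        assert!(cur.maybe(|c| (c.next()? == "INSERT").then_some(())).is_none());
        assert_eq!(cur.next(), Some("SELECT"));
    }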
- pub fn expect_keywords(&mut self, expected: &[Keyword]) -> Result<(), ParserError> { + pub fn expect_keywords(&mut self, expected: &[Keyword]) -> PResult<()> { for &kw in expected { self.expect_keyword(kw)?; } @@ -2087,18 +1989,18 @@ impl Parser { } /// Bail out if the current token is not an expected keyword, or consume it if it is - pub fn expect_token(&mut self, expected: &Token) -> Result<(), ParserError> { + pub fn expect_token(&mut self, expected: &Token) -> PResult<()> { if self.consume_token(expected) { Ok(()) } else { - self.expected(&expected.to_string(), self.peek_token()) + self.expected(&expected.to_string()) } } /// Parse a comma-separated list of 1+ items accepted by `F` - pub fn parse_comma_separated(&mut self, mut f: F) -> Result, ParserError> + pub fn parse_comma_separated(&mut self, mut f: F) -> PResult> where - F: FnMut(&mut Parser) -> Result, + F: FnMut(&mut Self) -> PResult, { let mut values = vec![]; loop { @@ -2115,31 +2017,31 @@ impl Parser { #[must_use] fn maybe_parse(&mut self, mut f: F) -> Option where - F: FnMut(&mut Parser) -> Result, + F: FnMut(&mut Self) -> PResult, { - let index = self.index; + let checkpoint = *self; if let Ok(t) = f(self) { Some(t) } else { - self.index = index; + *self = checkpoint; None } } /// Parse either `ALL` or `DISTINCT`. Returns `true` if `DISTINCT` is parsed and results in a /// `ParserError` if both `ALL` and `DISTINCT` are fround. - pub fn parse_all_or_distinct(&mut self) -> Result { + pub fn parse_all_or_distinct(&mut self) -> PResult { let all = self.parse_keyword(Keyword::ALL); let distinct = self.parse_keyword(Keyword::DISTINCT); if all && distinct { - parser_err!("Cannot specify both ALL and DISTINCT".to_string()) + parser_err!("Cannot specify both ALL and DISTINCT") } else { Ok(distinct) } } /// Parse either `ALL` or `DISTINCT` or `DISTINCT ON ()`. 
- pub fn parse_all_or_distinct_on(&mut self) -> Result { + pub fn parse_all_or_distinct_on(&mut self) -> PResult { if self.parse_keywords(&[Keyword::DISTINCT, Keyword::ON]) { self.expect_token(&Token::LParen)?; let exprs = self.parse_comma_separated(Parser::parse_expr)?; @@ -2153,7 +2055,7 @@ impl Parser { } /// Parse a SQL CREATE statement - pub fn parse_create(&mut self) -> Result { + pub fn parse_create(&mut self) -> PResult { let or_replace = self.parse_keywords(&[Keyword::OR, Keyword::REPLACE]); let temporary = self .parse_one_of_keywords(&[Keyword::TEMP, Keyword::TEMPORARY]) @@ -2181,7 +2083,6 @@ impl Parser { } else if or_replace { self.expected( "[EXTERNAL] TABLE or [MATERIALIZED] VIEW or [MATERIALIZED] SOURCE or SINK or FUNCTION after CREATE OR REPLACE", - self.peek_token(), ) } else if self.parse_keyword(Keyword::INDEX) { self.parse_create_index(false) @@ -2196,11 +2097,11 @@ impl Parser { } else if self.parse_keyword(Keyword::SECRET) { self.parse_create_secret() } else { - self.expected("an object type after CREATE", self.peek_token()) + self.expected("an object type after CREATE") } } - pub fn parse_create_schema(&mut self) -> Result { + pub fn parse_create_schema(&mut self) -> PResult { let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let (schema_name, user_specified) = if self.parse_keyword(Keyword::AUTHORIZATION) { let user_specified = self.parse_object_name()?; @@ -2221,7 +2122,7 @@ impl Parser { }) } - pub fn parse_create_database(&mut self) -> Result { + pub fn parse_create_database(&mut self) -> PResult { let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let db_name = self.parse_object_name()?; Ok(Statement::CreateDatabase { @@ -2234,7 +2135,7 @@ impl Parser { &mut self, materialized: bool, or_replace: bool, - ) -> Result { + ) -> PResult { let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); // Many dialects support `OR ALTER` right after `CREATE`, but we don't (yet). // ANSI SQL and Postgres support RECURSIVE here, but we don't support it either. @@ -2269,7 +2170,7 @@ impl Parser { // [WITH (properties)]? // ROW FORMAT // [ROW SCHEMA LOCATION ]? - pub fn parse_create_source(&mut self, _or_replace: bool) -> Result { + pub fn parse_create_source(&mut self, _or_replace: bool) -> PResult { Ok(Statement::CreateSource { stmt: CreateSourceStatement::parse_to(self)?, }) @@ -2282,7 +2183,7 @@ impl Parser { // FROM // // [WITH (properties)]? - pub fn parse_create_sink(&mut self, _or_replace: bool) -> Result { + pub fn parse_create_sink(&mut self, _or_replace: bool) -> PResult { Ok(Statement::CreateSink { stmt: CreateSinkStatement::parse_to(self)?, }) @@ -2295,10 +2196,7 @@ impl Parser { // FROM // // [WITH (properties)]? - pub fn parse_create_subscription( - &mut self, - _or_replace: bool, - ) -> Result { + pub fn parse_create_subscription(&mut self, _or_replace: bool) -> PResult { Ok(Statement::CreateSubscription { stmt: CreateSubscriptionStatement::parse_to(self)?, }) @@ -2309,7 +2207,7 @@ impl Parser { // [IF NOT EXISTS]? // // [WITH (properties)]? 
- pub fn parse_create_connection(&mut self) -> Result { + pub fn parse_create_connection(&mut self) -> PResult { Ok(Statement::CreateConnection { stmt: CreateConnectionStatement::parse_to(self)?, }) @@ -2319,11 +2217,10 @@ impl Parser { &mut self, or_replace: bool, temporary: bool, - ) -> Result { + ) -> PResult { let name = self.parse_object_name()?; self.expect_token(&Token::LParen)?; - let args = if self.consume_token(&Token::RParen) { - self.prev_token(); + let args = if self.peek_token().token == Token::RParen { None } else { Some(self.parse_comma_separated(Parser::parse_function_arg)?) @@ -2342,7 +2239,7 @@ impl Parser { // allow a trailing comma, even though it's not in standard break; } else if !comma { - return self.expected("',' or ')'", self.peek_token()); + return self.expected("',' or ')'"); } } Some(CreateFunctionReturns::Table(values)) @@ -2367,7 +2264,7 @@ impl Parser { }) } - fn parse_create_aggregate(&mut self, or_replace: bool) -> Result { + fn parse_create_aggregate(&mut self, or_replace: bool) -> PResult { let name = self.parse_object_name()?; self.expect_token(&Token::LParen)?; let args = self.parse_comma_separated(Parser::parse_function_arg)?; @@ -2389,32 +2286,32 @@ impl Parser { }) } - pub fn parse_declare(&mut self) -> Result { + pub fn parse_declare(&mut self) -> PResult { Ok(Statement::DeclareCursor { stmt: DeclareCursorStatement::parse_to(self)?, }) } - pub fn parse_fetch_cursor(&mut self) -> Result { + pub fn parse_fetch_cursor(&mut self) -> PResult { Ok(Statement::FetchCursor { stmt: FetchCursorStatement::parse_to(self)?, }) } - pub fn parse_close_cursor(&mut self) -> Result { + pub fn parse_close_cursor(&mut self) -> PResult { Ok(Statement::CloseCursor { stmt: CloseCursorStatement::parse_to(self)?, }) } - fn parse_table_column_def(&mut self) -> Result { + fn parse_table_column_def(&mut self) -> PResult { Ok(TableColumnDef { name: self.parse_identifier_non_reserved()?, data_type: self.parse_data_type()?, }) } - fn parse_function_arg(&mut self) -> Result { + fn parse_function_arg(&mut self) -> PResult { let mode = if self.parse_keyword(Keyword::IN) { Some(ArgMode::In) } else if self.parse_keyword(Keyword::OUT) { @@ -2450,14 +2347,12 @@ impl Parser { }) } - fn parse_create_function_body(&mut self) -> Result { + fn parse_create_function_body(&mut self) -> PResult { let mut body = CreateFunctionBody::default(); loop { - fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { + fn ensure_not_set(field: &Option, name: &str) -> PResult<()> { if field.is_some() { - return Err(ParserError::ParserError(format!( - "{name} specified more than once", - ))); + parser_err!("{name} specified more than once"); } Ok(()) } @@ -2500,7 +2395,7 @@ impl Parser { } } - fn parse_create_function_using(&mut self) -> Result { + fn parse_create_function_using(&mut self) -> PResult { let keyword = self.expect_one_of_keywords(&[Keyword::LINK, Keyword::BASE64])?; match keyword { @@ -2520,7 +2415,7 @@ impl Parser { &mut self, is_async: bool, is_generator: bool, - ) -> Result { + ) -> PResult { let is_generator = if is_generator { true } else { @@ -2542,29 +2437,28 @@ impl Parser { // | CREATEUSER | NOCREATEUSER // | LOGIN | NOLOGIN // | [ ENCRYPTED ] PASSWORD 'password' | PASSWORD NULL | OAUTH - fn parse_create_user(&mut self) -> Result { + fn parse_create_user(&mut self) -> PResult { Ok(Statement::CreateUser(CreateUserStatement::parse_to(self)?)) } - fn parse_create_secret(&mut self) -> Result { + fn parse_create_secret(&mut self) -> PResult { Ok(Statement::CreateSecret { 
stmt: CreateSecretStatement::parse_to(self)?, }) } - pub fn parse_with_properties(&mut self) -> Result, ParserError> { + pub fn parse_with_properties(&mut self) -> PResult> { Ok(self .parse_options_with_preceding_keyword(Keyword::WITH)? .to_vec()) } - pub fn parse_discard(&mut self) -> Result { - self.expect_keyword(Keyword::ALL) - .map_err(|_| ParserError::ParserError("only DISCARD ALL is supported".to_string()))?; + pub fn parse_discard(&mut self) -> PResult { + self.expect_keyword(Keyword::ALL)?; Ok(Statement::Discard(DiscardType::All)) } - pub fn parse_drop(&mut self) -> Result { + pub fn parse_drop(&mut self) -> PResult { if self.parse_keyword(Keyword::FUNCTION) { return self.parse_drop_function(); } else if self.parse_keyword(Keyword::AGGREGATE) { @@ -2577,7 +2471,7 @@ impl Parser { /// DROP FUNCTION [ IF EXISTS ] name [ ( [ [ argmode ] [ argname ] argtype [, ...] ] ) ] [, ...] /// [ CASCADE | RESTRICT ] /// ``` - fn parse_drop_function(&mut self) -> Result { + fn parse_drop_function(&mut self) -> PResult { let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]); let func_desc = self.parse_comma_separated(Parser::parse_function_desc)?; let option = match self.parse_one_of_keywords(&[Keyword::CASCADE, Keyword::RESTRICT]) { @@ -2596,7 +2490,7 @@ impl Parser { /// DROP AGGREGATE [ IF EXISTS ] name [ ( [ [ argmode ] [ argname ] argtype [, ...] ] ) ] [, ...] /// [ CASCADE | RESTRICT ] /// ``` - fn parse_drop_aggregate(&mut self) -> Result { + fn parse_drop_aggregate(&mut self) -> PResult { let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]); let func_desc = self.parse_comma_separated(Parser::parse_function_desc)?; let option = match self.parse_one_of_keywords(&[Keyword::CASCADE, Keyword::RESTRICT]) { @@ -2611,7 +2505,7 @@ impl Parser { }) } - fn parse_function_desc(&mut self) -> Result { + fn parse_function_desc(&mut self) -> PResult { let name = self.parse_object_name()?; let args = if self.consume_token(&Token::LParen) { @@ -2629,7 +2523,7 @@ impl Parser { Ok(FunctionDesc { name, args }) } - pub fn parse_create_index(&mut self, unique: bool) -> Result { + pub fn parse_create_index(&mut self, unique: bool) -> PResult { let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let index_name = self.parse_object_name()?; self.expect_keyword(Keyword::ON)?; @@ -2660,7 +2554,7 @@ impl Parser { }) } - pub fn parse_with_version_column(&mut self) -> Result, ParserError> { + pub fn parse_with_version_column(&mut self) -> PResult> { if self.parse_keywords(&[Keyword::WITH, Keyword::VERSION, Keyword::COLUMN]) { self.expect_token(&Token::LParen)?; let name = self.parse_identifier_non_reserved()?; @@ -2671,7 +2565,7 @@ impl Parser { } } - pub fn parse_on_conflict(&mut self) -> Result, ParserError> { + pub fn parse_on_conflict(&mut self) -> PResult> { if self.parse_keywords(&[Keyword::ON, Keyword::CONFLICT]) { self.parse_handle_conflict_behavior() } else { @@ -2679,11 +2573,7 @@ impl Parser { } } - pub fn parse_create_table( - &mut self, - or_replace: bool, - temporary: bool, - ) -> Result { + pub fn parse_create_table(&mut self, or_replace: bool, temporary: bool) -> PResult { let if_not_exists = self.parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let table_name = self.parse_object_name()?; // parse optional column list (schema) and watermarks on source. 
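The hunk below keeps the CREATE TABLE column, constraint, and watermark grammar intact while moving it to `PResult`. A hedged round-trip sketch of the syntax it accepts; the crate path `risingwave_sqlparser` and the `WATERMARK FOR ... AS ...` clause follow RisingWave's documented DDL, and acceptance here is only about parsing, not later semantic checks:

    // Illustrative only: feed a CREATE TABLE with a watermark clause through the parser.
    use risingwave_sqlparser::parser::Parser;

    fn main() {
        let sql = "CREATE TABLE t (v INT, ts TIMESTAMP, WATERMARK FOR ts AS ts - INTERVAL '5' SECOND)";
        match Parser::parse_sql(sql) {
            // More than one watermark, or combining a watermark with CREATE TABLE AS,
            // is rejected by the checks in the hunk below.
            Ok(stmts) => println!("parsed {} statement(s)", stmts.len()),
            Err(e) => println!("rejected: {e}"),
        }
    }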
@@ -2718,9 +2608,7 @@ impl Parser { // Parse optional `AS ( query )` let query = if self.parse_keyword(Keyword::AS) { if !source_watermarks.is_empty() { - return Err(ParserError::ParserError( - "Watermarks can't be defined on table created by CREATE TABLE AS".to_string(), - )); + parser_err!("Watermarks can't be defined on table created by CREATE TABLE AS"); } Some(Box::new(self.parse_query()?)) } else { @@ -2759,7 +2647,7 @@ impl Parser { }) } - pub fn parse_include_options(&mut self) -> Result { + pub fn parse_include_options(&mut self) -> PResult { let mut options = vec![]; while self.parse_keyword(Keyword::INCLUDE) { let column_type = self.parse_identifier()?; @@ -2803,7 +2691,7 @@ impl Parser { Ok(options) } - pub fn parse_columns_with_watermark(&mut self) -> Result { + pub fn parse_columns_with_watermark(&mut self) -> PResult { let mut columns = vec![]; let mut constraints = vec![]; let mut watermarks = vec![]; @@ -2817,9 +2705,7 @@ impl Parser { if wildcard_idx.is_none() { wildcard_idx = Some(columns.len()); } else { - return Err(ParserError::ParserError( - "At most 1 wildcard is allowed in source definetion".to_string(), - )); + parser_err!("At most 1 wildcard is allowed in source definetion"); } } else if let Some(constraint) = self.parse_optional_table_constraint()? { constraints.push(constraint); @@ -2827,28 +2713,26 @@ impl Parser { watermarks.push(watermark); if watermarks.len() > 1 { // TODO(yuhao): allow multiple watermark on source. - return Err(ParserError::ParserError( - "Only 1 watermark is allowed to be defined on source.".to_string(), - )); + parser_err!("Only 1 watermark is allowed to be defined on source."); } } else if let Token::Word(_) = self.peek_token().token { columns.push(self.parse_column_def()?); } else { - return self.expected("column name or constraint definition", self.peek_token()); + return self.expected("column name or constraint definition"); } let comma = self.consume_token(&Token::Comma); if self.consume_token(&Token::RParen) { // allow a trailing comma, even though it's not in standard break; } else if !comma { - return self.expected("',' or ')' after column definition", self.peek_token()); + return self.expected("',' or ')' after column definition"); } } Ok((columns, constraints, watermarks, wildcard_idx)) } - fn parse_column_def(&mut self) -> Result { + fn parse_column_def(&mut self) -> PResult { let name = self.parse_identifier_non_reserved()?; let data_type = if let Token::Word(_) = self.peek_token().token { Some(self.parse_data_type()?) @@ -2868,10 +2752,7 @@ impl Parser { if let Some(option) = self.parse_optional_column_option()? { options.push(ColumnOptionDef { name, option }); } else { - return self.expected( - "constraint details after CONSTRAINT ", - self.peek_token(), - ); + return self.expected("constraint details after CONSTRAINT "); } } else if let Some(option) = self.parse_optional_column_option()? 
{ options.push(ColumnOptionDef { name: None, option }); @@ -2887,7 +2768,7 @@ impl Parser { }) } - pub fn parse_optional_column_option(&mut self) -> Result, ParserError> { + pub fn parse_optional_column_option(&mut self) -> PResult> { if self.parse_keywords(&[Keyword::NOT, Keyword::NULL]) { Ok(Some(ColumnOption::NotNull)) } else if self.parse_keyword(Keyword::NULL) { @@ -2934,7 +2815,7 @@ impl Parser { } } - pub fn parse_handle_conflict_behavior(&mut self) -> Result, ParserError> { + pub fn parse_handle_conflict_behavior(&mut self) -> PResult> { if self.parse_keyword(Keyword::OVERWRITE) { Ok(Some(OnConflict::OverWrite)) } else if self.parse_keyword(Keyword::IGNORE) { @@ -2952,7 +2833,7 @@ impl Parser { } } - pub fn parse_referential_action(&mut self) -> Result { + pub fn parse_referential_action(&mut self) -> PResult { if self.parse_keyword(Keyword::RESTRICT) { Ok(ReferentialAction::Restrict) } else if self.parse_keyword(Keyword::CASCADE) { @@ -2964,14 +2845,11 @@ impl Parser { } else if self.parse_keywords(&[Keyword::SET, Keyword::DEFAULT]) { Ok(ReferentialAction::SetDefault) } else { - self.expected( - "one of RESTRICT, CASCADE, SET NULL, NO ACTION or SET DEFAULT", - self.peek_token(), - ) + self.expected("one of RESTRICT, CASCADE, SET NULL, NO ACTION or SET DEFAULT") } } - pub fn parse_optional_watermark(&mut self) -> Result, ParserError> { + pub fn parse_optional_watermark(&mut self) -> PResult> { if self.parse_keyword(Keyword::WATERMARK) { self.expect_keyword(Keyword::FOR)?; let column = self.parse_identifier_non_reserved()?; @@ -2983,14 +2861,13 @@ impl Parser { } } - pub fn parse_optional_table_constraint( - &mut self, - ) -> Result, ParserError> { + pub fn parse_optional_table_constraint(&mut self) -> PResult> { let name = if self.parse_keyword(Keyword::CONSTRAINT) { Some(self.parse_identifier_non_reserved()?) 
} else { None }; + let checkpoint = *self; let token = self.next_token(); match token.token { Token::Word(w) if w.keyword == Keyword::PRIMARY || w.keyword == Keyword::UNIQUE => { @@ -3039,14 +2916,11 @@ impl Parser { self.expect_token(&Token::RParen)?; Ok(Some(TableConstraint::Check { name, expr })) } - unexpected => { + _ => { + *self = checkpoint; if name.is_some() { - self.expected( - "PRIMARY, UNIQUE, FOREIGN, or CHECK", - unexpected.with_location(token.location), - ) + self.expected("PRIMARY, UNIQUE, FOREIGN, or CHECK") } else { - self.prev_token(); Ok(None) } } @@ -3056,7 +2930,7 @@ impl Parser { pub fn parse_options_with_preceding_keyword( &mut self, keyword: Keyword, - ) -> Result, ParserError> { + ) -> PResult> { if self.parse_keyword(keyword) { self.expect_token(&Token::LParen)?; self.parse_options_inner() @@ -3065,7 +2939,7 @@ impl Parser { } } - pub fn parse_options(&mut self) -> Result, ParserError> { + pub fn parse_options(&mut self) -> PResult> { if self.peek_token() == Token::LParen { self.next_token(); self.parse_options_inner() @@ -3075,7 +2949,7 @@ impl Parser { } // has parsed a LParen - pub fn parse_options_inner(&mut self) -> Result, ParserError> { + pub fn parse_options_inner(&mut self) -> PResult> { let mut values = vec![]; loop { values.push(Parser::parse_sql_option(self)?); @@ -3084,21 +2958,22 @@ impl Parser { // allow a trailing comma, even though it's not in standard break; } else if !comma { - return self.expected("',' or ')' after option definition", self.peek_token()); + return self.expected("',' or ')' after option definition"); } } Ok(values) } - pub fn parse_sql_option(&mut self) -> Result { + pub fn parse_sql_option(&mut self) -> PResult { let name = self.parse_object_name()?; self.expect_token(&Token::Eq)?; let value = self.parse_value()?; Ok(SqlOption { name, value }) } - pub fn parse_since(&mut self) -> Result, ParserError> { + pub fn parse_since(&mut self) -> PResult> { if self.parse_keyword(Keyword::SINCE) { + let checkpoint = *self; let token = self.next_token(); match token.token { Token::Word(w) => { @@ -3113,29 +2988,26 @@ impl Parser { self.expect_token(&Token::RParen)?; Ok(Some(Since::Begin)) } else { - parser_err!(format!( + parser_err!( "Expected proctime(), begin() or now(), found: {}", ident.real_value() - )) + ) } } Token::Number(s) => { - let num = s.parse::().map_err(|e| { - ParserError::ParserError(format!("Could not parse '{}' as u64: {}", s, e)) - }); - Ok(Some(Since::TimestampMsNum(num?))) + let num = s + .parse::() + .map_err(|e| StrError(format!("Could not parse '{}' as u64: {}", s, e)))?; + Ok(Some(Since::TimestampMsNum(num))) } - unexpected => self.expected( - "proctime(), begin() , now(), Number", - unexpected.with_location(token.location), - ), + _ => self.expected_at(checkpoint, "proctime(), begin() , now(), Number"), } } else { Ok(None) } } - pub fn parse_emit_mode(&mut self) -> Result, ParserError> { + pub fn parse_emit_mode(&mut self) -> PResult> { if self.parse_keyword(Keyword::EMIT) { match self.parse_one_of_keywords(&[Keyword::IMMEDIATELY, Keyword::ON]) { Some(Keyword::IMMEDIATELY) => Ok(Some(EmitMode::Immediately)), @@ -3144,17 +3016,14 @@ impl Parser { Ok(Some(EmitMode::OnWindowClose)) } Some(_) => unreachable!(), - None => self.expected( - "IMMEDIATELY or ON WINDOW CLOSE after EMIT", - self.peek_token(), - ), + None => self.expected("IMMEDIATELY or ON WINDOW CLOSE after EMIT"), } } else { Ok(None) } } - pub fn parse_alter(&mut self) -> Result { + pub fn parse_alter(&mut self) -> PResult { if 
self.parse_keyword(Keyword::DATABASE) { self.parse_alter_database() } else if self.parse_keyword(Keyword::SCHEMA) { @@ -3183,13 +3052,12 @@ impl Parser { self.parse_alter_subscription() } else { self.expected( - "DATABASE, SCHEMA, TABLE, INDEX, MATERIALIZED, VIEW, SINK, SUBSCRIPTION, SOURCE, FUNCTION, USER or SYSTEM after ALTER", - self.peek_token(), + "DATABASE, SCHEMA, TABLE, INDEX, MATERIALIZED, VIEW, SINK, SUBSCRIPTION, SOURCE, FUNCTION, USER or SYSTEM after ALTER" ) } } - pub fn parse_alter_database(&mut self) -> Result { + pub fn parse_alter_database(&mut self) -> PResult { let database_name = self.parse_object_name()?; let operation = if self.parse_keywords(&[Keyword::OWNER, Keyword::TO]) { let owner_name: Ident = self.parse_identifier()?; @@ -3201,10 +3069,10 @@ impl Parser { let database_name = self.parse_object_name()?; AlterDatabaseOperation::RenameDatabase { database_name } } else { - return self.expected("TO after RENAME", self.peek_token()); + return self.expected("TO after RENAME"); } } else { - return self.expected("OWNER TO after ALTER DATABASE", self.peek_token()); + return self.expected("OWNER TO after ALTER DATABASE"); }; Ok(Statement::AlterDatabase { @@ -3213,7 +3081,7 @@ impl Parser { }) } - pub fn parse_alter_schema(&mut self) -> Result { + pub fn parse_alter_schema(&mut self) -> PResult { let schema_name = self.parse_object_name()?; let operation = if self.parse_keywords(&[Keyword::OWNER, Keyword::TO]) { let owner_name: Ident = self.parse_identifier()?; @@ -3221,14 +3089,11 @@ impl Parser { new_owner_name: owner_name, } } else if self.parse_keyword(Keyword::RENAME) { - if self.parse_keyword(Keyword::TO) { - let schema_name = self.parse_object_name()?; - AlterSchemaOperation::RenameSchema { schema_name } - } else { - return self.expected("TO after RENAME", self.peek_token()); - } + self.expect_keyword(Keyword::TO)?; + let schema_name = self.parse_object_name()?; + AlterSchemaOperation::RenameSchema { schema_name } } else { - return self.expected("RENAME OR OWNER TO after ALTER SCHEMA", self.peek_token()); + return self.expected("RENAME OR OWNER TO after ALTER SCHEMA"); }; Ok(Statement::AlterSchema { @@ -3237,11 +3102,11 @@ impl Parser { }) } - pub fn parse_alter_user(&mut self) -> Result { + pub fn parse_alter_user(&mut self) -> PResult { Ok(Statement::AlterUser(AlterUserStatement::parse_to(self)?)) } - pub fn parse_alter_table(&mut self) -> Result { + pub fn parse_alter_table(&mut self) -> PResult { let _ = self.parse_keyword(Keyword::ONLY); let table_name = self.parse_object_name()?; let operation = if self.parse_keyword(Keyword::ADD) { @@ -3288,10 +3153,7 @@ impl Parser { if self.expect_keyword(Keyword::TO).is_err() && self.expect_token(&Token::Eq).is_err() { - return self.expected( - "TO or = after ALTER TABLE SET PARALLELISM", - self.peek_token(), - ); + return self.expected("TO or = after ALTER TABLE SET PARALLELISM"); } let value = self.parse_set_variable()?; @@ -3305,10 +3167,7 @@ impl Parser { } else if let Some(rate_limit) = self.parse_alter_streaming_rate_limit()? 
{ AlterTableOperation::SetStreamingRateLimit { rate_limit } } else { - return self.expected( - "SCHEMA/PARALLELISM/STREAMING_RATE_LIMIT after SET", - self.peek_token(), - ); + return self.expected("SCHEMA/PARALLELISM/STREAMING_RATE_LIMIT after SET"); } } else if self.parse_keyword(Keyword::DROP) { let _ = self.parse_keyword(Keyword::COLUMN); @@ -3345,19 +3204,14 @@ impl Parser { }; AlterColumnOperation::SetDataType { data_type, using } } else { - return self.expected( - "SET/DROP NOT NULL, SET DEFAULT, SET DATA TYPE after ALTER COLUMN", - self.peek_token(), - ); + return self + .expected("SET/DROP NOT NULL, SET DEFAULT, SET DATA TYPE after ALTER COLUMN"); }; AlterTableOperation::AlterColumn { column_name, op } } else if self.parse_keywords(&[Keyword::REFRESH, Keyword::SCHEMA]) { AlterTableOperation::RefreshSchema } else { - return self.expected( - "ADD or RENAME or OWNER TO or SET or DROP after ALTER TABLE", - self.peek_token(), - ); + return self.expected("ADD or RENAME or OWNER TO or SET or DROP after ALTER TABLE"); }; Ok(Statement::AlterTable { name: table_name, @@ -3367,15 +3221,12 @@ impl Parser { /// STREAMING_RATE_LIMIT = default | NUMBER /// STREAMING_RATE_LIMIT TO default | NUMBER - pub fn parse_alter_streaming_rate_limit(&mut self) -> Result, ParserError> { + pub fn parse_alter_streaming_rate_limit(&mut self) -> PResult> { if !self.parse_keyword(Keyword::STREAMING_RATE_LIMIT) { return Ok(None); } if self.expect_keyword(Keyword::TO).is_err() && self.expect_token(&Token::Eq).is_err() { - return self.expected( - "TO or = after ALTER TABLE SET STREAMING_RATE_LIMIT", - self.peek_token(), - ); + return self.expected("TO or = after ALTER TABLE SET STREAMING_RATE_LIMIT"); } let rate_limit = if self.parse_keyword(Keyword::DEFAULT) { -1 @@ -3384,30 +3235,27 @@ impl Parser { if let Ok(n) = s.parse::() { n } else { - return self.expected("number or DEFAULT", self.peek_token()); + return self.expected("number or DEFAULT"); } }; Ok(Some(rate_limit)) } - pub fn parse_alter_index(&mut self) -> Result { + pub fn parse_alter_index(&mut self) -> PResult { let index_name = self.parse_object_name()?; let operation = if self.parse_keyword(Keyword::RENAME) { if self.parse_keyword(Keyword::TO) { let index_name = self.parse_object_name()?; AlterIndexOperation::RenameIndex { index_name } } else { - return self.expected("TO after RENAME", self.peek_token()); + return self.expected("TO after RENAME"); } } else if self.parse_keyword(Keyword::SET) { if self.parse_keyword(Keyword::PARALLELISM) { if self.expect_keyword(Keyword::TO).is_err() && self.expect_token(&Token::Eq).is_err() { - return self.expected( - "TO or = after ALTER TABLE SET PARALLELISM", - self.peek_token(), - ); + return self.expected("TO or = after ALTER TABLE SET PARALLELISM"); } let value = self.parse_set_variable()?; @@ -3419,10 +3267,10 @@ impl Parser { deferred, } } else { - return self.expected("PARALLELISM after SET", self.peek_token()); + return self.expected("PARALLELISM after SET"); } } else { - return self.expected("RENAME after ALTER INDEX", self.peek_token()); + return self.expected("RENAME after ALTER INDEX"); }; Ok(Statement::AlterIndex { @@ -3431,14 +3279,14 @@ impl Parser { }) } - pub fn parse_alter_view(&mut self, materialized: bool) -> Result { + pub fn parse_alter_view(&mut self, materialized: bool) -> PResult { let view_name = self.parse_object_name()?; let operation = if self.parse_keyword(Keyword::RENAME) { if self.parse_keyword(Keyword::TO) { let view_name = self.parse_object_name()?; AlterViewOperation::RenameView { 
view_name } } else { - return self.expected("TO after RENAME", self.peek_token()); + return self.expected("TO after RENAME"); } } else if self.parse_keywords(&[Keyword::OWNER, Keyword::TO]) { let owner_name: Ident = self.parse_identifier()?; @@ -3455,10 +3303,7 @@ impl Parser { if self.expect_keyword(Keyword::TO).is_err() && self.expect_token(&Token::Eq).is_err() { - return self.expected( - "TO or = after ALTER TABLE SET PARALLELISM", - self.peek_token(), - ); + return self.expected("TO or = after ALTER TABLE SET PARALLELISM"); } let value = self.parse_set_variable()?; @@ -3474,19 +3319,13 @@ impl Parser { { AlterViewOperation::SetStreamingRateLimit { rate_limit } } else { - return self.expected( - "SCHEMA/PARALLELISM/STREAMING_RATE_LIMIT after SET", - self.peek_token(), - ); + return self.expected("SCHEMA/PARALLELISM/STREAMING_RATE_LIMIT after SET"); } } else { - return self.expected( - &format!( - "RENAME or OWNER TO or SET after ALTER {}VIEW", - if materialized { "MATERIALIZED " } else { "" } - ), - self.peek_token(), - ); + return self.expected(&format!( + "RENAME or OWNER TO or SET after ALTER {}VIEW", + if materialized { "MATERIALIZED " } else { "" } + )); }; Ok(Statement::AlterView { @@ -3496,14 +3335,14 @@ impl Parser { }) } - pub fn parse_alter_sink(&mut self) -> Result { + pub fn parse_alter_sink(&mut self) -> PResult { let sink_name = self.parse_object_name()?; let operation = if self.parse_keyword(Keyword::RENAME) { if self.parse_keyword(Keyword::TO) { let sink_name = self.parse_object_name()?; AlterSinkOperation::RenameSink { sink_name } } else { - return self.expected("TO after RENAME", self.peek_token()); + return self.expected("TO after RENAME"); } } else if self.parse_keywords(&[Keyword::OWNER, Keyword::TO]) { let owner_name: Ident = self.parse_identifier()?; @@ -3520,10 +3359,7 @@ impl Parser { if self.expect_keyword(Keyword::TO).is_err() && self.expect_token(&Token::Eq).is_err() { - return self.expected( - "TO or = after ALTER TABLE SET PARALLELISM", - self.peek_token(), - ); + return self.expected("TO or = after ALTER TABLE SET PARALLELISM"); } let value = self.parse_set_variable()?; @@ -3534,13 +3370,10 @@ impl Parser { deferred, } } else { - return self.expected("SCHEMA/PARALLELISM after SET", self.peek_token()); + return self.expected("SCHEMA/PARALLELISM after SET"); } } else { - return self.expected( - "RENAME or OWNER TO or SET after ALTER SINK", - self.peek_token(), - ); + return self.expected("RENAME or OWNER TO or SET after ALTER SINK"); }; Ok(Statement::AlterSink { @@ -3549,14 +3382,14 @@ impl Parser { }) } - pub fn parse_alter_subscription(&mut self) -> Result { + pub fn parse_alter_subscription(&mut self) -> PResult { let subscription_name = self.parse_object_name()?; let operation = if self.parse_keyword(Keyword::RENAME) { if self.parse_keyword(Keyword::TO) { let subscription_name = self.parse_object_name()?; AlterSubscriptionOperation::RenameSubscription { subscription_name } } else { - return self.expected("TO after RENAME", self.peek_token()); + return self.expected("TO after RENAME"); } } else if self.parse_keywords(&[Keyword::OWNER, Keyword::TO]) { let owner_name: Ident = self.parse_identifier()?; @@ -3570,13 +3403,10 @@ impl Parser { new_schema_name: schema_name, } } else { - return self.expected("SCHEMA after SET", self.peek_token()); + return self.expected("SCHEMA after SET"); } } else { - return self.expected( - "RENAME or OWNER TO or SET after ALTER SUBSCRIPTION", - self.peek_token(), - ); + return self.expected("RENAME or OWNER TO or SET after 
ALTER SUBSCRIPTION"); }; Ok(Statement::AlterSubscription { @@ -3585,14 +3415,14 @@ impl Parser { }) } - pub fn parse_alter_source(&mut self) -> Result { + pub fn parse_alter_source(&mut self) -> PResult { let source_name = self.parse_object_name()?; let operation = if self.parse_keyword(Keyword::RENAME) { if self.parse_keyword(Keyword::TO) { let source_name = self.parse_object_name()?; AlterSourceOperation::RenameSource { source_name } } else { - return self.expected("TO after RENAME", self.peek_token()); + return self.expected("TO after RENAME"); } } else if self.parse_keyword(Keyword::ADD) { let _ = self.parse_keyword(Keyword::COLUMN); @@ -3613,14 +3443,12 @@ impl Parser { } else if let Some(rate_limit) = self.parse_alter_streaming_rate_limit()? { AlterSourceOperation::SetStreamingRateLimit { rate_limit } } else { - return self.expected("SCHEMA after SET", self.peek_token()); + return self.expected("SCHEMA after SET"); } } else if self.peek_nth_any_of_keywords(0, &[Keyword::FORMAT]) { let connector_schema = self.parse_schema()?.unwrap(); if connector_schema.key_encode.is_some() { - return Err(ParserError::ParserError( - "key encode clause is not supported in source schema".to_string(), - )); + parser_err!("key encode clause is not supported in source schema"); } AlterSourceOperation::FormatEncode { connector_schema } } else if self.parse_keywords(&[Keyword::REFRESH, Keyword::SCHEMA]) { @@ -3628,7 +3456,6 @@ impl Parser { } else { return self.expected( "RENAME, ADD COLUMN, OWNER TO, SET or STREAMING_RATE_LIMIT after ALTER SOURCE", - self.peek_token(), ); }; @@ -3638,7 +3465,7 @@ impl Parser { }) } - pub fn parse_alter_function(&mut self) -> Result { + pub fn parse_alter_function(&mut self) -> PResult { let FunctionDesc { name, args } = self.parse_function_desc()?; let operation = if self.parse_keyword(Keyword::SET) { @@ -3648,10 +3475,10 @@ impl Parser { new_schema_name: schema_name, } } else { - return self.expected("SCHEMA after SET", self.peek_token()); + return self.expected("SCHEMA after SET"); } } else { - return self.expected("SET after ALTER FUNCTION", self.peek_token()); + return self.expected("SET after ALTER FUNCTION"); }; Ok(Statement::AlterFunction { @@ -3661,7 +3488,7 @@ impl Parser { }) } - pub fn parse_alter_connection(&mut self) -> Result { + pub fn parse_alter_connection(&mut self) -> PResult { let connection_name = self.parse_object_name()?; let operation = if self.parse_keyword(Keyword::SET) { if self.parse_keyword(Keyword::SCHEMA) { @@ -3670,10 +3497,10 @@ impl Parser { new_schema_name: schema_name, } } else { - return self.expected("SCHEMA after SET", self.peek_token()); + return self.expected("SCHEMA after SET"); } } else { - return self.expected("SET after ALTER CONNECTION", self.peek_token()); + return self.expected("SET after ALTER CONNECTION"); }; Ok(Statement::AlterConnection { @@ -3682,18 +3509,18 @@ impl Parser { }) } - pub fn parse_alter_system(&mut self) -> Result { + pub fn parse_alter_system(&mut self) -> PResult { self.expect_keyword(Keyword::SET)?; let param = self.parse_identifier()?; if self.expect_keyword(Keyword::TO).is_err() && self.expect_token(&Token::Eq).is_err() { - return self.expected("TO or = after ALTER SYSTEM SET", self.peek_token()); + return self.expected("TO or = after ALTER SYSTEM SET"); } let value = self.parse_set_variable()?; Ok(Statement::AlterSystem { param, value }) } /// Parse a copy statement - pub fn parse_copy(&mut self) -> Result { + pub fn parse_copy(&mut self) -> PResult { let table_name = self.parse_object_name()?; let 
columns = self.parse_parenthesized_column_list(Optional)?; self.expect_keywords(&[Keyword::FROM, Keyword::STDIN])?; @@ -3744,7 +3571,8 @@ impl Parser { } /// Parse a literal value (numbers, strings, date/time, booleans) - pub fn parse_value(&mut self) -> Result { + pub fn parse_value(&mut self) -> PResult { + let checkpoint = *self; let token = self.next_token(); match token.token { Token::Word(w) => match w.keyword { @@ -3754,12 +3582,9 @@ impl Parser { Keyword::NoKeyword if w.quote_style.is_some() => match w.quote_style { Some('"') => Ok(Value::DoubleQuotedString(w.value)), Some('\'') => Ok(Value::SingleQuotedString(w.value)), - _ => self.expected("A value?", Token::Word(w).with_location(token.location))?, + _ => self.expected_at(checkpoint, "A value")?, }, - _ => self.expected( - "a concrete value", - Token::Word(w).with_location(token.location), - ), + _ => self.expected_at(checkpoint, "a concrete value"), }, Token::Number(ref n) => Ok(Value::Number(n.clone())), Token::SingleQuotedString(ref s) => Ok(Value::SingleQuotedString(s.to_string())), @@ -3767,67 +3592,55 @@ impl Parser { Token::CstyleEscapesString(ref s) => Ok(Value::CstyleEscapedString(s.clone())), Token::NationalStringLiteral(ref s) => Ok(Value::NationalStringLiteral(s.to_string())), Token::HexStringLiteral(ref s) => Ok(Value::HexStringLiteral(s.to_string())), - unexpected => self.expected("a value", unexpected.with_location(token.location)), - } - } - - fn parse_set_variable(&mut self) -> Result { - let mut values = vec![]; - loop { - let token = self.peek_token(); - let value = match (self.parse_value(), token.token) { - (Ok(value), _) => SetVariableValueSingle::Literal(value), - (Err(_), Token::Word(w)) => { - if w.keyword == Keyword::DEFAULT { - if !values.is_empty() { - self.expected( - "parameter list value", - Token::Word(w).with_location(token.location), - )? + _ => self.expected_at(checkpoint, "a value"), + } + } + + fn parse_set_variable(&mut self) -> PResult { + alt(( + Keyword::DEFAULT.value(SetVariableValue::Default), + separated( + 1.., + alt(( + Self::parse_value.map(SetVariableValueSingle::Literal), + |parser: &mut Self| { + let checkpoint = *parser; + let ident = parser.parse_identifier()?; + if ident.value == "default" { + *parser = checkpoint; + return parser.expected("parameter list value").map_err(|e| e.cut()); } - return Ok(SetVariableValue::Default); - } else { - SetVariableValueSingle::Ident(w.to_ident()?) - } - } - (Err(_), unexpected) => { - self.expected("parameter value", unexpected.with_location(token.location))? + Ok(SetVariableValueSingle::Ident(ident)) + }, + fail.expect("parameter value"), + )), + Token::Comma, + ) + .map(|list: Vec| { + if list.len() == 1 { + SetVariableValue::Single(list[0].clone()) + } else { + SetVariableValue::List(list) } - }; - values.push(value); - if !self.consume_token(&Token::Comma) { - break; - } - } - if values.len() == 1 { - Ok(SetVariableValue::Single(values[0].clone())) - } else { - Ok(SetVariableValue::List(values)) - } + }), + )) + .parse_next(self) } - pub fn parse_number_value(&mut self) -> Result { + pub fn parse_number_value(&mut self) -> PResult { + let checkpoint = *self; match self.parse_value()? 
{ Value::Number(v) => Ok(v), - _ => { - self.prev_token(); - self.expected("literal number", self.peek_token()) - } + _ => self.expected_at(checkpoint, "literal number"), } } /// Parse an unsigned literal integer/long - pub fn parse_literal_uint(&mut self) -> Result { - let token = self.next_token(); - match token.token { - Token::Number(s) => s.parse::().map_err(|e| { - ParserError::ParserError(format!("Could not parse '{}' as u64: {}", s, e)) - }), - unexpected => self.expected("literal int", unexpected.with_location(token.location)), - } + pub fn parse_literal_uint(&mut self) -> PResult { + literal_uint(self) } - pub fn parse_function_definition(&mut self) -> Result { + pub fn parse_function_definition(&mut self) -> PResult { let peek_token = self.peek_token(); match peek_token.token { Token::DollarQuotedString(value) => { @@ -3841,7 +3654,8 @@ impl Parser { } /// Parse a literal string - pub fn parse_literal_string(&mut self) -> Result { + pub fn parse_literal_string(&mut self) -> PResult { + let checkpoint = *self; let token = self.next_token(); match token.token { Token::Word(Word { @@ -3850,12 +3664,13 @@ impl Parser { .. }) => Ok(value), Token::SingleQuotedString(s) => Ok(s), - unexpected => self.expected("literal string", unexpected.with_location(token.location)), + _ => self.expected_at(checkpoint, "literal string"), } } /// Parse a map key string - pub fn parse_map_key(&mut self) -> Result { + pub fn parse_map_key(&mut self) -> PResult { + let checkpoint = *self; let token = self.next_token(); match token.token { Token::Word(Word { @@ -3870,27 +3685,22 @@ impl Parser { } Token::SingleQuotedString(s) => Ok(Expr::Value(Value::SingleQuotedString(s))), Token::Number(s) => Ok(Expr::Value(Value::Number(s))), - unexpected => self.expected( - "literal string, number or function", - unexpected.with_location(token.location), - ), + _ => self.expected_at(checkpoint, "literal string, number or function"), } } /// Parse a SQL datatype (in the context of a CREATE TABLE statement for example) and convert /// into an array of that datatype if needed - pub fn parse_data_type(&mut self) -> Result { - self.parse_v2(parser_v2::data_type) + pub fn parse_data_type(&mut self) -> PResult { + parser_v2::data_type(self) } /// Parse `AS identifier` (or simply `identifier` if it's not a reserved keyword) /// Some examples with aliases: `SELECT 1 foo`, `SELECT COUNT(*) AS cnt`, /// `SELECT ... FROM t1 foo, t2 bar`, `SELECT ... FROM (...) AS bar` - pub fn parse_optional_alias( - &mut self, - reserved_kwds: &[Keyword], - ) -> Result, ParserError> { + pub fn parse_optional_alias(&mut self, reserved_kwds: &[Keyword]) -> PResult> { let after_as = self.parse_keyword(Keyword::AS); + let checkpoint = *self; let token = self.next_token(); match token.token { // Accept any identifier after `AS` (though many dialects have restrictions on @@ -3901,14 +3711,11 @@ impl Parser { Token::Word(w) if after_as || (!reserved_kwds.contains(&w.keyword)) => { Ok(Some(w.to_ident()?)) } - not_an_ident => { + _ => { + *self = checkpoint; if after_as { - return self.expected( - "an identifier after AS", - not_an_ident.with_location(token.location), - ); + return self.expected("an identifier after AS"); } - self.prev_token(); Ok(None) // no alias found } } @@ -3921,7 +3728,7 @@ impl Parser { pub fn parse_optional_table_alias( &mut self, reserved_kwds: &[Keyword], - ) -> Result, ParserError> { + ) -> PResult> { match self.parse_optional_alias(reserved_kwds)? 
{ Some(name) => { let columns = self.parse_parenthesized_column_list(Optional)?; @@ -3932,67 +3739,38 @@ impl Parser { } /// syntax `FOR SYSTEM_TIME AS OF PROCTIME()` is used for temporal join. - pub fn parse_as_of(&mut self) -> Result, ParserError> { - let after_for = self.parse_keyword(Keyword::FOR); - if after_for { - if self.peek_nth_any_of_keywords(0, &[Keyword::SYSTEM_TIME]) { - self.expect_keywords(&[Keyword::SYSTEM_TIME, Keyword::AS, Keyword::OF])?; - let token = self.next_token(); - match token.token { - Token::Word(w) => { - let ident = w.to_ident()?; - // Backward compatibility for now. - if ident.real_value() == "proctime" || ident.real_value() == "now" { - self.expect_token(&Token::LParen)?; - self.expect_token(&Token::RParen)?; - Ok(Some(AsOf::ProcessTime)) - } else { - parser_err!(format!("Expected proctime, found: {}", ident.real_value())) - } - } - Token::Number(s) => { - let num = s.parse::().map_err(|e| { - ParserError::ParserError(format!( - "Could not parse '{}' as i64: {}", - s, e - )) - }); - Ok(Some(AsOf::TimestampNum(num?))) - } - Token::SingleQuotedString(s) => Ok(Some(AsOf::TimestampString(s))), - unexpected => self.expected( - "Proctime(), Number or SingleQuotedString", - unexpected.with_location(token.location), - ), - } - } else { - self.expect_keywords(&[Keyword::SYSTEM_VERSION, Keyword::AS, Keyword::OF])?; - let token = self.next_token(); - match token.token { - Token::Number(s) => { - let num = s.parse::().map_err(|e| { - ParserError::ParserError(format!( - "Could not parse '{}' as i64: {}", - s, e - )) - }); - Ok(Some(AsOf::VersionNum(num?))) - } - Token::SingleQuotedString(s) => Ok(Some(AsOf::VersionString(s))), - unexpected => self.expected( - "Number or SingleQuotedString", - unexpected.with_location(token.location), - ), - } - } - } else { - Ok(None) - } + pub fn parse_as_of(&mut self) -> PResult { + Keyword::FOR.parse_next(self)?; + alt(( + preceded( + (Keyword::SYSTEM_TIME, Keyword::AS, Keyword::OF), + alt(( + ( + Self::parse_identifier.verify(|ident| { + ident.real_value() == "proctime" || ident.real_value() == "now" + }), + Token::LParen, + Token::RParen, + ) + .value(AsOf::ProcessTime), + literal_i64.map(AsOf::VersionNum), + Self::parse_literal_string.map(AsOf::TimestampString), + )), + ), + preceded( + (Keyword::SYSTEM_VERSION, Keyword::AS, Keyword::OF), + alt(( + literal_i64.map(AsOf::VersionNum), + Self::parse_literal_string.map(AsOf::VersionString), + )), + ), + )) + .parse_next(self) } /// Parse a possibly qualified, possibly quoted identifier, e.g. /// `foo` or `myschema."table" - pub fn parse_object_name(&mut self) -> Result { + pub fn parse_object_name(&mut self) -> PResult { let mut idents = vec![]; loop { idents.push(self.parse_identifier()?); @@ -4004,7 +3782,7 @@ impl Parser { } /// Parse identifiers strictly i.e. 
don't parse keywords - pub fn parse_identifiers_non_keywords(&mut self) -> Result, ParserError> { + pub fn parse_identifiers_non_keywords(&mut self) -> PResult> { let mut idents = vec![]; loop { match self.peek_token().token { @@ -4026,7 +3804,7 @@ impl Parser { } /// Parse identifiers - pub fn parse_identifiers(&mut self) -> Result, ParserError> { + pub fn parse_identifiers(&mut self) -> PResult> { let mut idents = vec![]; loop { let token = self.next_token(); @@ -4043,33 +3821,32 @@ impl Parser { } /// Parse a simple one-word identifier (possibly quoted, possibly a keyword) - pub fn parse_identifier(&mut self) -> Result { + pub fn parse_identifier(&mut self) -> PResult { + let checkpoint = *self; let token = self.next_token(); match token.token { Token::Word(w) => Ok(w.to_ident()?), - unexpected => self.expected("identifier", unexpected.with_location(token.location)), + _ => self.expected_at(checkpoint, "identifier"), } } /// Parse a simple one-word identifier (possibly quoted, possibly a non-reserved keyword) - pub fn parse_identifier_non_reserved(&mut self) -> Result { + pub fn parse_identifier_non_reserved(&mut self) -> PResult { + let checkpoint = *self; let token = self.next_token(); - match token.token.clone() { + match token.token { Token::Word(w) => { match keywords::RESERVED_FOR_COLUMN_OR_TABLE_NAME.contains(&w.keyword) { - true => parser_err!(format!("syntax error at or near {token}")), + true => parser_err!("syntax error at or near {w}"), false => Ok(w.to_ident()?), } } - unexpected => self.expected("identifier", unexpected.with_location(token.location)), + _ => self.expected_at(checkpoint, "identifier"), } } /// Parse a parenthesized comma-separated list of unqualified, possibly quoted identifiers - pub fn parse_parenthesized_column_list( - &mut self, - optional: IsOptional, - ) -> Result, ParserError> { + pub fn parse_parenthesized_column_list(&mut self, optional: IsOptional) -> PResult> { if self.consume_token(&Token::LParen) { let cols = self.parse_comma_separated(Parser::parse_identifier_non_reserved)?; self.expect_token(&Token::RParen)?; @@ -4077,25 +3854,22 @@ impl Parser { } else if optional == Optional { Ok(vec![]) } else { - self.expected("a list of columns in parentheses", self.peek_token()) + self.expected("a list of columns in parentheses") } } - pub fn parse_returning( - &mut self, - optional: IsOptional, - ) -> Result, ParserError> { + pub fn parse_returning(&mut self, optional: IsOptional) -> PResult> { if self.parse_keyword(Keyword::RETURNING) { let cols = self.parse_comma_separated(Parser::parse_select_item)?; Ok(cols) } else if optional == Optional { Ok(vec![]) } else { - self.expected("a list of columns or * after returning", self.peek_token()) + self.expected("a list of columns or * after returning") } } - pub fn parse_row_expr(&mut self) -> Result { + pub fn parse_row_expr(&mut self) -> PResult { Ok(Expr::Row(self.parse_token_wrapped_exprs( &Token::LParen, &Token::RParen, @@ -4103,11 +3877,7 @@ impl Parser { } /// Parse a comma-separated list (maybe empty) from a wrapped expression - pub fn parse_token_wrapped_exprs( - &mut self, - left: &Token, - right: &Token, - ) -> Result, ParserError> { + pub fn parse_token_wrapped_exprs(&mut self, left: &Token, right: &Token) -> PResult> { if self.consume_token(left) { let exprs = if self.consume_token(right) { vec![] @@ -4118,11 +3888,11 @@ impl Parser { }; Ok(exprs) } else { - self.expected(left.to_string().as_str(), self.peek_token()) + self.expected(left.to_string().as_str()) } } - pub fn 
parse_optional_precision(&mut self) -> Result, ParserError> { + pub fn parse_optional_precision(&mut self) -> PResult> { if self.consume_token(&Token::LParen) { let n = self.parse_literal_uint()?; self.expect_token(&Token::RParen)?; @@ -4132,9 +3902,7 @@ impl Parser { } } - pub fn parse_optional_precision_scale( - &mut self, - ) -> Result<(Option, Option), ParserError> { + pub fn parse_optional_precision_scale(&mut self) -> PResult<(Option, Option)> { if self.consume_token(&Token::LParen) { let n = self.parse_literal_uint()?; let scale = if self.consume_token(&Token::Comma) { @@ -4149,7 +3917,7 @@ impl Parser { } } - pub fn parse_delete(&mut self) -> Result { + pub fn parse_delete(&mut self) -> PResult { self.expect_keyword(Keyword::FROM)?; let table_name = self.parse_object_name()?; let selection = if self.parse_keyword(Keyword::WHERE) { @@ -4178,7 +3946,7 @@ impl Parser { } } - pub fn parse_explain(&mut self) -> Result { + pub fn parse_explain(&mut self) -> PResult { let mut options = ExplainOptions::default(); let explain_key_words = [ @@ -4190,7 +3958,7 @@ impl Parser { Keyword::DISTSQL, ]; - let parse_explain_option = |parser: &mut Parser| -> Result<(), ParserError> { + let parse_explain_option = |parser: &mut Parser<'_>| -> PResult<()> { let keyword = parser.expect_one_of_keywords(&explain_key_words)?; match keyword { Keyword::VERBOSE => options.verbose = parser.parse_optional_boolean(true), @@ -4239,7 +4007,7 @@ impl Parser { /// preceded with some `WITH` CTE declarations and optionally followed /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't /// expect the initial keyword to be already consumed - pub fn parse_query(&mut self) -> Result { + pub fn parse_query(&mut self) -> PResult { let with = if self.parse_keyword(Keyword::WITH) { Some(With { recursive: self.parse_keyword(Keyword::RECURSIVE), @@ -4271,13 +4039,11 @@ impl Parser { let fetch = if self.parse_keyword(Keyword::FETCH) { if limit.is_some() { - return parser_err!("Cannot specify both LIMIT and FETCH".to_string()); + parser_err!("Cannot specify both LIMIT and FETCH"); } let fetch = self.parse_fetch()?; if fetch.with_ties && order_by.is_empty() { - return parser_err!( - "WITH TIES cannot be specified without ORDER BY clause".to_string() - ); + parser_err!("WITH TIES cannot be specified without ORDER BY clause"); } Some(fetch) } else { @@ -4295,7 +4061,7 @@ impl Parser { } /// Parse a CTE (`alias [( col1, col2, ... )] AS (subquery)`) - fn parse_cte(&mut self) -> Result { + fn parse_cte(&mut self) -> PResult { let name = self.parse_identifier_non_reserved()?; let mut cte = if self.parse_keyword(Keyword::AS) { @@ -4338,7 +4104,7 @@ impl Parser { /// subquery ::= query_body [ order_by_limit ] /// set_operation ::= query_body { 'UNION' | 'EXCEPT' | 'INTERSECT' } [ 'ALL' ] query_body /// ``` - fn parse_query_body(&mut self, precedence: u8) -> Result { + fn parse_query_body(&mut self, precedence: u8) -> PResult { // We parse the expression using a Pratt parser, as in `parse_expr()`. // Start by parsing a restricted SELECT or a `(subquery)`: let mut expr = if self.parse_keyword(Keyword::SELECT) { @@ -4351,10 +4117,7 @@ impl Parser { } else if self.parse_keyword(Keyword::VALUES) { SetExpr::Values(self.parse_values()?) 
         } else {
-            return self.expected(
-                "SELECT, VALUES, or a subquery in the query body",
-                self.peek_token(),
-            );
+            return self.expected("SELECT, VALUES, or a subquery in the query body");
         };

         loop {
@@ -4394,7 +4157,7 @@ impl Parser {

     /// Parse a restricted `SELECT` statement (no CTEs / `UNION` / `ORDER BY`),
     /// assuming the initial `SELECT` was already consumed
-    pub fn parse_select(&mut self) -> Result<Select, ParserError> {
+    pub fn parse_select(&mut self) -> PResult<Select>