diff --git a/.buildkite/engineer b/.buildkite/engineer deleted file mode 100755 index e820a489d1a0..000000000000 --- a/.buildkite/engineer +++ /dev/null @@ -1,67 +0,0 @@ -#!/usr/bin/env bash - -set -e - -if [[ -z "$2" ]]; then - printf "Error: the name of the pipeline must be provided.\nExample: './engineer pipeline test'" 1>&2 - exit 1 -else - echo "We are in the $2 pipeline." -fi - -# Checks what's the diff with the previous commit -# This is used to detect if the previous commit was empty -GIT_DIFF=$(git diff --name-only HEAD HEAD~1 -- .) - -# Checks what's the diff with the previous commit, -# excluding some paths that do not need a run, -# because they do not affect tests running in Buildkite. -GIT_DIFF_WITH_IGNORED_PATHS=$(git diff --name-only HEAD HEAD~1 -- . ':!.github' ':!query-engine/driver-adapters/js' ':!query-engine/query-engine-wasm' ':!renovate.json' ':!*.md' ':!LICENSE' ':!CODEOWNERS';) - -# $2 is either "test" or "build", depending on the pipeline -# Example: ./.buildkite/engineer pipeline test -# We only want to check for changes and skip in the test pipeline. -if [[ "$2" == "test" ]]; then - # If GIT_DIFF is empty then the previous commit was empty - # We assume it's intended and we continue with the run - # Example use: to get a new engine hash built with identical code - if [ -z "${GIT_DIFF}" ]; then - echo "The previous commit is empty, this run will continue..." - else - # Checking if GIT_DIFF_WITH_IGNORED_PATHS is empty - # If it's empty then it's most likely that there are changes but they are in ignored paths. - # So we do not start Buildkite - if [ -z "${GIT_DIFF_WITH_IGNORED_PATHS}" ]; then - echo "No changes found for the previous commit in paths that are not ignored, this run will now be skipped." - exit 0 - else - # Note that printf works better for displaying line returns in CI - printf "Changes found for the previous commit in paths that are not ignored: \n\n%s\n\nThis run will continue...\n" "${GIT_DIFF_WITH_IGNORED_PATHS}" - fi - fi -fi - -# Check OS -if [[ "$OSTYPE" == "linux-gnu" ]]; then - OS=linux-amzn -elif [[ "$OSTYPE" == "darwin"* ]]; then - OS=darwin -else - echo "Unhandled OS: '$OSTYPE'" - exit 1 -fi - -# Check if the system has engineer installed, if not, use a local copy. -if ! type "engineer" &> /dev/null; then - # Setup Prisma engine build & test tool (engineer). 
- curl --fail -sSL "https://prisma-engineer.s3-eu-west-1.amazonaws.com/1.67/latest/$OS/engineer.gz" --output engineer.gz - gzip -d engineer.gz - chmod +x engineer - - # Execute passed command and clean up - ./engineer "$@" - rm -rf ./engineer -else - # Already installed on the system - engineer "$@" -fi diff --git a/.github/workflows/build-engines-apple-intel.yml b/.github/workflows/build-engines-apple-intel-template.yml similarity index 58% rename from .github/workflows/build-engines-apple-intel.yml rename to .github/workflows/build-engines-apple-intel-template.yml index 05ebb4f97b33..30d8ef9697b5 100644 --- a/.github/workflows/build-engines-apple-intel.yml +++ b/.github/workflows/build-engines-apple-intel-template.yml @@ -1,15 +1,16 @@ name: Build Engines for Apple Intel + on: - workflow_dispatch: + workflow_call: inputs: commit: - description: "Commit on the given branch to build" + description: 'Commit on the given branch to build' + type: string required: false jobs: build: - # Do not change `name`, prisma-engines Buildkite build job depends on this name ending with the commit - name: "MacOS Intel engines build on branch ${{ github.event.ref }} for commit ${{ github.event.inputs.commit }}" + name: 'MacOS Intel engines build for commit ${{ inputs.commit }}' env: SQLITE_MAX_VARIABLE_NUMBER: 250000 SQLITE_MAX_EXPR_DEPTH: 10000 @@ -20,12 +21,12 @@ jobs: steps: - name: Output link to real commit - run: echo ${{ github.repository }}/commit/${{ github.event.inputs.commit }} + run: echo ${{ github.repository }}/commit/${{ inputs.commit }} - - name: Checkout ${{ github.event.inputs.commit }} + - name: Checkout ${{ inputs.commit }} uses: actions/checkout@v4 with: - ref: ${{ github.event.inputs.commit }} + ref: ${{ inputs.commit }} - uses: actions-rust-lang/setup-rust-toolchain@v1 @@ -35,16 +36,27 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-intel-cargo-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - run: | cargo build --release -p query-engine -p query-engine-node-api -p schema-engine-cli -p prisma-fmt + - name: Rename files + working-directory: ${{ github.workspace }}/target/release/ + run: | + echo "Files in target/release before renaming" + ls -la . + + mv libquery_engine.dylib libquery_engine.dylib.node + + echo "Files in target/release after renaming" + ls -la . 
+ - uses: actions/upload-artifact@v4 with: - name: binaries + name: darwin path: | ${{ github.workspace }}/target/release/schema-engine ${{ github.workspace }}/target/release/prisma-fmt ${{ github.workspace }}/target/release/query-engine - ${{ github.workspace }}/target/release/libquery_engine.dylib + ${{ github.workspace }}/target/release/libquery_engine.dylib.node diff --git a/.github/workflows/build-engines-apple-silicon.yml b/.github/workflows/build-engines-apple-silicon-template.yml similarity index 61% rename from .github/workflows/build-engines-apple-silicon.yml rename to .github/workflows/build-engines-apple-silicon-template.yml index 959c0e935653..de136018a2c2 100644 --- a/.github/workflows/build-engines-apple-silicon.yml +++ b/.github/workflows/build-engines-apple-silicon-template.yml @@ -1,15 +1,16 @@ name: Build Engines for Apple Silicon + on: - workflow_dispatch: + workflow_call: inputs: commit: - description: "Commit on the given branch to build" + description: 'Commit on the given branch to build' + type: string required: false jobs: build: - # Do not change `name`, prisma-engines Buildkite build job depends on this name ending with the commit - name: "MacOS ARM64 (Apple Silicon) engines build on branch ${{ github.event.ref }} for commit ${{ github.event.inputs.commit }}" + name: 'MacOS ARM64 (Apple Silicon) engines build for commit ${{ inputs.commit }}' env: SQLITE_MAX_VARIABLE_NUMBER: 250000 SQLITE_MAX_EXPR_DEPTH: 10000 @@ -17,17 +18,16 @@ jobs: steps: - name: Output link to real commit - run: echo ${{ github.repository }}/commit/${{ github.event.inputs.commit }} + run: echo ${{ github.repository }}/commit/${{ inputs.commit }} - - name: Checkout ${{ github.event.inputs.commit }} + - name: Checkout ${{ inputs.commit }} uses: actions/checkout@v4 with: - ref: ${{ github.event.inputs.commit }} + ref: ${{ inputs.commit }} - - uses: actions-rust-lang/setup-rust-toolchain@v1 + - run: xcodebuild -showsdks - - name: Install aarch64 toolchain - run: rustup target add aarch64-apple-darwin + - uses: actions-rust-lang/setup-rust-toolchain@v1 - uses: actions/cache@v4 with: @@ -37,16 +37,25 @@ jobs: target key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - - run: xcodebuild -showsdks - - run: | cargo build --target=aarch64-apple-darwin --release -p query-engine -p query-engine-node-api -p schema-engine-cli -p prisma-fmt + - name: Rename files + working-directory: ${{ github.workspace }}/target/aarch64-apple-darwin/release + run: | + echo "Files in target/release before renaming" + ls -la . + + mv libquery_engine.dylib libquery_engine.dylib.node + + echo "Files in target/release after renaming" + ls -la . 
+ - uses: actions/upload-artifact@v4 + with: - name: binaries + name: darwin-arm64 + path: | ${{ github.workspace }}/target/aarch64-apple-darwin/release/schema-engine ${{ github.workspace }}/target/aarch64-apple-darwin/release/prisma-fmt ${{ github.workspace }}/target/aarch64-apple-darwin/release/query-engine - ${{ github.workspace }}/target/aarch64-apple-darwin/release/libquery_engine.dylib + ${{ github.workspace }}/target/aarch64-apple-darwin/release/libquery_engine.dylib.node diff --git a/.github/workflows/build-engines-linux-template.yml b/.github/workflows/build-engines-linux-template.yml new file mode 100644 index 000000000000..7fa0446b420e --- /dev/null +++ b/.github/workflows/build-engines-linux-template.yml @@ -0,0 +1,202 @@ +name: Build Engines for Linux + +on: + workflow_call: + inputs: + commit: + description: 'Commit on the given branch to build' + type: string + required: false + +jobs: + build: + name: '${{ matrix.target.name }} for commit ${{ inputs.commit }}' + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + # ⚠️ The target names are used to determine the directory name when uploaded to the buckets. + # Do not change them. + target: + # Linux Glibc + - name: 'rhel-openssl-1.0.x' + image: 'prismagraphql/build:rhel-libssl1.0.x' + target_string: '' + target_path: '' + features_string: '--features vendored-openssl' + - name: 'rhel-openssl-1.1.x' + image: 'prismagraphql/build:rhel-libssl1.1.x' + target_string: '' + target_path: '' + features_string: '' + - name: 'rhel-openssl-3.0.x' + image: 'prismagraphql/build:rhel-libssl3.0.x' + target_string: '' + target_path: '' + features_string: '' + # Linux Musl + # A better name would be "linux-musl-openssl-1.1.x" + # But we keep the old name for compatibility reasons + - name: 'linux-musl' + image: 'prismagraphql/build:alpine-libssl1.1.x' + target_string: '' + target_path: '' + features_string: '' + - name: 'linux-musl-openssl-3.0.x' + image: 'prismagraphql/build:alpine-libssl3.0.x' + target_string: '' + target_path: '' + features_string: '' + # Linux Static x86_64 + # Note that the name should have "-static-" + # Because we look for "-static-" later in the construct_build_command step + - name: 'linux-static-x64' + image: 'prismagraphql/build:linux-static-x64' + target_string: '--target x86_64-unknown-linux-musl' + target_path: 'x86_64-unknown-linux-musl' + features_string: '--features vendored-openssl' + # Linux Glibc ARM64 + - name: 'linux-arm64-openssl-1.0.x' + image: 'prismagraphql/build:cross-linux-arm-ssl-1.0.x' + target_string: '--target aarch64-unknown-linux-gnu' + target_path: 'aarch64-unknown-linux-gnu' + features_string: '--features vendored-openssl' + - name: 'linux-arm64-openssl-1.1.x' + image: 'prismagraphql/build:cross-linux-arm-ssl-1.1.x' + target_string: '--target aarch64-unknown-linux-gnu' + target_path: 'aarch64-unknown-linux-gnu' + features_string: '' + - name: 'linux-arm64-openssl-3.0.x' + image: 'prismagraphql/build:cross-linux-arm-ssl-3.0.x' + target_string: '--target aarch64-unknown-linux-gnu' + target_path: 'aarch64-unknown-linux-gnu' + features_string: '' + # Linux Musl ARM64 + - name: 'linux-musl-arm64-openssl-1.1.x' + image: 'prismagraphql/build:cross-linux-musl-arm-ssl-1.1.x' + target_string: '--target aarch64-unknown-linux-musl' + target_path: 'aarch64-unknown-linux-musl' + features_string: '' + - name: 'linux-musl-arm64-openssl-3.0.x' + image: 'prismagraphql/build:cross-linux-musl-arm-ssl-3.0.x' + target_string: '--target aarch64-unknown-linux-musl' + target_path:
'aarch64-unknown-linux-musl' + features_string: '' + # Linux Static ARM64 + # Note that the name should have "-static-" + # Because we look for "-static-" later in the construct_build_command step + - name: 'linux-static-arm64' + image: 'prismagraphql/build:linux-static-arm64' + target_string: '--target aarch64-unknown-linux-musl' + target_path: 'aarch64-unknown-linux-musl' + features_string: '--features vendored-openssl' + + steps: + - name: Output link to commit + if: ${{ inputs.commit }} + run: echo https://github.com/prisma/prisma-engines/commit/${{ inputs.commit }} + + - name: Checkout ${{ inputs.commit }} + uses: actions/checkout@v4 + with: + ref: ${{ inputs.commit }} + + - uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ${{ runner.os }}-cargo- + + - name: Construct build command + id: construct_build_command + env: + TARGET_NAME: ${{ matrix.target.name }} + IMAGE: ${{ matrix.target.image }} + TARGET_STRING: ${{ matrix.target.target_string }} + FEATURES_STRING: ${{ matrix.target.features_string }} + run: | + set -eux; + + command=$(bash .github/workflows/utils/constructDockerBuildCommand.sh) + + # store command in GitHub output + echo "COMMAND=$command" >> "$GITHUB_OUTPUT" + + - name: Show Build Command + env: + COMMAND: ${{ steps.construct_build_command.outputs.COMMAND }} + run: echo "Build command is $COMMAND" + + - name: Execute Build command + run: ${{ steps.construct_build_command.outputs.command }} + + - name: Prepare files for "release" target + if: ${{ matrix.target.target_path == '' }} + env: + TARGET_NAME: ${{ matrix.target.name }} + RELEASE_DIR: ${{ github.workspace }}/target/release + run: | + echo "Files in target/release before renaming" + ls -la $RELEASE_DIR + + echo "Copying files to engines-artifacts" + cp -r $RELEASE_DIR/ engines-artifacts + + echo "Rename libquery_engine.so to libquery_engine.so.node for non-static targets" + if [[ "$TARGET_NAME" == *-static-* ]]; then + echo "Current target is static. Skipping." + else + mv engines-artifacts/libquery_engine.so engines-artifacts/libquery_engine.so.node + fi + + echo "Files in engines-artifacts after renaming" + ls -la engines-artifacts + + - name: Upload artifacts for "release" target + uses: actions/upload-artifact@v4 + if: ${{ matrix.target.target_path == '' }} + with: + name: '${{ matrix.target.name }}' + path: | + ${{ github.workspace }}/engines-artifacts/libquery_engine.so.node + ${{ github.workspace }}/engines-artifacts/schema-engine + ${{ github.workspace }}/engines-artifacts/query-engine + ${{ github.workspace }}/engines-artifacts/prisma-fmt + + - name: Prepare files for "${{ matrix.target.name }}" target + if: ${{ matrix.target.target_path != '' }} + env: + TARGET_NAME: ${{ matrix.target.name }} + RELEASE_DIR: ${{ github.workspace }}/target/${{ matrix.target.target_path }}/release + run: | + echo "Files in target/release before renaming" + ls -la $RELEASE_DIR + + echo "Copying files to engines-artifacts" + cp -r $RELEASE_DIR/ engines-artifacts + + echo "Rename libquery_engine.so to libquery_engine.so.node for non-static targets" + if [[ "$TARGET_NAME" == *-static-* ]]; then + echo "Current target is static. Skipping."
+ else + mv engines-artifacts/libquery_engine.so engines-artifacts/libquery_engine.so.node + fi + + echo "Files in engines-artifacts after renaming" + ls -la engines-artifacts + + - name: Upload artifacts for "${{ matrix.target.name }}" target + uses: actions/upload-artifact@v4 + if: ${{ matrix.target.target_path != '' }} + with: + name: ${{ matrix.target.name }} + path: | + ${{ github.workspace }}/engines-artifacts/libquery_engine.so.node + ${{ github.workspace }}/engines-artifacts/schema-engine + ${{ github.workspace }}/engines-artifacts/query-engine + ${{ github.workspace }}/engines-artifacts/prisma-fmt diff --git a/.github/workflows/build-engines-react-native-template.yml b/.github/workflows/build-engines-react-native-template.yml new file mode 100644 index 000000000000..09cf6f6858b5 --- /dev/null +++ b/.github/workflows/build-engines-react-native-template.yml @@ -0,0 +1,102 @@ +name: Build Engines for React native + +on: + workflow_call: + inputs: + commit: + description: 'Commit on the given branch to build' + type: string + required: false + uploadArtifacts: + description: If the job should upload artifacts after build finishes + type: boolean + default: true + +jobs: + build-ios: + name: 'iOS build for commit ${{ inputs.commit }}' + runs-on: macos-14 + + steps: + - name: Output link to real commit + run: echo ${{ github.repository }}/commit/${{ inputs.commit }} + + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ inputs.commit }} + + - uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ${{ runner.os }}-cargo- + + - uses: dtolnay/rust-toolchain@stable + with: + targets: x86_64-apple-ios,aarch64-apple-ios,aarch64-apple-ios-sim + + - run: | + cd query-engine/query-engine-c-abi + make ios + + - name: Print files + working-directory: ${{ github.workspace }}/query-engine/query-engine-c-abi/ios/ + run: | + ls -la . + + - uses: actions/upload-artifact@v4 + if: ${{ inputs.uploadArtifacts }} + with: + name: ios + path: | + ${{ github.workspace }}/query-engine/query-engine-c-abi/ios/* + + build-android: + name: 'Android build for commit ${{ inputs.commit }}' + runs-on: ubuntu-latest + + steps: + - name: Output link to real commit + run: echo ${{ github.repository }}/commit/${{ inputs.commit }} + + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ inputs.commit }} + + - uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ${{ runner.os }}-cargo- + + - uses: actions-rust-lang/setup-rust-toolchain@v1 + + - uses: nttld/setup-ndk@v1 + with: + ndk-version: r26d + + - run: | + cd query-engine/query-engine-c-abi + make android + + - name: Print files + working-directory: ${{ github.workspace }}/query-engine/query-engine-c-abi/android/ + run: | + ls -la . 
+ + - uses: actions/upload-artifact@v4 + if: ${{ inputs.uploadArtifacts }} + with: + name: android + path: | + ${{ github.workspace }}/query-engine/query-engine-c-abi/android/* diff --git a/.github/workflows/build-engines-react-native.yml b/.github/workflows/build-engines-react-native.yml deleted file mode 100644 index 235fc633fd3c..000000000000 --- a/.github/workflows/build-engines-react-native.yml +++ /dev/null @@ -1,101 +0,0 @@ -name: Build Engines for React native -on: - workflow_dispatch: - inputs: - commit: - description: "Commit on the given branch to build" - required: false - -jobs: - build-ios: - # Do not change `name`, prisma-engines Buildkite build job depends on this name ending with the commit - name: "iOS build on branch ${{ github.event.ref }} for commit ${{ github.event.inputs.commit }}" - runs-on: macos-14 - - steps: - - name: Output link to real commit - run: echo ${{ github.repository }}/commit/${{ github.event.inputs.commit }} - - - name: Checkout - uses: actions/checkout@v4 - with: - ref: ${{ github.event.inputs.commit }} - - - uses: actions-rust-lang/setup-rust-toolchain@v1 - with: - targets: x86_64-apple-ios,aarch64-apple-ios,aarch64-apple-ios-sim - - - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-ios-cargo-${{ hashFiles('**/Cargo.lock') }} - - - run: | - cd query-engine/query-engine-c-abi - make ios - - uses: actions/upload-artifact@v4 - with: - name: ios - path: | - ${{ github.workspace }}/query-engine/query-engine-c-abi/ios/* - - build-android: - # Do not change `name`, prisma-engines Buildkite build job depends on this name ending with the commit - name: "Android build on branch ${{ github.event.ref }} for commit ${{ github.event.inputs.commit }}" - runs-on: ubuntu-latest - - steps: - - name: Output link to real commit - run: echo ${{ github.repository }}/commit/${{ github.event.inputs.commit }} - - - name: Checkout - uses: actions/checkout@v4 - with: - ref: ${{ github.event.inputs.commit }} - - uses: actions-rust-lang/setup-rust-toolchain@v1 - with: - targets: aarch64-linux-android,armv7-linux-androideabi,x86_64-linux-android,i686-linux-android - - - uses: nttld/setup-ndk@v1 - with: - ndk-version: r26d - - - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-android-cargo-${{ hashFiles('**/Cargo.lock') }} - - - run: | - cd query-engine/query-engine-c-abi - make android - - - uses: actions/upload-artifact@v4 - with: - name: android - path: | - ${{ github.workspace }}/query-engine/query-engine-c-abi/android/* - combine-artifacts: - # Do not change `name`, prisma-engines Buildkite build job depends on this name ending with the commit - name: "Combine iOS and Android artifacts on branch ${{ github.event.ref }} for commit ${{ github.event.inputs.commit }}" - runs-on: ubuntu-latest - needs: - - build-ios - - build-android - steps: - - name: Download artifacts - uses: actions/download-artifact@v4 - - - name: Upload combined artifact - uses: actions/upload-artifact@v4 - with: - name: binaries - path: | - ${{ github.workspace }}/ios - ${{ github.workspace }}/android - diff --git a/.github/workflows/build-engines-windows-template.yml b/.github/workflows/build-engines-windows-template.yml new file mode 100644 index 000000000000..a129393a87c9 --- /dev/null +++ b/.github/workflows/build-engines-windows-template.yml @@ -0,0 +1,62 @@ +name: Build Engines for Windows + +on: + workflow_call: + inputs: + commit: + description: 'Commit on the given branch to 
build' + type: string + required: false + +jobs: + build: + name: 'Windows engines build for commit ${{ inputs.commit }}' + env: + SQLITE_MAX_VARIABLE_NUMBER: 250000 + SQLITE_MAX_EXPR_DEPTH: 10000 + RUSTFLAGS: '-C target-feature=+crt-static' + runs-on: windows-latest + + steps: + - name: Output link to real commit + run: echo ${{ github.repository }}/commit/${{ inputs.commit }} + + - name: Checkout ${{ inputs.commit }} + uses: actions/checkout@v4 + with: + ref: ${{ inputs.commit }} + + - uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ${{ runner.os }}-cargo- + + - uses: actions-rust-lang/setup-rust-toolchain@v1 + + - run: cargo build --release -p query-engine -p query-engine-node-api -p schema-engine-cli -p prisma-fmt + + - name: Rename files + working-directory: ${{ github.workspace }}/target/release/ + run: | + echo "Files in target/release before renaming" + ls . + + mv query_engine.dll query_engine.dll.node + mv query-engine.exe query-engine.exe + + echo "Files in target/release after renaming" + ls . + + - uses: actions/upload-artifact@v4 + with: + name: windows + path: | + ${{ github.workspace }}/target/release/prisma-fmt.exe + ${{ github.workspace }}/target/release/schema-engine.exe + ${{ github.workspace }}/target/release/query-engine.exe + ${{ github.workspace }}/target/release/query_engine.dll.node diff --git a/.github/workflows/build-engines.yml b/.github/workflows/build-engines.yml new file mode 100644 index 000000000000..db1c1a42ce6d --- /dev/null +++ b/.github/workflows/build-engines.yml @@ -0,0 +1,330 @@ +name: Build Engines +run-name: Build Engines for ${{ github.sha }} + +# Run on `push` only for main, if not it will trigger `push` & `pull_request` on PRs at the same time +on: + push: + branches: + - main + - '*.*.x' + - 'integration/*' + paths-ignore: + - '!.github/workflows/build-engines*' + - '.github/**' + - '.buildkite/**' + - '*.md' + - 'LICENSE' + - 'CODEOWNERS' + - 'renovate.json' + workflow_dispatch: + pull_request: + paths-ignore: + - '!.github/workflows/build-engines*' + - '.github/**' + - '.buildkite/**' + - '*.md' + - 'LICENSE' + - 'CODEOWNERS' + - 'renovate.json' + +jobs: + is-release-necessary: + name: 'Decide if a release of the engines artifacts is necessary' + runs-on: ubuntu-22.04 + outputs: + release: ${{ steps.decision.outputs.release }} + steps: + - uses: actions/checkout@v4 + with: + # using head ref rather than merge branch to get original commit message + ref: ${{ github.event.pull_request.head.sha }} + - name: Get commit message + id: commit-msg + run: | + commit_msg=$(git log --format=%B -n 1) + echo 'commit-msg<<EOF' >> $GITHUB_OUTPUT + echo "$commit_msg" >> $GITHUB_OUTPUT + echo 'EOF' >> $GITHUB_OUTPUT + + - name: Debug Pull Request Event + if: ${{ github.event_name == 'pull_request' }} + run: | + echo "Pull Request: ${{ github.event.pull_request.number }}" + echo "Repository Owner: ${{ github.repository_owner }}" + echo "Pull Request Author: ${{ github.actor }}" + echo "Pull Request Author Association: ${{ github.event.pull_request.author_association }}" + cat <> $GITHUB_OUTPUT + + # + # A patch branch (e.g.
"4.6.x") + # + - name: Check if branch is a patch or integration branch + id: check-branch + uses: actions/github-script@v7 + env: + BRANCH: ${{ github.ref }} + with: + script: | + const { BRANCH } = process.env + const parts = BRANCH.split('.') + if (parts.length === 3 && parts[2] === 'x') { + console.log(`Branch is a patch branch: ${BRANCH}`) + core.setOutput('release', true) + } else if (BRANCH.startsWith("integration/")) { + console.log(`Branch is an "integration/" branch: ${BRANCH}`) + core.setOutput('release', true) + } else { + core.setOutput('release', false) + } + + - name: Debug event & outputs + env: + EVENT_NAME: ${{ github.event_name }} + EVENT_PATH: ${{ github.event_path }} + CHECK_COMMIT_MESSAGE: ${{ steps.check-commit-message.outputs.release }} + CHECK_BRANCH: ${{ steps.check-branch.outputs.release }} + run: | + echo "Event Name: $EVENT_NAME" + echo "Event path: $EVENT_PATH" + echo "Check Commit Message outputs: $CHECK_COMMIT_MESSAGE" + echo "Check branch: $CHECK_BRANCH" + + - name: Release is necessary! + # https://github.com/peter-evans/find-comment/tree/v3/?tab=readme-ov-file#outputs + # Tip: Empty strings evaluate to zero in GitHub Actions expressions. e.g. If comment-id is an empty string steps.fc.outputs.comment-id == 0 evaluates to true. + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'push' || + steps.check-commit-message.outputs.release == 'true' || + steps.check-branch.outputs.release == 'true' + + id: decision + env: + EVENT_NAME: ${{ github.event_name }} + EVENT_PATH: ${{ github.event_path }} + CHECK_COMMIT_MESSAGE: ${{ steps.check-commit-message.outputs.release }} + CHECK_BRANCH: ${{ steps.check-branch.outputs.release }} + run: | + echo "Event Name: $EVENT_NAME" + echo "Event path: $EVENT_PATH" + echo "Check Commit Message outputs: $CHECK_COMMIT_MESSAGE" + echo "Check branch: $CHECK_BRANCH" + + echo "Release is necessary" + echo "release=true" >> $GITHUB_OUTPUT + + build-linux: + name: Build Engines for Linux + needs: + - is-release-necessary + if: ${{ needs.is-release-necessary.outputs.release == 'true' }} + uses: ./.github/workflows/build-engines-linux-template.yml + with: + commit: ${{ github.sha }} + + build-macos-intel: + name: Build Engines for Apple Intel + needs: + - is-release-necessary + if: ${{ needs.is-release-necessary.outputs.release == 'true' }} + uses: ./.github/workflows/build-engines-apple-intel-template.yml + with: + commit: ${{ github.sha }} + + build-macos-silicon: + name: Build Engines for Apple Silicon + needs: + - is-release-necessary + if: ${{ needs.is-release-necessary.outputs.release == 'true' }} + uses: ./.github/workflows/build-engines-apple-silicon-template.yml + with: + commit: ${{ github.sha }} + + build-react-native: + name: Build Engines for React native + needs: + - is-release-necessary + if: ${{ needs.is-release-necessary.outputs.release == 'true' }} + uses: ./.github/workflows/build-engines-react-native-template.yml + with: + commit: ${{ github.sha }} + + build-windows: + name: Build Engines for Windows + needs: + - is-release-necessary + if: ${{ needs.is-release-necessary.outputs.release == 'true' }} + uses: ./.github/workflows/build-engines-windows-template.yml + with: + commit: ${{ github.sha }} + + release-artifacts: + name: 'Release artifacts from branch ${{ github.head_ref || github.ref_name }} for commit ${{ github.sha }}' + runs-on: ubuntu-22.04 + concurrency: + group: ${{ github.sha }} + needs: + - build-linux + - build-macos-intel + - build-macos-silicon + - build-react-native + - 
build-windows + env: + BUCKET_NAME: 'prisma-builds' + PRISMA_ENGINES_COMMIT_SHA: ${{ github.sha }} + DESTINATION_TARGET_PATH: 's3://prisma-builds/all_commits/${{ github.sha }}' + + steps: + # Because we need the scripts + - name: Checkout git repository + uses: actions/checkout@v4 + + - uses: actions/download-artifact@v4 + with: + path: engines-artifacts + # For debug purposes + # A previous run ID can be specified, to avoid the build step + # First disable the build step, then specify the run ID + # The github-token is mandatory for this to work + # https://github.com/prisma/prisma-engines-builds/actions/runs/9526334324 + # run-id: 9526334324 + # github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: 'R2: Check if artifacts were already built and uploaded before via `.finished` file' + env: + FILE_PATH: 'all_commits/${{ github.sha }}/.finished' + FILE_PATH_LEGACY: 'all_commits/${{ github.sha }}/rhel-openssl-1.1.x/.finished' + AWS_DEFAULT_REGION: 'auto' + AWS_ACCESS_KEY_ID: ${{ vars.R2_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} + AWS_ENDPOINT_URL_S3: ${{ vars.R2_ENDPOINT }} + working-directory: .github/workflows/utils + run: bash checkFinishedMarker.sh + + - name: 'S3: Check if artifacts were already built and uploaded before via `.finished` file' + env: + FILE_PATH: 'all_commits/${{ github.sha }}/.finished' + FILE_PATH_LEGACY: 'all_commits/${{ github.sha }}/rhel-openssl-1.1.x/.finished' + AWS_DEFAULT_REGION: 'eu-west-1' + AWS_ACCESS_KEY_ID: ${{ vars.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + working-directory: .github/workflows/utils + run: bash checkFinishedMarker.sh + + - name: Display structure of downloaded files + run: ls -Rl engines-artifacts + + # TODO in a next major version of Prisma: remove this, and replace both `Debian` and `Rhel` with a single `LinuxGlibc`/`LinuxGnu` option. + - name: Duplicate engines for debian + working-directory: engines-artifacts + run: | + cp -r rhel-openssl-1.0.x debian-openssl-1.0.x + cp -r rhel-openssl-1.1.x debian-openssl-1.1.x + cp -r rhel-openssl-3.0.x debian-openssl-3.0.x + + - name: Create .zip for react-native + working-directory: engines-artifacts + run: | + mkdir react-native + zip -r react-native/binaries.zip ios android + rm -rf ios android + + - name: 'Create compressed engine files (.gz)' + working-directory: engines-artifacts + run: | + set -eu + + find . -type f -not -name "*.zip" | while read filename; do + gzip -c "$filename" > "$filename.gz" + echo "$filename.gz file created." + done + + ls -Rl . + + - name: 'Create SHA256 checksum files (.sha256).' + working-directory: engines-artifacts + run: | + set -eu + + find . -type f | while read filename; do + sha256sum "$filename" > "$filename.sha256" + echo "$filename.sha256 file created." + done + + ls -Rl . + + # https://github.com/crazy-max/ghaction-import-gpg + - name: Import GPG key + # See https://github.com/crazy-max/ghaction-import-gpg/releases + # v6 -> 01dd5d3ca463c7f10f7f4f7b4f177225ac661ee4 + # For security reasons, we should pin the version of the action + uses: crazy-max/ghaction-import-gpg@01dd5d3ca463c7f10f7f4f7b4f177225ac661ee4 + with: + gpg_private_key: ${{ secrets.GPG_PRIVATE_KEY }} + passphrase: ${{ secrets.GPG_KEY_PASSPHRASE }} + + - name: List keys + run: gpg -K + + # next to each file (excluding .sha256 files) + - name: 'Create a GPG detached signature (.sig)' + working-directory: engines-artifacts + run: | + set -eu + + for file in $(find . -type f ! 
-name "*.sha256"); do + gpg --detach-sign --armor --batch --output "${file#*/}.sig" "$file" + done + + ls -Rl . + + - name: 'Cloudflare R2: Upload to bucket and verify uploaded files then create `.finished` file' + # https://docs.aws.amazon.com/cli/v1/userguide/cli-configure-envvars.html + env: + AWS_DEFAULT_REGION: 'auto' + AWS_ACCESS_KEY_ID: ${{ vars.R2_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} + AWS_ENDPOINT_URL_S3: ${{ vars.R2_ENDPOINT }} + run: bash .github/workflows/utils/uploadAndVerify.sh engines-artifacts-for-r2 + + - name: 'AWS S3: Upload to bucket and verify uploaded files then create `.finished` file' + env: + AWS_DEFAULT_REGION: 'eu-west-1' + AWS_ACCESS_KEY_ID: ${{ vars.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + run: bash .github/workflows/utils/uploadAndVerify.sh engines-artifacts-for-s3 + + - name: Repository dispatch to prisma/engines-wrapper + uses: peter-evans/repository-dispatch@v3 + with: + repository: prisma/engines-wrapper + event-type: publish-engines + client-payload: '{ "commit": "${{ github.sha }}", "branch": "${{ github.head_ref || github.ref_name }}" }' + token: ${{ secrets.PRISMA_BOT_TOKEN }} + - name: Cleanup local directories + run: rm -rf engines-artifacts engines-artifacts-for-r2 engines-artifacts-for-s3 diff --git a/.github/workflows/codspeed.yml b/.github/workflows/codspeed.yml index 14b0d11108ca..f5958960cd62 100644 --- a/.github/workflows/codspeed.yml +++ b/.github/workflows/codspeed.yml @@ -25,7 +25,7 @@ jobs: - uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Install cargo-codspeed - run: cargo install cargo-codspeed + run: cargo install --locked cargo-codspeed - name: "Build the benchmark targets: schema" run: cargo codspeed build -p schema --features all_connectors diff --git a/.github/workflows/include/rust-wasm-setup/action.yml b/.github/workflows/include/rust-wasm-setup/action.yml index 21b963f30fd3..81c197fa80bb 100644 --- a/.github/workflows/include/rust-wasm-setup/action.yml +++ b/.github/workflows/include/rust-wasm-setup/action.yml @@ -17,7 +17,7 @@ runs: shell: bash run: | cargo binstall -y \ - wasm-bindgen-cli@0.2.92 \ + wasm-bindgen-cli@0.2.93 \ wasm-opt@0.116.0 - name: Install bc diff --git a/.github/workflows/send-main-push-event.yml b/.github/workflows/send-main-push-event.yml deleted file mode 100644 index fa9294cba03f..000000000000 --- a/.github/workflows/send-main-push-event.yml +++ /dev/null @@ -1,20 +0,0 @@ -name: Trigger prisma-engines-builds run -run-name: Trigger prisma-engines-builds run for ${{ github.sha }} - -on: - push: - branches: - - main - -jobs: - send-commit-hash: - runs-on: ubuntu-22.04 - steps: - - run: echo "Sending event for commit $GITHUB_SHA" - - name: Workflow dispatch to prisma/prisma-engines-builds - uses: benc-uk/workflow-dispatch@v1 - with: - workflow: .github/workflows/build-engines.yml - repo: prisma/prisma-engines-builds - token: ${{ secrets.BOT_TOKEN_PRISMA_ENGINES_BUILD }} - inputs: '{ "commit": "${{ github.sha }}" }' diff --git a/.github/workflows/test-compilation.yml b/.github/workflows/test-compilation.yml index b1d80995f264..3db71c67b5e7 100644 --- a/.github/workflows/test-compilation.yml +++ b/.github/workflows/test-compilation.yml @@ -16,25 +16,43 @@ concurrency: jobs: test-crate-compilation: - name: "Check release compilation" + name: "${{ matrix.crate }} on ${{ matrix.os }}" strategy: fail-fast: false - runs-on: ubuntu-latest + matrix: + os: + - ubuntu-latest + - windows-latest + - macos-13 + crate: + - 
schema-engine-cli + - prisma-fmt + - query-engine + - query-engine-node-api + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - uses: actions-rust-lang/setup-rust-toolchain@v1 + - uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - - run: "cargo clean && cargo build --release -p schema-engine-cli" - name: "Compile Migration Engine" - - - run: "cargo clean && cargo build --release -p prisma-fmt" - name: "Compile prisma-fmt" - - - run: "cargo clean && cargo build --release -p query-engine" - name: "Compile Query Engine Binary" - - - run: "cargo clean && cargo build --release -p query-engine-node-api" - name: "Compile Query Engine Library" + - name: compile ${{ matrix.crate }} + shell: bash + env: + CRATE: ${{ matrix.crate }} + run: cargo build --release -p "$CRATE" - name: "Check that Cargo.lock did not change" run: "git diff --exit-code" + + test-react-native-compilation: + name: React Native + uses: ./.github/workflows/build-engines-react-native-template.yml + with: + commit: ${{ github.sha }} + uploadArtifacts: false diff --git a/.github/workflows/test-query-engine-driver-adapters.yml b/.github/workflows/test-driver-adapters-template.yml similarity index 60% rename from .github/workflows/test-query-engine-driver-adapters.yml rename to .github/workflows/test-driver-adapters-template.yml index 011342f65fcd..eb45af278bc4 100644 --- a/.github/workflows/test-query-engine-driver-adapters.yml +++ b/.github/workflows/test-driver-adapters-template.yml @@ -1,48 +1,18 @@ name: "QE: driver-adapter integration tests" on: - push: - branches: - - main - pull_request: - paths-ignore: - - "!.github/workflows/test-query-engine-driver-adapters.yml" - - ".github/**" - - ".buildkite/**" - - "*.md" - - "LICENSE" - - "CODEOWNERS" - - "renovate.json" - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + workflow_call: + inputs: + setup_task: + type: string + required: true jobs: rust-query-engine-tests: - name: "${{ matrix.adapter.name }} ${{ matrix.partition }}" + name: "${{ matrix.partition }}" strategy: fail-fast: false matrix: - adapter: - - name: "planetscale (napi)" - setup_task: "dev-planetscale-js" - - name: "pg (napi)" - setup_task: "dev-pg-js" - - name: "neon (napi)" - setup_task: "dev-neon-js" - - name: "libsql (napi)" - setup_task: "dev-libsql-js" - - name: "planetscale (wasm)" - setup_task: "dev-planetscale-wasm" - - name: "pg (wasm)" - setup_task: "dev-pg-wasm" - - name: "neon (wasm)" - setup_task: "dev-neon-wasm" - - name: "libsql (wasm)" - setup_task: "dev-libsql-wasm" - - name: "d1 (wasm)" - setup_task: "dev-d1" node_version: ["18"] partition: ["1/4", "2/4", "3/4", "4/4"] env: @@ -62,8 +32,8 @@ jobs: steps: - uses: actions/checkout@v4 with: + # using head ref rather than merge branch to get original commit message ref: ${{ github.event.pull_request.head.sha }} - - name: "Setup Node.js" uses: actions/setup-node@v4 with: @@ -74,11 +44,6 @@ jobs: with: version: 8 - - name: "Get pnpm store directory" - shell: bash - run: | - echo "STORE_PATH=$(pnpm store path --silent)" >> $GITHUB_ENV - - name: "Login to Docker Hub" uses: docker/login-action@v3 continue-on-error: true @@ -93,7 +58,9 @@ jobs: - name: Extract Branch Name id: extract-branch run: | + echo "Extracting branch name from: $(git show -s --format=%s)" branch="$(git show -s --format=%s | grep -o "DRIVER_ADAPTERS_BRANCH=[^ ]*" | cut -f2 -d=)" + echo "branch=$branch" if [ -n "$branch" ]; then 
echo "Using $branch branch of driver adapters" echo "DRIVER_ADAPTERS_BRANCH=$branch" >> "$GITHUB_ENV" @@ -102,7 +69,12 @@ jobs: - uses: ./.github/workflows/include/rust-wasm-setup - uses: taiki-e/install-action@nextest - - run: make ${{ matrix.adapter.setup_task }} + - name: Setup + env: + SETUP_TASK: ${{ inputs.setup_task }} + run: make "$SETUP_TASK" - name: "Run tests" - run: cargo nextest run --package query-engine-tests --test-threads=1 --partition hash:${{ matrix.partition }} + env: + PARTITION: ${{ matrix.partition }} + run: cargo nextest run --package query-engine-tests --test-threads=1 --partition hash:"$PARTITION" diff --git a/.github/workflows/test-query-engine-template.yml b/.github/workflows/test-query-engine-template.yml new file mode 100644 index 000000000000..9f20fdac1514 --- /dev/null +++ b/.github/workflows/test-query-engine-template.yml @@ -0,0 +1,82 @@ +name: "QE: integration template" +run-name: ${{ inputs.connector }} +on: + workflow_call: + inputs: + name: + type: string + required: true + connector: + type: string + required: true + version: + type: string + required: true + ubuntu: + type: string + default: 'latest' + single_threaded: + type: boolean + default: false + relation_load_strategy: + type: string + default: '["join", "query"]' + + +jobs: + rust-query-engine-tests: + name: "${{ matrix.engine_protocol }} ${{ matrix.relation_load_strategy }} ${{ matrix.partition }}" + + strategy: + fail-fast: false + matrix: + engine_protocol: [graphql, json] + relation_load_strategy: ${{ fromJson(inputs.relation_load_strategy) }} + partition: ["1/4", "2/4", "3/4", "4/4"] + + env: + LOG_LEVEL: "info" + LOG_QUERIES: "y" + RUST_LOG_FORMAT: "devel" + RUST_BACKTRACE: "1" + CLICOLOR_FORCE: "1" + CLOSED_TX_CLEANUP: "2" + SIMPLE_TEST_MODE: "1" + QUERY_BATCH_SIZE: "10" + TEST_RUNNER: "direct" + TEST_CONNECTOR: ${{ inputs.connector }} + TEST_CONNECTOR_VERSION: ${{ inputs.version }} + PRISMA_ENGINE_PROTOCOL: ${{ matrix.engine_protocol }} + PRISMA_RELATION_LOAD_STRATEGY: ${{ matrix.relation_load_strategy }} + + runs-on: "ubuntu-20.04" + # TODO: Replace with the following once `prisma@5.20.0` is released. 
+ # runs-on: "ubuntu-${{ inputs.ubuntu }}" + steps: + - uses: actions/checkout@v4 + - uses: actions-rust-lang/setup-rust-toolchain@v1 + - uses: taiki-e/install-action@nextest + + - name: Login to Docker Hub + uses: docker/login-action@v3 + continue-on-error: true + env: + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} + DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} + if: "${{ env.DOCKERHUB_USERNAME != '' && env.DOCKERHUB_TOKEN != '' }}" + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: "Start ${{ inputs.name }} (${{ matrix.engine_protocol }})" + run: make start-${{ inputs.name }} + + - run: export WORKSPACE_ROOT=$(pwd) && cargo nextest run -p query-engine-tests --partition hash:${{ matrix.partition }} --test-threads=1 + if: ${{ inputs.single_threaded }} + env: + CLICOLOR_FORCE: 1 + + - run: export WORKSPACE_ROOT=$(pwd) && cargo nextest run -p query-engine-tests --partition hash:${{ matrix.partition }} --test-threads=8 + if: ${{ !inputs.single_threaded }} + env: + CLICOLOR_FORCE: 1 diff --git a/.github/workflows/test-query-engine.yml b/.github/workflows/test-query-engine.yml index 6c56575dabad..6ee6629fee80 100644 --- a/.github/workflows/test-query-engine.yml +++ b/.github/workflows/test-query-engine.yml @@ -1,4 +1,4 @@ -name: "QE: integration tests" +name: "QE" on: push: branches: @@ -6,6 +6,7 @@ on: pull_request: paths-ignore: - "!.github/workflows/test-query-engine.yml" + - "!.github/workflows/test-query-engine-template.yml" - ".github/**" - ".buildkite/**" - "*.md" @@ -18,105 +19,156 @@ concurrency: cancel-in-progress: true jobs: - rust-query-engine-tests: - name: "${{ matrix.database.name }} - ${{ matrix.engine_protocol }} ${{ matrix.relation_load_strategy }} ${{ matrix.partition }}" - + postgres: strategy: fail-fast: false matrix: database: - - name: "vitess_8_0" - single_threaded: true - connector: "vitess" - version: "8.0" - # Arbitrary PostgreSQL version - # we opted for the most recent one, there is no need to have a matrix - name: "postgres16" - single_threaded: true - connector: "postgres" version: "16" - - name: "mssql_2022" - single_threaded: false - connector: "sqlserver" - version: "2022" - - name: "sqlite" - single_threaded: false - connector: "sqlite" - version: "3" - - name: "mongodb_4_2" - single_threaded: true - connector: "mongodb" - version: "4.2" + - name: "postgres15" + version: "15" + - name: "postgres14" + version: "14" + - name: "postgres13" + version: "13" + - name: "postgres12" + version: "12" + - name: "postgres11" + version: "11" + - name: "postgres10" + version: "10" + - name: "postgres9" + version: "9" + uses: ./.github/workflows/test-query-engine-template.yml + name: postgres ${{ matrix.database.version }} + with: + name: ${{ matrix.database.name }} + version: ${{ matrix.database.version }} + connector: "postgres" + single_threaded: true + + mysql: + strategy: + fail-fast: false + matrix: + database: + - name: "mysql_5_6" + version: "5.6" + relation_load_strategy: '["query"]' + - name: "mysql_5_7" + version: "5.7" + relation_load_strategy: '["query"]' + - name: "mysql_8" + version: "8" + relation_load_strategy: '["join", "query"]' + - name: "mysql_mariadb" + version: "mariadb" + relation_load_strategy: '["query"]' + + uses: ./.github/workflows/test-query-engine-template.yml + name: mysql ${{ matrix.database.version }} + with: + name: ${{ matrix.database.name }} + version: ${{ matrix.database.version }} + connector: "mysql" + relation_load_strategy: ${{ 
matrix.database.relation_load_strategy }} + single_threaded: true + + cockroachdb: + strategy: + fail-fast: false + matrix: + database: - name: "cockroach_23_1" - single_threaded: false connector: "cockroachdb" version: "23.1" - name: "cockroach_22_2" - single_threaded: false - connector: "cockroachdb" version: "22.2" - name: "cockroach_22_1_0" - single_threaded: false - connector: "cockroachdb" version: "22.1" - - name: "mysql_8" - single_threaded: false - connector: "mysql" - version: "8" - engine_protocol: [graphql, json] - relation_load_strategy: [join, query] - partition: ["1/4", "2/4", "3/4", "4/4"] - exclude: - - relation_load_strategy: join - database: - [ - { "connector": "mongodb" }, - { "connector": "sqlite" }, - { "connector": "mssql_2022" }, - ] - - env: - LOG_LEVEL: "info" - LOG_QUERIES: "y" - RUST_LOG_FORMAT: "devel" - RUST_BACKTRACE: "1" - CLICOLOR_FORCE: "1" - CLOSED_TX_CLEANUP: "2" - SIMPLE_TEST_MODE: "1" - QUERY_BATCH_SIZE: "10" - TEST_RUNNER: "direct" - TEST_CONNECTOR: ${{ matrix.database.connector }} - TEST_CONNECTOR_VERSION: ${{ matrix.database.version }} - PRISMA_ENGINE_PROTOCOL: ${{ matrix.engine_protocol }} - PRISMA_RELATION_LOAD_STRATEGY: ${{ matrix.relation_load_strategy }} - WORKSPACE_ROOT: ${{ github.workspace }} + uses: ./.github/workflows/test-query-engine-template.yml + name: cockroachdb ${{ matrix.database.version }} + with: + name: ${{ matrix.database.name }} + version: ${{ matrix.database.version }} + connector: "cockroachdb" - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions-rust-lang/setup-rust-toolchain@v1 - - uses: taiki-e/install-action@nextest - - - name: Login to Docker Hub - uses: docker/login-action@v3 - continue-on-error: true - env: - DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} - DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} - if: "${{ env.DOCKERHUB_USERNAME != '' && env.DOCKERHUB_TOKEN != '' }}" - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: "Start ${{ matrix.database.name }} (${{ matrix.engine_protocol }})" - run: make start-${{ matrix.database.name }} + mongodb: + strategy: + fail-fast: false + matrix: + database: + - name: "mongodb_4_2" + version: "4.2" + - name: "mongodb_4_4" + version: "4.4" + - name: "mongodb_5" + connector: "mongodb" + version: "5" + uses: ./.github/workflows/test-query-engine-template.yml + name: mongodb ${{ matrix.database.version }} + with: + name: ${{ matrix.database.name }} + version: ${{ matrix.database.version }} + single_threaded: true + connector: "mongodb" + relation_load_strategy: '["query"]' - - run: cargo nextest run -p query-engine-tests --partition hash:${{ matrix.partition }} --test-threads=1 - if: ${{ matrix.database.single_threaded }} - env: - CLICOLOR_FORCE: 1 + mssql: + strategy: + fail-fast: false + matrix: + database: + - name: "mssql_2022" + version: "2022" + - name: "mssql_2019" + version: "2019" + - name: "mssql_2017" + version: "2017" + ubuntu: "20.04" + uses: ./.github/workflows/test-query-engine-template.yml + name: mssql ${{ matrix.database.version }} + with: + name: ${{ matrix.database.name }} + version: ${{ matrix.database.version }} + ubuntu: ${{ matrix.database.ubuntu }} + connector: "sqlserver" + relation_load_strategy: '["query"]' - - run: cargo nextest run -p query-engine-tests --partition hash:${{ matrix.partition }} --test-threads=8 - if: ${{ !matrix.database.single_threaded }} - env: - CLICOLOR_FORCE: 1 + sqlite: + uses: ./.github/workflows/test-query-engine-template.yml + name: 
sqlite + with: + name: "sqlite" + version: 3 + connector: "sqlite" + relation_load_strategy: '["query"]' + + driver_adapters: + strategy: + fail-fast: false + matrix: + adapter: + - name: "planetscale (napi)" + setup_task: "dev-planetscale-js" + - name: "pg (napi)" + setup_task: "dev-pg-js" + - name: "neon (napi)" + setup_task: "dev-neon-js" + - name: "libsql (napi)" + setup_task: "dev-libsql-js" + - name: "planetscale (wasm)" + setup_task: "dev-planetscale-wasm" + - name: "pg (wasm)" + setup_task: "dev-pg-wasm" + - name: "neon (wasm)" + setup_task: "dev-neon-wasm" + - name: "libsql (wasm)" + setup_task: "dev-libsql-wasm" + - name: "d1 (wasm)" + setup_task: "dev-d1" + name: ${{ matrix.adapter.name }} + uses: ./.github/workflows/test-driver-adapters-template.yml + with: + setup_task: ${{ matrix.adapter.setup_task }} diff --git a/.github/workflows/test-schema-engine.yml b/.github/workflows/test-schema-engine.yml index a661fd0554c8..88c9ba43f745 100644 --- a/.github/workflows/test-schema-engine.yml +++ b/.github/workflows/test-schema-engine.yml @@ -62,6 +62,7 @@ jobs: database: - name: mssql_2017 url: "sqlserver://localhost:1434;database=master;user=SA;password=;trustServerCertificate=true;socket_timeout=60;isolationLevel=READ UNCOMMITTED" + ubuntu: "20.04" - name: mssql_2019 url: "sqlserver://localhost:1433;database=master;user=SA;password=;trustServerCertificate=true;socket_timeout=60;isolationLevel=READ UNCOMMITTED" - name: mssql_2022 @@ -104,7 +105,9 @@ is_vitess: true single_threaded: true - runs-on: ubuntu-latest + runs-on: "ubuntu-20.04" + # TODO: Replace with the following once `prisma@5.20.0` is released. + # runs-on: "ubuntu-${{ matrix.database.ubuntu || 'latest' }}" steps: - uses: actions/checkout@v4 - uses: actions-rust-lang/setup-rust-toolchain@v1 diff --git a/.github/workflows/utils/checkFinishedMarker.sh b/.github/workflows/utils/checkFinishedMarker.sh new file mode 100644 index 000000000000..2eb57f1600b9 --- /dev/null +++ b/.github/workflows/utils/checkFinishedMarker.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +set -eux + +# We check if the .finished file marker exists in the S3 bucket +# i.e. 'all_commits/[COMMIT]/.finished' +object_exists=$(aws s3api head-object --bucket "$BUCKET_NAME" --key "$FILE_PATH" || true) + +if [ -z "$object_exists" ]; then +echo ".finished file marker was NOT found at $FILE_PATH. Continuing..." +else +echo "::error::.finished file marker was found at $FILE_PATH - This means that artifacts were already uploaded in a previous run. Aborting to avoid overwriting the artifacts.", +exit 1 +fi; + + +# When we were using our Buildkite pipeline +# Before this GitHub Actions pipeline +# We were uploading the artifacts for each build separately +# And the .finished file marker was in the same directory as the build target +# i.e. 'all_commits/[COMMIT]/rhel-openssl-1.1.x/.finished' +object_exists_in_legacy_path=$(aws s3api head-object --bucket "$BUCKET_NAME" --key "$FILE_PATH_LEGACY" || true) + +if [ -z "$object_exists_in_legacy_path" ]; then +echo "(legacy) .finished file marker was NOT found at $FILE_PATH_LEGACY. Continuing..." +else +echo "::error::(legacy) .finished file marker was found at $FILE_PATH_LEGACY - This means that artifacts were already uploaded in a previous run.
Aborting to avoid overwriting the artifacts.", +exit 1 +fi; diff --git a/.github/workflows/utils/constructDockerBuildCommand.sh b/.github/workflows/utils/constructDockerBuildCommand.sh new file mode 100644 index 000000000000..6ec9e3072665 --- /dev/null +++ b/.github/workflows/utils/constructDockerBuildCommand.sh @@ -0,0 +1,37 @@ +#!/bin/bash +set -eux; + +DOCKER_WORKSPACE="/root/build" + +# Full command, Docker + Bash. +# In Bash, we use `git config` to avoid "fatal: detected dubious ownership in repository at /root/build" panic messages +# that can occur when Prisma Engines' `build.rs` scripts run `git rev-parse HEAD` to extract the current commit hash. +# See: https://www.kenmuse.com/blog/avoiding-dubious-ownership-in-dev-containers/. +command="docker run \ +-e SQLITE_MAX_VARIABLE_NUMBER=250000 \ +-e SQLITE_MAX_EXPR_DEPTH=10000 \ +-e LIBZ_SYS_STATIC=1 \ +-w ${DOCKER_WORKSPACE} \ +-v \"$(pwd)\":${DOCKER_WORKSPACE} \ +-v \"$HOME\"/.cargo/bin:/root/cargo/bin \ +-v \"$HOME\"/.cargo/registry/index:/root/cargo/registry/index \ +-v \"$HOME\"/.cargo/registry/cache:/root/cargo/registry/cache \ +-v \"$HOME\"/.cargo/git/db:/root/cargo/git/db \ +$IMAGE \ +bash -c \ + \" \ + git config --global --add safe.directory ${DOCKER_WORKSPACE} \ + && cargo clean \ + && cargo build --release -p query-engine --manifest-path query-engine/query-engine/Cargo.toml $TARGET_STRING $FEATURES_STRING \ + && cargo build --release -p query-engine-node-api --manifest-path query-engine/query-engine-node-api/Cargo.toml $TARGET_STRING $FEATURES_STRING \ + && cargo build --release -p schema-engine-cli --manifest-path schema-engine/cli/Cargo.toml $TARGET_STRING $FEATURES_STRING \ + && cargo build --release -p prisma-fmt --manifest-path prisma-fmt/Cargo.toml $TARGET_STRING $FEATURES_STRING \ + \" \ +" +# remove query-engine-node-api for "static" targets +if [[ "$TARGET_NAME" == *-static-* ]]; then + substring_to_replace="&& cargo build --release -p query-engine-node-api --manifest-path query-engine/query-engine-node-api/Cargo.toml $TARGET_STRING $FEATURES_STRING" + command="${command/$substring_to_replace/}" +fi + +echo "$command" \ No newline at end of file diff --git a/.github/workflows/utils/expectedFiles.txt b/.github/workflows/utils/expectedFiles.txt new file mode 100644 index 000000000000..2da123d8da2f --- /dev/null +++ b/.github/workflows/utils/expectedFiles.txt @@ -0,0 +1,373 @@ +. 
+./darwin +./darwin-arm64 +./darwin-arm64/libquery_engine.dylib.node.gz +./darwin-arm64/libquery_engine.dylib.node.gz.sha256 +./darwin-arm64/libquery_engine.dylib.node.gz.sig +./darwin-arm64/libquery_engine.dylib.node.sha256 +./darwin-arm64/libquery_engine.dylib.node.sig +./darwin-arm64/prisma-fmt.gz +./darwin-arm64/prisma-fmt.gz.sha256 +./darwin-arm64/prisma-fmt.gz.sig +./darwin-arm64/prisma-fmt.sha256 +./darwin-arm64/prisma-fmt.sig +./darwin-arm64/query-engine.gz +./darwin-arm64/query-engine.gz.sha256 +./darwin-arm64/query-engine.gz.sig +./darwin-arm64/query-engine.sha256 +./darwin-arm64/query-engine.sig +./darwin-arm64/schema-engine.gz +./darwin-arm64/schema-engine.gz.sha256 +./darwin-arm64/schema-engine.gz.sig +./darwin-arm64/schema-engine.sha256 +./darwin-arm64/schema-engine.sig +./darwin/libquery_engine.dylib.node.gz +./darwin/libquery_engine.dylib.node.gz.sha256 +./darwin/libquery_engine.dylib.node.gz.sig +./darwin/libquery_engine.dylib.node.sha256 +./darwin/libquery_engine.dylib.node.sig +./darwin/prisma-fmt.gz +./darwin/prisma-fmt.gz.sha256 +./darwin/prisma-fmt.gz.sig +./darwin/prisma-fmt.sha256 +./darwin/prisma-fmt.sig +./darwin/query-engine.gz +./darwin/query-engine.gz.sha256 +./darwin/query-engine.gz.sig +./darwin/query-engine.sha256 +./darwin/query-engine.sig +./darwin/schema-engine.gz +./darwin/schema-engine.gz.sha256 +./darwin/schema-engine.gz.sig +./darwin/schema-engine.sha256 +./darwin/schema-engine.sig +./debian-openssl-1.0.x +./debian-openssl-1.0.x/libquery_engine.so.node.gz +./debian-openssl-1.0.x/libquery_engine.so.node.gz.sha256 +./debian-openssl-1.0.x/libquery_engine.so.node.gz.sig +./debian-openssl-1.0.x/libquery_engine.so.node.sha256 +./debian-openssl-1.0.x/libquery_engine.so.node.sig +./debian-openssl-1.0.x/prisma-fmt.gz +./debian-openssl-1.0.x/prisma-fmt.gz.sha256 +./debian-openssl-1.0.x/prisma-fmt.gz.sig +./debian-openssl-1.0.x/prisma-fmt.sha256 +./debian-openssl-1.0.x/prisma-fmt.sig +./debian-openssl-1.0.x/query-engine.gz +./debian-openssl-1.0.x/query-engine.gz.sha256 +./debian-openssl-1.0.x/query-engine.gz.sig +./debian-openssl-1.0.x/query-engine.sha256 +./debian-openssl-1.0.x/query-engine.sig +./debian-openssl-1.0.x/schema-engine.gz +./debian-openssl-1.0.x/schema-engine.gz.sha256 +./debian-openssl-1.0.x/schema-engine.gz.sig +./debian-openssl-1.0.x/schema-engine.sha256 +./debian-openssl-1.0.x/schema-engine.sig +./debian-openssl-1.1.x +./debian-openssl-1.1.x/libquery_engine.so.node.gz +./debian-openssl-1.1.x/libquery_engine.so.node.gz.sha256 +./debian-openssl-1.1.x/libquery_engine.so.node.gz.sig +./debian-openssl-1.1.x/libquery_engine.so.node.sha256 +./debian-openssl-1.1.x/libquery_engine.so.node.sig +./debian-openssl-1.1.x/prisma-fmt.gz +./debian-openssl-1.1.x/prisma-fmt.gz.sha256 +./debian-openssl-1.1.x/prisma-fmt.gz.sig +./debian-openssl-1.1.x/prisma-fmt.sha256 +./debian-openssl-1.1.x/prisma-fmt.sig +./debian-openssl-1.1.x/query-engine.gz +./debian-openssl-1.1.x/query-engine.gz.sha256 +./debian-openssl-1.1.x/query-engine.gz.sig +./debian-openssl-1.1.x/query-engine.sha256 +./debian-openssl-1.1.x/query-engine.sig +./debian-openssl-1.1.x/schema-engine.gz +./debian-openssl-1.1.x/schema-engine.gz.sha256 +./debian-openssl-1.1.x/schema-engine.gz.sig +./debian-openssl-1.1.x/schema-engine.sha256 +./debian-openssl-1.1.x/schema-engine.sig +./debian-openssl-3.0.x +./debian-openssl-3.0.x/libquery_engine.so.node.gz +./debian-openssl-3.0.x/libquery_engine.so.node.gz.sha256 +./debian-openssl-3.0.x/libquery_engine.so.node.gz.sig 
+./debian-openssl-3.0.x/libquery_engine.so.node.sha256 +./debian-openssl-3.0.x/libquery_engine.so.node.sig +./debian-openssl-3.0.x/prisma-fmt.gz +./debian-openssl-3.0.x/prisma-fmt.gz.sha256 +./debian-openssl-3.0.x/prisma-fmt.gz.sig +./debian-openssl-3.0.x/prisma-fmt.sha256 +./debian-openssl-3.0.x/prisma-fmt.sig +./debian-openssl-3.0.x/query-engine.gz +./debian-openssl-3.0.x/query-engine.gz.sha256 +./debian-openssl-3.0.x/query-engine.gz.sig +./debian-openssl-3.0.x/query-engine.sha256 +./debian-openssl-3.0.x/query-engine.sig +./debian-openssl-3.0.x/schema-engine.gz +./debian-openssl-3.0.x/schema-engine.gz.sha256 +./debian-openssl-3.0.x/schema-engine.gz.sig +./debian-openssl-3.0.x/schema-engine.sha256 +./debian-openssl-3.0.x/schema-engine.sig +./linux-arm64-openssl-1.0.x +./linux-arm64-openssl-1.0.x/libquery_engine.so.node.gz +./linux-arm64-openssl-1.0.x/libquery_engine.so.node.gz.sha256 +./linux-arm64-openssl-1.0.x/libquery_engine.so.node.gz.sig +./linux-arm64-openssl-1.0.x/libquery_engine.so.node.sha256 +./linux-arm64-openssl-1.0.x/libquery_engine.so.node.sig +./linux-arm64-openssl-1.0.x/prisma-fmt.gz +./linux-arm64-openssl-1.0.x/prisma-fmt.gz.sha256 +./linux-arm64-openssl-1.0.x/prisma-fmt.gz.sig +./linux-arm64-openssl-1.0.x/prisma-fmt.sha256 +./linux-arm64-openssl-1.0.x/prisma-fmt.sig +./linux-arm64-openssl-1.0.x/query-engine.gz +./linux-arm64-openssl-1.0.x/query-engine.gz.sha256 +./linux-arm64-openssl-1.0.x/query-engine.gz.sig +./linux-arm64-openssl-1.0.x/query-engine.sha256 +./linux-arm64-openssl-1.0.x/query-engine.sig +./linux-arm64-openssl-1.0.x/schema-engine.gz +./linux-arm64-openssl-1.0.x/schema-engine.gz.sha256 +./linux-arm64-openssl-1.0.x/schema-engine.gz.sig +./linux-arm64-openssl-1.0.x/schema-engine.sha256 +./linux-arm64-openssl-1.0.x/schema-engine.sig +./linux-arm64-openssl-1.1.x +./linux-arm64-openssl-1.1.x/libquery_engine.so.node.gz +./linux-arm64-openssl-1.1.x/libquery_engine.so.node.gz.sha256 +./linux-arm64-openssl-1.1.x/libquery_engine.so.node.gz.sig +./linux-arm64-openssl-1.1.x/libquery_engine.so.node.sha256 +./linux-arm64-openssl-1.1.x/libquery_engine.so.node.sig +./linux-arm64-openssl-1.1.x/prisma-fmt.gz +./linux-arm64-openssl-1.1.x/prisma-fmt.gz.sha256 +./linux-arm64-openssl-1.1.x/prisma-fmt.gz.sig +./linux-arm64-openssl-1.1.x/prisma-fmt.sha256 +./linux-arm64-openssl-1.1.x/prisma-fmt.sig +./linux-arm64-openssl-1.1.x/query-engine.gz +./linux-arm64-openssl-1.1.x/query-engine.gz.sha256 +./linux-arm64-openssl-1.1.x/query-engine.gz.sig +./linux-arm64-openssl-1.1.x/query-engine.sha256 +./linux-arm64-openssl-1.1.x/query-engine.sig +./linux-arm64-openssl-1.1.x/schema-engine.gz +./linux-arm64-openssl-1.1.x/schema-engine.gz.sha256 +./linux-arm64-openssl-1.1.x/schema-engine.gz.sig +./linux-arm64-openssl-1.1.x/schema-engine.sha256 +./linux-arm64-openssl-1.1.x/schema-engine.sig +./linux-arm64-openssl-3.0.x +./linux-arm64-openssl-3.0.x/libquery_engine.so.node.gz +./linux-arm64-openssl-3.0.x/libquery_engine.so.node.gz.sha256 +./linux-arm64-openssl-3.0.x/libquery_engine.so.node.gz.sig +./linux-arm64-openssl-3.0.x/libquery_engine.so.node.sha256 +./linux-arm64-openssl-3.0.x/libquery_engine.so.node.sig +./linux-arm64-openssl-3.0.x/prisma-fmt.gz +./linux-arm64-openssl-3.0.x/prisma-fmt.gz.sha256 +./linux-arm64-openssl-3.0.x/prisma-fmt.gz.sig +./linux-arm64-openssl-3.0.x/prisma-fmt.sha256 +./linux-arm64-openssl-3.0.x/prisma-fmt.sig +./linux-arm64-openssl-3.0.x/query-engine.gz +./linux-arm64-openssl-3.0.x/query-engine.gz.sha256 +./linux-arm64-openssl-3.0.x/query-engine.gz.sig 
+./linux-arm64-openssl-3.0.x/query-engine.sha256 +./linux-arm64-openssl-3.0.x/query-engine.sig +./linux-arm64-openssl-3.0.x/schema-engine.gz +./linux-arm64-openssl-3.0.x/schema-engine.gz.sha256 +./linux-arm64-openssl-3.0.x/schema-engine.gz.sig +./linux-arm64-openssl-3.0.x/schema-engine.sha256 +./linux-arm64-openssl-3.0.x/schema-engine.sig +./linux-musl +./linux-musl-arm64-openssl-1.1.x +./linux-musl-arm64-openssl-1.1.x/libquery_engine.so.node.gz +./linux-musl-arm64-openssl-1.1.x/libquery_engine.so.node.gz.sha256 +./linux-musl-arm64-openssl-1.1.x/libquery_engine.so.node.gz.sig +./linux-musl-arm64-openssl-1.1.x/libquery_engine.so.node.sha256 +./linux-musl-arm64-openssl-1.1.x/libquery_engine.so.node.sig +./linux-musl-arm64-openssl-1.1.x/prisma-fmt.gz +./linux-musl-arm64-openssl-1.1.x/prisma-fmt.gz.sha256 +./linux-musl-arm64-openssl-1.1.x/prisma-fmt.gz.sig +./linux-musl-arm64-openssl-1.1.x/prisma-fmt.sha256 +./linux-musl-arm64-openssl-1.1.x/prisma-fmt.sig +./linux-musl-arm64-openssl-1.1.x/query-engine.gz +./linux-musl-arm64-openssl-1.1.x/query-engine.gz.sha256 +./linux-musl-arm64-openssl-1.1.x/query-engine.gz.sig +./linux-musl-arm64-openssl-1.1.x/query-engine.sha256 +./linux-musl-arm64-openssl-1.1.x/query-engine.sig +./linux-musl-arm64-openssl-1.1.x/schema-engine.gz +./linux-musl-arm64-openssl-1.1.x/schema-engine.gz.sha256 +./linux-musl-arm64-openssl-1.1.x/schema-engine.gz.sig +./linux-musl-arm64-openssl-1.1.x/schema-engine.sha256 +./linux-musl-arm64-openssl-1.1.x/schema-engine.sig +./linux-musl-arm64-openssl-3.0.x +./linux-musl-arm64-openssl-3.0.x/libquery_engine.so.node.gz +./linux-musl-arm64-openssl-3.0.x/libquery_engine.so.node.gz.sha256 +./linux-musl-arm64-openssl-3.0.x/libquery_engine.so.node.gz.sig +./linux-musl-arm64-openssl-3.0.x/libquery_engine.so.node.sha256 +./linux-musl-arm64-openssl-3.0.x/libquery_engine.so.node.sig +./linux-musl-arm64-openssl-3.0.x/prisma-fmt.gz +./linux-musl-arm64-openssl-3.0.x/prisma-fmt.gz.sha256 +./linux-musl-arm64-openssl-3.0.x/prisma-fmt.gz.sig +./linux-musl-arm64-openssl-3.0.x/prisma-fmt.sha256 +./linux-musl-arm64-openssl-3.0.x/prisma-fmt.sig +./linux-musl-arm64-openssl-3.0.x/query-engine.gz +./linux-musl-arm64-openssl-3.0.x/query-engine.gz.sha256 +./linux-musl-arm64-openssl-3.0.x/query-engine.gz.sig +./linux-musl-arm64-openssl-3.0.x/query-engine.sha256 +./linux-musl-arm64-openssl-3.0.x/query-engine.sig +./linux-musl-arm64-openssl-3.0.x/schema-engine.gz +./linux-musl-arm64-openssl-3.0.x/schema-engine.gz.sha256 +./linux-musl-arm64-openssl-3.0.x/schema-engine.gz.sig +./linux-musl-arm64-openssl-3.0.x/schema-engine.sha256 +./linux-musl-arm64-openssl-3.0.x/schema-engine.sig +./linux-musl-openssl-3.0.x +./linux-musl-openssl-3.0.x/libquery_engine.so.node.gz +./linux-musl-openssl-3.0.x/libquery_engine.so.node.gz.sha256 +./linux-musl-openssl-3.0.x/libquery_engine.so.node.gz.sig +./linux-musl-openssl-3.0.x/libquery_engine.so.node.sha256 +./linux-musl-openssl-3.0.x/libquery_engine.so.node.sig +./linux-musl-openssl-3.0.x/prisma-fmt.gz +./linux-musl-openssl-3.0.x/prisma-fmt.gz.sha256 +./linux-musl-openssl-3.0.x/prisma-fmt.gz.sig +./linux-musl-openssl-3.0.x/prisma-fmt.sha256 +./linux-musl-openssl-3.0.x/prisma-fmt.sig +./linux-musl-openssl-3.0.x/query-engine.gz +./linux-musl-openssl-3.0.x/query-engine.gz.sha256 +./linux-musl-openssl-3.0.x/query-engine.gz.sig +./linux-musl-openssl-3.0.x/query-engine.sha256 +./linux-musl-openssl-3.0.x/query-engine.sig +./linux-musl-openssl-3.0.x/schema-engine.gz +./linux-musl-openssl-3.0.x/schema-engine.gz.sha256 
+./linux-musl-openssl-3.0.x/schema-engine.gz.sig +./linux-musl-openssl-3.0.x/schema-engine.sha256 +./linux-musl-openssl-3.0.x/schema-engine.sig +./linux-musl/libquery_engine.so.node.gz +./linux-musl/libquery_engine.so.node.gz.sha256 +./linux-musl/libquery_engine.so.node.gz.sig +./linux-musl/libquery_engine.so.node.sha256 +./linux-musl/libquery_engine.so.node.sig +./linux-musl/prisma-fmt.gz +./linux-musl/prisma-fmt.gz.sha256 +./linux-musl/prisma-fmt.gz.sig +./linux-musl/prisma-fmt.sha256 +./linux-musl/prisma-fmt.sig +./linux-musl/query-engine.gz +./linux-musl/query-engine.gz.sha256 +./linux-musl/query-engine.gz.sig +./linux-musl/query-engine.sha256 +./linux-musl/query-engine.sig +./linux-musl/schema-engine.gz +./linux-musl/schema-engine.gz.sha256 +./linux-musl/schema-engine.gz.sig +./linux-musl/schema-engine.sha256 +./linux-musl/schema-engine.sig +./linux-static-arm64 +./linux-static-arm64/prisma-fmt.gz +./linux-static-arm64/prisma-fmt.gz.sha256 +./linux-static-arm64/prisma-fmt.gz.sig +./linux-static-arm64/prisma-fmt.sha256 +./linux-static-arm64/prisma-fmt.sig +./linux-static-arm64/query-engine.gz +./linux-static-arm64/query-engine.gz.sha256 +./linux-static-arm64/query-engine.gz.sig +./linux-static-arm64/query-engine.sha256 +./linux-static-arm64/query-engine.sig +./linux-static-arm64/schema-engine.gz +./linux-static-arm64/schema-engine.gz.sha256 +./linux-static-arm64/schema-engine.gz.sig +./linux-static-arm64/schema-engine.sha256 +./linux-static-arm64/schema-engine.sig +./linux-static-x64 +./linux-static-x64/prisma-fmt.gz +./linux-static-x64/prisma-fmt.gz.sha256 +./linux-static-x64/prisma-fmt.gz.sig +./linux-static-x64/prisma-fmt.sha256 +./linux-static-x64/prisma-fmt.sig +./linux-static-x64/query-engine.gz +./linux-static-x64/query-engine.gz.sha256 +./linux-static-x64/query-engine.gz.sig +./linux-static-x64/query-engine.sha256 +./linux-static-x64/query-engine.sig +./linux-static-x64/schema-engine.gz +./linux-static-x64/schema-engine.gz.sha256 +./linux-static-x64/schema-engine.gz.sig +./linux-static-x64/schema-engine.sha256 +./linux-static-x64/schema-engine.sig +./react-native +./react-native/binaries.zip +./react-native/binaries.zip.sha256 +./react-native/binaries.zip.sig +./rhel-openssl-1.0.x +./rhel-openssl-1.0.x/libquery_engine.so.node.gz +./rhel-openssl-1.0.x/libquery_engine.so.node.gz.sha256 +./rhel-openssl-1.0.x/libquery_engine.so.node.gz.sig +./rhel-openssl-1.0.x/libquery_engine.so.node.sha256 +./rhel-openssl-1.0.x/libquery_engine.so.node.sig +./rhel-openssl-1.0.x/prisma-fmt.gz +./rhel-openssl-1.0.x/prisma-fmt.gz.sha256 +./rhel-openssl-1.0.x/prisma-fmt.gz.sig +./rhel-openssl-1.0.x/prisma-fmt.sha256 +./rhel-openssl-1.0.x/prisma-fmt.sig +./rhel-openssl-1.0.x/query-engine.gz +./rhel-openssl-1.0.x/query-engine.gz.sha256 +./rhel-openssl-1.0.x/query-engine.gz.sig +./rhel-openssl-1.0.x/query-engine.sha256 +./rhel-openssl-1.0.x/query-engine.sig +./rhel-openssl-1.0.x/schema-engine.gz +./rhel-openssl-1.0.x/schema-engine.gz.sha256 +./rhel-openssl-1.0.x/schema-engine.gz.sig +./rhel-openssl-1.0.x/schema-engine.sha256 +./rhel-openssl-1.0.x/schema-engine.sig +./rhel-openssl-1.1.x +./rhel-openssl-1.1.x/libquery_engine.so.node.gz +./rhel-openssl-1.1.x/libquery_engine.so.node.gz.sha256 +./rhel-openssl-1.1.x/libquery_engine.so.node.gz.sig +./rhel-openssl-1.1.x/libquery_engine.so.node.sha256 +./rhel-openssl-1.1.x/libquery_engine.so.node.sig +./rhel-openssl-1.1.x/prisma-fmt.gz +./rhel-openssl-1.1.x/prisma-fmt.gz.sha256 +./rhel-openssl-1.1.x/prisma-fmt.gz.sig +./rhel-openssl-1.1.x/prisma-fmt.sha256 
+./rhel-openssl-1.1.x/prisma-fmt.sig +./rhel-openssl-1.1.x/query-engine.gz +./rhel-openssl-1.1.x/query-engine.gz.sha256 +./rhel-openssl-1.1.x/query-engine.gz.sig +./rhel-openssl-1.1.x/query-engine.sha256 +./rhel-openssl-1.1.x/query-engine.sig +./rhel-openssl-1.1.x/schema-engine.gz +./rhel-openssl-1.1.x/schema-engine.gz.sha256 +./rhel-openssl-1.1.x/schema-engine.gz.sig +./rhel-openssl-1.1.x/schema-engine.sha256 +./rhel-openssl-1.1.x/schema-engine.sig +./rhel-openssl-3.0.x +./rhel-openssl-3.0.x/libquery_engine.so.node.gz +./rhel-openssl-3.0.x/libquery_engine.so.node.gz.sha256 +./rhel-openssl-3.0.x/libquery_engine.so.node.gz.sig +./rhel-openssl-3.0.x/libquery_engine.so.node.sha256 +./rhel-openssl-3.0.x/libquery_engine.so.node.sig +./rhel-openssl-3.0.x/prisma-fmt.gz +./rhel-openssl-3.0.x/prisma-fmt.gz.sha256 +./rhel-openssl-3.0.x/prisma-fmt.gz.sig +./rhel-openssl-3.0.x/prisma-fmt.sha256 +./rhel-openssl-3.0.x/prisma-fmt.sig +./rhel-openssl-3.0.x/query-engine.gz +./rhel-openssl-3.0.x/query-engine.gz.sha256 +./rhel-openssl-3.0.x/query-engine.gz.sig +./rhel-openssl-3.0.x/query-engine.sha256 +./rhel-openssl-3.0.x/query-engine.sig +./rhel-openssl-3.0.x/schema-engine.gz +./rhel-openssl-3.0.x/schema-engine.gz.sha256 +./rhel-openssl-3.0.x/schema-engine.gz.sig +./rhel-openssl-3.0.x/schema-engine.sha256 +./rhel-openssl-3.0.x/schema-engine.sig +./windows +./windows/prisma-fmt.exe.gz +./windows/prisma-fmt.exe.gz.sha256 +./windows/prisma-fmt.exe.gz.sig +./windows/prisma-fmt.exe.sha256 +./windows/prisma-fmt.exe.sig +./windows/query-engine.exe.gz +./windows/query-engine.exe.gz.sha256 +./windows/query-engine.exe.gz.sig +./windows/query-engine.exe.sha256 +./windows/query-engine.exe.sig +./windows/query_engine.dll.node.gz +./windows/query_engine.dll.node.gz.sha256 +./windows/query_engine.dll.node.gz.sig +./windows/query_engine.dll.node.sha256 +./windows/query_engine.dll.node.sig +./windows/schema-engine.exe.gz +./windows/schema-engine.exe.gz.sha256 +./windows/schema-engine.exe.gz.sig +./windows/schema-engine.exe.sha256 +./windows/schema-engine.exe.sig diff --git a/.github/workflows/utils/uploadAndVerify.sh b/.github/workflows/utils/uploadAndVerify.sh new file mode 100644 index 000000000000..606a4c2532fa --- /dev/null +++ b/.github/workflows/utils/uploadAndVerify.sh @@ -0,0 +1,128 @@ +#!/bin/bash + +set -eux; + +# engines-artifacts-for-r2 +# engines-artifacts-for-s3 +LOCAL_DIR_PATH=${1:-} + +if [ -z "$LOCAL_DIR_PATH" ]; then + echo "::error::LOCAL_DIR_PATH is not set." + exit 1 +fi + +echo "Uploading files..." +cd engines-artifacts +aws s3 sync . "$DESTINATION_TARGET_PATH" --no-progress \ + --exclude "*" \ + --include "*.gz" \ + --include "*.zip" \ + --include "*.sha256" \ + --include "*.sig" +cd ".." + +echo "Downloading files..." +mkdir "$LOCAL_DIR_PATH" +cd "$LOCAL_DIR_PATH" +aws s3 sync "$DESTINATION_TARGET_PATH" . --no-progress + +echo "Verifying downloaded files..." +ls -R . + +FILECOUNT_FOR_SHA256=$(find . -type f -name "*.sha256" | wc -l) +if [ "$FILECOUNT_FOR_SHA256" -eq 0 ]; then + echo "::error::No .sha256 files found." + exit 1 +fi + +FILECOUNT_FOR_GZ=$(find . -type f -name "*.gz" | wc -l) +if [ "$FILECOUNT_FOR_GZ" -eq 0 ]; then + echo "::error::No .gz files found." + exit 1 +fi + +FILECOUNT_FOR_SIG=$(find . -type f -name "*.sig" | wc -l) +if [ "$FILECOUNT_FOR_SIG" -eq 0 ]; then + echo "::error::No .sig files found."
+ exit 1 +fi + +# Manual check +# +# Set PROD env vars +# mkdir engines-artifacts-from-prod +# Download the artifacts from the S3 bucket +# aws s3 sync s3://prisma-builds/all_commits/6f3b8db04fa234ab2812fdd27456e9d9590eedb1 engines-artifacts-from-prod/ +# Print the files and save the output to a file +# cd engines-artifacts-from-prod +# find . | sort > ../expectedFiles.txt +# +# cd .. +# +# Set DEV env vars +# mkdir engines-artifacts-from-dev +# Download the artifacts from the S3 bucket +# aws s3 sync s3://prisma-builds-github-actions/all_commits/6f3b8db04fa234ab2812fdd27456e9d9590eedb1 engines-artifacts-from-dev/ +# Print the files and save the output to a file +# cd engines-artifacts-from-dev +# find . | sort > ../currentFiles.txt + +# Automated check +# expectedFiles.txt is in the same directory as this script + +echo "Create list of files" +find . | sort > ../currentFiles.txt +cd .. +echo "Comparing expectedFiles.txt vs currentFiles.txt" +diff -c .github/workflows/utils/expectedFiles.txt currentFiles.txt +cd "$LOCAL_DIR_PATH" + +# Unpack all .gz files first +find . -type f | while read -r filename; do + echo "Unpacking $filename file." + gzip -d "$filename" --keep -q +done + +# Verify .sha256 files +find . -type f -name "*.sha256" | while read -r filename; do + echo "Validating sha256 sum." + sha256sum -c "$filename" +done + +# Verify .sig files +find . -type f -name "*.sig" | while read -r filename; do + # Remove .sig from the file name + fileToVerify=$(echo "$filename" | rev | cut -c5- | rev) + + echo "Validating signature $filename for $fileToVerify" + gpg --verify "$filename" "$fileToVerify" +done + +echo "Validating OpenSSL linking." +if [[ "$(uname)" == 'Darwin' ]]; then + echo "::error::Mac OS does not have ldd command." + exit 1 +fi + +FILES_TO_VALIDATE_WITH_LDD=$(find . -type f | grep -E "./(rhel|debian)-openssl-(3.0|1.1).*(query-engine|schema-engine|libquery_engine.so.node)$") +echo "FILES_TO_VALIDATE_WITH_LDD: $FILES_TO_VALIDATE_WITH_LDD" + +for filename in $FILES_TO_VALIDATE_WITH_LDD +do + echo "Validating libssl linking for $filename." + GREP_OUTPUT=$(ldd "$filename" | grep "libssl") + OUTPUT=$(echo "$GREP_OUTPUT" | cut -f2 | cut -d'.' -f1) + + if [[ "$OUTPUT" == "libssl" ]]; then + echo "Linux build linked correctly to libssl." + else + echo "GREP_OUTPUT: $GREP_OUTPUT" + echo "Linux build linked incorrectly to libssl." 
+ exit 1 + fi +done + +echo "Upload .finished marker file" +touch .finished +aws s3 cp .finished "$DESTINATION_TARGET_PATH/.finished" +rm .finished diff --git a/.github/workflows/wasm-benchmarks.yml b/.github/workflows/wasm-benchmarks.yml index 160abfaf5b50..cfbe4ca7b531 100644 --- a/.github/workflows/wasm-benchmarks.yml +++ b/.github/workflows/wasm-benchmarks.yml @@ -23,6 +23,7 @@ jobs: - name: Checkout PR branch uses: actions/checkout@v4 with: + # using head ref rather than merge branch to get original commit message ref: ${{ github.event.pull_request.head.sha }} - uses: ./.github/workflows/include/rust-wasm-setup @@ -48,7 +49,9 @@ jobs: - name: Extract Branch Name run: | + echo "Extracting branch name from: $(git show -s --format=%s)" branch="$(git show -s --format=%s | grep -o "DRIVER_ADAPTERS_BRANCH=[^ ]*" | cut -f2 -d=)" + echo "branch=$branch" if [ -n "$branch" ]; then echo "Using $branch branch of driver adapters" echo "DRIVER_ADAPTERS_BRANCH=$branch" >> "$GITHUB_ENV" diff --git a/Cargo.lock b/Cargo.lock index fbe4a8ffb542..b915dbf08a69 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -30,9 +30,9 @@ dependencies = [ [[package]] name = "ahash" -version = "0.8.7" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "getrandom 0.2.11", @@ -41,15 +41,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "aho-corasick" -version = "0.7.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" -dependencies = [ - "memchr", -] - [[package]] name = "aho-corasick" version = "1.0.3" @@ -150,7 +141,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -161,7 +152,35 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", +] + +[[package]] +name = "async-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90e661b6cb0a6eb34d02c520b052daa3aa9ac0cc02495c9d066bbce13ead132b" +dependencies = [ + "futures-io", + "futures-util", + "log", + "native-tls", + "pin-project-lite", + "tokio", + "tokio-native-tls", + "tungstenite", +] + +[[package]] +name = "async_io_stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c" +dependencies = [ + "futures", + "pharos", + "rustc_version", + "tokio", ] [[package]] @@ -178,12 +197,12 @@ dependencies = [ ] [[package]] -name = "atomic-shim" -version = "0.2.0" +name = "atoi" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67cd4b51d303cf3501c301e8125df442128d3c6d7c69f71b27833d253de47e77" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" dependencies = [ - "crossbeam-utils", + "num-traits", ] [[package]] @@ -257,6 +276,12 @@ version = "0.21.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bigdecimal" version = "0.3.1" @@ -325,7 +350,7 @@ dependencies = [ "enumflags2", "indoc 2.0.3", "insta", - "query-engine-metrics", + "prisma-metrics", "query-engine-tests", "query-tests-setup", "regex", @@ -400,16 +425,16 @@ dependencies = [ [[package]] name = "bson" -version = "2.8.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88c18b51216e1f74b9d769cead6ace2f82b965b807e3d73330aabe9faec31c84" +checksum = "d8a88e82b9106923b5c4d6edfca9e7db958d4e98a478ec115022e81b9b38e2c8" dependencies = [ - "ahash 0.8.7", + "ahash 0.8.11", "base64 0.13.1", "bitvec", "chrono", "hex", - "indexmap 1.9.3", + "indexmap 2.2.2", "js-sys", "once_cell", "rand 0.8.5", @@ -429,6 +454,10 @@ dependencies = [ "memchr", ] +[[package]] +name = "build-utils" +version = "0.1.0" + [[package]] name = "bumpalo" version = "3.13.0" @@ -465,9 +494,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.4.0" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" [[package]] name = "cast" @@ -496,11 +525,11 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.83" +version = "1.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945" dependencies = [ - "libc", + "shlex", ] [[package]] @@ -520,15 +549,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "cfg_aliases" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" - -[[package]] -name = "cfg_aliases" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77e53693616d3075149f4ead59bdeecd204ac6b8192d8969757601b74bddf00f" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" @@ -676,6 +699,25 @@ dependencies = [ "unreachable", ] +[[package]] +name = "concat-idents" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f76990911f2267d837d9d0ad060aa63aaad170af40904b29461734c339030d4d" +dependencies = [ + "quote", + "syn 2.0.58", +] + +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "connection-string" version = "0.2.0" @@ -863,9 +905,12 @@ dependencies = [ name = "crosstarget-utils" version = "0.1.0" dependencies = [ + "derive_more", + "enumflags2", "futures", "js-sys", "pin-project", + "regex", "tokio", "wasm-bindgen", "wasm-bindgen-futures", @@ -888,7 +933,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f34ba9a9bcb8645379e9de8cb3ecfcf4d1c85ba66d90deb3259206fa5aa193b" dependencies = [ "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -943,6 +988,16 @@ dependencies = [ "darling_macro 0.13.4", ] +[[package]] +name = "darling" +version = "0.20.10" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core 0.20.10", + "darling_macro 0.20.10", +] + [[package]] name = "darling_core" version = "0.10.2" @@ -971,6 +1026,20 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.58", +] + [[package]] name = "darling_macro" version = "0.10.2" @@ -993,6 +1062,17 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core 0.20.10", + "quote", + "syn 2.0.58", +] + [[package]] name = "dashmap" version = "5.5.0" @@ -1000,10 +1080,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6943ae99c34386c84a470c499d3414f66502a41340aa895406e0d2e4a207b91d" dependencies = [ "cfg-if", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "lock_api", "once_cell", - "parking_lot_core 0.9.8", + "parking_lot_core", ] [[package]] @@ -1027,9 +1107,13 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.7" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7684a49fb1af197853ef7b2ee694bc1f5b4179556f1e5710e1760c5db6f5e929" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", + "serde", +] [[package]] name = "derivative" @@ -1051,7 +1135,7 @@ dependencies = [ "convert_case 0.4.0", "proc-macro2", "quote", - "rustc_version 0.4.0", + "rustc_version", "syn 1.0.109", ] @@ -1124,11 +1208,11 @@ dependencies = [ "expect-test", "futures", "js-sys", - "metrics 0.18.1", "napi", "napi-derive", "once_cell", "pin-project", + "prisma-metrics", "quaint", "serde", "serde-wasm-bindgen", @@ -1156,70 +1240,6 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" -[[package]] -name = "encoding" -version = "0.2.33" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b0d943856b990d12d3b55b359144ff341533e516d94098b1d3fc1ac666d36ec" -dependencies = [ - "encoding-index-japanese", - "encoding-index-korean", - "encoding-index-simpchinese", - "encoding-index-singlebyte", - "encoding-index-tradchinese", -] - -[[package]] -name = "encoding-index-japanese" -version = "1.20141219.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e8b2ff42e9a05335dbf8b5c6f7567e5591d0d916ccef4e0b1710d32a0d0c91" -dependencies = [ - "encoding_index_tests", -] - -[[package]] -name = "encoding-index-korean" -version = "1.20141219.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dc33fb8e6bcba213fe2f14275f0963fd16f0a02c878e3095ecfdf5bee529d81" -dependencies = [ - "encoding_index_tests", -] - -[[package]] -name = "encoding-index-simpchinese" -version = "1.20141219.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d87a7194909b9118fc707194baa434a4e3b0fb6a5a757c73c3adb07aa25031f7" -dependencies = [ - "encoding_index_tests", -] - -[[package]] -name = 
"encoding-index-singlebyte" -version = "1.20141219.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3351d5acffb224af9ca265f435b859c7c01537c0849754d3db3fdf2bfe2ae84a" -dependencies = [ - "encoding_index_tests", -] - -[[package]] -name = "encoding-index-tradchinese" -version = "1.20141219.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd0e20d5688ce3cab59eb3ef3a2083a5c77bf496cb798dc6fcdb75f323890c18" -dependencies = [ - "encoding_index_tests", -] - -[[package]] -name = "encoding_index_tests" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a246d82be1c9d791c5dfde9a2bd045fc3cbba3fa2b11ad558f27d01712f00569" - [[package]] name = "encoding_rs" version = "0.8.32" @@ -1237,14 +1257,14 @@ checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" [[package]] name = "enum-as-inner" -version = "0.4.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21cdad81446a7f7dc43f6a77409efeb9733d2fa65553efef6018ef257c959b73" +checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a" dependencies = [ "heck 0.4.1", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.58", ] [[package]] @@ -1265,7 +1285,7 @@ checksum = "5e9a1f9f7d83e59740248a6e14ecf93929ade55027844dfcea78beafccc15745" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -1284,6 +1304,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "event-listener" +version = "5.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + [[package]] name = "expect-test" version = "1.4.1" @@ -1322,6 +1353,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + [[package]] name = "fallible-streaming-iterator" version = "0.1.9" @@ -1357,6 +1394,17 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "flume" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" +dependencies = [ + "futures-core", + "futures-sink", + "spin 0.9.8", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1412,7 +1460,7 @@ checksum = "b0fa992f1656e1707946bbba340ad244f0814009ef8c0118eb7b658395f19a2e" dependencies = [ "frunk_proc_macro_helpers", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -1424,7 +1472,7 @@ dependencies = [ "frunk_core", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -1436,7 +1484,7 @@ dependencies = [ "frunk_core", "frunk_proc_macro_helpers", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -1493,6 +1541,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.28" @@ -1507,7 +1566,7 
@@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -1626,7 +1685,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.9", "indexmap 2.2.2", "slab", "tokio", @@ -1640,15 +1699,6 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" -[[package]] -name = "hashbrown" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" -dependencies = [ - "ahash 0.7.8", -] - [[package]] name = "hashbrown" version = "0.12.3" @@ -1660,21 +1710,21 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.3" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ - "ahash 0.8.7", + "ahash 0.8.11", "allocator-api2", ] [[package]] name = "hashlink" -version = "0.8.3" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "312f66718a2d7789ffef4f4b7b213138ed9f1eb3aa1d0d82fc99f88fb3ffd26f" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" dependencies = [ - "hashbrown 0.14.3", + "hashbrown 0.14.5", ] [[package]] @@ -1713,6 +1763,51 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hickory-proto" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07698b8420e2f0d6447a436ba999ec85d8fbf2a398bbd737b82cac4a2e96e512" +dependencies = [ + "async-trait", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-channel", + "futures-io", + "futures-util", + "idna 0.4.0", + "ipnet", + "once_cell", + "rand 0.8.5", + "thiserror", + "tinyvec", + "tokio", + "tracing", + "url", +] + +[[package]] +name = "hickory-resolver" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28757f23aa75c98f254cf0405e6d8c25b831b32921b050a66692427679b1f243" +dependencies = [ + "cfg-if", + "futures-util", + "hickory-proto", + "ipconfig", + "lru-cache", + "once_cell", + "parking_lot", + "rand 0.8.5", + "resolv-conf", + "smallvec", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "hmac" version = "0.12.1" @@ -1753,6 +1848,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.5" @@ -1760,7 +1866,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" dependencies = [ "bytes", - "http", + "http 0.2.9", "pin-project-lite", ] @@ -1787,7 +1893,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "httparse", "httpdate", @@ -1856,11 +1962,10 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "0.2.3" 
+version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8" +checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" dependencies = [ - "matches", "unicode-bidi", "unicode-normalization", ] @@ -1883,6 +1988,7 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", + "serde", ] [[package]] @@ -1892,7 +1998,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" dependencies = [ "equivalent", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "serde", ] @@ -1941,22 +2047,13 @@ dependencies = [ "yaml-rust", ] -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", -] - [[package]] name = "ipconfig" version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" dependencies = [ - "socket2 0.5.3", + "socket2 0.5.7", "widestring", "windows-sys 0.48.0", "winreg 0.50.0", @@ -2159,9 +2256,9 @@ dependencies = [ [[package]] name = "libsqlite3-sys" -version = "0.26.0" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc22eff61b133b115c6e8c74e818c628d6d5e7a502afea6f64dee076dd94326" +checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f" dependencies = [ "cc", "pkg-config", @@ -2247,15 +2344,6 @@ dependencies = [ "url", ] -[[package]] -name = "mach" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa" -dependencies = [ - "libc", -] - [[package]] name = "match_cfg" version = "0.1.0" @@ -2271,12 +2359,6 @@ dependencies = [ "regex-automata 0.1.10", ] -[[package]] -name = "matches" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5" - [[package]] name = "md-5" version = "0.10.5" @@ -2309,91 +2391,47 @@ dependencies = [ [[package]] name = "metrics" -version = "0.18.1" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e52eb6380b6d2a10eb3434aec0885374490f5b82c8aaf5cd487a183c98be834" +checksum = "884adb57038347dfbaf2d5065887b6cf4312330dc8e94bc30a1a839bd79d3261" dependencies = [ - "ahash 0.7.8", - "metrics-macros", -] - -[[package]] -name = "metrics" -version = "0.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "142c53885123b68d94108295a09d4afe1a1388ed95b54d5dacd9a454753030f2" -dependencies = [ - "ahash 0.7.8", - "metrics-macros", + "ahash 0.8.11", + "portable-atomic", ] [[package]] name = "metrics-exporter-prometheus" -version = "0.10.0" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953cbbb6f9ba4b9304f4df79b98cdc9d14071ed93065a9fca11c00c5d9181b66" +checksum = "b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6" dependencies = [ - "hyper", - "indexmap 1.9.3", - "ipnet", - "metrics 0.19.0", - "metrics-util 0.13.0", - "parking_lot 0.11.2", + "base64 0.22.1", + "indexmap 2.2.2", + "metrics", + "metrics-util", "quanta", "thiserror", - 
"tokio", - "tracing", -] - -[[package]] -name = "metrics-macros" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49e30813093f757be5cf21e50389a24dc7dbb22c49f23b7e8f51d69b508a5ffa" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", ] [[package]] name = "metrics-util" -version = "0.12.1" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65a9e83b833e1d2e07010a386b197c13aa199bbd0fca5cf69bfa147972db890a" +checksum = "4259040465c955f9f2f1a4a8a16dc46726169bca0f88e8fb2dbeced487c3e828" dependencies = [ - "aho-corasick 0.7.20", - "atomic-shim", + "aho-corasick", "crossbeam-epoch", "crossbeam-utils", - "hashbrown 0.11.2", - "indexmap 1.9.3", - "metrics 0.18.1", + "hashbrown 0.14.5", + "indexmap 2.2.2", + "metrics", "num_cpus", "ordered-float", - "parking_lot 0.11.2", "quanta", "radix_trie", "sketches-ddsketch", ] -[[package]] -name = "metrics-util" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1f4b69bef1e2b392b2d4a12902f2af90bb438ba4a66aa222d1023fa6561b50" -dependencies = [ - "atomic-shim", - "crossbeam-epoch", - "crossbeam-utils", - "hashbrown 0.11.2", - "metrics 0.19.0", - "num_cpus", - "parking_lot 0.11.2", - "quanta", - "sketches-ddsketch", -] - [[package]] name = "mime" version = "0.3.17" @@ -2429,9 +2467,9 @@ dependencies = [ [[package]] name = "mobc" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90eb49dc5d193287ff80e72a86f34cfb27aae562299d22fea215e06ea1059dd3" +checksum = "316a7d198b51958a0ab57248bf5f42d8409551203cb3c821d5925819a8d5415f" dependencies = [ "async-trait", "futures-channel", @@ -2439,7 +2477,7 @@ dependencies = [ "futures-timer", "futures-util", "log", - "metrics 0.18.1", + "metrics", "thiserror", "tokio", "tracing", @@ -2448,9 +2486,8 @@ dependencies = [ [[package]] name = "mongodb" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c30763a5c6c52079602be44fa360ca3bfacee55fca73f4734aecd23706a7f2" +version = "3.0.0" +source = "git+https://github.com/prisma/mongo-rust-driver.git?branch=RUST-1994/happy-eyeballs#31e0356391a7871bec000ae35fe0edc35582449e" dependencies = [ "async-trait", "base64 0.13.1", @@ -2464,9 +2501,12 @@ dependencies = [ "futures-io", "futures-util", "hex", + "hickory-proto", + "hickory-resolver", "hmac", - "lazy_static", "md-5", + "mongodb-internal-macros", + "once_cell", "pbkdf2", "percent-encoding", "rand 0.8.5", @@ -2478,16 +2518,14 @@ dependencies = [ "serde_with", "sha-1", "sha2 0.10.7", - "socket2 0.4.9", + "socket2 0.5.7", "stringprep", - "strsim 0.10.0", + "strsim 0.11.1", "take_mut", "thiserror", "tokio", "tokio-rustls 0.24.1", "tokio-util 0.7.8", - "trust-dns-proto", - "trust-dns-resolver", "typed-builder", "uuid", "webpki-roots", @@ -2503,6 +2541,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "mongodb-internal-macros" +version = "3.0.0" +source = "git+https://github.com/prisma/mongo-rust-driver.git?branch=RUST-1994/happy-eyeballs#31e0356391a7871bec000ae35fe0edc35582449e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.58", +] + [[package]] name = "mongodb-query-connector" version = "0.1.0" @@ -2520,15 +2568,16 @@ dependencies = [ "mongodb", "mongodb-client", "pretty_assertions", + "prisma-metrics", "prisma-value", "psl", "query-connector", - "query-engine-metrics", "query-structure", "rand 0.8.5", "regex", "serde", "serde_json", + "telemetry", 
"thiserror", "tokio", "tracing", @@ -2541,6 +2590,7 @@ dependencies = [ name = "mongodb-schema-connector" version = "0.1.0" dependencies = [ + "bson", "convert_case 0.6.0", "datamodel-renderer", "dissimilar", @@ -2569,6 +2619,7 @@ dependencies = [ name = "mongodb-schema-describer" version = "0.1.0" dependencies = [ + "bson", "futures", "mongodb", "serde", @@ -2680,9 +2731,9 @@ dependencies = [ [[package]] name = "napi" -version = "2.15.1" +version = "2.16.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43792514b0c95c5beec42996da0c1b39265b02b75c97baa82d163d3ef55cbfa7" +checksum = "214f07a80874bb96a8433b3cdfc84980d56c7b02e1a0d7ba4ba0db5cef785e2b" dependencies = [ "bitflags 2.4.0", "ctor", @@ -2702,38 +2753,38 @@ checksum = "ebd4419172727423cf30351406c54f6cc1b354a2cfb4f1dba3e6cd07f6d5522b" [[package]] name = "napi-derive" -version = "2.15.0" +version = "2.16.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7622f0dbe0968af2dacdd64870eee6dee94f93c989c841f1ad8f300cf1abd514" +checksum = "17435f7a00bfdab20b0c27d9c56f58f6499e418252253081bfff448099da31d1" dependencies = [ "cfg-if", "convert_case 0.6.0", "napi-derive-backend", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] name = "napi-derive-backend" -version = "1.0.59" +version = "1.0.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ec514d65fce18a959be55e7f683ac89c6cb850fb59b09e25ab777fd5a4a8d9e" +checksum = "967c485e00f0bf3b1bdbe510a38a4606919cf1d34d9a37ad41f25a81aa077abe" dependencies = [ "convert_case 0.6.0", "once_cell", "proc-macro2", "quote", "regex", - "semver 1.0.18", - "syn 2.0.48", + "semver", + "syn 2.0.58", ] [[package]] name = "napi-sys" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2503fa6af34dc83fb74888df8b22afe933b58d37daf7d80424b1c60c68196b8b" +checksum = "427802e8ec3a734331fec1035594a210ce1ff4dc5bc1950530920ab717964ea3" dependencies = [ "libloading 0.8.1", ] @@ -2819,6 +2870,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.45" @@ -2921,7 +2978,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -2983,7 +3040,7 @@ dependencies = [ "async-trait", "futures", "futures-util", - "http", + "http 0.2.9", "opentelemetry", "prost", "thiserror", @@ -3008,9 +3065,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "2.10.0" +version = "4.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7940cf2ca942593318d07fcf2596cdca60a85c9e7fab408a5e21a4f9dcd40d87" +checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537" dependencies = [ "num-traits", ] @@ -3034,15 +3091,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] -name = "parking_lot" -version = "0.11.2" +name = "parking" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.6", -] +checksum = 
"bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" [[package]] name = "parking_lot" @@ -3051,21 +3103,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.8", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" -dependencies = [ - "cfg-if", - "instant", - "libc", - "redox_syscall 0.2.16", - "smallvec", - "winapi", + "parking_lot_core", ] [[package]] @@ -3172,7 +3210,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -3206,6 +3244,16 @@ dependencies = [ "indexmap 1.9.3", ] +[[package]] +name = "pharos" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9567389417feee6ce15dd6527a8a1ecac205ef62c2932bcf3d9f6fc5b78b414" +dependencies = [ + "futures", + "rustc_version", +] + [[package]] name = "phf" version = "0.11.2" @@ -3241,7 +3289,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -3290,10 +3338,16 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" + [[package]] name = "postgres-native-tls" version = "0.5.0" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#a1a2dc6d9584deaf70a14293c428e7b6ca614d98" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" dependencies = [ "native-tls", "tokio", @@ -3303,13 +3357,13 @@ dependencies = [ [[package]] name = "postgres-protocol" -version = "0.6.4" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#a1a2dc6d9584deaf70a14293c428e7b6ca614d98" +version = "0.6.7" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" dependencies = [ - "base64 0.13.1", + "base64 0.22.1", "byteorder", "bytes", - "fallible-iterator", + "fallible-iterator 0.2.0", "hmac", "md-5", "memchr", @@ -3320,19 +3374,25 @@ dependencies = [ [[package]] name = "postgres-types" -version = "0.2.4" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#a1a2dc6d9584deaf70a14293c428e7b6ca614d98" +version = "0.2.8" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" dependencies = [ "bit-vec", "bytes", "chrono", - "fallible-iterator", + "fallible-iterator 0.2.0", "postgres-protocol", "serde", "serde_json", "uuid", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -3369,19 +3429,41 @@ dependencies = [ name = "prisma-fmt" version = "0.1.0" dependencies = [ + "build-utils", "colored", "dissimilar", "dmmf", "enumflags2", "expect-test", - "indoc 2.0.3", - "log", - "lsp-types", + "indoc 2.0.3", + "log", + "lsp-types", + "once_cell", + "psl", + "serde", + "serde_json", + "structopt", +] + +[[package]] +name = 
"prisma-metrics" +version = "0.1.0" +dependencies = [ + "derive_more", + "expect-test", + "futures", + "metrics", + "metrics-exporter-prometheus", + "metrics-util", "once_cell", - "psl", + "parking_lot", + "pin-project", "serde", "serde_json", - "structopt", + "tokio", + "tracing", + "tracing-futures", + "tracing-subscriber", ] [[package]] @@ -3590,16 +3672,19 @@ name = "quaint" version = "0.2.0-alpha.13" dependencies = [ "async-trait", + "async-tungstenite", "base64 0.12.3", "bigdecimal", "bit-vec", "byteorder", "bytes", - "cfg_aliases 0.1.1", + "cfg_aliases", "chrono", + "concat-idents", "connection-string", "crosstarget-utils", "either", + "enumflags2", "expect-test", "futures", "getrandom 0.2.11", @@ -3607,7 +3692,6 @@ dependencies = [ "indoc 0.3.6", "itertools 0.12.0", "lru-cache", - "metrics 0.18.1", "mobc", "mysql_async", "names 0.11.0", @@ -3618,8 +3702,10 @@ dependencies = [ "percent-encoding", "postgres-native-tls", "postgres-types", + "prisma-metrics", "quaint-test-macros", "quaint-test-setup", + "regex", "rusqlite", "serde", "serde_json", @@ -3628,11 +3714,12 @@ dependencies = [ "tiberius", "tokio", "tokio-postgres", - "tokio-util 0.6.10", + "tokio-util 0.7.8", "tracing", - "tracing-core", + "tracing-futures", "url", "uuid", + "ws_stream_tungstenite", ] [[package]] @@ -3661,16 +3748,15 @@ dependencies = [ [[package]] name = "quanta" -version = "0.9.3" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20afe714292d5e879d8b12740aa223c6a88f118af41870e8b6196e39a02238a8" +checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" dependencies = [ "crossbeam-utils", "libc", - "mach", "once_cell", "raw-cpuid", - "wasi 0.10.2+wasi-snapshot-preview1", + "wasi 0.11.0+wasi-snapshot-preview1", "web-sys", "winapi", ] @@ -3689,6 +3775,7 @@ dependencies = [ "query-structure", "serde", "serde_json", + "telemetry", "thiserror", "user-facing-errors", "uuid", @@ -3705,6 +3792,7 @@ dependencies = [ "crossbeam-channel", "crosstarget-utils", "cuid", + "derive_more", "enumflags2", "futures", "indexmap 2.2.2", @@ -3713,13 +3801,14 @@ dependencies = [ "once_cell", "opentelemetry", "petgraph 0.4.13", + "prisma-metrics", "psl", "query-connector", - "query-engine-metrics", "query-structure", "schema", "serde", "serde_json", + "telemetry", "thiserror", "tokio", "tracing", @@ -3737,6 +3826,7 @@ dependencies = [ "anyhow", "async-trait", "base64 0.13.1", + "build-utils", "connection-string", "enumflags2", "graphql-parser", @@ -3745,17 +3835,18 @@ dependencies = [ "mongodb-query-connector", "opentelemetry", "opentelemetry-otlp", + "prisma-metrics", "psl", "quaint", "query-connector", "query-core", - "query-engine-metrics", "request-handlers", "serde", "serde_json", "serial_test", "sql-query-connector", "structopt", + "telemetry", "thiserror", "tokio", "tracing", @@ -3771,6 +3862,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "build-utils", "cbindgen", "chrono", "connection-string", @@ -3789,6 +3881,7 @@ dependencies = [ "serde", "serde_json", "sql-query-connector", + "telemetry", "thiserror", "tokio", "tracing", @@ -3808,12 +3901,13 @@ dependencies = [ "connection-string", "napi", "opentelemetry", + "prisma-metrics", "psl", "query-connector", "query-core", - "query-engine-metrics", "serde", "serde_json", + "telemetry", "thiserror", "tracing", "tracing-futures", @@ -3825,30 +3919,13 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "query-engine-metrics" -version = "0.1.0" -dependencies = [ - "expect-test", - "metrics 
0.18.1", - "metrics-exporter-prometheus", - "metrics-util 0.12.1", - "once_cell", - "parking_lot 0.12.1", - "serde", - "serde_json", - "tokio", - "tracing", - "tracing-futures", - "tracing-subscriber", -] - [[package]] name = "query-engine-node-api" version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "build-utils", "connection-string", "driver-adapters", "futures", @@ -3856,17 +3933,18 @@ dependencies = [ "napi-build", "napi-derive", "opentelemetry", + "prisma-metrics", "psl", "quaint", "query-connector", "query-core", "query-engine-common", - "query-engine-metrics", "query-structure", "request-handlers", "serde", "serde_json", "sql-query-connector", + "telemetry", "thiserror", "tokio", "tracing", @@ -3892,9 +3970,9 @@ dependencies = [ "itertools 0.12.0", "once_cell", "paste", + "prisma-metrics", "prisma-value", "psl", - "query-engine-metrics", "query-test-macros", "query-tests-setup", "serde_json", @@ -3911,6 +3989,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "build-utils", "connection-string", "driver-adapters", "futures", @@ -3927,6 +4006,7 @@ dependencies = [ "serde-wasm-bindgen", "serde_json", "sql-query-connector", + "telemetry", "thiserror", "tokio", "tracing", @@ -3983,12 +4063,12 @@ dependencies = [ "nom", "once_cell", "parse-hyperlinks", + "prisma-metrics", "psl", "qe-setup", "quaint", "query-core", "query-engine", - "query-engine-metrics", "query-structure", "regex", "request-handlers", @@ -3996,6 +4076,7 @@ dependencies = [ "serde_json", "sql-query-connector", "strip-ansi-escapes", + "telemetry", "thiserror", "tokio", "tracing", @@ -4147,11 +4228,11 @@ dependencies = [ [[package]] name = "raw-cpuid" -version = "10.7.0" +version = "11.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.0", ] [[package]] @@ -4187,29 +4268,29 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.16" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ "bitflags 1.3.2", ] [[package]] name = "redox_syscall" -version = "0.3.5" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.0", ] [[package]] name = "regex" -version = "1.10.3" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ - "aho-corasick 1.0.3", + "aho-corasick", "memchr", "regex-automata 0.4.5", "regex-syntax 0.8.2", @@ -4230,7 +4311,7 @@ version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" dependencies = [ - "aho-corasick 1.0.3", + "aho-corasick", "memchr", "regex-syntax 0.8.2", ] @@ -4261,7 +4342,7 @@ name = "request-handlers" version = "0.1.0" dependencies = [ "bigdecimal", - "cfg_aliases 0.2.0", + "cfg_aliases", 
"codspeed-criterion-compat", "connection-string", "dmmf", @@ -4280,6 +4361,7 @@ dependencies = [ "serde", "serde_json", "sql-query-connector", + "telemetry", "thiserror", "tracing", "url", @@ -4298,7 +4380,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-tls", @@ -4392,13 +4474,13 @@ dependencies = [ [[package]] name = "rusqlite" -version = "0.29.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2" +checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae" dependencies = [ "bitflags 2.4.0", "chrono", - "fallible-iterator", + "fallible-iterator 0.3.0", "fallible-streaming-iterator", "hashlink", "libsqlite3-sys", @@ -4434,32 +4516,23 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" -[[package]] -name = "rustc_version" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a" -dependencies = [ - "semver 0.9.0", -] - [[package]] name = "rustc_version" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" dependencies = [ - "semver 1.0.18", + "semver", ] [[package]] name = "rustc_version_runtime" -version = "0.2.1" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d31b7153270ebf48bf91c65ae5b0c00e749c4cfad505f66530ac74950249582f" +checksum = "2dd18cd2bae1820af0b6ad5e54f4a51d0f3fcc53b05f845675074efcc7af071d" dependencies = [ - "rustc_version 0.2.3", - "semver 0.9.0", + "rustc_version", + "semver", ] [[package]] @@ -4628,6 +4701,7 @@ version = "0.1.0" dependencies = [ "backtrace", "base64 0.13.1", + "build-utils", "connection-string", "expect-test", "indoc 2.0.3", @@ -4704,32 +4778,17 @@ dependencies = [ "libc", ] -[[package]] -name = "semver" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403" -dependencies = [ - "semver-parser", -] - [[package]] name = "semver" version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0293b4b29daaf487284529cc2f5675b8e57c61f70167ba415a463651fd6a918" -[[package]] -name = "semver-parser" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" - [[package]] name = "serde" -version = "1.0.183" +version = "1.0.206" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32ac8da02677876d532745a130fc9d8e6edfa81a269b107c5b00829b91d8eb3c" +checksum = "5b3e4cd94123dd520a128bcd11e34d9e9e423e7e3e50425cb1b4b1e3549d0284" dependencies = [ "serde_derive", ] @@ -4756,13 +4815,13 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.183" +version = "1.0.206" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aafe972d60b0b9bee71a91b92fee2d4fb3c9d7e8f6b179aa99f27203d99a4816" +checksum = "fabfb6138d2383ea8208cf98ccf69cdfb1aff4088460681d84189aa259762f97" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -4773,7 +4832,7 @@ checksum = 
"e578a843d40b4189a4d66bba51d7684f57da5bd7c304c64e14bd63efbef49509" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -4796,7 +4855,7 @@ checksum = "3081f5ffbb02284dda55132aa26daecedd7372a42417bbbab6f14ab7d6bb9145" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -4813,24 +4872,32 @@ dependencies = [ [[package]] name = "serde_with" -version = "1.14.0" +version = "3.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "678b5a069e50bf00ecd22d0cd8ddf7c236f68581b03db652061ed5eb13a312ff" +checksum = "69cecfa94848272156ea67b2b1a53f20fc7bc638c4a46d2f8abde08f05f4b857" dependencies = [ + "base64 0.22.1", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.2.2", "serde", + "serde_derive", + "serde_json", "serde_with_macros", + "time", ] [[package]] name = "serde_with_macros" -version = "1.5.2" +version = "3.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e182d6ec6f05393cc0e5ed1bf81ad6db3a8feedf8ee515ecdd369809bcce8082" +checksum = "a8fee4991ef4f274617a51ad4af30519438dacb2f56ac773b08a1922ff743350" dependencies = [ - "darling 0.13.4", + "darling 0.20.10", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.58", ] [[package]] @@ -4843,7 +4910,7 @@ dependencies = [ "futures", "lazy_static", "log", - "parking_lot 0.12.1", + "parking_lot", "serial_test_derive", ] @@ -4855,7 +4922,7 @@ checksum = "91d129178576168c589c9ec973feedf7d3126c01ac2bf08795109aa35b69fb8f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -4961,9 +5028,9 @@ checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" [[package]] name = "sketches-ddsketch" -version = "0.1.3" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04d2ecae5fcf33b122e2e6bd520a57ccf152d2dde3b38c71039df1a6867264ee" +checksum = "85636c14b73d81f541e525f585c0a2109e6744e1565b5c1668e31c70c10ed65c" [[package]] name = "slab" @@ -4976,9 +5043,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" @@ -4992,12 +5059,12 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.3" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" dependencies = [ "libc", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -5011,6 +5078,9 @@ name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] [[package]] name = "sql-ddl" @@ -5057,6 +5127,7 @@ dependencies = [ "indoc 2.0.3", "jsonrpc-core", "once_cell", + "paste", "pretty_assertions", "prisma-value", "psl", @@ -5098,6 +5169,7 @@ dependencies = [ "rand 0.8.5", "serde", "serde_json", + "telemetry", "thiserror", "tokio", "tracing", @@ -5116,6 +5188,7 @@ dependencies = [ "datamodel-renderer", "either", "enumflags2", + "expect-test", "indexmap 2.2.2", "indoc 2.0.3", "once_cell", @@ -5130,6 +5203,8 @@ dependencies = [ "sql-schema-describer", "sqlformat", 
"sqlparser", + "sqlx-core", + "sqlx-sqlite", "tokio", "tracing", "tracing-futures", @@ -5185,6 +5260,61 @@ dependencies = [ "log", ] +[[package]] +name = "sqlx-core" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a999083c1af5b5d6c071d34a708a19ba3e02106ad82ef7bbd69f5e48266b613b" +dependencies = [ + "atoi", + "byteorder", + "bytes", + "crossbeam-queue", + "either", + "event-listener", + "futures-channel", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashbrown 0.14.5", + "hashlink", + "hex", + "indexmap 2.2.2", + "log", + "memchr", + "once_cell", + "paste", + "percent-encoding", + "smallvec", + "sqlformat", + "thiserror", + "tracing", + "url", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b2cdd83c008a622d94499c0006d8ee5f821f36c89b7d625c900e5dc30b5c5ee" +dependencies = [ + "atoi", + "flume", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde_urlencoded", + "sqlx-core", + "tracing", + "url", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -5228,6 +5358,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "structopt" version = "0.3.26" @@ -5281,9 +5417,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.48" +version = "2.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "44cfb93f38070beee36b3fef7d4f5a16f27751d94b187b666a5cc5e9b0d30687" dependencies = [ "proc-macro2", "quote", @@ -5314,6 +5450,36 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "telemetry" +version = "0.1.0" +dependencies = [ + "async-trait", + "crossbeam-channel", + "crosstarget-utils", + "cuid", + "derive_more", + "enumflags2", + "futures", + "indexmap 2.2.2", + "itertools 0.12.0", + "lru 0.7.8", + "once_cell", + "opentelemetry", + "prisma-metrics", + "psl", + "rand 0.8.5", + "serde", + "serde_json", + "thiserror", + "tokio", + "tracing", + "tracing-futures", + "tracing-opentelemetry", + "tracing-subscriber", + "uuid", +] + [[package]] name = "tempfile" version = "3.7.1" @@ -5342,6 +5508,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "build-utils", "colored", "dmmf", "enumflags2", @@ -5413,7 +5580,7 @@ checksum = "090198534930841fab3a5d1bb637cde49e339654e606195f8d9c76eeb081dc96" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -5428,9 +5595,9 @@ dependencies = [ [[package]] name = "tiberius" -version = "0.11.7" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66303a42b7c5daffb95c10cd8f3007a9c29b3e90128cf42b3738f58102aa2516" +checksum = "a1446cb4198848d1562301a3340424b4f425ef79f35ef9ee034769a9dd92c10d" dependencies = [ "async-native-tls", "async-trait", @@ -5440,10 +5607,8 @@ dependencies = [ "bytes", "chrono", "connection-string", - 
"encoding", + "encoding_rs", "enumflags2", - "futures", - "futures-sink", "futures-util", "num-traits", "once_cell", @@ -5460,12 +5625,14 @@ dependencies = [ [[package]] name = "time" -version = "0.3.25" +version = "0.3.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fdd63d58b18d663fbdf70e049f00a22c8e42be082203be7f26589213cd75ea" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" dependencies = [ "deranged", "itoa", + "num-conv", + "powerfmt", "serde", "time-core", "time-macros", @@ -5473,16 +5640,17 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.11" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb71511c991639bb078fd5bf97757e03914361c48100d52878b8e52b46fb92cd" +checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" dependencies = [ + "num-conv", "time-core", ] @@ -5513,19 +5681,19 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.30.0" +version = "1.38.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d3ce25f50619af8b0aec2eb23deebe84249e19e2ddd393a6e16e3300a6dadfd" +checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df" dependencies = [ "backtrace", "bytes", "libc", "mio", "num_cpus", - "parking_lot 0.12.1", + "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2 0.5.3", + "socket2 0.5.7", "tokio-macros", "windows-sys 0.48.0", ] @@ -5542,13 +5710,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -5563,25 +5731,27 @@ dependencies = [ [[package]] name = "tokio-postgres" -version = "0.7.7" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#a1a2dc6d9584deaf70a14293c428e7b6ca614d98" +version = "0.7.12" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" dependencies = [ "async-trait", "byteorder", "bytes", - "fallible-iterator", + "fallible-iterator 0.2.0", "futures-channel", "futures-util", "log", - "parking_lot 0.12.1", + "parking_lot", "percent-encoding", "phf", "pin-project-lite", "postgres-protocol", "postgres-types", - "socket2 0.5.3", + "rand 0.8.5", + "socket2 0.5.7", "tokio", "tokio-util 0.7.8", + "whoami", ] [[package]] @@ -5624,7 +5794,6 @@ checksum = "36943ee01a6d67977dd3f84a5a1d2efeb4ada3a1ae771cadfaa535d9d9fc6507" dependencies = [ "bytes", "futures-core", - "futures-io", "futures-sink", "log", "pin-project-lite", @@ -5668,7 +5837,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-timeout", @@ -5753,7 +5922,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] [[package]] @@ -5843,51 
+6012,6 @@ dependencies = [ "tracing-serde", ] -[[package]] -name = "trust-dns-proto" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c31f240f59877c3d4bb3b3ea0ec5a6a0cff07323580ff8c7a605cd7d08b255d" -dependencies = [ - "async-trait", - "cfg-if", - "data-encoding", - "enum-as-inner", - "futures-channel", - "futures-io", - "futures-util", - "idna 0.2.3", - "ipnet", - "lazy_static", - "log", - "rand 0.8.5", - "smallvec", - "thiserror", - "tinyvec", - "tokio", - "url", -] - -[[package]] -name = "trust-dns-resolver" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4ba72c2ea84515690c9fcef4c6c660bb9df3036ed1051686de84605b74fd558" -dependencies = [ - "cfg-if", - "futures-util", - "ipconfig", - "lazy_static", - "log", - "lru-cache", - "parking_lot 0.12.1", - "resolv-conf", - "smallvec", - "thiserror", - "tokio", - "trust-dns-proto", -] - [[package]] name = "try-lock" version = "0.2.4" @@ -5916,7 +6040,26 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.48", + "syn 2.0.58", +] + +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "native-tls", + "rand 0.8.5", + "sha1", + "thiserror", + "utf-8", ] [[package]] @@ -6060,6 +6203,12 @@ dependencies = [ "user-facing-error-macros", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8-width" version = "0.1.6" @@ -6074,12 +6223,13 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.4.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79daa5ed5740825c40b389c5e50312b9c86df53fccd33f281df655642b43869d" +checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" dependencies = [ "getrandom 0.2.11", "serde", + "wasm-bindgen", ] [[package]] @@ -6170,38 +6320,39 @@ checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" [[package]] name = "wasi" -version = "0.10.2+wasi-snapshot-preview1" +version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +name = "wasite" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.93" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", "wasm-bindgen-shared", ] @@ -6219,9 +6370,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6229,22 +6380,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "wasm-logger" @@ -6303,6 +6454,17 @@ dependencies = [ "once_cell", ] +[[package]] +name = "whoami" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "372d5b87f58ec45c384ba03563b03544dc5fadc3983e434b286913f5b4a9bb6d" +dependencies = [ + "redox_syscall 0.5.7", + "wasite", + "web-sys", +] + [[package]] name = "widestring" version = "1.0.2" @@ -6579,6 +6741,26 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ws_stream_tungstenite" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed39ff9f8b2eda91bf6390f9f49eee93d655489e15708e3bb638c1c4f07cecb4" +dependencies = [ + "async-tungstenite", + "async_io_stream", + "bitflags 2.4.0", + "futures-core", + "futures-io", + "futures-sink", + "futures-util", + "pharos", + "rustc_version", + "tokio", + "tracing", + "tungstenite", +] + [[package]] name = "wyz" version = "0.5.1" @@ -6620,5 +6802,5 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.58", ] diff --git a/Cargo.toml b/Cargo.toml index fbba97fcc427..df9f44eee292 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ members = [ "query-engine/black-box-tests", "query-engine/dmmf", "query-engine/driver-adapters", - "query-engine/metrics", "query-engine/query-structure", "query-engine/query-engine", "query-engine/query-engine-node-api", @@ -38,10 +37,11 @@ members = [ [workspace.dependencies] async-trait = { version = "0.1.77" } enumflags2 = { version = "0.7", features = ["serde"] } +futures = "0.3" psl = { path = "./psl/psl" } serde_json = { version = "1", features = ["float_roundtrip", "preserve_order", "raw_value"] } serde = { version = "1", features = ["derive"] } -tokio = { version = "1.25", features = [ +tokio = { version = "1", features = [ "rt-multi-thread", "macros", "sync", @@ -51,30 +51,37 @@ tokio = { version = "1.25", features = [ "time", ] } chrono = { version = 
"0.4.38", features = ["serde"] } +derive_more = "0.99.17" user-facing-errors = { path = "./libs/user-facing-errors" } -uuid = { version = "1", features = ["serde", "v4"] } +uuid = { version = "1", features = ["serde", "v4", "v7", "js"] } indoc = "2.0.1" indexmap = { version = "2.2.2", features = ["serde"] } itertools = "0.12" connection-string = "0.2" -napi = { version = "2.15.1", default-features = false, features = [ - "napi8", +napi = { version = "2.16.13", default-features = false, features = [ + "napi9", "tokio_rt", "serde-json", ] } -napi-derive = "2.15.0" +napi-derive = "2.16.12" js-sys = { version = "0.3" } +pin-project = "1" rand = { version = "0.8" } +regex = { version = "1", features = ["std"] } serde_repr = { version = "0.1.17" } serde-wasm-bindgen = { version = "0.5" } tracing = { version = "0.1" } +tracing-futures = "0.2" tsify = { version = "0.4.5" } -wasm-bindgen = { version = "0.2.92" } +wasm-bindgen = { version = "0.2.93" } wasm-bindgen-futures = { version = "0.4" } wasm-rs-dbg = { version = "0.1.2", default-features = false, features = ["console-error"] } wasm-bindgen-test = { version = "0.3.0" } url = { version = "2.5.0" } +bson = { version = "2.11.0", features = ["chrono-0_4", "uuid-1"] } +mongodb = { git = "https://github.com/prisma/mongo-rust-driver.git", branch = "RUST-1994/happy-eyeballs" } + [workspace.dependencies.quaint] path = "quaint" diff --git a/Makefile b/Makefile index ec16c50b9dc2..7407f7d41fe3 100644 --- a/Makefile +++ b/Makefile @@ -67,6 +67,10 @@ build-qe-wasm-gz: build-qe-wasm gzip -knc $$provider/query_engine_bg.wasm > $$provider.gz; \ done; +integrate-qe-wasm: + cd query-engine/query-engine-wasm && \ + ./build.sh $(QE_WASM_VERSION) ../prisma/packages/client/node_modules/@prisma/query-engine-wasm + build-schema-wasm: @printf '%s\n' "🛠️ Building the Rust crate" cargo build --profile $(PROFILE) --target=wasm32-unknown-unknown -p prisma-schema-build diff --git a/README.md b/README.md index 4b792592be7f..84789c25f500 100644 --- a/README.md +++ b/README.md @@ -265,7 +265,10 @@ You can trigger releases from this repository to npm that can be used for testin #### Automated integration releases from this repository to npm -(Since July 2022). Any branch name starting with `integration/` will, first, run the full test suite in Buildkite `[Test] Prisma Engines` and, second, if passing, run the publish pipeline (build and upload engines to S3 & R2) +Any branch name starting with `integration/` will, first, run the full test suite in GH Actions and, second, run the release workflow (build and upload engines to S3 & R2). +To trigger the release on any other branch, you have two options: +- Either run [build-engines](https://github.com/prisma/prisma-engines/actions/workflows/build-engines.yml) workflow on a specified branch manually. +- Or add `[integration]` string anywhere in your commit messages/ The journey through the pipeline is the same as a commit on the `main` branch. - It will trigger [`prisma/engines-wrapper`](https://github.com/prisma/engines-wrapper) and publish a new [`@prisma/engines-version`](https://www.npmjs.com/package/@prisma/engines-version) npm package but on the `integration` tag. @@ -276,8 +279,8 @@ The journey through the pipeline is the same as a commit on the `main` branch. This end to end will take minimum ~1h20 to complete, but is completely automated :robot: Notes: -- in `prisma/prisma` repository, we do not run tests for `integration/` branches, it is much faster and also means that there is no risk of tests failing (e.g. 
flaky tests, snapshots) that would stop the publishing process.
-- in `prisma/prisma-engines` the Buildkite test pipeline must first pass, then the engines will be built and uploaded to our storage via the Buildkite release pipeline. These 2 pipelines can fail for different reasons, it's recommended to keep an eye on them (check notifications in Slack) and restart jobs as needed. Finally, it will trigger [`prisma/engines-wrapper`](https://github.com/prisma/engines-wrapper).
+- tests and publishing workflows run in parallel in both the `prisma/prisma-engines` and `prisma/prisma` repositories. It is therefore possible that the engines get published and only then does the test suite
+discover a defect. It is advised to keep an eye on both the test and publishing workflows.
 
 #### Manual integration releases from this repository to npm
diff --git a/docker/planetscale_proxy/Dockerfile b/docker/planetscale_proxy/Dockerfile
index 9d6cca2f5dd8..8c3ef94406e9 100644
--- a/docker/planetscale_proxy/Dockerfile
+++ b/docker/planetscale_proxy/Dockerfile
@@ -1,12 +1,14 @@
-FROM golang:1
+FROM ghcr.io/mattrobenolt/ps-http-sim:v0.0.9 AS planetscale-proxy
-RUN apt update && apt install netcat-openbsd -y
-RUN cd /go/src && git clone https://github.com/prisma/planetscale-proxy.git
-RUN cd /go/src/planetscale-proxy && go install .
+# ps-http-sim provides a barebones image with nothing but the static binary
+# but we also rely on netcat being present. Alpine provides it as part of busybox.
+FROM alpine:latest
-ENTRYPOINT /go/bin/ps-http-sim \
- -http-addr=0.0.0.0 \
- -http-port=8085 \
+COPY --from=planetscale-proxy /ps-http-sim /ps-http-sim
+
+ENTRYPOINT /ps-http-sim \
+ -listen-addr=0.0.0.0 \
+ -listen-port=8085 \
 -mysql-addr=$MYSQL_HOST \
 -mysql-port=$MYSQL_PORT \
 -mysql-idle-timeout=1s \
diff --git a/flake.lock b/flake.lock
index cbdab3e8a915..c20225ca22ea 100644
--- a/flake.lock
+++ b/flake.lock
@@ -1,17 +1,12 @@
 {
   "nodes": {
     "crane": {
-      "inputs": {
-        "nixpkgs": [
-          "nixpkgs"
-        ]
-      },
       "locked": {
-        "lastModified": 1715274763,
-        "narHash": "sha256-3Iv1PGHJn9sV3HO4FlOVaaztOxa9uGLfOmUWrH7v7+A=",
+        "lastModified": 1728776144,
+        "narHash": "sha256-fROVjMcKRoGHofDm8dY3uDUtCMwUICh/KjBFQnuBzfg=",
         "owner": "ipetkov",
         "repo": "crane",
-        "rev": "27025ab71bdca30e7ed0a16c88fd74c5970fc7f5",
+        "rev": "f876e3d905b922502f031aeec1a84490122254b7",
         "type": "github"
       },
       "original": {
@@ -27,11 +22,11 @@
         ]
       },
       "locked": {
-        "lastModified": 1714641030,
-        "narHash": "sha256-yzcRNDoyVP7+SCNX0wmuDju1NUCt8Dz9+lyUXEI0dbI=",
+        "lastModified": 1727826117,
+        "narHash": "sha256-K5ZLCyfO/Zj9mPFldf3iwS6oZStJcU4tSpiXTMYaaL0=",
         "owner": "hercules-ci",
         "repo": "flake-parts",
-        "rev": "e5d10a24b66c3ea8f150e47dfdb0416ab7c3390e",
+        "rev": "3d04084d54bedc3d6b8b736c70ef449225c361b1",
         "type": "github"
       },
       "original": {
@@ -40,26 +35,6 @@
         "type": "github"
       }
     },
-    "flake-utils": {
-      "inputs": {
-        "systems": [
-          "systems"
-        ]
-      },
-      "locked": {
-        "lastModified": 1710146030,
-        "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
-        "type": "github"
-      },
-      "original": {
-        "owner": "numtide",
-        "repo": "flake-utils",
-        "type": "github"
-      }
-    },
     "gitignore": {
       "inputs": {
         "nixpkgs": [
@@ -82,11 +57,11 @@
     },
     "nixpkgs": {
       "locked": {
-        "lastModified": 1715266358,
-        "narHash": "sha256-doPgfj+7FFe9rfzWo1siAV2mVCasW+Bh8I1cToAXEE4=",
+        "lastModified": 1728888510,
+        "narHash": "sha256-nsNdSldaAyu6PE3YUA+YQLqUDJh+gRbBooMMekZJwvI=",
         "owner": "NixOS",
         "repo":
"nixpkgs", - "rev": "f1010e0469db743d14519a1efd37e23f8513d714", + "rev": "a3c0b3b21515f74fd2665903d4ce6bc4dc81c77c", "type": "github" }, "original": { @@ -99,7 +74,6 @@ "inputs": { "crane": "crane", "flake-parts": "flake-parts", - "flake-utils": "flake-utils", "gitignore": "gitignore", "nixpkgs": "nixpkgs", "rust-overlay": "rust-overlay", @@ -108,19 +82,16 @@ }, "rust-overlay": { "inputs": { - "flake-utils": [ - "flake-utils" - ], "nixpkgs": [ "nixpkgs" ] }, "locked": { - "lastModified": 1715307487, - "narHash": "sha256-yuDAys3JuJmhQUQGMMsl3BDQNZUYZDw0eA71OVh9FeY=", + "lastModified": 1729184663, + "narHash": "sha256-uNyi5vQrzaLkt4jj6ZEOs4+4UqOAwP6jFG2s7LIDwIk=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "ec7a7caf50877bc32988c82653d6b3e6952a8c3f", + "rev": "16fb78d443c1970dda9a0bbb93070c9d8598a925", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index a76ab2edbea8..f55ed31a29d9 100644 --- a/flake.nix +++ b/flake.nix @@ -1,13 +1,6 @@ { inputs = { - crane = { - url = "github:ipetkov/crane"; - inputs.nixpkgs.follows = "nixpkgs"; - }; - flake-utils = { - url = "github:numtide/flake-utils"; - inputs.systems.follows = "systems"; - }; + crane.url = "github:ipetkov/crane"; flake-parts = { url = "github:hercules-ci/flake-parts"; inputs.nixpkgs-lib.follows = "nixpkgs"; @@ -19,7 +12,6 @@ rust-overlay = { url = "github:oxalica/rust-overlay"; inputs.nixpkgs.follows = "nixpkgs"; - inputs.flake-utils.follows = "flake-utils"; }; nixpkgs.url = "nixpkgs/nixos-unstable"; systems.url = "github:nix-systems/default"; diff --git a/libs/build-utils/Cargo.toml b/libs/build-utils/Cargo.toml new file mode 100644 index 000000000000..715b650505d4 --- /dev/null +++ b/libs/build-utils/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "build-utils" +version = "0.1.0" +edition = "2021" + +[dependencies] diff --git a/libs/build-utils/src/lib.rs b/libs/build-utils/src/lib.rs new file mode 100644 index 000000000000..03294a997a32 --- /dev/null +++ b/libs/build-utils/src/lib.rs @@ -0,0 +1,23 @@ +use std::process::Command; + +/// Store the current git commit hash in the `GIT_HASH` variable in rustc env. +/// If the `GIT_HASH` environment variable is already set, this function does nothing. +pub fn store_git_commit_hash_in_env() { + if std::env::var("GIT_HASH").is_ok() { + return; + } + + let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); + + // Sanity check on the output. 
+ if !output.status.success() { + panic!( + "Failed to get git commit hash.\nstderr: \n{}\nstdout {}\n", + String::from_utf8(output.stderr).unwrap_or_default(), + String::from_utf8(output.stdout).unwrap_or_default(), + ); + } + + let git_hash = String::from_utf8(output.stdout).unwrap(); + println!("cargo:rustc-env=GIT_HASH={git_hash}"); +} diff --git a/libs/crosstarget-utils/Cargo.toml b/libs/crosstarget-utils/Cargo.toml index 627efbf23c36..78d52dade2b0 100644 --- a/libs/crosstarget-utils/Cargo.toml +++ b/libs/crosstarget-utils/Cargo.toml @@ -6,14 +6,17 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -futures = "0.3" +derive_more.workspace = true +enumflags2.workspace = true +futures.workspace = true [target.'cfg(target_arch = "wasm32")'.dependencies] js-sys.workspace = true wasm-bindgen.workspace = true wasm-bindgen-futures.workspace = true -tokio = { version = "1.25", features = ["macros", "sync"] } -pin-project = "1" +tokio = { version = "1", features = ["macros", "sync"] } +pin-project.workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio.workspace = true +regex.workspace = true diff --git a/libs/crosstarget-utils/src/common.rs b/libs/crosstarget-utils/src/common.rs deleted file mode 100644 index 92a1d5094e89..000000000000 --- a/libs/crosstarget-utils/src/common.rs +++ /dev/null @@ -1,23 +0,0 @@ -use std::fmt::Display; - -#[derive(Debug)] -pub struct SpawnError; - -impl Display for SpawnError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Failed to spawn a future") - } -} - -impl std::error::Error for SpawnError {} - -#[derive(Debug)] -pub struct TimeoutError; - -impl Display for TimeoutError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Operation timed out") - } -} - -impl std::error::Error for TimeoutError {} diff --git a/libs/crosstarget-utils/src/common/mod.rs b/libs/crosstarget-utils/src/common/mod.rs new file mode 100644 index 000000000000..d1701cc3408f --- /dev/null +++ b/libs/crosstarget-utils/src/common/mod.rs @@ -0,0 +1,3 @@ +pub mod regex; +pub mod spawn; +pub mod timeout; diff --git a/libs/crosstarget-utils/src/common/regex.rs b/libs/crosstarget-utils/src/common/regex.rs new file mode 100644 index 000000000000..825d3e85d6c1 --- /dev/null +++ b/libs/crosstarget-utils/src/common/regex.rs @@ -0,0 +1,37 @@ +use derive_more::Display; + +#[derive(Debug, Display)] +#[display(fmt = "Regular expression error: {message}")] +pub struct RegExpError { + pub message: String, +} + +impl std::error::Error for RegExpError {} + +/// Flag modifiers for regular expressions. +#[enumflags2::bitflags] +#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(u8)] +pub enum RegExpFlags { + IgnoreCase = 0b0001, + Multiline = 0b0010, +} + +impl RegExpFlags { + pub fn as_str(&self) -> &'static str { + match self { + Self::IgnoreCase => "i", + Self::Multiline => "m", + } + } +} + +pub trait RegExpCompat { + /// Searches for the first match of this regex in the haystack given, and if found, + /// returns not only the overall match but also the matches of each capture group in the regex. + /// If no match is found, then None is returned. + fn captures(&self, message: &str) -> Option>; + + /// Tests if the regex matches the input string. 
+ fn test(&self, message: &str) -> bool; +} diff --git a/libs/crosstarget-utils/src/common/spawn.rs b/libs/crosstarget-utils/src/common/spawn.rs new file mode 100644 index 000000000000..77560452dbdc --- /dev/null +++ b/libs/crosstarget-utils/src/common/spawn.rs @@ -0,0 +1,8 @@ +use derive_more::Display; + +#[derive(Debug, Display)] +#[display(fmt = "Failed to spawn a future")] + +pub struct SpawnError; + +impl std::error::Error for SpawnError {} diff --git a/libs/crosstarget-utils/src/common/timeout.rs b/libs/crosstarget-utils/src/common/timeout.rs new file mode 100644 index 000000000000..829abaf0ec04 --- /dev/null +++ b/libs/crosstarget-utils/src/common/timeout.rs @@ -0,0 +1,7 @@ +use derive_more::Display; + +#[derive(Debug, Display)] +#[display(fmt = "Operation timed out")] +pub struct TimeoutError; + +impl std::error::Error for TimeoutError {} diff --git a/libs/crosstarget-utils/src/lib.rs b/libs/crosstarget-utils/src/lib.rs index a41d8dd0f9a6..1cfa25edeed3 100644 --- a/libs/crosstarget-utils/src/lib.rs +++ b/libs/crosstarget-utils/src/lib.rs @@ -9,4 +9,5 @@ mod native; #[cfg(not(target_arch = "wasm32"))] pub use crate::native::*; -pub use common::SpawnError; +pub use crate::common::regex::RegExpCompat; +pub use crate::common::spawn::SpawnError; diff --git a/libs/crosstarget-utils/src/native/mod.rs b/libs/crosstarget-utils/src/native/mod.rs index b19a356ff8ff..e3793a1de655 100644 --- a/libs/crosstarget-utils/src/native/mod.rs +++ b/libs/crosstarget-utils/src/native/mod.rs @@ -1,3 +1,4 @@ +pub mod regex; pub mod spawn; pub mod task; pub mod time; diff --git a/libs/crosstarget-utils/src/native/regex.rs b/libs/crosstarget-utils/src/native/regex.rs new file mode 100644 index 000000000000..59fe88899c63 --- /dev/null +++ b/libs/crosstarget-utils/src/native/regex.rs @@ -0,0 +1,41 @@ +use enumflags2::BitFlags; +use regex::{Regex as NativeRegex, RegexBuilder}; + +use crate::common::regex::{RegExpCompat, RegExpError, RegExpFlags}; + +pub struct RegExp { + inner: NativeRegex, +} + +impl RegExp { + pub fn new(pattern: &str, flags: BitFlags) -> Result { + let mut builder = RegexBuilder::new(pattern); + + if flags.contains(RegExpFlags::Multiline) { + builder.multi_line(true); + } + + if flags.contains(RegExpFlags::IgnoreCase) { + builder.case_insensitive(true); + } + + let inner = builder.build().map_err(|e| RegExpError { message: e.to_string() })?; + + Ok(Self { inner }) + } +} + +impl RegExpCompat for RegExp { + fn captures(&self, message: &str) -> Option> { + self.inner.captures(message).map(|captures| { + captures + .iter() + .flat_map(|capture| capture.map(|cap| cap.as_str().to_owned())) + .collect() + }) + } + + fn test(&self, message: &str) -> bool { + self.inner.is_match(message) + } +} diff --git a/libs/crosstarget-utils/src/native/spawn.rs b/libs/crosstarget-utils/src/native/spawn.rs index 70e4c3708f22..31971aa47c41 100644 --- a/libs/crosstarget-utils/src/native/spawn.rs +++ b/libs/crosstarget-utils/src/native/spawn.rs @@ -1,7 +1,7 @@ use futures::TryFutureExt; use std::future::Future; -use crate::common::SpawnError; +use crate::common::spawn::SpawnError; pub fn spawn_if_possible(future: F) -> impl Future> where diff --git a/libs/crosstarget-utils/src/native/time.rs b/libs/crosstarget-utils/src/native/time.rs index 3b154a27565c..c17cb07c5eb2 100644 --- a/libs/crosstarget-utils/src/native/time.rs +++ b/libs/crosstarget-utils/src/native/time.rs @@ -3,8 +3,9 @@ use std::{ time::{Duration, Instant}, }; -use crate::common::TimeoutError; +use crate::common::timeout::TimeoutError; +#[derive(Clone, 
Copy)] pub struct ElapsedTimeCounter { instant: Instant, } diff --git a/libs/crosstarget-utils/src/wasm/mod.rs b/libs/crosstarget-utils/src/wasm/mod.rs index b19a356ff8ff..e3793a1de655 100644 --- a/libs/crosstarget-utils/src/wasm/mod.rs +++ b/libs/crosstarget-utils/src/wasm/mod.rs @@ -1,3 +1,4 @@ +pub mod regex; pub mod spawn; pub mod task; pub mod time; diff --git a/libs/crosstarget-utils/src/wasm/regex.rs b/libs/crosstarget-utils/src/wasm/regex.rs new file mode 100644 index 000000000000..500f631282ea --- /dev/null +++ b/libs/crosstarget-utils/src/wasm/regex.rs @@ -0,0 +1,38 @@ +use enumflags2::BitFlags; +use js_sys::RegExp as JSRegExp; + +use crate::common::regex::{RegExpCompat, RegExpError, RegExpFlags}; + +pub struct RegExp { + inner: JSRegExp, +} + +impl RegExp { + pub fn new(pattern: &str, flags: BitFlags) -> Result { + let mut flags: String = flags.into_iter().map(|flag| flag.as_str()).collect(); + + // Global flag is implied in `regex::Regex`, so we match that behavior for consistency. + flags.push('g'); + + Ok(Self { + inner: JSRegExp::new(pattern, &flags), + }) + } +} + +impl RegExpCompat for RegExp { + fn captures(&self, message: &str) -> Option> { + self.inner.exec(message).map(|matches| { + // We keep the same number of captures as the number of groups in the regex pattern, + // but we guarantee that the captures are always strings. + matches + .iter() + .map(|match_value| match_value.try_into().ok().unwrap_or_default()) + .collect() + }) + } + + fn test(&self, input: &str) -> bool { + self.inner.test(input) + } +} diff --git a/libs/crosstarget-utils/src/wasm/spawn.rs b/libs/crosstarget-utils/src/wasm/spawn.rs index e27104c3b941..f9700d8a0071 100644 --- a/libs/crosstarget-utils/src/wasm/spawn.rs +++ b/libs/crosstarget-utils/src/wasm/spawn.rs @@ -4,7 +4,7 @@ use futures::TryFutureExt; use tokio::sync::oneshot; use wasm_bindgen_futures::spawn_local; -use crate::common::SpawnError; +use crate::common::spawn::SpawnError; pub fn spawn_if_possible(future: F) -> impl Future> where diff --git a/libs/crosstarget-utils/src/wasm/time.rs b/libs/crosstarget-utils/src/wasm/time.rs index 18f3394b7464..6c36a7b4d400 100644 --- a/libs/crosstarget-utils/src/wasm/time.rs +++ b/libs/crosstarget-utils/src/wasm/time.rs @@ -7,7 +7,7 @@ use std::time::Duration; use wasm_bindgen::prelude::*; use wasm_bindgen_futures::JsFuture; -use crate::common::TimeoutError; +use crate::common::timeout::TimeoutError; #[wasm_bindgen] extern "C" { @@ -21,6 +21,7 @@ extern "C" { } +#[derive(Clone, Copy)] pub struct ElapsedTimeCounter { start_time: f64, } diff --git a/query-engine/metrics/Cargo.toml b/libs/metrics/Cargo.toml similarity index 50% rename from query-engine/metrics/Cargo.toml rename to libs/metrics/Cargo.toml index 5593b246c093..916c464dda5c 100644 --- a/query-engine/metrics/Cargo.toml +++ b/libs/metrics/Cargo.toml @@ -1,19 +1,22 @@ [package] -name = "query-engine-metrics" +name = "prisma-metrics" version = "0.1.0" edition = "2021" [dependencies] -metrics = "0.18" -metrics-util = "0.12.1" -metrics-exporter-prometheus = "0.10.0" +futures.workspace = true +derive_more.workspace = true +metrics = "0.23.0" +metrics-util = "0.17.0" +metrics-exporter-prometheus = { version = "0.15.3", default-features = false } once_cell = "1.3" serde.workspace = true serde_json.workspace = true tracing.workspace = true -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-subscriber = "0.3.11" parking_lot = "0.12" +pin-project.workspace = true [dev-dependencies] expect-test = "1" diff --git 
a/query-engine/metrics/src/common.rs b/libs/metrics/src/common.rs
similarity index 100%
rename from query-engine/metrics/src/common.rs
rename to libs/metrics/src/common.rs
diff --git a/query-engine/metrics/src/formatters.rs b/libs/metrics/src/formatters.rs
similarity index 100%
rename from query-engine/metrics/src/formatters.rs
rename to libs/metrics/src/formatters.rs
diff --git a/libs/metrics/src/guards.rs b/libs/metrics/src/guards.rs
new file mode 100644
index 000000000000..331db1249904
--- /dev/null
+++ b/libs/metrics/src/guards.rs
@@ -0,0 +1,31 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+
+use crate::gauge;
+
+pub struct GaugeGuard {
+    name: &'static str,
+    decremented: AtomicBool,
+}
+
+impl GaugeGuard {
+    pub fn increment(name: &'static str) -> Self {
+        gauge!(name).increment(1.0);
+
+        Self {
+            name,
+            decremented: AtomicBool::new(false),
+        }
+    }
+
+    pub fn decrement(&self) {
+        if !self.decremented.swap(true, Ordering::Relaxed) {
+            gauge!(self.name).decrement(1.0);
+        }
+    }
+}
+
+impl Drop for GaugeGuard {
+    fn drop(&mut self) {
+        self.decrement();
+    }
+}
diff --git a/libs/metrics/src/instrument.rs b/libs/metrics/src/instrument.rs
new file mode 100644
index 000000000000..a2cb16de48f8
--- /dev/null
+++ b/libs/metrics/src/instrument.rs
@@ -0,0 +1,83 @@
+use std::{
+    cell::RefCell,
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll},
+};
+
+use futures::future::Either;
+use pin_project::pin_project;
+
+use crate::MetricRecorder;
+
+thread_local! {
+    /// The current metric recorder temporarily set on the current thread while polling a future.
+    ///
+    /// See the description of `GLOBAL_RECORDER` in [`crate::recorder`] module for more
+    /// information.
+    static CURRENT_RECORDER: RefCell<Option<MetricRecorder>> = const { RefCell::new(None) };
+}
+
+/// Instruments a type with a metrics recorder.
+///
+/// The instrumentation logic is currently only implemented for futures, but it could be extended
+/// to support streams, sinks, and other types later if needed. Right now we only need it to be
+/// able to set the initial recorder in the Node-API engine methods and forward the recorder to
+/// spawned tokio tasks; in other words, to instrument the top-level future of each task.
+pub trait WithMetricsInstrumentation: Sized {
+    /// Instruments the type with a [`MetricRecorder`].
+    fn with_recorder(self, recorder: MetricRecorder) -> WithRecorder<Self> {
+        WithRecorder { inner: self, recorder }
+    }
+
+    /// Instruments the type with a [`MetricRecorder`] if it is a `Some` or returns `self` as is
+    /// if the `recorder` is a `None`.
+    fn with_optional_recorder(self, recorder: Option<MetricRecorder>) -> Either<WithRecorder<Self>, Self> {
+        match recorder {
+            Some(recorder) => Either::Left(self.with_recorder(recorder)),
+            None => Either::Right(self),
+        }
+    }
+
+    /// Instruments the type with the current [`MetricRecorder`] from the parent context on this
+    /// thread, or the default global recorder otherwise. If neither is set, then `self` is
+    /// returned as is.
+    fn with_current_recorder(self) -> Either<WithRecorder<Self>, Self> {
+        CURRENT_RECORDER.with_borrow(|recorder| {
+            let recorder = recorder.clone().or_else(crate::recorder::global_recorder);
+            self.with_optional_recorder(recorder)
+        })
+    }
+}
+
+impl<T> WithMetricsInstrumentation for T {}
+
+/// A type instrumented with a metric recorder.
+///
+/// If `T` is a `Future`, then `WithRecorder<T>` is also a `Future`. When polled, it temporarily
+/// sets the local metric recorder for the duration of polling the inner future, and then restores
+/// the previous recorder on the stack.
+///
+/// Similar logic can be implemented for cases where `T` is another async primitive like a stream
+/// or a sink, or any other type where such instrumentation makes sense (e.g. a function).
+#[pin_project]
+pub struct WithRecorder<T> {
+    #[pin]
+    inner: T,
+    recorder: MetricRecorder,
+}
+
+impl<T: Future> Future for WithRecorder<T> {
+    type Output = T::Output;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let this = self.project();
+        let prev_recorder = CURRENT_RECORDER.replace(Some(this.recorder.clone()));
+
+        let poll = metrics::with_local_recorder(this.recorder, || this.inner.poll(cx));
+
+        CURRENT_RECORDER.set(prev_recorder);
+
+        poll
+    }
+}
diff --git a/query-engine/metrics/src/lib.rs b/libs/metrics/src/lib.rs
similarity index 78%
rename from query-engine/metrics/src/lib.rs
rename to libs/metrics/src/lib.rs
index 1965b56cb076..43aeab775922 100644
--- a/query-engine/metrics/src/lib.rs
+++ b/libs/metrics/src/lib.rs
@@ -1,5 +1,8 @@
-//! Query Engine Metrics
-//! This crate is responsible for capturing and recording metrics in the Query Engine.
+//! # Prisma Metrics
+//!
+//! This crate is responsible for capturing and recording metrics in the Query Engine and its
+//! dependencies.
+//!
 //! Metrics is broken into two parts, `MetricsRecorder` and `MetricsRegistry`, and uses our tracing framework to communicate.
 //! An example best explains this system.
 //! When the engine boots up, the `MetricRegistry` is added to our tracing as a layer and The MetricRecorder is
@@ -19,29 +22,23 @@
 //! * At the moment, with the Histogram we only support one type of bucket which is a bucket for timings in milliseconds.
 //!
-const METRIC_TARGET: &str = "qe_metrics";
-const METRIC_COUNTER: &str = "counter";
-const METRIC_GAUGE: &str = "gauge";
-const METRIC_HISTOGRAM: &str = "histogram";
-const METRIC_DESCRIPTION: &str = "description";
-
 mod common;
 mod formatters;
+mod instrument;
 mod recorder;
 mod registry;
+
+pub mod guards;
+
 use once_cell::sync::Lazy;
-use recorder::*;
-pub use registry::MetricRegistry;
 use serde::Deserialize;
 use std::collections::HashMap;
-use std::sync::Once;
-pub extern crate metrics;
-pub use metrics::{
-    absolute_counter, decrement_gauge, describe_counter, describe_gauge, describe_histogram, gauge, histogram,
-    increment_counter, increment_gauge,
-};
+pub use metrics::{self, counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
+
+pub use instrument::*;
+pub use recorder::MetricRecorder;
+pub use registry::MetricRegistry;
 
 // Metrics that we emit from the engines, third party metrics emitted by libraries and that we rename are omitted.
 pub const PRISMA_CLIENT_QUERIES_TOTAL: &str = "prisma_client_queries_total"; // counter
@@ -94,21 +91,8 @@ static METRIC_RENAMES: Lazy<HashMap<&'static str, (&'static str, &'static str)>>
     ])
 });
 
-pub fn setup() {
-    set_recorder();
-    initialize_metrics();
-}
-
-static METRIC_RECORDER: Once = Once::new();
-
-fn set_recorder() {
-    METRIC_RECORDER.call_once(|| {
-        metrics::set_boxed_recorder(Box::new(MetricRecorder)).unwrap();
-    });
-}
-
 /// Initialize metrics descriptions and values
-pub fn initialize_metrics() {
+pub(crate) fn initialize_metrics() {
     initialize_metrics_descriptions();
     initialize_metrics_values();
 }
@@ -145,15 +129,15 @@ fn initialize_metrics_descriptions() {
 /// Histograms are excluded, as their initialization will alter the histogram values.
 /// (i.e.
histograms don't have a neutral value, like counters or gauges) fn initialize_metrics_values() { - absolute_counter!(PRISMA_CLIENT_QUERIES_TOTAL, 0); - absolute_counter!(PRISMA_DATASOURCE_QUERIES_TOTAL, 0); - gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 0.0); - absolute_counter!(MOBC_POOL_CONNECTIONS_OPENED_TOTAL, 0); - absolute_counter!(MOBC_POOL_CONNECTIONS_CLOSED_TOTAL, 0); - gauge!(MOBC_POOL_CONNECTIONS_OPEN, 0.0); - gauge!(MOBC_POOL_CONNECTIONS_BUSY, 0.0); - gauge!(MOBC_POOL_CONNECTIONS_IDLE, 0.0); - gauge!(MOBC_POOL_WAIT_COUNT, 0.0); + counter!(PRISMA_CLIENT_QUERIES_TOTAL).absolute(0); + counter!(PRISMA_DATASOURCE_QUERIES_TOTAL).absolute(0); + gauge!(PRISMA_CLIENT_QUERIES_ACTIVE).set(0.0); + counter!(MOBC_POOL_CONNECTIONS_OPENED_TOTAL).absolute(0); + counter!(MOBC_POOL_CONNECTIONS_CLOSED_TOTAL).absolute(0); + gauge!(MOBC_POOL_CONNECTIONS_OPEN).set(0.0); + gauge!(MOBC_POOL_CONNECTIONS_BUSY).set(0.0); + gauge!(MOBC_POOL_CONNECTIONS_IDLE).set(0.0); + gauge!(MOBC_POOL_WAIT_COUNT).set(0.0); } // At the moment the histogram is only used for timings. So the bounds are hard coded here @@ -171,24 +155,16 @@ pub enum MetricFormat { #[cfg(test)] mod tests { use super::*; - use metrics::{ - absolute_counter, decrement_gauge, describe_counter, describe_gauge, describe_histogram, gauge, histogram, - increment_counter, increment_gauge, register_counter, register_gauge, register_histogram, - }; + use metrics::{describe_counter, describe_gauge, describe_histogram, gauge, histogram}; use serde_json::json; use std::collections::HashMap; use std::time::Duration; - use tracing::instrument::WithSubscriber; - use tracing::{trace, Dispatch}; - use tracing_subscriber::layer::SubscriberExt; + use tracing::trace; use once_cell::sync::Lazy; use tokio::runtime::Runtime; - static RT: Lazy = Lazy::new(|| { - set_recorder(); - Runtime::new().unwrap() - }); + static RT: Lazy = Lazy::new(|| Runtime::new().unwrap()); const TESTING_ACCEPT_LIST: &[&str] = &[ "test_counter", @@ -209,14 +185,14 @@ mod tests { fn test_counters() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { - let counter1 = register_counter!("test_counter"); + let recorder = MetricRecorder::new(metrics.clone()); + async move { + let counter1 = counter!("test_counter"); counter1.increment(1); - increment_counter!("test_counter"); - increment_counter!("test_counter"); + counter!("test_counter").increment(1); + counter!("test_counter").increment(1); - increment_counter!("another_counter"); + counter!("another_counter").increment(1); let val = metrics.counter_value("test_counter").unwrap(); assert_eq!(val, 3); @@ -224,11 +200,11 @@ mod tests { let val2 = metrics.counter_value("another_counter").unwrap(); assert_eq!(val2, 1); - absolute_counter!("test_counter", 5); + counter!("test_counter").absolute(5); let val3 = metrics.counter_value("test_counter").unwrap(); assert_eq!(val3, 5); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -237,13 +213,13 @@ mod tests { fn test_gauges() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { - let gauge1 = register_gauge!("test_gauge"); + let recorder = MetricRecorder::new(metrics.clone()); + async move { + let gauge1 = gauge!("test_gauge"); gauge1.increment(1.0); - 
increment_gauge!("test_gauge", 1.0); - increment_gauge!("test_gauge", 1.0); - increment_gauge!("another_gauge", 1.0); + gauge!("test_gauge").increment(1.0); + gauge!("test_gauge").increment(1.0); + gauge!("another_gauge").increment(1.0); let val = metrics.gauge_value("test_gauge").unwrap(); assert_eq!(val, 3.0); @@ -253,15 +229,15 @@ mod tests { assert_eq!(None, metrics.counter_value("test_gauge")); - gauge!("test_gauge", 5.0); + gauge!("test_gauge").set(5.0); let val3 = metrics.gauge_value("test_gauge").unwrap(); assert_eq!(val3, 5.0); - decrement_gauge!("test_gauge", 2.0); + gauge!("test_gauge").decrement(2.0); let val4 = metrics.gauge_value("test_gauge").unwrap(); assert_eq!(val4, 3.0); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -270,19 +246,19 @@ mod tests { fn test_no_panic_and_ignore_other_traces() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { + let recorder = MetricRecorder::new(metrics.clone()); + async move { trace!("a fake trace"); - increment_gauge!("test_gauge", 1.0); - increment_counter!("test_counter"); + gauge!("test_gauge").set(1.0); + counter!("test_counter").increment(1); trace!("another fake trace"); assert_eq!(1.0, metrics.gauge_value("test_gauge").unwrap()); assert_eq!(1, metrics.counter_value("test_counter").unwrap()); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -291,15 +267,15 @@ mod tests { fn test_ignore_non_accepted_metrics() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { - increment_gauge!("not_accepted", 1.0); - increment_gauge!("test_gauge", 1.0); + let recorder = MetricRecorder::new(metrics.clone()); + async move { + gauge!("not_accepted").set(1.0); + gauge!("test_gauge").set(1.0); assert_eq!(1.0, metrics.gauge_value("test_gauge").unwrap()); assert_eq!(None, metrics.gauge_value("not_accepted")); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -308,17 +284,17 @@ mod tests { fn test_histograms() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { - let hist = register_histogram!("test_histogram"); + let recorder = MetricRecorder::new(metrics.clone()); + async move { + let hist = histogram!("test_histogram"); hist.record(Duration::from_millis(9)); - histogram!("test_histogram", Duration::from_millis(100)); - histogram!("test_histogram", Duration::from_millis(1)); + histogram!("test_histogram").record(Duration::from_millis(100)); + histogram!("test_histogram").record(Duration::from_millis(1)); - histogram!("test_histogram", Duration::from_millis(1999)); - histogram!("test_histogram", Duration::from_millis(3999)); - histogram!("test_histogram", Duration::from_millis(610)); + histogram!("test_histogram").record(Duration::from_millis(1999)); + histogram!("test_histogram").record(Duration::from_millis(3999)); + histogram!("test_histogram").record(Duration::from_millis(610)); let hist = metrics.histogram_values("test_histogram").unwrap(); let expected: Vec<(f64, u64)> = Vec::from([ @@ -336,7 +312,7 @@ mod tests { assert_eq!(hist.buckets(), expected); } - .with_subscriber(dispatch) + .with_recorder(recorder) 
.await; }); } @@ -345,8 +321,8 @@ mod tests { fn test_set_and_read_descriptions() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { + let recorder = MetricRecorder::new(metrics.clone()); + async move { describe_counter!("test_counter", "This is a counter"); let descriptions = metrics.get_descriptions(); @@ -367,7 +343,7 @@ mod tests { let description = descriptions.get("test_histogram").unwrap(); assert_eq!("This is a hist", description); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -376,8 +352,8 @@ mod tests { fn test_to_json() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { + let recorder = MetricRecorder::new(metrics.clone()); + async move { let empty = json!({ "counters": [], "gauges": [], @@ -386,21 +362,21 @@ mod tests { assert_eq!(metrics.to_json(Default::default()), empty); - absolute_counter!("counter_1", 4, "label" => "one"); + counter!("counter_1", "label" => "one").absolute(4); describe_counter!("counter_2", "this is a description for counter 2"); - absolute_counter!("counter_2", 2, "label" => "one", "another_label" => "two"); + counter!("counter_2", "label" => "one", "another_label" => "two").absolute(2); describe_gauge!("gauge_1", "a description for gauge 1"); - gauge!("gauge_1", 7.0); - gauge!("gauge_2", 3.0, "label" => "three"); + gauge!("gauge_1").set(7.0); + gauge!("gauge_2", "label" => "three").set(3.0); describe_histogram!("histogram_1", "a description for histogram"); - let hist = register_histogram!("histogram_1", "label" => "one", "hist_two" => "two"); + let hist = histogram!("histogram_1", "label" => "one", "hist_two" => "two"); hist.record(Duration::from_millis(9)); - histogram!("histogram_2", Duration::from_millis(9)); - histogram!("histogram_2", Duration::from_millis(1000)); - histogram!("histogram_2", Duration::from_millis(40)); + histogram!("histogram_2").record(Duration::from_millis(9)); + histogram!("histogram_2").record(Duration::from_millis(1000)); + histogram!("histogram_2").record(Duration::from_millis(40)); let json = metrics.to_json(Default::default()); let expected = json!({ @@ -448,7 +424,7 @@ mod tests { assert_eq!(json, expected); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -457,12 +433,12 @@ mod tests { fn test_global_and_metric_labels() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { - let hist = register_histogram!("test_histogram", "label" => "one", "two" => "another"); + let recorder = MetricRecorder::new(metrics.clone()); + async move { + let hist = histogram!("test_histogram", "label" => "one", "two" => "another"); hist.record(Duration::from_millis(9)); - absolute_counter!("counter_1", 1); + counter!("counter_1").absolute(1); let mut global_labels: HashMap = HashMap::new(); global_labels.insert("global_one".to_string(), "one".to_string()); @@ -491,7 +467,7 @@ mod tests { }); assert_eq!(expected, json); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -500,21 +476,21 @@ mod tests { fn test_prometheus_format() { RT.block_on(async { let metrics = 
MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { - absolute_counter!("counter_1", 4, "label" => "one"); + let recorder = MetricRecorder::new(metrics.clone()); + async move { + counter!("counter_1", "label" => "one").absolute(4); describe_counter!("counter_2", "this is a description for counter 2"); - absolute_counter!("counter_2", 2, "label" => "one", "another_label" => "two"); + counter!("counter_2", "label" => "one", "another_label" => "two").absolute(2); describe_gauge!("gauge_1", "a description for gauge 1"); - gauge!("gauge_1", 7.0); - gauge!("gauge_2", 3.0, "label" => "three"); + gauge!("gauge_1").set(7.0); + gauge!("gauge_2", "label" => "three").set(3.0); describe_histogram!("histogram_1", "a description for histogram"); - let hist = register_histogram!("histogram_1", "label" => "one", "hist_two" => "two"); + let hist = histogram!("histogram_1", "label" => "one", "hist_two" => "two"); hist.record(Duration::from_millis(9)); - histogram!("histogram_2", Duration::from_millis(1000)); + histogram!("histogram_2").record(Duration::from_millis(1000)); let mut global_labels: HashMap = HashMap::new(); global_labels.insert("global_two".to_string(), "two".to_string()); @@ -574,7 +550,7 @@ mod tests { snapshot.assert_eq(&prometheus); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } diff --git a/libs/metrics/src/recorder.rs b/libs/metrics/src/recorder.rs new file mode 100644 index 000000000000..a911a45d67e5 --- /dev/null +++ b/libs/metrics/src/recorder.rs @@ -0,0 +1,190 @@ +use std::sync::{Arc, OnceLock}; + +use derive_more::Display; +use metrics::{Counter, CounterFn, Gauge, GaugeFn, Histogram, HistogramFn, Key, Recorder, Unit}; +use metrics::{KeyName, Metadata, SharedString}; + +use crate::common::{MetricAction, MetricType}; +use crate::registry::MetricVisitor; +use crate::MetricRegistry; + +/// Default global metric recorder. +/// +/// `metrics` crate has the state on its own. It allows setting the global recorder, it allows +/// overriding it for a duration of an async closure, and it allows borrowing the current recorder +/// for a short while. We, however, can't use this in our async instrumentation because we need the +/// current recorder to be `Send + 'static` to be able to store it in a future that would be usable +/// in a work-stealing runtime, especially since we need to be able to instrument the futures +/// spawned as tasks. The solution to this is to maintain our own state in parallel. +/// +/// The APIs exposed by the crate guarantee that the state we modify on our side is updated on the +/// `metrics` side as well. Using `metrics::set_global_recorder` or `metrics::with_local_recorder` +/// in user code won't be detected by us but is safe and won't lead to any issues (even if the new +/// recorder isn't the [`MetricRecorder`] from this crate), we just won't know about any new local +/// recorders on the stack, and calling +/// [`crate::WithMetricsInstrumentation::with_current_recorder`] will re-use the last +/// [`MetricRecorder`] known to us. 
+static GLOBAL_RECORDER: OnceLock<Option<MetricRecorder>> = const { OnceLock::new() };
+
+#[derive(Display, Debug)]
+#[display(fmt = "global recorder can only be installed once")]
+pub struct AlreadyInstalled;
+
+impl std::error::Error for AlreadyInstalled {}
+
+fn set_global_recorder(recorder: MetricRecorder) -> Result<(), AlreadyInstalled> {
+    GLOBAL_RECORDER.set(Some(recorder)).map_err(|_| AlreadyInstalled)
+}
+
+pub(crate) fn global_recorder() -> Option<MetricRecorder> {
+    GLOBAL_RECORDER.get()?.clone()
+}
+
+/// Receives the metrics from the macros provided by the `metrics` crate and forwards them to
+/// [`MetricRegistry`].
+///
+/// To provide an analogy, `MetricRecorder` to `MetricRegistry` is what `Dispatch` is to
+/// `Subscriber` in `tracing`. Just like `Dispatch`, it acts like a handle to the registry and is
+/// cheaply clonable with reference-counting semantics.
+#[derive(Clone)]
+pub struct MetricRecorder {
+    registry: MetricRegistry,
+}
+
+impl MetricRecorder {
+    pub fn new(registry: MetricRegistry) -> Self {
+        Self { registry }
+    }
+
+    /// Convenience method to call [`Self::init_prisma_metrics`] immediately after creating the
+    /// recorder.
+    pub fn with_initialized_prisma_metrics(self) -> Self {
+        self.init_prisma_metrics();
+        self
+    }
+
+    /// Initializes the default Prisma metrics by dispatching their descriptions and initial values
+    /// to the registry.
+    ///
+    /// Query engine needs this, but the metrics can also be used without this, especially in
+    /// tests.
+    pub fn init_prisma_metrics(&self) {
+        metrics::with_local_recorder(self, || {
+            super::initialize_metrics();
+        });
+    }
+
+    /// Installs the metrics recorder globally, registering it both with the `metrics` crate and
+    /// our own instrumentation.
+    pub fn install_globally(&self) -> Result<(), AlreadyInstalled> {
+        set_global_recorder(self.clone())?;
+        metrics::set_global_recorder(self.clone()).map_err(|_| AlreadyInstalled)
+    }
+
+    fn register_description(&self, name: KeyName, description: &str) {
+        self.record_in_registry(&MetricVisitor {
+            metric_type: MetricType::Description,
+            action: MetricAction::Description(description.to_owned()),
+            name: Key::from_name(name),
+        });
+    }
+
+    fn record_in_registry(&self, visitor: &MetricVisitor) {
+        self.registry.record(visitor);
+    }
+}
+
+impl Recorder for MetricRecorder {
+    fn describe_counter(&self, key_name: KeyName, _unit: Option<Unit>, description: SharedString) {
+        self.register_description(key_name, &description);
+    }
+
+    fn describe_gauge(&self, key_name: KeyName, _unit: Option<Unit>, description: SharedString) {
+        self.register_description(key_name, &description);
+    }
+
+    fn describe_histogram(&self, key_name: KeyName, _unit: Option<Unit>, description: SharedString) {
+        self.register_description(key_name, &description);
+    }
+
+    fn register_counter(&self, key: &Key, _metadata: &Metadata<'_>) -> Counter {
+        Counter::from_arc(Arc::new(MetricHandle::new(key.clone(), self.registry.clone())))
+    }
+
+    fn register_gauge(&self, key: &Key, _metadata: &Metadata<'_>) -> Gauge {
+        Gauge::from_arc(Arc::new(MetricHandle::new(key.clone(), self.registry.clone())))
+    }
+
+    fn register_histogram(&self, key: &Key, _metadata: &Metadata<'_>) -> Histogram {
+        Histogram::from_arc(Arc::new(MetricHandle::new(key.clone(), self.registry.clone())))
+    }
+}
+
+pub(crate) struct MetricHandle {
+    key: Key,
+    registry: MetricRegistry,
+}
+
+impl MetricHandle {
+    pub fn new(key: Key, registry: MetricRegistry) -> Self {
+        Self { key, registry }
+    }
+
+    fn record_in_registry(&self, visitor: &MetricVisitor) {
+        self.registry.record(visitor);
+    }
+}
+
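For orientation, here is a minimal usage sketch of the recorder API introduced above. It is illustrative only and not part of this diff: the crate path `prisma_metrics` is inferred from the `prisma-metrics` package name, and the accept-list contents mirror the constants and tests elsewhere in this change.

use prisma_metrics::{counter, MetricRecorder, MetricRegistry, WithMetricsInstrumentation};

async fn instrumented_work() {
    // The registry stores the recorded values; the recorder is a cheap, clonable
    // handle that forwards every macro call into it.
    let registry = MetricRegistry::new_with_accept_list(vec!["prisma_client_queries_total"]);
    let recorder = MetricRecorder::new(registry.clone());

    // Option A: install the recorder process-wide once (a second call returns AlreadyInstalled).
    // recorder.install_globally().unwrap();

    // Option B: scope the recorder to a future, which is how spawned tasks pick it up.
    async {
        counter!("prisma_client_queries_total").increment(1);
    }
    .with_recorder(recorder)
    .await;
}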
+impl CounterFn for MetricHandle { + fn increment(&self, value: u64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Counter, + action: MetricAction::Increment(value), + name: self.key.clone(), + }); + } + + fn absolute(&self, value: u64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Counter, + action: MetricAction::Absolute(value), + name: self.key.clone(), + }); + } +} + +impl GaugeFn for MetricHandle { + fn increment(&self, value: f64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Gauge, + action: MetricAction::GaugeInc(value), + name: self.key.clone(), + }); + } + + fn decrement(&self, value: f64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Gauge, + action: MetricAction::GaugeDec(value), + name: self.key.clone(), + }); + } + + fn set(&self, value: f64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Gauge, + action: MetricAction::GaugeSet(value), + name: self.key.clone(), + }); + } +} + +impl HistogramFn for MetricHandle { + fn record(&self, value: f64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Histogram, + action: MetricAction::HistRecord(value), + name: self.key.clone(), + }); + } +} diff --git a/query-engine/metrics/src/registry.rs b/libs/metrics/src/registry.rs similarity index 66% rename from query-engine/metrics/src/registry.rs rename to libs/metrics/src/registry.rs index 6530edbe8764..f6b217fcda56 100644 --- a/query-engine/metrics/src/registry.rs +++ b/libs/metrics/src/registry.rs @@ -1,11 +1,7 @@ -use super::formatters::metrics_to_json; -use super::{ - common::{KeyLabels, Metric, MetricAction, MetricType, MetricValue, Snapshot}, - formatters::metrics_to_prometheus, -}; -use super::{ - ACCEPT_LIST, HISTOGRAM_BOUNDS, METRIC_COUNTER, METRIC_DESCRIPTION, METRIC_GAUGE, METRIC_HISTOGRAM, METRIC_TARGET, -}; +use std::collections::HashMap; +use std::fmt; +use std::sync::{atomic::Ordering, Arc}; + use metrics::{CounterFn, GaugeFn, HistogramFn, Key}; use metrics_util::{ registry::{GenerationalAtomicStorage, GenerationalStorage, Registry}, @@ -13,15 +9,13 @@ use metrics_util::{ }; use parking_lot::RwLock; use serde_json::Value; -use std::collections::HashMap; -use std::fmt; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use tracing::{ - field::{Field, Visit}, - Subscriber, + +use super::formatters::metrics_to_json; +use super::{ + common::{Metric, MetricAction, MetricType, MetricValue, Snapshot}, + formatters::metrics_to_prometheus, }; -use tracing_subscriber::Layer; +use super::{ACCEPT_LIST, HISTOGRAM_BOUNDS}; struct Inner { descriptions: RwLock>, @@ -46,7 +40,7 @@ pub struct MetricRegistry { impl fmt::Debug for MetricRegistry { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Metric Registry") + write!(f, "MetricRegistry {{ .. 
}}") } } @@ -68,12 +62,14 @@ impl MetricRegistry { } } - fn record(&self, metric: &MetricVisitor) { - match metric.metric_type { - MetricType::Counter => self.handle_counter(metric), - MetricType::Gauge => self.handle_gauge(metric), - MetricType::Histogram => self.handle_histogram(metric), - MetricType::Description => self.handle_description(metric), + pub(crate) fn record(&self, metric: &MetricVisitor) { + if self.is_accepted_metric(metric) { + match metric.metric_type { + MetricType::Counter => self.handle_counter(metric), + MetricType::Gauge => self.handle_gauge(metric), + MetricType::Histogram => self.handle_histogram(metric), + MetricType::Description => self.handle_description(metric), + } } } @@ -223,80 +219,8 @@ impl MetricRegistry { } #[derive(Debug)] -struct MetricVisitor { - metric_type: MetricType, - action: MetricAction, - name: Key, -} - -impl MetricVisitor { - pub fn new() -> Self { - Self { - metric_type: MetricType::Description, - action: MetricAction::Absolute(0), - name: Key::from_name(""), - } - } -} - -impl Visit for MetricVisitor { - fn record_debug(&mut self, _field: &Field, _value: &dyn std::fmt::Debug) {} - - fn record_f64(&mut self, field: &Field, value: f64) { - match field.name() { - "gauge_inc" => self.action = MetricAction::GaugeInc(value), - "gauge_dec" => self.action = MetricAction::GaugeDec(value), - "gauge_set" => self.action = MetricAction::GaugeSet(value), - "hist_record" => self.action = MetricAction::HistRecord(value), - _ => (), - } - } - - fn record_i64(&mut self, field: &Field, value: i64) { - match field.name() { - "increment" => self.action = MetricAction::Increment(value as u64), - "absolute" => self.action = MetricAction::Absolute(value as u64), - _ => (), - } - } - - fn record_u64(&mut self, field: &Field, value: u64) { - match field.name() { - "increment" => self.action = MetricAction::Increment(value), - "absolute" => self.action = MetricAction::Absolute(value), - _ => (), - } - } - - fn record_str(&mut self, field: &Field, value: &str) { - match (field.name(), value) { - ("metric_type", METRIC_COUNTER) => self.metric_type = MetricType::Counter, - ("metric_type", METRIC_GAUGE) => self.metric_type = MetricType::Gauge, - ("metric_type", METRIC_HISTOGRAM) => self.metric_type = MetricType::Histogram, - ("metric_type", METRIC_DESCRIPTION) => self.metric_type = MetricType::Description, - ("name", _) => self.name = Key::from_name(value.to_string()), - ("key_labels", _) => { - let key_labels: KeyLabels = serde_json::from_str(value).unwrap(); - self.name = key_labels.into(); - } - (METRIC_DESCRIPTION, _) => self.action = MetricAction::Description(value.to_string()), - _ => (), - } - } -} - -// A tracing layer for receiving metric trace events and storing them in the registry. 
-impl Layer for MetricRegistry { - fn on_event(&self, event: &tracing::Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) { - if event.metadata().target() != METRIC_TARGET { - return; - } - - let mut visitor = MetricVisitor::new(); - event.record(&mut visitor); - - if self.is_accepted_metric(&visitor) { - self.record(&visitor); - } - } +pub(crate) struct MetricVisitor { + pub(crate) metric_type: MetricType, + pub(crate) action: MetricAction, + pub(crate) name: Key, } diff --git a/libs/mongodb-client/Cargo.toml b/libs/mongodb-client/Cargo.toml index 0144543ca055..dcfec7c21ac1 100644 --- a/libs/mongodb-client/Cargo.toml +++ b/libs/mongodb-client/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -mongodb = "2.8.0" +mongodb.workspace = true # Remove these when mongo opens up their connection string parsing percent-encoding = "2.0.0" diff --git a/libs/mongodb-client/src/lib.rs b/libs/mongodb-client/src/lib.rs index 584568e178f8..c5f6a16145ea 100644 --- a/libs/mongodb-client/src/lib.rs +++ b/libs/mongodb-client/src/lib.rs @@ -12,12 +12,12 @@ use mongodb::{ /// A wrapper to create a new MongoDB client. Please remove me when we do not /// need special setup anymore for this. pub async fn create(connection_string: impl AsRef) -> Result { - let mut options = if cfg!(target_os = "windows") { - ClientOptions::parse_with_resolver_config(connection_string, ResolverConfig::cloudflare()).await? - } else { - ClientOptions::parse(connection_string).await? - }; + let mut connection_string_parser = ClientOptions::parse(connection_string.as_ref()); + if cfg!(target_os = "windows") { + connection_string_parser = connection_string_parser.resolver_config(ResolverConfig::cloudflare()); + } + let mut options = connection_string_parser.await?; options.driver_info = Some(DriverInfo::builder().name("Prisma").build()); Ok(Client::with_options(options)?) diff --git a/libs/prisma-value/Cargo.toml b/libs/prisma-value/Cargo.toml index 1a0d28e06db3..9833b6ee104b 100644 --- a/libs/prisma-value/Cargo.toml +++ b/libs/prisma-value/Cargo.toml @@ -7,7 +7,7 @@ version = "0.1.0" base64 = "0.13" chrono.workspace = true once_cell = "1.3" -regex = "1.2" +regex.workspace = true bigdecimal = "0.3" serde.workspace = true serde_json.workspace = true diff --git a/libs/prisma-value/src/raw_json.rs b/libs/prisma-value/src/raw_json.rs index e0e3596b05d4..83db97658459 100644 --- a/libs/prisma-value/src/raw_json.rs +++ b/libs/prisma-value/src/raw_json.rs @@ -13,7 +13,7 @@ use serde_json::value::RawValue; /// directly because: /// 1. We need `Eq` implementation /// 2. 
`serde_json::value::RawValue::from_string` may error and we'd like to delay handling of that error to -/// serialization time +/// serialization time #[derive(Clone, Debug, PartialEq, Eq)] pub struct RawJson { value: String, diff --git a/libs/query-engine-common/Cargo.toml b/libs/query-engine-common/Cargo.toml index daf41ba50f66..258639d2d943 100644 --- a/libs/query-engine-common/Cargo.toml +++ b/libs/query-engine-common/Cargo.toml @@ -8,6 +8,7 @@ thiserror = "1" url.workspace = true query-connector = { path = "../../query-engine/connectors/query-connector" } query-core = { path = "../../query-engine/core" } +telemetry = { path = "../telemetry" } user-facing-errors = { path = "../user-facing-errors" } serde_json.workspace = true serde.workspace = true @@ -16,12 +17,12 @@ psl.workspace = true async-trait.workspace = true tracing.workspace = true tracing-subscriber = { version = "0.3" } -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-opentelemetry = "0.17.3" opentelemetry = { version = "0.17" } [target.'cfg(all(not(target_arch = "wasm32")))'.dependencies] -query-engine-metrics = { path = "../../query-engine/metrics" } +prisma-metrics.path = "../metrics" napi.workspace = true [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/libs/query-engine-common/src/engine.rs b/libs/query-engine-common/src/engine.rs index 5129fca185c4..91ddc2acb6df 100644 --- a/libs/query-engine-common/src/engine.rs +++ b/libs/query-engine-common/src/engine.rs @@ -59,7 +59,7 @@ pub struct EngineBuilder { pub struct ConnectedEngineNative { pub config_dir: PathBuf, pub env: HashMap, - pub metrics: Option, + pub metrics: Option, } /// Internal structure for querying and reconnecting with the engine. diff --git a/libs/query-engine-common/src/tracer.rs b/libs/query-engine-common/src/tracer.rs index 19d17cf13a05..256ec95c172f 100644 --- a/libs/query-engine-common/src/tracer.rs +++ b/libs/query-engine-common/src/tracer.rs @@ -8,7 +8,6 @@ use opentelemetry::{ }, trace::{TraceError, TracerProvider}, }; -use query_core::telemetry; use std::fmt::{self, Debug}; /// Pipeline builder diff --git a/libs/telemetry/Cargo.toml b/libs/telemetry/Cargo.toml new file mode 100644 index 000000000000..4b9f9b79dacc --- /dev/null +++ b/libs/telemetry/Cargo.toml @@ -0,0 +1,33 @@ +[package] +edition = "2021" +name = "telemetry" +version = "0.1.0" + +[features] +metrics = ["dep:prisma-metrics"] + +[dependencies] +async-trait.workspace = true +crossbeam-channel = "0.5.6" +psl.workspace = true +futures = "0.3" +indexmap.workspace = true +itertools.workspace = true +once_cell = "1" +opentelemetry = { version = "0.17.0", features = ["rt-tokio", "serialize"] } +rand.workspace = true +serde.workspace = true +serde_json.workspace = true +thiserror = "1.0" +tokio = { version = "1.0", features = ["macros", "time"] } +tracing = { workspace = true, features = ["attributes"] } +tracing-futures = "0.2" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +tracing-opentelemetry = "0.17.4" +uuid.workspace = true +cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } +crosstarget-utils = { path = "../crosstarget-utils" } +lru = "0.7.7" +enumflags2.workspace = true +derive_more = "0.99.17" +prisma-metrics = { path = "../metrics", optional = true } diff --git a/query-engine/core/src/telemetry/capturing/capturer.rs b/libs/telemetry/src/capturing/capturer.rs similarity index 95% rename from query-engine/core/src/telemetry/capturing/capturer.rs rename to libs/telemetry/src/capturing/capturer.rs 
index d0d9886acd2b..a978b0766c83 100644 --- a/query-engine/core/src/telemetry/capturing/capturer.rs +++ b/libs/telemetry/src/capturing/capturer.rs @@ -29,6 +29,20 @@ impl Capturer { Self::Disabled } + + pub async fn try_start_capturing(&self) { + if let Capturer::Enabled(capturer) = self { + capturer.start_capturing().await + } + } + + pub async fn try_fetch_captures(&self) -> Option { + if let Capturer::Enabled(capturer) = self { + capturer.fetch_captures().await + } else { + None + } + } } #[derive(Debug, Clone)] @@ -92,7 +106,7 @@ impl SpanProcessor for Processor { /// mongo / relational, the information to build this kind of log event is logged diffeerently in /// the server. /// - /// In the case of the of relational databaes --queried through sql_query_connector and eventually + /// In the case of the of relational database --queried through sql_query_connector and eventually /// through quaint, a trace span describes the query-- `TraceSpan::represents_query_event` /// determines if a span represents a query event. /// diff --git a/query-engine/core/src/telemetry/capturing/helpers.rs b/libs/telemetry/src/capturing/helpers.rs similarity index 100% rename from query-engine/core/src/telemetry/capturing/helpers.rs rename to libs/telemetry/src/capturing/helpers.rs diff --git a/query-engine/core/src/telemetry/capturing/mod.rs b/libs/telemetry/src/capturing/mod.rs similarity index 70% rename from query-engine/core/src/telemetry/capturing/mod.rs rename to libs/telemetry/src/capturing/mod.rs index bbdc6ae9a083..0fdf711afb47 100644 --- a/query-engine/core/src/telemetry/capturing/mod.rs +++ b/libs/telemetry/src/capturing/mod.rs @@ -4,148 +4,150 @@ //! The interaction diagram below (soorry width!) shows the different roles at play during telemetry //! capturing. A textual explanatation follows it. For the sake of example a server environment //! --the query-engine crate-- is assumed. -//! # ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ -//! # -//! # │ <> │ -//! # -//! # ╔═══════════════════════╗ │╔═══════════════╗ │ -//! # ║<>║ ║ <> ║ ╔════════════════╗ ╔═══════════════════╗ -//! # ┌───────────────────┐ ║ PROCESSOR ║ │║ Sender ║ ║ Storage ║│ ║ TRACER ║ -//! # │ Server │ ╚═══════════╦═══════════╝ ╚══════╦════════╝ ╚═══════╦════════╝ ╚═════════╦═════════╝ -//! # └─────────┬─────────┘ │ │ │ │ │ │ -//! # │ │ │ │ │ -//! # │ │ │ │ │ │ │ -//! # POST │ │ │ │ │ -//! # (body, headers)│ │ │ │ │ │ │ -//! # ──────────▶┌┴┐ │ │ │ │ -//! # ┌─┐ │ │new(headers)╔════════════╗ │ │ │ │ │ │ -//! # │1│ │ ├───────────▶║s: Settings ║ │ │ │ │ -//! # └─┘ │ │ ╚════════════╝ │ │ │ │ │ │ -//! # │ │ │ │ │ │ -//! # │ │ ╔═══════════════════╗ │ │ │ │ │ │ -//! # │ │ ║ Capturer::Enabled ║ │ │ │ │ ┌────────────┐ -//! # │ │ ╚═══════════════════╝ │ │ │ │ │ │ │<│ -//! # │ │ │ │ │ │ │ └──────┬─────┘ -//! # │ │ ┌─┐ new(trace_id, s) │ │ │ │ │ │ │ │ -//! # │ ├───┤2├───────────────────────▶│ │ │ │ │ │ -//! # │ │ └─┘ │ │ │ │ │ │ │ │ -//! # │ │ │ │ │ │ │ │ -//! # │ │ ┌─┐ start_capturing() │ start_capturing │ │ │ │ │ │ │ -//! # │ ├───┤3├───────────────────────▶│ (trace_id, s) │ │ │ │ │ -//! # │ │ └─┘ │ │ │ │ │ │ │ │ -//! # │ │ ├─────────────────────▶│ send(StartCapturing, │ │ │ │ -//! # │ │ │ │ trace_id)│ │ │ │ │ │ -//! # │ │ │ │── ── ── ── ── ── ── ─▶│ │ │ │ -//! # │ │ │ │ ┌─┐ │ │insert(trace_id, s) │ │ │ │ -//! # │ │ │ │ │4│ │────────────────────▶│ │ │ -//! # │ │ │ │ └─┘ │ │ │ │ ┌─┐ │ process_query │ -//! 
# │ │──────────────────────────────┼──────────────────────┼───────────────────────┼─────────────────────┼────────────┤5├──────┼──────────────────────────▶┌┴┐ -//! # │ │ │ │ │ │ │ │ └─┘ │ │ │ -//! # │ │ │ │ │ │ │ │ │ +//! # ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +//! # +//! # │ <> │ +//! # +//! # ╔═══════════════════════╗ │╔═══════════════╗ │ +//! # ║<>║ ║ <> ║ ╔════════════════╗ ╔═══════════════════╗ +//! # ┌───────────────────┐ ║ PROCESSOR ║ │║ Sender ║ ║ Storage ║│ ║ TRACER ║ +//! # │ Server │ ╚═══════════╦═══════════╝ ╚══════╦════════╝ ╚═══════╦════════╝ ╚═════════╦═════════╝ +//! # └─────────┬─────────┘ │ │ │ │ │ │ +//! # │ │ │ │ │ +//! # │ │ │ │ │ │ │ +//! # POST │ │ │ │ │ +//! # (body, headers)│ │ │ │ │ │ │ +//! # ──────────▶┌┴┐ │ │ │ │ +//! # ┌─┐ │ │new(headers)╔════════════╗ │ │ │ │ │ │ +//! # │1│ │ ├───────────▶║s: Settings ║ │ │ │ │ +//! # └─┘ │ │ ╚════════════╝ │ │ │ │ │ │ +//! # │ │ │ │ │ │ +//! # │ │ ╔═══════════════════╗ │ │ │ │ │ │ +//! # │ │ ║ Capturer::Enabled ║ │ │ │ │ ┌────────────┐ +//! # │ │ ╚═══════════════════╝ │ │ │ │ │ │ │<│ +//! # │ │ │ │ │ │ │ └──────┬─────┘ +//! # │ │ ┌─┐ new(trace_id, s) │ │ │ │ │ │ │ │ +//! # │ ├───┤2├───────────────────────▶│ │ │ │ │ │ +//! # │ │ └─┘ │ │ │ │ │ │ │ │ +//! # │ │ │ │ │ │ │ │ +//! # │ │ ┌─┐ start_capturing() │ start_capturing │ │ │ │ │ │ │ +//! # │ ├───┤3├───────────────────────▶│ (trace_id, s) │ │ │ │ │ +//! # │ │ └─┘ │ │ │ │ │ │ │ │ +//! # │ │ ├─────────────────────▶│ send(StartCapturing, │ │ │ │ +//! # │ │ │ │ trace_id)│ │ │ │ │ │ +//! # │ │ │ │── ── ── ── ── ── ── ─▶│ │ │ │ +//! # │ │ │ │ ┌─┐ │ │insert(trace_id, s) │ │ │ │ +//! # │ │ │ │ │4│ │────────────────────▶│ │ │ +//! # │ │ │ │ └─┘ │ │ │ │ ┌─┐ │ process_query │ +//! # │ │──────────────────────────────┼──────────────────────┼───────────────────────┼─────────────────────┼────────────┤5├──────┼──────────────────────────▶┌┴┐ +//! # │ │ │ │ │ │ │ │ └─┘ │ │ │ +//! # │ │ │ │ │ │ │ │ │ //! # │ │ │ │ │ │ │ │ │ │ │ ┌─────────────────────┐ //! # │ │ │ │ │ │ │ log! / span! ┌─┐ │ │ │ res: PrismaResponse │ //! # │ │ │ │ │ │ │ │ │◀─────────────────────┤6├──│ │ └──────────┬──────────┘ -//! # │ │ │ │ │ on_end(span_data)│ ┌─┐ │ └─┘ │ │ new │ -//! # │ │ │ │◀──────────────┼───────┼─────────────────────┼─────────┼──┤7├──────┤ │ │────────────▶│ -//! # │ │ │ │ send(SpanDataProcessed│ │ └─┘ │ │ │ │ -//! # │ │ │ │ , trace_id) │ append(trace_id, │ │ │ │ │ │ -//! # │ │ │ │── ── ── ── ── ── ── ─▶│ logs, traces) │ │ │ │ │ -//! # │ │ │ │ ┌─┐ │ ├────────────────────▶│ │ │ │ │ │ -//! # │ │ │ │ │8│ │ │ │ │ │ │ -//! # │ │ res: PrismaResponse │ ┌─┐ │ └─┘ │ │ │ │ │ │ │ │ -//! # │ │─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┼ ┤9├ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─return ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─└┬┘ │ -//! # │ │ ┌────┐ fetch_captures() │ └─┘ │ │ │ │ │ │ │ │ -//! # │ ├─┤ 10 ├──────────────────────▶│ fetch_captures │ │ │ │ │ │ -//! # │ │ └────┘ │ (trace_id) │ │ │ │ │ │ │ │ -//! # │ │ ├─────────────────────▶│ send(FetchCaptures, │ │ │ x │ -//! # │ │ │ │ trace_id) │ │ │ │ │ -//! # │ │ │ │── ── ── ── ── ── ── ─▶│ get logs/traces │ │ │ -//! # │ │ │ │ ┌────┐ │ ├─────────────────────▶ │ │ │ -//! # │ │ │ │ │ 11 │ │ │ │ │ -//! # │ │ │ │ └────┘ │ │◁ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ -//! # │ │ │ │ │ │ │ │ -//! # │ │ ◁ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ │ │ │ -//! # │ │ logs, traces │ │ │ │ │ │ -//! # │ │◁─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ │ │ │ │ -//! # │ │ x ┌────┐ │ │ │ │ res.set_extension(logs) │ -//! 
# │ ├───────────────────────────────────────┤ 12 ├────────┼───────────────┼───────┼─────────────────────┼─────────┼───────────┼──────────────────────────────────────────▶│ -//! # │ │ └────┘ │ │ │ │ res.set_extension(traces) │ -//! # │ ├─────────────────────────────────────────────────────┼───────────────┼───────┼─────────────────────┼─────────┼───────────┼──────────────────────────────────────────▶│ -//! # ◀ ─ ─ ─└┬┘ │ │ │ │ x -//! # json!(res) │ │ │ -//! # ┌────┐ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ -//! # │ 13 │ │ -//! # └────┘ -//! # -//! # ◀─────── call (pseudo-signatures) -//! # -//! # ◀─ ── ── async message passing (channels) -//! # -//! # ◁─ ─ ─ ─ return -//! # -//! +//! # │ │ │ │ │ on_end(span_data)│ ┌─┐ │ └─┘ │ │ new │ +//! # │ │ │ │◀──────────────┼───────┼─────────────────────┼─────────┼──┤7├──────┤ │ │────────────▶│ +//! # │ │ │ │ send(SpanDataProcessed│ │ └─┘ │ │ │ │ +//! # │ │ │ │ , trace_id) │ append(trace_id, │ │ │ │ │ │ +//! # │ │ │ │── ── ── ── ── ── ── ─▶│ logs, traces) │ │ │ │ │ +//! # │ │ │ │ ┌─┐ │ ├────────────────────▶│ │ │ │ │ │ +//! # │ │ │ │ │8│ │ │ │ │ │ │ +//! # │ │ res: PrismaResponse │ ┌─┐ │ └─┘ │ │ │ │ │ │ │ │ +//! # │ │─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┼ ┤9├ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─return ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─└┬┘ │ +//! # │ │ ┌────┐ fetch_captures() │ └─┘ │ │ │ │ │ │ │ │ +//! # │ ├─┤ 10 ├──────────────────────▶│ fetch_captures │ │ │ │ │ │ +//! # │ │ └────┘ │ (trace_id) │ │ │ │ │ │ │ │ +//! # │ │ ├─────────────────────▶│ send(FetchCaptures, │ │ │ x │ +//! # │ │ │ │ trace_id) │ │ │ │ │ +//! # │ │ │ │── ── ── ── ── ── ── ─▶│ get logs/traces │ │ │ +//! # │ │ │ │ ┌────┐ │ ├─────────────────────▶ │ │ │ +//! # │ │ │ │ │ 11 │ │ │ │ │ +//! # │ │ │ │ └────┘ │ │◁ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ +//! # │ │ │ │ │ │ │ │ +//! # │ │ ◁ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ │ │ │ +//! # │ │ logs, traces │ │ │ │ │ │ +//! # │ │◁─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ │ │ │ │ +//! # │ │ x ┌────┐ │ │ │ │ res.set_extension(logs) │ +//! # │ ├───────────────────────────────────────┤ 12 ├────────┼───────────────┼───────┼─────────────────────┼─────────┼───────────┼──────────────────────────────────────────▶│ +//! # │ │ └────┘ │ │ │ │ res.set_extension(traces) │ +//! # │ ├─────────────────────────────────────────────────────┼───────────────┼───────┼─────────────────────┼─────────┼───────────┼──────────────────────────────────────────▶│ +//! # ◀ ─ ─ ─└┬┘ │ │ │ │ x +//! # json!(res) │ │ │ +//! # ┌────┐ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +//! # │ 13 │ │ +//! # └────┘ +//! # +//! # ◀─────── call (pseudo-signatures) +//! # +//! # ◀─ ── ── async message passing (channels) +//! # +//! # ◁─ ─ ─ ─ return +//! # +//! //! In the diagram, you will see objects whose lifetime is static. The boxes for those have a double //! width margin. These are: -//! +//! //! - The `server` itself //! - The global `TRACER`, which handles `log!` and `span!` and uses the global `PROCESSOR` to -//! process the data constituting a trace `Span`s and log `Event`s +//! process the data constituting a trace `Span`s and log `Event`s //! - The global `PROCESSOR`, which manages the `Storage` set of data structures, holding logs, -//! traces (and capture settings) per request. -//! +//! traces (and capture settings) per request. +//! //! Then, through the request lifecycle, different objects are created and dropped: -//! +//! //! - When a request comes in, its headers are processed and a [`Settings`] object is built, this -//! 
object determines, for the request, how logging and tracing are going to be captured: if only -//! traces, logs, or both, and which log levels are going to be captured. +//! object determines, for the request, how logging and tracing are going to be captured: if only +//! traces, logs, or both, and which log levels are going to be captured. //! - Based on the settings, a new `Capturer` is created; a capturer is nothing but an exporter -//! wrapped to start capturing / fetch the captures for this particular request. +//! wrapped to start capturing / fetch the captures for this particular request. //! - An asynchronous task is spawned to own the storage of telemetry data without needing to share -//! memory accross threads. Communication with this task is done through channels. The `Sender` -//! part of the channel is kept in a global, so it can be cloned and used by a) the Capturer -//! (to start capturing / fetch the captures) or by the tracer's SpanProcessor, to extract -//! tracing and logging information that's eventually displayed to the user. -//! +//! memory accross threads. Communication with this task is done through channels. The `Sender` +//! part of the channel is kept in a global, so it can be cloned and used by a) the Capturer +//! (to start capturing / fetch the captures) or by the tracer's SpanProcessor, to extract +//! tracing and logging information that's eventually displayed to the user. +//! //! Then the capturing process works in this way: -//! +//! //! - The server receives a query **[1]** //! - It grabs the HTTP headers and builds a `Capture` object **[2]**, which is configured with the settings //! denoted by the `X-capture-telemetry` //! - Now the server tells the `Capturer` to start capturing all the logs and traces occurring on -//! the request **[3]** (denoted by a `trace_id`) The `trace_id` is either carried on the `traceparent` -//! header or implicitly created on the first span of the request. +//! the request **[3]** (denoted by a `trace_id`) The `trace_id` is either carried on the `traceparent` +//! header or implicitly created on the first span of the request. //! - The `Capturer` sends a message to the task owning the storage to start capturing **[4]**. //! The tasks creates a new entry in the storage for the given trace_id. Spans without a -//! corresponding trace_id in the storage are ignored. +//! corresponding trace_id in the storage are ignored. //! - The server dispatches the request and _Somewhere_ else in the code, it is processed **[5]**. //! - There the code logs events and emits traces asynchronously, as part of the processing **[6]** //! - Traces and Logs arrive at the `TRACER`, and get hydrated as SpanData in the `PROCESSOR` -//! **[7]**. +//! **[7]**. //! - This SpanData is sent through a channel to the task running in parallel, **[8]**. -//! The task transforms the SpanData into `TraceSpans` and `LogEvents` depending on the capture -//! settings and stores those spans and events in the storage. +//! The task transforms the SpanData into `TraceSpans` and `LogEvents` depending on the capture +//! settings and stores those spans and events in the storage. //! - When the code that dispatches the request is done it returns a `PrismaResponse` to the -//! server **[9]**. +//! server **[9]**. //! - Then the server asks the `PROCESSOR` to fetch the captures **[10]** //! - Like before, the `PROCESSOR` sends a message to the task running in parallel, -//! to fetch the captures from the `Storage` **[11]**. At that time, although -//! 
that's not represented in the diagram, the captures are deleted from the storage, thus -//! freeing any memory used for capturing during the request +//! to fetch the captures from the `Storage` **[11]**. At that time, although +//! that's not represented in the diagram, the captures are deleted from the storage, thus +//! freeing any memory used for capturing during the request //! - Finally, the server sets the `logs` and `traces` extensions in the `PrismaResponse`**[12]**, -//! it serializes the extended response in json format and returns it as an HTTP Response -//! blob **[13]**. +//! it serializes the extended response in json format and returns it as an HTTP Response +//! blob **[13]**. //! #![allow(unused_imports, dead_code)] pub use self::capturer::Capturer; pub use self::settings::Settings; -pub use tx_ext::TxTraceExt; use self::capturer::Processor; use once_cell::sync::Lazy; use opentelemetry::{global, sdk, trace}; use tracing::subscriber; use tracing_subscriber::{ - filter::filter_fn, layer::Layered, prelude::__tracing_subscriber_SubscriberExt, Layer, Registry, + filter::filter_fn, + layer::{Layered, SubscriberExt}, + registry::LookupSpan, + Layer, Registry, }; static PROCESSOR: Lazy = Lazy::new(Processor::default); @@ -159,12 +161,8 @@ pub fn capturer(trace_id: trace::TraceId, settings: Settings) -> Capturer { /// Adds a capturing layer to the given subscriber and installs the transformed subscriber as the /// global, default subscriber #[cfg(feature = "metrics")] -#[allow(clippy::type_complexity)] pub fn install_capturing_layer( - subscriber: Layered< - Option, - Layered + Send + Sync>, Registry>, - >, + subscriber: impl SubscriberExt + for<'a> LookupSpan<'a> + Send + Sync + 'static, log_queries: bool, ) { // set a trace context propagator, so that the trace context is propagated via the @@ -198,4 +196,3 @@ mod capturer; mod helpers; mod settings; pub mod storage; -mod tx_ext; diff --git a/query-engine/core/src/telemetry/capturing/settings.rs b/libs/telemetry/src/capturing/settings.rs similarity index 100% rename from query-engine/core/src/telemetry/capturing/settings.rs rename to libs/telemetry/src/capturing/settings.rs diff --git a/query-engine/core/src/telemetry/capturing/storage.rs b/libs/telemetry/src/capturing/storage.rs similarity index 92% rename from query-engine/core/src/telemetry/capturing/storage.rs rename to libs/telemetry/src/capturing/storage.rs index 9c2767169112..5c83affc85e3 100644 --- a/query-engine/core/src/telemetry/capturing/storage.rs +++ b/libs/telemetry/src/capturing/storage.rs @@ -1,5 +1,5 @@ use super::settings::Settings; -use crate::telemetry::models; +use crate::models; #[derive(Debug, Default)] pub struct Storage { diff --git a/libs/telemetry/src/helpers.rs b/libs/telemetry/src/helpers.rs new file mode 100644 index 000000000000..4a332e86af63 --- /dev/null +++ b/libs/telemetry/src/helpers.rs @@ -0,0 +1,178 @@ +use super::models::TraceSpan; +use derive_more::Display; +use once_cell::sync::Lazy; +use opentelemetry::propagation::Extractor; +use opentelemetry::sdk::export::trace::SpanData; +use opentelemetry::trace::{SpanId, TraceContextExt, TraceFlags, TraceId}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use tracing::Metadata; +use tracing_subscriber::EnvFilter; + +pub static SHOW_ALL_TRACES: Lazy = Lazy::new(|| match std::env::var("PRISMA_SHOW_ALL_TRACES") { + Ok(enabled) => enabled.eq_ignore_ascii_case("true"), + Err(_) => false, +}); + +/// `TraceParent` is a remote span. It is identified by `trace_id` and `span_id`. 
+/// +/// By "remote" we mean that this span was not emitted in the current process. In real life, it is +/// either: +/// - Emitted by the JS part of the Prisma ORM. This is true both for Accelerate (where the Rust +/// part is deployed as a server) and for the ORM (where the Rust part is a shared library) +/// - Never emitted at all. This happens when the `TraceParent` is created artificially from `TxId` +/// (see `TxId::as_traceparent`). In this case, `TraceParent` is used only to correlate logs +/// from different transaction operations - it is never used as a part of the trace +#[derive(Display, Copy, Clone)] +// This conforms with https://www.w3.org/TR/trace-context/#traceparent-header-field-values. Accelerate +// relies on this behaviour. +#[display(fmt = "00-{trace_id:032x}-{span_id:016x}-{flags:02x}")] +pub struct TraceParent { + trace_id: TraceId, + span_id: SpanId, + flags: TraceFlags, +} + +impl TraceParent { + pub fn from_remote_context(context: &opentelemetry::Context) -> Option { + let span = context.span(); + let span_context = span.span_context(); + + if span_context.is_valid() { + Some(Self { + trace_id: span_context.trace_id(), + span_id: span_context.span_id(), + flags: span_context.trace_flags(), + }) + } else { + None + } + } + + // TODO(aqrln): remove this method once the log capturing doesn't rely on trace IDs anymore + #[deprecated = "this must only be used to create an artificial traceparent for log capturing when tracing is disabled on the client"] + pub fn new_random() -> Self { + Self { + trace_id: TraceId::from_bytes(rand::random()), + span_id: SpanId::from_bytes(rand::random()), + flags: TraceFlags::SAMPLED, + } + } + + pub fn trace_id(&self) -> TraceId { + self.trace_id + } + + pub fn sampled(&self) -> bool { + self.flags.is_sampled() + } + + /// Returns a remote `opentelemetry::Context`. By "remote" we mean that it wasn't emitted in the + /// current process. + pub fn to_remote_context(&self) -> opentelemetry::Context { + // This relies on the fact that global text map propagator was installed that + // can handle `traceparent` field (for example, `TraceContextPropagator`). + opentelemetry::global::get_text_map_propagator(|propagator| { + propagator.extract(&TraceParentExtractor::new(self)) + }) + } +} + +/// An extractor to use with `TraceContextPropagator`. It allows to avoid creating a full `HashMap` +/// to convert a `TraceParent` to a `Context`. +pub struct TraceParentExtractor(String); + +impl TraceParentExtractor { + pub fn new(traceparent: &TraceParent) -> Self { + Self(traceparent.to_string()) + } +} + +impl Extractor for TraceParentExtractor { + fn get(&self, key: &str) -> Option<&str> { + if key == "traceparent" { + Some(&self.0) + } else { + None + } + } + + fn keys(&self) -> Vec<&str> { + vec!["traceparent"] + } +} + +pub fn spans_to_json(spans: Vec) -> String { + let json_spans: Vec = spans.into_iter().map(|span| json!(TraceSpan::from(span))).collect(); + let span_result = json!({ + "span": true, + "spans": json_spans + }); + serde_json::to_string(&span_result).unwrap_or_default() +} + +pub fn restore_remote_context_from_json_str(serialized: &str) -> opentelemetry::Context { + // This relies on the fact that global text map propagator was installed that + // can handle `traceparent` field (for example, `TraceContextPropagator`). 
+ let trace: HashMap = serde_json::from_str(serialized).unwrap_or_default(); + opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&trace)) +} + +pub enum QueryEngineLogLevel { + FromEnv, + Override(String), +} + +impl QueryEngineLogLevel { + fn level(self) -> Option { + match self { + Self::FromEnv => std::env::var("QE_LOG_LEVEL").ok(), + Self::Override(l) => Some(l), + } + } +} + +#[rustfmt::skip] +pub fn env_filter(log_queries: bool, qe_log_level: QueryEngineLogLevel) -> EnvFilter { + let mut filter = EnvFilter::from_default_env() + .add_directive("tide=error".parse().unwrap()) + .add_directive("tonic=error".parse().unwrap()) + .add_directive("h2=error".parse().unwrap()) + .add_directive("hyper=error".parse().unwrap()) + .add_directive("tower=error".parse().unwrap()); + + if let Some(ref level) = qe_log_level.level() { + filter = filter + .add_directive(format!("query_engine={}", level).parse().unwrap()) + .add_directive(format!("query_core={}", level).parse().unwrap()) + .add_directive(format!("query_connector={}", level).parse().unwrap()) + .add_directive(format!("sql_query_connector={}", level).parse().unwrap()) + .add_directive(format!("mongodb_query_connector={}", level).parse().unwrap()); + } + + if log_queries { + filter = filter + .add_directive("quaint[{is_query}]=trace".parse().unwrap()) + .add_directive("mongodb_query_connector=debug".parse().unwrap()); + } + + filter +} + +pub fn user_facing_span_only_filter(meta: &Metadata<'_>) -> bool { + if !meta.is_span() { + return false; + } + + if *SHOW_ALL_TRACES { + return true; + } + + if meta.fields().iter().any(|f| f.name() == "user_facing") { + return true; + } + + // spans describing a quaint query. + // TODO: should this span be made user_facing in quaint? + meta.target() == "quaint::connector::metrics" && meta.name() == "quaint:query" +} diff --git a/query-engine/core/src/telemetry/mod.rs b/libs/telemetry/src/lib.rs similarity index 100% rename from query-engine/core/src/telemetry/mod.rs rename to libs/telemetry/src/lib.rs diff --git a/query-engine/core/src/telemetry/models.rs b/libs/telemetry/src/models.rs similarity index 93% rename from query-engine/core/src/telemetry/models.rs rename to libs/telemetry/src/models.rs index c1e9ff0158b9..275ec5e56930 100644 --- a/query-engine/core/src/telemetry/models.rs +++ b/libs/telemetry/src/models.rs @@ -7,7 +7,21 @@ use std::{ time::{Duration, SystemTime}, }; -const ACCEPT_ATTRIBUTES: &[&str] = &["db.statement", "itx_id", "db.type"]; +const ACCEPT_ATTRIBUTES: &[&str] = &[ + "db.system", + "db.statement", + "db.collection.name", + "db.operation.name", + "itx_id", +]; + +#[derive(Serialize, Debug, Clone, PartialEq, Eq)] +pub enum SpanKind { + #[serde(rename = "client")] + Client, + #[serde(rename = "internal")] + Internal, +} #[derive(Serialize, Debug, Clone, PartialEq, Eq)] pub struct TraceSpan { @@ -23,6 +37,7 @@ pub struct TraceSpan { pub(super) events: Vec, #[serde(skip_serializing_if = "Vec::is_empty")] pub(super) links: Vec, + pub(super) kind: SpanKind, } #[derive(Serialize, Debug, Clone, PartialEq, Eq)] @@ -39,6 +54,11 @@ impl TraceSpan { impl From for TraceSpan { fn from(span: SpanData) -> Self { + let kind = match span.span_kind { + opentelemetry::trace::SpanKind::Client => SpanKind::Client, + _ => SpanKind::Internal, + }; + let attributes: HashMap = span.attributes .iter() @@ -105,6 +125,7 @@ impl From for TraceSpan { attributes, links, events, + kind, } } } diff --git a/libs/test-cli/Cargo.toml b/libs/test-cli/Cargo.toml index 
936ff3d9ee46..48c1f3170671 100644 --- a/libs/test-cli/Cargo.toml +++ b/libs/test-cli/Cargo.toml @@ -18,3 +18,6 @@ tracing.workspace = true tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-error = "0.2" async-trait.workspace = true + +[build-dependencies] +build-utils.path = "../build-utils" diff --git a/libs/test-cli/build.rs b/libs/test-cli/build.rs index 9bd10ecb9c58..33aded23a4a5 100644 --- a/libs/test-cli/build.rs +++ b/libs/test-cli/build.rs @@ -1,7 +1,3 @@ -use std::process::Command; - fn main() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); + build_utils::store_git_commit_hash_in_env(); } diff --git a/libs/test-cli/src/main.rs b/libs/test-cli/src/main.rs index 89dbad852060..1341e9258ab8 100644 --- a/libs/test-cli/src/main.rs +++ b/libs/test-cli/src/main.rs @@ -4,6 +4,7 @@ mod diagnose_migration_history; use anyhow::Context; use colored::Colorize; +use psl::parse_configuration; use schema_connector::BoxFuture; use schema_core::json_rpc::types::*; use std::{fmt, fs::File, io::Read, str::FromStr, sync::Arc}; @@ -24,6 +25,17 @@ enum Command { #[structopt(long)] composite_type_depth: Option, }, + /// Parse SQL queries and returns type information. + IntrospectSql { + /// URL of the database to introspect. + #[structopt(long)] + url: Option, + /// Path to the schema file to introspect for. + #[structopt(long = "file-path")] + file_path: Option, + /// The SQL query to introspect. + query_file_path: String, + }, /// Generate DMMF from a schema, or directly from a database URL. Dmmf(DmmfCommand), /// Push a prisma schema directly to the database. @@ -184,25 +196,38 @@ async fn main() -> anyhow::Result<()> { Command::Dmmf(cmd) => generate_dmmf(&cmd).await?, Command::SchemaPush(cmd) => schema_push(&cmd).await?, Command::MigrateDiff(cmd) => migrate_diff(&cmd).await?, + Command::IntrospectSql { + url, + file_path, + query_file_path, + } => { + let schema = schema_from_args(url.as_deref(), file_path.as_deref())?; + let config = parse_configuration(&schema).unwrap(); + let api = schema_core::schema_api(Some(schema.clone()), None)?; + let query_str = std::fs::read_to_string(query_file_path)?; + + let res = api + .introspect_sql(IntrospectSqlParams { + url: config + .first_datasource() + .load_url(|key| std::env::var(key).ok()) + .unwrap(), + queries: vec![SqlQueryInput { + name: "query".to_string(), + source: query_str, + }], + }) + .await + .map_err(|err| anyhow::anyhow!("{err:?}"))?; + + println!("{}", serde_json::to_string_pretty(&res).unwrap()); + } Command::Introspect { url, file_path, composite_type_depth, } => { - if url.as_ref().xor(file_path.as_ref()).is_none() { - anyhow::bail!( - "{}", - "Exactly one of --url or --file-path must be provided".bold().red() - ); - } - - let schema = if let Some(file_path) = &file_path { - read_datamodel_from_file(file_path)? - } else if let Some(url) = &url { - minimal_schema_from_url(url)? 
- } else { - unreachable!() - }; + let schema = schema_from_args(url.as_deref(), file_path.as_deref())?; let base_directory_path = file_path .as_ref() @@ -292,6 +317,20 @@ async fn main() -> anyhow::Result<()> { Ok(()) } +fn schema_from_args(url: Option<&str>, file_path: Option<&str>) -> anyhow::Result { + if let Some(url) = url { + let schema = minimal_schema_from_url(url)?; + + Ok(schema) + } else if let Some(file_path) = file_path { + let schema = read_datamodel_from_file(file_path)?; + + Ok(schema) + } else { + anyhow::bail!("Please provide one of --url or --file-path") + } +} + fn read_datamodel_from_file(path: &str) -> std::io::Result { use std::path::Path; diff --git a/libs/test-macros/Cargo.toml b/libs/test-macros/Cargo.toml index 1d13b8029c05..eaeedc45a9b5 100644 --- a/libs/test-macros/Cargo.toml +++ b/libs/test-macros/Cargo.toml @@ -9,4 +9,4 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.26" quote = "1.0.2" -syn = "1.0.5" +syn = { version = "1.0.5", features = ["full"] } diff --git a/libs/user-facing-errors/src/lib.rs b/libs/user-facing-errors/src/lib.rs index 7d7856831637..a1916e55162b 100644 --- a/libs/user-facing-errors/src/lib.rs +++ b/libs/user-facing-errors/src/lib.rs @@ -119,9 +119,9 @@ impl Error { } } - /// Construct a new UnknownError from a `PanicInfo` in a panic hook. `UnknownError`s created - /// with this constructor will have a proper, useful backtrace. - pub fn new_in_panic_hook(panic_info: &std::panic::PanicInfo<'_>) -> Self { + /// Construct a new UnknownError from a [`PanicHookInfo`] in a panic hook. [`UnknownError`]s + /// created with this constructor will have a proper, useful backtrace. + pub fn new_in_panic_hook(panic_info: &std::panic::PanicHookInfo<'_>) -> Self { let message = panic_info .payload() .downcast_ref::<&str>() diff --git a/libs/user-facing-errors/src/query_engine/mod.rs b/libs/user-facing-errors/src/query_engine/mod.rs index e42fbcb03f56..804ab2406533 100644 --- a/libs/user-facing-errors/src/query_engine/mod.rs +++ b/libs/user-facing-errors/src/query_engine/mod.rs @@ -68,10 +68,7 @@ pub struct UniqueKeyViolation { } #[derive(Debug, UserFacingError, Serialize)] -#[user_facing( - code = "P2003", - message = "Foreign key constraint failed on the field: `{field_name}`" -)] +#[user_facing(code = "P2003", message = "Foreign key constraint violated: `{field_name}`")] pub struct ForeignKeyViolation { /// Field name from one model from Prisma schema pub field_name: String, diff --git a/libs/user-facing-errors/src/schema_engine.rs b/libs/user-facing-errors/src/schema_engine.rs index 7329461ff2be..a3a81211c12c 100644 --- a/libs/user-facing-errors/src/schema_engine.rs +++ b/libs/user-facing-errors/src/schema_engine.rs @@ -15,25 +15,29 @@ pub struct DatabaseCreationFailed { code = "P3001", message = "Migration possible with destructive changes and possible data loss: {destructive_details}" )] +#[allow(dead_code)] pub struct DestructiveMigrationDetected { pub destructive_details: String, } +/// No longer used. #[derive(Debug, UserFacingError, Serialize)] #[user_facing( code = "P3002", message = "The attempted migration was rolled back: {database_error}" )] +#[allow(dead_code)] struct MigrationRollback { pub database_error: String, } -// No longer used. +/// No longer used. #[derive(Debug, SimpleUserFacingError)] #[user_facing( code = "P3003", message = "The format of migrations changed, the saved migrations are no longer valid. 
To solve this problem, please follow the steps at: https://pris.ly/d/migrate" )] +#[allow(dead_code)] pub struct DatabaseMigrationFormatChanged; #[derive(Debug, UserFacingError, Serialize)] diff --git a/nix/publish-engine-size.nix b/nix/publish-engine-size.nix index 11a63d7de7e8..7fe34f36d6c7 100644 --- a/nix/publish-engine-size.nix +++ b/nix/publish-engine-size.nix @@ -22,12 +22,15 @@ let craneLib = (flakeInputs.crane.mkLib pkgs).overrideToolchain rustToolchain; deps = craneLib.vendorCargoDeps { inherit src; }; libSuffix = stdenv.hostPlatform.extensions.sharedLibrary; + fakeGitHash = "0000000000000000000000000000000000000000"; in { packages.prisma-engines = stdenv.mkDerivation { name = "prisma-engines"; inherit src; + GIT_HASH = "${fakeGitHash}"; + buildInputs = [ pkgs.openssl.out ]; nativeBuildInputs = with pkgs; [ rustToolchain @@ -38,6 +41,7 @@ in ] ++ lib.optionals stdenv.isDarwin [ perl # required to build openssl darwin.apple_sdk.frameworks.Security + darwin.apple_sdk.frameworks.SystemConfiguration iconv ]; @@ -68,6 +72,8 @@ in inherit src; inherit (self'.packages.prisma-engines) buildInputs nativeBuildInputs configurePhase dontStrip; + GIT_HASH = "${fakeGitHash}"; + buildPhase = "cargo build --profile=${profile} --bin=test-cli"; installPhase = '' @@ -85,6 +91,8 @@ in inherit src; inherit (self'.packages.prisma-engines) buildInputs nativeBuildInputs configurePhase dontStrip; + GIT_HASH = "${fakeGitHash}"; + buildPhase = "cargo build --profile=${profile} --bin=query-engine"; installPhase = '' @@ -105,6 +113,8 @@ in inherit src; inherit (self'.packages.prisma-engines) buildInputs nativeBuildInputs configurePhase dontStrip; + GIT_HASH = "${fakeGitHash}"; + buildPhase = '' cargo build --profile=${profile} --bin=query-engine cargo build --profile=${profile} -p query-engine-node-api @@ -134,6 +144,8 @@ in inherit src; buildInputs = with pkgs; [ iconv ]; + GIT_HASH = "${fakeGitHash}"; + buildPhase = '' export HOME=$(mktemp -dt wasm-engine-home-XXXX) diff --git a/nix/shell.nix b/nix/shell.nix index 14d33f64abfc..f2767cbcf2ce 100644 --- a/nix/shell.nix +++ b/nix/shell.nix @@ -13,7 +13,7 @@ in nodejs_20 nodejs_20.pkgs.typescript-language-server - nodejs_20.pkgs.pnpm + pnpm_8 binaryen cargo-insta diff --git a/prisma-fmt/Cargo.toml b/prisma-fmt/Cargo.toml index 6778573f3a68..18b9a9042c6c 100644 --- a/prisma-fmt/Cargo.toml +++ b/prisma-fmt/Cargo.toml @@ -22,6 +22,9 @@ dissimilar = "1.0.3" once_cell = "1.9.0" expect-test = "1" +[build-dependencies] +build-utils.path = "../libs/build-utils" + [features] # sigh please don't ask :( vendored-openssl = [] diff --git a/prisma-fmt/build.rs b/prisma-fmt/build.rs index 2e8fe20c0503..33aded23a4a5 100644 --- a/prisma-fmt/build.rs +++ b/prisma-fmt/build.rs @@ -1,11 +1,3 @@ -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} - fn main() { - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); } diff --git a/prisma-fmt/src/get_dmmf.rs b/prisma-fmt/src/get_dmmf.rs index 6f3f03aa4f18..d398b3f131b9 100644 --- a/prisma-fmt/src/get_dmmf.rs +++ b/prisma-fmt/src/get_dmmf.rs @@ -606,7 +606,7 @@ mod tests { "isNullable": false, "inputTypes": [ { - "type": "BRelationFilter", + "type": "BScalarRelationFilter", "namespace": "prisma", "location": "inputObjectTypes", "isList": false @@ -764,7 +764,7 @@ mod tests { "isNullable": false, "inputTypes": [ { - 
"type": "BRelationFilter", + "type": "BScalarRelationFilter", "namespace": "prisma", "location": "inputObjectTypes", "isList": false @@ -1037,7 +1037,7 @@ mod tests { "isNullable": true, "inputTypes": [ { - "type": "ANullableRelationFilter", + "type": "ANullableScalarRelationFilter", "namespace": "prisma", "location": "inputObjectTypes", "isList": false @@ -1174,7 +1174,7 @@ mod tests { "isNullable": true, "inputTypes": [ { - "type": "ANullableRelationFilter", + "type": "ANullableScalarRelationFilter", "namespace": "prisma", "location": "inputObjectTypes", "isList": false @@ -2037,7 +2037,7 @@ mod tests { ] }, { - "name": "BRelationFilter", + "name": "BScalarRelationFilter", "constraints": { "maxNumFields": null, "minNumFields": null @@ -2436,7 +2436,7 @@ mod tests { ] }, { - "name": "ANullableRelationFilter", + "name": "ANullableScalarRelationFilter", "constraints": { "maxNumFields": null, "minNumFields": null diff --git a/prisma-fmt/src/lib.rs b/prisma-fmt/src/lib.rs index b6b13c47838f..3ec5514313bd 100644 --- a/prisma-fmt/src/lib.rs +++ b/prisma-fmt/src/lib.rs @@ -136,14 +136,15 @@ pub fn hover(schema_files: String, params: &str) -> String { /// The two parameters are: /// - The [`SchemaFileInput`] to reformat, as a string. -/// - An LSP -/// [DocumentFormattingParams](https://github.com/microsoft/language-server-protocol/blob/gh-pages/_specifications/specification-3-16.md#textDocument_formatting) object, as JSON. +/// - An LSP [`DocumentFormattingParams`][1] object, as JSON. /// /// The function returns the formatted schema, as a string. /// If the schema or any of the provided parameters is invalid, the function returns the original schema. /// This function never panics. /// /// Of the DocumentFormattingParams, we only take into account tabSize, at the moment. 
+/// +/// [1]: https://github.com/microsoft/language-server-protocol/blob/gh-pages/_specifications/specification-3-16.md#textDocument_formatting pub fn format(datamodel: String, params: &str) -> String { let schema: SchemaFileInput = match serde_json::from_str(&datamodel) { Ok(params) => params, diff --git a/prisma-fmt/src/validate.rs b/prisma-fmt/src/validate.rs index 67b12c45ce25..d458d389793d 100644 --- a/prisma-fmt/src/validate.rs +++ b/prisma-fmt/src/validate.rs @@ -59,6 +59,28 @@ mod tests { use super::*; use expect_test::expect; + #[test] + fn validate_non_ascii_identifiers() { + let schema = r#" + datasource db { + provider = "postgresql" + url = env("DBURL") + } + + model Lööps { + id Int @id + läderlappen Boolean + } + "#; + + let request = json!({ + "prismaSchema": schema, + }); + + let response = validate(&request.to_string()); + assert!(response.is_ok()) + } + #[test] fn validate_invalid_schema_with_colors() { let schema = r#" diff --git a/prisma-fmt/tests/code_actions/test_api.rs b/prisma-fmt/tests/code_actions/test_api.rs index b09f517be9c5..95021c73cef7 100644 --- a/prisma-fmt/tests/code_actions/test_api.rs +++ b/prisma-fmt/tests/code_actions/test_api.rs @@ -90,10 +90,7 @@ pub(crate) fn test_scenario(scenario_name: &str) { .as_str() }; - let diagnostics = match parse_schema_diagnostics(&schema_files, initiating_file_name) { - Some(diagnostics) => diagnostics, - None => Vec::new(), - }; + let diagnostics = parse_schema_diagnostics(&schema_files, initiating_file_name).unwrap_or_default(); path.clear(); write!(path, "{SCENARIOS_PATH}/{scenario_name}/result.json").unwrap(); diff --git a/psl/parser-database/src/attributes/default.rs b/psl/parser-database/src/attributes/default.rs index e2be240f152c..d1f6b887fa22 100644 --- a/psl/parser-database/src/attributes/default.rs +++ b/psl/parser-database/src/attributes/default.rs @@ -196,11 +196,12 @@ fn validate_model_builtin_scalar_type_default( { validate_empty_function_args(funcname, &funcargs.arguments, accept, ctx) } - (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) - if funcname == FN_UUID || funcname == FN_CUID => - { + (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_CUID => { validate_empty_function_args(funcname, &funcargs.arguments, accept, ctx) } + (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_UUID => { + validate_uuid_args(&funcargs.arguments, accept, ctx) + } (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_NANOID => { validate_nanoid_args(&funcargs.arguments, accept, ctx) } @@ -242,11 +243,12 @@ fn validate_composite_builtin_scalar_type_default( ) { match (scalar_type, value) { // Functions - (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) - if funcname == FN_UUID || funcname == FN_CUID => - { + (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_CUID => { validate_empty_function_args(funcname, &funcargs.arguments, accept, ctx) } + (ScalarType::String, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_UUID => { + validate_uuid_args(&funcargs.arguments, accept, ctx) + } (ScalarType::DateTime, ast::Expression::Function(funcname, funcargs, _)) if funcname == FN_NOW => { validate_empty_function_args(FN_NOW, &funcargs.arguments, accept, ctx) } @@ -379,6 +381,24 @@ fn validate_dbgenerated_args(args: &[ast::Argument], accept: AcceptFn<'_>, ctx: } } +fn validate_uuid_args(args: &[ast::Argument], accept: 
AcceptFn<'_>, ctx: &mut Context<'_>) { + let mut bail = || ctx.push_attribute_validation_error("`uuid()` takes a single Int argument."); + + if args.len() > 1 { + bail() + } + + match args.first().map(|arg| &arg.value) { + Some(ast::Expression::NumericValue(val, _)) if ![4u8, 7u8].contains(&val.parse::().unwrap()) => { + ctx.push_attribute_validation_error( + "`uuid()` takes either no argument, or a single integer argument which is either 4 or 7.", + ); + } + None | Some(ast::Expression::NumericValue(_, _)) => accept(ctx), + _ => bail(), + } +} + fn validate_nanoid_args(args: &[ast::Argument], accept: AcceptFn<'_>, ctx: &mut Context<'_>) { let mut bail = || ctx.push_attribute_validation_error("`nanoid()` takes a single Int argument."); diff --git a/psl/parser-database/src/types.rs b/psl/parser-database/src/types.rs index 7d8a0c6a949f..e977c7342b5c 100644 --- a/psl/parser-database/src/types.rs +++ b/psl/parser-database/src/types.rs @@ -1468,7 +1468,8 @@ impl ScalarType { matches!(self, ScalarType::Bytes) } - pub(crate) fn try_from_str(s: &str, ignore_case: bool) -> Option { + /// Tries to parse a scalar type from a string. + pub fn try_from_str(s: &str, ignore_case: bool) -> Option { match ignore_case { true => match s.to_lowercase().as_str() { "int" => Some(ScalarType::Int), diff --git a/psl/parser-database/src/walkers/relation_field.rs b/psl/parser-database/src/walkers/relation_field.rs index 5e387480d0b7..7e44b2ca0df7 100644 --- a/psl/parser-database/src/walkers/relation_field.rs +++ b/psl/parser-database/src/walkers/relation_field.rs @@ -169,7 +169,7 @@ impl<'db> RelationFieldWalker<'db> { } /// The relation name. -#[derive(Debug, Clone, PartialOrd)] +#[derive(Debug, Clone)] pub enum RelationName<'db> { /// A relation name specified in the AST. Explicit(&'db str), @@ -201,6 +201,12 @@ impl<'db> Ord for RelationName<'db> { } } +impl<'db> PartialOrd for RelationName<'db> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + impl<'db> std::hash::Hash for RelationName<'db> { fn hash(&self, state: &mut H) { match self { diff --git a/psl/psl-core/Cargo.toml b/psl/psl-core/Cargo.toml index ca1087939681..cd069d9bce37 100644 --- a/psl/psl-core/Cargo.toml +++ b/psl/psl-core/Cargo.toml @@ -22,7 +22,7 @@ chrono = { workspace = true } connection-string.workspace = true itertools.workspace = true once_cell = "1.3.1" -regex = "1.3.7" +regex.workspace = true serde.workspace = true serde_json.workspace = true enumflags2.workspace = true diff --git a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs index fa1559c33c2b..65a0d929995f 100644 --- a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs @@ -318,7 +318,7 @@ impl Connector for PostgresDatamodelConnector { DoublePrecision => ScalarType::Float, // Decimal Decimal(_) => ScalarType::Decimal, - Money => ScalarType::Float, + Money => ScalarType::Decimal, // DateTime Timestamp(_) => ScalarType::DateTime, Timestamptz(_) => ScalarType::DateTime, @@ -463,7 +463,10 @@ impl Connector for PostgresDatamodelConnector { } fn validate_url(&self, url: &str) -> Result<(), String> { - if !url.starts_with("postgres://") && !url.starts_with("postgresql://") { + if !url.starts_with("postgres://") + && !url.starts_with("postgresql://") + && !url.starts_with("prisma+postgres://") + { return Err("must start with the protocol `postgresql://` or 
`postgres://`.".to_owned()); } diff --git a/psl/psl-core/src/common/preview_features.rs b/psl/psl-core/src/common/preview_features.rs index cadaa636395e..ea9b0eceea81 100644 --- a/psl/psl-core/src/common/preview_features.rs +++ b/psl/psl-core/src/common/preview_features.rs @@ -80,7 +80,9 @@ features!( RelationJoins, ReactNative, PrismaSchemaFolder, - OmitApi + OmitApi, + TypedSql, + StrictUndefinedChecks ); /// Generator preview features (alphabetically sorted) @@ -99,6 +101,7 @@ pub const ALL_PREVIEW_FEATURES: FeatureMap = FeatureMap { | RelationJoins | OmitApi | PrismaSchemaFolder + | StrictUndefinedChecks }), deprecated: enumflags2::make_bitflags!(PreviewFeature::{ AtomicNumberOperations @@ -133,7 +136,7 @@ pub const ALL_PREVIEW_FEATURES: FeatureMap = FeatureMap { | TransactionApi | UncheckedScalarInputs }), - hidden: enumflags2::make_bitflags!(PreviewFeature::{ReactNative}), + hidden: enumflags2::make_bitflags!(PreviewFeature::{ReactNative | TypedSql}), }; #[derive(Debug)] diff --git a/psl/psl-core/src/mcf.rs b/psl/psl-core/src/mcf.rs index c03edeb66993..75cde3bb5f01 100644 --- a/psl/psl-core/src/mcf.rs +++ b/psl/psl-core/src/mcf.rs @@ -8,7 +8,7 @@ pub use source::*; use serde::Serialize; pub fn config_to_mcf_json_value(mcf: &crate::Configuration, files: &Files) -> serde_json::Value { - serde_json::to_value(&model_to_serializable(mcf, files)).expect("Failed to render JSON.") + serde_json::to_value(model_to_serializable(mcf, files)).expect("Failed to render JSON.") } #[derive(Debug, Serialize)] diff --git a/psl/psl/tests/attributes/id_negative.rs b/psl/psl/tests/attributes/id_negative.rs index ce4a1ccbaa4d..bc3c9785118c 100644 --- a/psl/psl/tests/attributes/id_negative.rs +++ b/psl/psl/tests/attributes/id_negative.rs @@ -44,6 +44,26 @@ fn id_should_error_multiple_ids_are_provided() { expect_error(dml, &expectation) } +#[test] +fn id_should_error_on_invalid_uuid_version() { + let dml = indoc! {r#" + model Model { + id String @id @default(uuid(1)) + } + "#}; + + let expectation = expect![[r#" + error: Error parsing attribute "@default": `uuid()` takes either no argument, or a single integer argument which is either 4 or 7. + --> schema.prisma:2 +  |  +  1 | model Model { +  2 |  id String @id @default(uuid(1)) +  |  + "#]]; + + expect_error(dml, &expectation) +} + #[test] fn id_must_error_when_single_and_multi_field_id_is_used() { let dml = indoc! {r#" diff --git a/psl/psl/tests/attributes/id_positive.rs b/psl/psl/tests/attributes/id_positive.rs index 396c334f94fd..e81c1305c274 100644 --- a/psl/psl/tests/attributes/id_positive.rs +++ b/psl/psl/tests/attributes/id_positive.rs @@ -70,6 +70,45 @@ fn should_allow_string_ids_with_uuid() { model.assert_id_on_fields(&["id"]); } +#[test] +fn should_allow_string_ids_with_uuid_version_specified() { + let dml = indoc! {r#" + model ModelA { + id String @id @default(uuid(4)) + } + + model ModelB { + id String @id @default(uuid(7)) + } + "#}; + + let schema = psl::parse_schema(dml).unwrap(); + + { + let model = schema.assert_has_model("ModelA"); + + model + .assert_has_scalar_field("id") + .assert_scalar_type(ScalarType::String) + .assert_default_value() + .assert_uuid(); + + model.assert_id_on_fields(&["id"]); + } + + { + let model = schema.assert_has_model("ModelB"); + + model + .assert_has_scalar_field("id") + .assert_scalar_type(ScalarType::String) + .assert_default_value() + .assert_uuid(); + + model.assert_id_on_fields(&["id"]); + } +} + #[test] fn should_allow_string_ids_without_default() { let dml = indoc! 
{r#" diff --git a/psl/psl/tests/common/asserts.rs b/psl/psl/tests/common/asserts.rs index 4278f5cb77e5..e75df7c2cd7b 100644 --- a/psl/psl/tests/common/asserts.rs +++ b/psl/psl/tests/common/asserts.rs @@ -631,9 +631,7 @@ impl DefaultValueAssert for ast::Expression { #[track_caller] fn assert_uuid(&self) -> &Self { - assert!( - matches!(self, ast::Expression::Function(name, args, _) if name == "uuid" && args.arguments.is_empty()) - ); + assert!(matches!(self, ast::Expression::Function(name, _, _) if name == "uuid")); self } diff --git a/psl/psl/tests/config/generators.rs b/psl/psl/tests/config/generators.rs index 273de14c7443..20e8f8864402 100644 --- a/psl/psl/tests/config/generators.rs +++ b/psl/psl/tests/config/generators.rs @@ -258,7 +258,7 @@ fn nice_error_for_unknown_generator_preview_feature() { .unwrap_err(); let expectation = expect![[r#" - error: The preview feature "foo" is not known. Expected one of: deno, driverAdapters, fullTextIndex, fullTextSearch, metrics, multiSchema, nativeDistinct, postgresqlExtensions, tracing, views, relationJoins, prismaSchemaFolder, omitApi + error: The preview feature "foo" is not known. Expected one of: deno, driverAdapters, fullTextIndex, fullTextSearch, metrics, multiSchema, nativeDistinct, postgresqlExtensions, tracing, views, relationJoins, prismaSchemaFolder, omitApi, strictUndefinedChecks --> schema.prisma:3  |   2 |  provider = "prisma-client-js" diff --git a/psl/psl/tests/validation/enums/value_with_non_ascii_ident_should_not_error.prisma b/psl/psl/tests/validation/enums/value_with_non_ascii_ident_should_not_error.prisma new file mode 100644 index 000000000000..9adb651bde57 --- /dev/null +++ b/psl/psl/tests/validation/enums/value_with_non_ascii_ident_should_not_error.prisma @@ -0,0 +1,12 @@ +generator client { + provider = "prisma-client-js" +} + +datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} + +enum CatVariant { + lööps +} diff --git a/psl/psl/tests/validation/models/field_with_non_ascii_ident_should_not_error.prisma b/psl/psl/tests/validation/models/field_with_non_ascii_ident_should_not_error.prisma new file mode 100644 index 000000000000..f5752f9f2bfa --- /dev/null +++ b/psl/psl/tests/validation/models/field_with_non_ascii_ident_should_not_error.prisma @@ -0,0 +1,14 @@ +generator client { + provider = "prisma-client-js" +} + +datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} + +model A { + id Int @id @map("_id") + lööps String[] +} + diff --git a/psl/psl/tests/validation/models/model_with_non_ascii_ident_should_not_error.prisma b/psl/psl/tests/validation/models/model_with_non_ascii_ident_should_not_error.prisma new file mode 100644 index 000000000000..a2e4ca0c3af6 --- /dev/null +++ b/psl/psl/tests/validation/models/model_with_non_ascii_ident_should_not_error.prisma @@ -0,0 +1,12 @@ +generator client { + provider = "prisma-client-js" +} + +datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} + +model Lööp { + id Int @id +} diff --git a/psl/schema-ast/src/parser/datamodel.pest b/psl/schema-ast/src/parser/datamodel.pest index 62bf8459e00b..a0b84bd0d060 100644 --- a/psl/schema-ast/src/parser/datamodel.pest +++ b/psl/schema-ast/src/parser/datamodel.pest @@ -130,9 +130,11 @@ doc_content = @{ (!NEWLINE ~ ANY)* } // ###################################### // shared building blocks // ###################################### -identifier = @{ ASCII_ALPHANUMERIC ~ ( "_" | "-" | ASCII_ALPHANUMERIC)* } +unicode_alphanumeric = { LETTER | ASCII_DIGIT } +identifier = @{ 
unicode_alphanumeric ~ ( "_" | "-" | unicode_alphanumeric)* } path = @{ identifier ~ ("." ~ path?)* } + WHITESPACE = _{ SPACE_SEPARATOR | "\t" } // tabs are also whitespace NEWLINE = _{ "\n" | "\r\n" | "\r" } empty_lines = @{ (WHITESPACE* ~ NEWLINE)+ } diff --git a/quaint/.envrc b/quaint/.envrc index 226b0ae234c1..6c7b5116c89a 100644 --- a/quaint/.envrc +++ b/quaint/.envrc @@ -1,7 +1,7 @@ export TEST_MYSQL="mysql://root:prisma@localhost:3306/prisma" export TEST_MYSQL8="mysql://root:prisma@localhost:3307/prisma" export TEST_MYSQL_MARIADB="mysql://root:prisma@localhost:3308/prisma" -export TEST_PSQL="postgres://postgres:prisma@localhost:5432/postgres" +export TEST_PSQL="postgresql://postgres:prisma@localhost:5432/postgres" export TEST_CRDB="postgresql://prisma@127.0.0.1:26259/postgres" export TEST_MSSQL="jdbc:sqlserver://localhost:1433;database=master;user=SA;password=;trustServerCertificate=true" if command -v nix-shell &> /dev/null diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index d30e27408782..482f6f0a7626 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -51,6 +51,8 @@ postgresql-native = [ "bit-vec", "lru-cache", "byteorder", + "dep:ws_stream_tungstenite", + "dep:async-tungstenite" ] postgresql = [] @@ -70,15 +72,17 @@ fmt-sql = ["sqlformat"] connection-string = "0.2" percent-encoding = "2" tracing.workspace = true -tracing-core = "0.1" +tracing-futures.workspace = true async-trait.workspace = true thiserror = "1.0" num_cpus = "1.12" -metrics = "0.18" -futures = "0.3" +prisma-metrics.path = "../libs/metrics" +futures.workspace = true url.workspace = true hex = "0.4" itertools.workspace = true +regex.workspace = true +enumflags2.workspace = true either = { version = "1.6" } base64 = { version = "0.12.3" } @@ -88,11 +92,12 @@ serde_json.workspace = true native-tls = { version = "0.2", optional = true } bit-vec = { version = "0.6.1", optional = true } bytes = { version = "1.0", optional = true } -mobc = { version = "0.8", optional = true } +mobc = { version = "0.8.5", optional = true } serde = { version = "1.0" } sqlformat = { version = "0.2.3", optional = true } uuid.workspace = true crosstarget-utils = { path = "../libs/crosstarget-utils" } +concat-idents = "1.1.5" [dev-dependencies] once_cell = "1.3" @@ -102,13 +107,23 @@ paste = "1.0" serde = { version = "1.0", features = ["derive"] } quaint-test-macros = { path = "quaint-test-macros" } quaint-test-setup = { path = "quaint-test-setup" } -tokio = { version = "1.0", features = ["macros", "time"] } +tokio = { version = "1", features = ["macros", "time"] } expect-test = "1" [target.'cfg(target_arch = "wasm32")'.dependencies.getrandom] version = "0.2" features = ["js"] +[dependencies.ws_stream_tungstenite] +version = "0.14.0" +features = ["tokio_io"] +optional = true + +[dependencies.async-tungstenite] +version = "0.28.0" +features = ["tokio-runtime", "tokio-native-tls"] +optional = true + [dependencies.byteorder] default-features = false optional = true @@ -120,17 +135,17 @@ optional = true branch = "vendored-openssl" [dependencies.rusqlite] -version = "0.29" +version = "0.31" features = ["chrono", "column_decltype"] optional = true [target.'cfg(not(any(target_os = "macos", target_os = "ios")))'.dependencies.tiberius] -version = "0.11.6" +version = "0.12.3" optional = true features = ["sql-browser-tokio", "chrono", "bigdecimal"] [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies.tiberius] -version = "0.11.2" +version = "0.12.3" optional = true default-features = false features = [ @@ -178,9 +193,9 @@ 
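The grammar change above (`identifier = @{ unicode_alphanumeric ~ ( "_" | "-" | unicode_alphanumeric)* }`) is what the new `*_with_non_ascii_ident_should_not_error.prisma` validation files exercise: model, field and enum names may now contain any Unicode letter, not just ASCII. A minimal sketch of the same check as a Rust test, assuming nothing beyond the `psl::parse_schema` entry point already used by the positive tests in this diff:

```rust
#[test]
fn non_ascii_identifiers_should_parse() {
    let dml = r#"
        datasource db {
          provider = "postgresql"
          url      = env("DATABASE_URL")
        }

        model Lööp {
          id    Int      @id
          lööps String[]
        }
    "#;

    // Before the `identifier` rule switched from ASCII_ALPHANUMERIC to any LETTER,
    // this schema failed to parse; now it validates cleanly.
    psl::parse_schema(dml).unwrap();
}
```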
features = ["rt-multi-thread", "macros", "sync"] optional = true [dependencies.tokio-util] -version = "0.6" +version = "0.7" features = ["compat"] optional = true [build-dependencies] -cfg_aliases = "0.1.0" +cfg_aliases = "0.2.1" diff --git a/quaint/quaint-test-setup/Cargo.toml b/quaint/quaint-test-setup/Cargo.toml index eb47bd587e95..a1120af88104 100644 --- a/quaint/quaint-test-setup/Cargo.toml +++ b/quaint/quaint-test-setup/Cargo.toml @@ -9,5 +9,5 @@ once_cell = "1.3.1" bitflags = "1.2.1" async-trait.workspace = true names = "0.11" -tokio = { version = "1.0", features = ["rt-multi-thread"] } +tokio = { version = "1", features = ["rt-multi-thread"] } quaint = { path = "..", features = ["all-native", "pooled"] } diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index e5b0f760be0d..ad07b6698f38 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -9,8 +9,10 @@ //! implement the [Queryable](trait.Queryable.html) trait for generalized //! querying interface. +mod column_type; mod connection_info; +mod describe; pub mod external; pub mod metrics; #[cfg(native)] @@ -24,11 +26,13 @@ mod transaction; mod type_identifier; pub use self::result_set::*; +pub use column_type::*; pub use connection_info::*; #[cfg(native)] pub use native::*; +pub use describe::*; pub use external::*; pub use queryable::*; pub use transaction::*; diff --git a/quaint/src/connector/column_type.rs b/quaint/src/connector/column_type.rs new file mode 100644 index 000000000000..38fb3d786dc0 --- /dev/null +++ b/quaint/src/connector/column_type.rs @@ -0,0 +1,180 @@ +#[cfg(not(target_arch = "wasm32"))] +use super::TypeIdentifier; + +use crate::{Value, ValueType}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ColumnType { + Int32, + Int64, + Float, + Double, + Text, + Bytes, + Boolean, + Char, + Numeric, + Json, + Xml, + Uuid, + DateTime, + Date, + Time, + Enum, + + Int32Array, + Int64Array, + FloatArray, + DoubleArray, + TextArray, + CharArray, + BytesArray, + BooleanArray, + NumericArray, + JsonArray, + XmlArray, + UuidArray, + DateTimeArray, + DateArray, + TimeArray, + + Null, + + Unknown, +} + +impl ColumnType { + pub fn is_unknown(&self) -> bool { + matches!(self, ColumnType::Unknown) + } +} + +impl std::fmt::Display for ColumnType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ColumnType::Int32 => write!(f, "int"), + ColumnType::Int64 => write!(f, "bigint"), + ColumnType::Float => write!(f, "float"), + ColumnType::Double => write!(f, "double"), + ColumnType::Text => write!(f, "string"), + ColumnType::Enum => write!(f, "enum"), + ColumnType::Bytes => write!(f, "bytes"), + ColumnType::Boolean => write!(f, "bool"), + ColumnType::Char => write!(f, "char"), + ColumnType::Numeric => write!(f, "decimal"), + ColumnType::Json => write!(f, "json"), + ColumnType::Xml => write!(f, "xml"), + ColumnType::Uuid => write!(f, "uuid"), + ColumnType::DateTime => write!(f, "datetime"), + ColumnType::Date => write!(f, "date"), + ColumnType::Time => write!(f, "time"), + ColumnType::Int32Array => write!(f, "int-array"), + ColumnType::Int64Array => write!(f, "bigint-array"), + ColumnType::FloatArray => write!(f, "float-array"), + ColumnType::DoubleArray => write!(f, "double-array"), + ColumnType::TextArray => write!(f, "string-array"), + ColumnType::BytesArray => write!(f, "bytes-array"), + ColumnType::BooleanArray => write!(f, "bool-array"), + ColumnType::CharArray => write!(f, "char-array"), + ColumnType::NumericArray => write!(f, "decimal-array"), + 
ColumnType::JsonArray => write!(f, "json-array"), + ColumnType::XmlArray => write!(f, "xml-array"), + ColumnType::UuidArray => write!(f, "uuid-array"), + ColumnType::DateTimeArray => write!(f, "datetime-array"), + ColumnType::DateArray => write!(f, "date-array"), + ColumnType::TimeArray => write!(f, "time-array"), + + ColumnType::Null => write!(f, "null"), + ColumnType::Unknown => write!(f, "unknown"), + } + } +} + +impl From<&Value<'_>> for ColumnType { + fn from(value: &Value<'_>) -> Self { + Self::from(&value.typed) + } +} + +impl From<&ValueType<'_>> for ColumnType { + fn from(value: &ValueType) -> Self { + match value { + ValueType::Int32(_) => ColumnType::Int32, + ValueType::Int64(_) => ColumnType::Int64, + ValueType::Float(_) => ColumnType::Float, + ValueType::Double(_) => ColumnType::Double, + ValueType::Text(_) => ColumnType::Text, + ValueType::Enum(_, _) => ColumnType::Enum, + ValueType::EnumArray(_, _) => ColumnType::TextArray, + ValueType::Bytes(_) => ColumnType::Bytes, + ValueType::Boolean(_) => ColumnType::Boolean, + ValueType::Char(_) => ColumnType::Char, + ValueType::Numeric(_) => ColumnType::Numeric, + ValueType::Json(_) => ColumnType::Json, + ValueType::Xml(_) => ColumnType::Xml, + ValueType::Uuid(_) => ColumnType::Uuid, + ValueType::DateTime(_) => ColumnType::DateTime, + ValueType::Date(_) => ColumnType::Date, + ValueType::Time(_) => ColumnType::Time, + ValueType::Array(Some(vals)) if !vals.is_empty() => match &vals[0].typed { + ValueType::Int32(_) => ColumnType::Int32Array, + ValueType::Int64(_) => ColumnType::Int64Array, + ValueType::Float(_) => ColumnType::FloatArray, + ValueType::Double(_) => ColumnType::DoubleArray, + ValueType::Text(_) => ColumnType::TextArray, + ValueType::Enum(_, _) => ColumnType::TextArray, + ValueType::Bytes(_) => ColumnType::BytesArray, + ValueType::Boolean(_) => ColumnType::BooleanArray, + ValueType::Char(_) => ColumnType::CharArray, + ValueType::Numeric(_) => ColumnType::NumericArray, + ValueType::Json(_) => ColumnType::JsonArray, + ValueType::Xml(_) => ColumnType::TextArray, + ValueType::Uuid(_) => ColumnType::UuidArray, + ValueType::DateTime(_) => ColumnType::DateTimeArray, + ValueType::Date(_) => ColumnType::DateArray, + ValueType::Time(_) => ColumnType::TimeArray, + ValueType::Array(_) => ColumnType::Unknown, + ValueType::EnumArray(_, _) => ColumnType::Unknown, + }, + ValueType::Array(_) => ColumnType::Unknown, + } + } +} + +impl ColumnType { + #[cfg(not(target_arch = "wasm32"))] + pub(crate) fn from_type_identifier(value: T) -> Self + where + T: TypeIdentifier, + { + if value.is_bool() { + ColumnType::Boolean + } else if value.is_bytes() { + ColumnType::Bytes + } else if value.is_date() { + ColumnType::Date + } else if value.is_datetime() { + ColumnType::DateTime + } else if value.is_time() { + ColumnType::Time + } else if value.is_double() { + ColumnType::Double + } else if value.is_float() { + ColumnType::Float + } else if value.is_int32() { + ColumnType::Int32 + } else if value.is_int64() { + ColumnType::Int64 + } else if value.is_enum() { + ColumnType::Enum + } else if value.is_json() { + ColumnType::Json + } else if value.is_real() { + ColumnType::Numeric + } else if value.is_text() { + ColumnType::Text + } else { + ColumnType::Unknown + } + } +} diff --git a/quaint/src/connector/connection_info.rs b/quaint/src/connector/connection_info.rs index 7dd8a5b58257..80d63089e489 100644 --- a/quaint/src/connector/connection_info.rs +++ b/quaint/src/connector/connection_info.rs @@ -84,7 +84,7 @@ impl ConnectionInfo { } #[cfg(feature = 
"postgresql")] SqlFamily::Postgres => Ok(ConnectionInfo::Native(NativeConnectionInfo::Postgres( - PostgresUrl::new(url)?, + super::PostgresUrl::new_native(url)?, ))), #[allow(unreachable_patterns)] _ => unreachable!(), @@ -243,7 +243,7 @@ impl ConnectionInfo { pub fn pg_bouncer(&self) -> bool { match self { #[cfg(all(not(target_arch = "wasm32"), feature = "postgresql"))] - ConnectionInfo::Native(NativeConnectionInfo::Postgres(url)) => url.pg_bouncer(), + ConnectionInfo::Native(NativeConnectionInfo::Postgres(PostgresUrl::Native(url))) => url.pg_bouncer(), _ => false, } } diff --git a/quaint/src/connector/describe.rs b/quaint/src/connector/describe.rs new file mode 100644 index 000000000000..b719d594a56f --- /dev/null +++ b/quaint/src/connector/describe.rs @@ -0,0 +1,93 @@ +use std::borrow::Cow; + +use super::ColumnType; + +#[derive(Debug)] +pub struct DescribedQuery { + pub parameters: Vec, + pub columns: Vec, + pub enum_names: Option>, +} + +impl DescribedQuery { + pub fn param_enum_names(&self) -> Vec<&str> { + self.parameters.iter().filter_map(|p| p.enum_name.as_deref()).collect() + } +} + +#[derive(Debug)] +pub struct DescribedParameter { + pub name: String, + pub typ: ColumnType, + pub enum_name: Option, +} + +#[derive(Debug)] +pub struct DescribedColumn { + pub name: String, + pub typ: ColumnType, + pub nullable: bool, + pub enum_name: Option, +} + +impl DescribedParameter { + pub fn new_named<'a>(name: impl Into>, typ: impl Into) -> Self { + let name: Cow<'_, str> = name.into(); + + Self { + name: name.into_owned(), + typ: typ.into(), + enum_name: None, + } + } + + pub fn new_unnamed(idx: usize, typ: impl Into) -> Self { + Self { + name: format!("_{idx}"), + typ: typ.into(), + enum_name: None, + } + } + + pub fn with_enum_name(mut self, enum_name: Option) -> Self { + self.enum_name = enum_name; + self + } + + pub fn set_typ(mut self, typ: ColumnType) -> Self { + self.typ = typ; + self + } +} + +impl DescribedColumn { + pub fn new_named<'a>(name: impl Into>, typ: impl Into) -> Self { + let name: Cow<'_, str> = name.into(); + + Self { + name: name.into_owned(), + typ: typ.into(), + enum_name: None, + nullable: false, + } + } + + pub fn new_unnamed(idx: usize, typ: impl Into) -> Self { + Self { + name: format!("_{idx}"), + typ: typ.into(), + enum_name: None, + nullable: false, + } + } + + pub fn with_enum_name(mut self, enum_name: Option) -> Self { + self.enum_name = enum_name; + self + } + + pub fn is_nullable(mut self, nullable: bool) -> Self { + self.nullable = nullable; + self + } +} diff --git a/quaint/src/connector/metrics.rs b/quaint/src/connector/metrics.rs index a0c4ef426988..78fc7f99c724 100644 --- a/quaint/src/connector/metrics.rs +++ b/quaint/src/connector/metrics.rs @@ -1,15 +1,23 @@ +use prisma_metrics::{counter, histogram}; use tracing::{info_span, Instrument}; use crate::ast::{Params, Value}; use crosstarget_utils::time::ElapsedTimeCounter; use std::future::Future; -pub async fn query<'a, F, T, U>(tag: &'static str, query: &'a str, params: &'a [Value<'_>], f: F) -> crate::Result +pub async fn query<'a, F, T, U>( + tag: &'static str, + db_system_name: &'static str, + query: &'a str, + params: &'a [Value<'_>], + f: F, +) -> crate::Result where F: FnOnce() -> U + 'a, U: Future>, { - let span = info_span!("quaint:query", "db.statement" = %query); + let span = + info_span!("quaint:query", "db.system" = db_system_name, "db.statement" = %query, "otel.kind" = "client"); do_query(tag, query, params, f).instrument(span).await } @@ -46,9 +54,9 @@ where trace_query(query, params, 
result, &start); } - histogram!(format!("{tag}.query.time"), start.elapsed_time()); - histogram!("prisma_datasource_queries_duration_histogram_ms", start.elapsed_time()); - increment_counter!("prisma_datasource_queries_total"); + histogram!(format!("{tag}.query.time")).record(start.elapsed_time()); + histogram!("prisma_datasource_queries_duration_histogram_ms").record(start.elapsed_time()); + counter!("prisma_datasource_queries_total").increment(1); res } @@ -74,7 +82,7 @@ where result, ); - histogram!("pool.check_out", start.elapsed_time()); + histogram!("pool.check_out").record(start.elapsed_time()); res } diff --git a/quaint/src/connector/mssql/native/column_type.rs b/quaint/src/connector/mssql/native/column_type.rs new file mode 100644 index 000000000000..a133b883b34e --- /dev/null +++ b/quaint/src/connector/mssql/native/column_type.rs @@ -0,0 +1,43 @@ +use crate::connector::ColumnType; +use tiberius::{Column, ColumnType as MssqlColumnType}; + +impl From<&Column> for ColumnType { + fn from(value: &Column) -> Self { + match value.column_type() { + MssqlColumnType::Null => ColumnType::Unknown, + + MssqlColumnType::BigVarChar + | MssqlColumnType::BigChar + | MssqlColumnType::NVarchar + | MssqlColumnType::NChar + | MssqlColumnType::Text + | MssqlColumnType::NText => ColumnType::Text, + + MssqlColumnType::Xml => ColumnType::Xml, + + MssqlColumnType::Bit | MssqlColumnType::Bitn => ColumnType::Boolean, + MssqlColumnType::Int1 | MssqlColumnType::Int2 | MssqlColumnType::Int4 => ColumnType::Int32, + MssqlColumnType::Int8 | MssqlColumnType::Intn => ColumnType::Int64, + + MssqlColumnType::Datetime2 + | MssqlColumnType::Datetime4 + | MssqlColumnType::Datetime + | MssqlColumnType::Datetimen + | MssqlColumnType::DatetimeOffsetn => ColumnType::DateTime, + + MssqlColumnType::Float4 => ColumnType::Float, + MssqlColumnType::Float8 | MssqlColumnType::Money | MssqlColumnType::Money4 | MssqlColumnType::Floatn => { + ColumnType::Double + } + MssqlColumnType::Guid => ColumnType::Uuid, + MssqlColumnType::Decimaln | MssqlColumnType::Numericn => ColumnType::Numeric, + MssqlColumnType::Daten => ColumnType::Date, + MssqlColumnType::Timen => ColumnType::Time, + MssqlColumnType::BigVarBin | MssqlColumnType::BigBinary | MssqlColumnType::Image => ColumnType::Bytes, + + MssqlColumnType::Udt | MssqlColumnType::SSVariant => { + unreachable!("UDT and SSVariant types are not supported by Tiberius.") + } + } + } +} diff --git a/quaint/src/connector/mssql/native/conversion.rs b/quaint/src/connector/mssql/native/conversion.rs index c6f2b1f37f48..5d2eb2eb08b8 100644 --- a/quaint/src/connector/mssql/native/conversion.rs +++ b/quaint/src/connector/mssql/native/conversion.rs @@ -3,8 +3,7 @@ use crate::ast::{Value, ValueType}; use bigdecimal::BigDecimal; use std::{borrow::Cow, convert::TryFrom}; -use tiberius::ToSql; -use tiberius::{ColumnData, FromSql, IntoSql}; +use tiberius::{ColumnData, FromSql, IntoSql, ToSql}; impl<'a> IntoSql<'a> for &'a Value<'a> { fn into_sql(self) -> ColumnData<'a> { diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 124e14ac94d0..fe7751ddf373 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -1,15 +1,16 @@ //! Definitions for the MSSQL connector. //! This module is not compatible with wasm32-* targets. //! This module is only available with the `mssql-native` feature. 
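The `quaint/src/connector/metrics.rs` hunk above does two things: every query span now carries a per-connector `db.system` attribute and `otel.kind = "client"`, and the old `metrics` 0.18 macros are replaced with the `prisma-metrics` facade, where the macro returns a handle and the measurement is an explicit method call. A small sketch of the new recording style, assuming `prisma_metrics` re-exports `counter!`/`histogram!` exactly as the diff uses them:

```rust
use prisma_metrics::{counter, histogram};

// Hypothetical helper mirroring the calls in `do_query` above.
fn record_query_metrics(tag: &str, elapsed_ms: f64) {
    // Old style: histogram!(name, value). New style: build the handle, then record.
    histogram!(format!("{tag}.query.time")).record(elapsed_ms);
    histogram!("prisma_datasource_queries_duration_histogram_ms").record(elapsed_ms);

    // Old style: increment_counter!(name). New style: counter!(name).increment(1).
    counter!("prisma_datasource_queries_total").increment(1);
}
```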
+mod column_type; mod conversion; mod error; pub(crate) use crate::connector::mssql::MssqlUrl; -use crate::connector::{timeout, IsolationLevel, Transaction, TransactionOptions}; +use crate::connector::{timeout, DescribedQuery, IsolationLevel, Transaction, TransactionOptions}; use crate::{ ast::{Query, Value}, - connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, + connector::{metrics, queryable::*, ColumnType as QuaintColumnType, DefaultTransaction, ResultSet}, visitor::{self, Visitor}, }; use async_trait::async_trait; @@ -29,6 +30,7 @@ use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; pub use tiberius; static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; +const DB_SYSTEM_NAME: &str = "mssql"; #[async_trait] impl TransactionCapable for Mssql { @@ -129,7 +131,7 @@ impl Queryable for Mssql { } async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.query_raw", sql, params, move || async move { + metrics::query("mssql.query_raw", DB_SYSTEM_NAME, sql, params, move || async move { let mut client = self.client.lock().await; let mut query = tiberius::Query::new(sql); @@ -144,6 +146,10 @@ impl Queryable for Mssql { Some(rows) => { let mut columns_set = false; let mut columns = Vec::new(); + + let mut types_set = false; + let mut types = Vec::new(); + let mut result_rows = Vec::with_capacity(rows.len()); for row in rows.into_iter() { @@ -152,6 +158,11 @@ impl Queryable for Mssql { columns_set = true; } + if !types_set { + types = row.columns().iter().map(QuaintColumnType::from).collect(); + types_set = true; + } + let mut values: Vec> = Vec::with_capacity(row.len()); for val in row.into_iter() { @@ -161,9 +172,9 @@ impl Queryable for Mssql { result_rows.push(values); } - Ok(ResultSet::new(columns, result_rows)) + Ok(ResultSet::new(columns, types, result_rows)) } - None => Ok(ResultSet::new(Vec::new(), Vec::new())), + None => Ok(ResultSet::new(Vec::new(), Vec::new(), Vec::new())), } }) .await @@ -173,13 +184,17 @@ impl Queryable for Mssql { self.query_raw(sql, params).await } + async fn describe_query(&self, _sql: &str) -> crate::Result { + unimplemented!("SQL Server does not support describe_query yet.") + } + async fn execute(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Mssql::build(q)?; self.execute_raw(&sql, ¶ms[..]).await } async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.execute_raw", sql, params, move || async move { + metrics::query("mssql.execute_raw", DB_SYSTEM_NAME, sql, params, move || async move { let mut query = tiberius::Query::new(sql); for param in params { @@ -199,7 +214,7 @@ impl Queryable for Mssql { } async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mssql.raw_cmd", cmd, &[], move || async move { + metrics::query("mssql.raw_cmd", DB_SYSTEM_NAME, cmd, &[], move || async move { let mut client = self.client.lock().await; self.perform_io(client.simple_query(cmd)).await?.into_results().await?; Ok(()) diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index f18fd6a0b94a..064bc152d25f 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,6 +1,7 @@ //! Wasm-compatible definitions for the MySQL connector. //! This module is only available with the `mysql` feature. 
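Both the MSSQL `query_raw` above and the MySQL one further down now collect a `Vec<ColumnType>` next to the column names, because `ResultSet::new` gained a types argument. A hedged sketch of the new constructor shape; the argument order is taken from the call sites in this diff, and it assumes the constructor and the `Value` helpers remain public:

```rust
use quaint::ast::Value;
use quaint::connector::{ColumnType, ResultSet};

fn build_result_sets() -> (ResultSet, ResultSet) {
    // No columns, no types, no rows: the empty case returned when a statement yields nothing.
    let empty = ResultSet::new(Vec::new(), Vec::new(), Vec::new());

    // Column types now travel with the set instead of being re-derived from the first
    // row's values by every consumer.
    let typed = ResultSet::new(
        vec!["id".to_string(), "name".to_string()],
        vec![ColumnType::Int32, ColumnType::Text],
        vec![vec![Value::int32(1), Value::text("Musti")]],
    );

    (empty, typed)
}
```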
mod defaults; + pub(crate) mod error; pub(crate) mod url; diff --git a/quaint/src/connector/mysql/native/column_type.rs b/quaint/src/connector/mysql/native/column_type.rs new file mode 100644 index 000000000000..801cd697460a --- /dev/null +++ b/quaint/src/connector/mysql/native/column_type.rs @@ -0,0 +1,8 @@ +use crate::connector::ColumnType; +use mysql_async::Column as MysqlColumn; + +impl From<&MysqlColumn> for ColumnType { + fn from(value: &MysqlColumn) -> Self { + ColumnType::from_type_identifier(value) + } +} diff --git a/quaint/src/connector/mysql/native/conversion.rs b/quaint/src/connector/mysql/native/conversion.rs index cccb1dc3130a..1a2d065f03af 100644 --- a/quaint/src/connector/mysql/native/conversion.rs +++ b/quaint/src/connector/mysql/native/conversion.rs @@ -80,7 +80,7 @@ pub fn conv_params(params: &[Value<'_>]) -> crate::Result { } } -impl TypeIdentifier for my::Column { +impl TypeIdentifier for &my::Column { fn is_real(&self) -> bool { use ColumnType::*; @@ -175,14 +175,19 @@ impl TypeIdentifier for my::Column { fn is_bytes(&self) -> bool { use ColumnType::*; - let is_a_blob = matches!( + let is_bytes = matches!( self.column_type(), - MYSQL_TYPE_TINY_BLOB | MYSQL_TYPE_MEDIUM_BLOB | MYSQL_TYPE_LONG_BLOB | MYSQL_TYPE_BLOB + MYSQL_TYPE_TINY_BLOB + | MYSQL_TYPE_MEDIUM_BLOB + | MYSQL_TYPE_LONG_BLOB + | MYSQL_TYPE_BLOB + | MYSQL_TYPE_VAR_STRING + | MYSQL_TYPE_STRING ) && self.character_set() == 63; let is_bits = self.column_type() == MYSQL_TYPE_BIT && self.column_length() > 1; - is_a_blob || is_bits + is_bytes || is_bits } fn is_bool(&self) -> bool { @@ -268,6 +273,20 @@ impl TakeRow for my::Row { })?), my::Value::Float(f) => Value::from(f), my::Value::Double(f) => Value::from(f), + my::Value::Date(year, month, day, _, _, _, _) if column.is_date() => { + if day == 0 || month == 0 { + let msg = format!( + "The column `{}` contained an invalid datetime value with either day or month set to zero.", + column.name_str() + ); + let kind = ErrorKind::value_out_of_range(msg); + return Err(Error::builder(kind).build()); + } + + let date = NaiveDate::from_ymd_opt(year.into(), month.into(), day.into()).unwrap(); + + Value::date(date) + } my::Value::Date(year, month, day, hour, min, sec, micro) => { if day == 0 || month == 0 { let msg = format!( diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 4ffdfc88b4cf..2c8a757a48e7 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -1,11 +1,12 @@ //! Definitions for the MySQL connector. //! This module is not compatible with wasm32-* targets. //! This module is only available with the `mysql-native` feature. +mod column_type; mod conversion; mod error; pub(crate) use crate::connector::mysql::MysqlUrl; -use crate::connector::{timeout, IsolationLevel}; +use crate::connector::{timeout, ColumnType, DescribedColumn, DescribedParameter, DescribedQuery, IsolationLevel}; use crate::{ ast::{Query, Value}, @@ -15,6 +16,7 @@ use crate::{ }; use async_trait::async_trait; use lru_cache::LruCache; +use mysql_async::consts::ColumnFlags; use mysql_async::{ self as my, prelude::{Query as _, Queryable as _}, @@ -66,6 +68,8 @@ impl MysqlUrl { } } +const DB_SYSTEM_NAME: &str = "mysql"; + /// A connector interface for the MySQL database. 
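The `TypeIdentifier` change above widens what a MySQL column counts as bytes: besides the BLOB types, `VAR_STRING`/`STRING` columns stored with the binary character set (id 63) are now included, and `BIT` columns wider than one bit continue to be treated as bytes rather than booleans. Restated as a standalone predicate over `mysql_async::Column`, an illustrative sketch using the same accessors as the surrounding context lines, not the exact implementation:

```rust
use mysql_async::consts::ColumnType::*;
use mysql_async::Column;

fn is_bytes_column(col: &Column) -> bool {
    // BLOBs, plus string columns whose character set is `binary` (id 63).
    let is_binary_string = matches!(
        col.column_type(),
        MYSQL_TYPE_TINY_BLOB
            | MYSQL_TYPE_MEDIUM_BLOB
            | MYSQL_TYPE_LONG_BLOB
            | MYSQL_TYPE_BLOB
            | MYSQL_TYPE_VAR_STRING
            | MYSQL_TYPE_STRING
    ) && col.character_set() == 63;

    // BIT(n) with n > 1 cannot be represented as a boolean, so it is bytes as well.
    let is_wide_bit = col.column_type() == MYSQL_TYPE_BIT && col.column_length() > 1;

    is_binary_string || is_wide_bit
}
```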
#[derive(Debug)] pub struct Mysql { @@ -193,19 +197,44 @@ impl Queryable for Mysql { } async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.query_raw", sql, params, move || async move { + metrics::query("mysql.query_raw", DB_SYSTEM_NAME, sql, params, move || async move { self.prepared(sql, |stmt| async move { let mut conn = self.conn.lock().await; let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; - let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); let last_id = conn.last_insert_id(); - let mut result_set = ResultSet::new(columns, Vec::new()); + + let mut result_rows = Vec::with_capacity(rows.len()); + let mut columns: Vec = Vec::new(); + let mut column_types: Vec = Vec::new(); + + let mut columns_set = false; for mut row in rows { - result_set.rows.push(row.take_result_row()?); + let row = row.take_result_row()?; + + if !columns_set { + for (idx, _) in row.iter().enumerate() { + let maybe_column = stmt.columns().get(idx); + // `mysql_async` does not return columns in `ResultSet` when a call to a stored procedure is done + // See https://github.com/prisma/prisma/issues/6173 + let column = maybe_column + .map(|col| col.name_str().into_owned()) + .unwrap_or_else(|| format!("f{idx}")); + let column_type = maybe_column.map(ColumnType::from).unwrap_or(ColumnType::Unknown); + + columns.push(column); + column_types.push(column_type); + } + + columns_set = true; + } + + result_rows.push(row); } + let mut result_set = ResultSet::new(columns, column_types, result_rows); + if let Some(id) = last_id { result_set.set_last_insert_id(id); }; @@ -221,13 +250,39 @@ impl Queryable for Mysql { self.query_raw(sql, params).await } + async fn describe_query(&self, sql: &str) -> crate::Result { + self.prepared(sql, |stmt| async move { + let columns = stmt + .columns() + .iter() + .map(|col| { + DescribedColumn::new_named(col.name_str(), col) + .is_nullable(!col.flags().contains(ColumnFlags::NOT_NULL_FLAG)) + }) + .collect(); + let parameters = stmt + .params() + .iter() + .enumerate() + .map(|(idx, col)| DescribedParameter::new_unnamed(idx, col)) + .collect(); + + Ok(DescribedQuery { + columns, + parameters, + enum_names: None, + }) + }) + .await + } + async fn execute(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Mysql::build(q)?; self.execute_raw(&sql, ¶ms).await } async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.execute_raw", sql, params, move || async move { + metrics::query("mysql.execute_raw", DB_SYSTEM_NAME, sql, params, move || async move { self.prepared(sql, |stmt| async move { let mut conn = self.conn.lock().await; conn.exec_drop(stmt, conversion::conv_params(params)?).await?; @@ -244,7 +299,7 @@ impl Queryable for Mysql { } async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mysql.raw_cmd", cmd, &[], move || async move { + metrics::query("mysql.raw_cmd", DB_SYSTEM_NAME, cmd, &[], move || async move { self.perform_io(|| async move { let mut conn = self.conn.lock().await; let mut result = cmd.run(&mut *conn).await?; diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 2ebb428f7dd6..e4d0b439f278 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,6 +1,7 @@ //! Wasm-compatible definitions for the PostgreSQL connector. //! This module is only available with the `postgresql` feature. 
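`describe_query` is the new `Queryable` method implemented for MySQL above (MSSQL still returns `unimplemented!`): it prepares the statement and reports column names, `ColumnType`s and nullability, plus the positional parameters. A hedged usage sketch; it assumes a native build where `quaint::single::Quaint` and its `new(url)` constructor are available as in quaint's existing API, while the field names come from `DescribedQuery`/`DescribedColumn`/`DescribedParameter` in this diff:

```rust
use quaint::{prelude::*, single::Quaint};

async fn print_query_shape(url: &str) -> Result<(), quaint::error::Error> {
    let conn = Quaint::new(url).await?;

    let described = conn
        .describe_query("SELECT id, name FROM `User` WHERE id = ?")
        .await?;

    for col in &described.columns {
        // `typ` is a `ColumnType`, which implements `Display` (see `column_type.rs`).
        println!("column {} -> {} (nullable: {})", col.name, col.typ, col.nullable);
    }

    for param in &described.parameters {
        // Unnamed parameters are reported as `_0`, `_1`, ... by `new_unnamed`.
        println!("param {} -> {}", param.name, param.typ);
    }

    Ok(())
}
```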
mod defaults; + pub(crate) mod error; pub(crate) mod url; diff --git a/quaint/src/connector/postgres/error.rs b/quaint/src/connector/postgres/error.rs index 3dcc481eccba..0f19cf99e137 100644 --- a/quaint/src/connector/postgres/error.rs +++ b/quaint/src/connector/postgres/error.rs @@ -1,3 +1,5 @@ +use crosstarget_utils::{regex::RegExp, RegExpCompat}; +use enumflags2::BitFlags; use std::fmt::{Display, Formatter}; use crate::error::{DatabaseConstraint, Error, ErrorKind, Name}; @@ -28,6 +30,11 @@ impl Display for PostgresError { } } +fn extract_fk_constraint_name(message: &str) -> Option { + let re = RegExp::new(r#"foreign key constraint "([^"]+)""#, BitFlags::empty()).unwrap(); + re.captures(message).and_then(|caps| caps.get(1).cloned()) +} + impl From for Error { fn from(value: PostgresError) -> Self { match value.code.as_str() { @@ -89,12 +96,8 @@ impl From for Error { builder.build() } None => { - let constraint = value - .message - .split_whitespace() - .nth(10) - .and_then(|s| s.split('"').nth(1)) - .map(ToString::to_string) + // `value.message` looks like `update on table "Child" violates foreign key constraint "Child_parent_id_fkey"` + let constraint = extract_fk_constraint_name(value.message.as_str()) .map(DatabaseConstraint::Index) .unwrap_or(DatabaseConstraint::CannotParse); diff --git a/quaint/src/connector/postgres/native/column_type.rs b/quaint/src/connector/postgres/native/column_type.rs new file mode 100644 index 000000000000..ebb68497dacf --- /dev/null +++ b/quaint/src/connector/postgres/native/column_type.rs @@ -0,0 +1,135 @@ +use crate::connector::ColumnType; + +use std::borrow::Cow; +use tokio_postgres::types::{Kind as PostgresKind, Type as PostgresType}; + +macro_rules! create_pg_mapping { + ( + $($key:ident($typ: ty) => [$($value:ident),+]),* $(,)? 
+ $([$pg_only_key:ident => $column_type_mapping:ident]),* + ) => { + // Generate PGColumnType enums + $( + concat_idents::concat_idents!(enum_name = PGColumnType, $key { + #[derive(Debug)] + #[allow(non_camel_case_types)] + #[allow(clippy::upper_case_acronyms)] + pub(crate) enum enum_name { + $($value,)* + } + }); + )* + + // Generate validators + $( + concat_idents::concat_idents!(struct_name = PGColumnValidator, $key { + #[derive(Debug)] + #[allow(non_camel_case_types)] + pub struct struct_name; + + impl struct_name { + #[inline] + #[allow(clippy::extra_unused_lifetimes)] + pub fn read<'a>(&self, val: $typ) -> $typ { + val + } + } + }); + )* + + pub(crate) enum PGColumnType { + $( + $key( + concat_idents::concat_idents!(variant = PGColumnType, $key, { variant }), + concat_idents::concat_idents!(enum_name = PGColumnValidator, $key, { enum_name }) + ), + )* + $($pg_only_key(concat_idents::concat_idents!(enum_name = PGColumnValidator, $column_type_mapping, { enum_name })),)* + } + + impl PGColumnType { + /// Takes a Postgres type and returns the corresponding ColumnType + #[deny(unreachable_patterns)] + pub(crate) fn from_pg_type(ty: &PostgresType) -> PGColumnType { + match ty { + $( + $( + &PostgresType::$value => PGColumnType::$key( + concat_idents::concat_idents!(variant = PGColumnType, $key, { variant::$value }), + concat_idents::concat_idents!(enum_name = PGColumnValidator, $key, { enum_name }), + ), + )* + )* + ref x => match x.kind() { + PostgresKind::Enum => PGColumnType::Enum(PGColumnValidatorText), + PostgresKind::Array(inner) => match inner.kind() { + PostgresKind::Enum => PGColumnType::EnumArray(PGColumnValidatorTextArray), + _ => PGColumnType::UnknownArray(PGColumnValidatorTextArray), + }, + _ => PGColumnType::Unknown(PGColumnValidatorText), + }, + } + } + } + + impl From for ColumnType { + fn from(ty: PGColumnType) -> ColumnType { + match ty { + $( + PGColumnType::$key(..) => ColumnType::$key, + )* + $( + PGColumnType::$pg_only_key(..) => ColumnType::$column_type_mapping, + )* + } + } + } + + impl From<&PostgresType> for ColumnType { + fn from(ty: &PostgresType) -> ColumnType { + PGColumnType::from_pg_type(&ty).into() + } + } + }; +} + +// Create a mapping between Postgres types and ColumnType and ensures there's a single source of truth. +// ColumnType() => [PostgresType(s)...] +create_pg_mapping! 
{ + Boolean(Option) => [BOOL], + Int32(Option) => [INT2, INT4], + Int64(Option) => [INT8, OID], + Float(Option) => [FLOAT4], + Double(Option) => [FLOAT8], + Bytes(Option>) => [BYTEA], + Numeric(Option) => [NUMERIC, MONEY], + DateTime(Option>) => [TIMESTAMP, TIMESTAMPTZ], + Date(Option) => [DATE], + Time(Option) => [TIME, TIMETZ], + Text(Option>) => [INET, CIDR, BIT, VARBIT], + Uuid(Option) => [UUID], + Json(Option) => [JSON, JSONB], + Xml(Option>) => [XML], + Char(Option) => [CHAR], + + BooleanArray(impl Iterator>) => [BOOL_ARRAY], + Int32Array(impl Iterator>) => [INT2_ARRAY, INT4_ARRAY], + Int64Array(impl Iterator>) => [INT8_ARRAY, OID_ARRAY], + FloatArray(impl Iterator>) => [FLOAT4_ARRAY], + DoubleArray(impl Iterator>) => [FLOAT8_ARRAY], + BytesArray(impl Iterator>>) => [BYTEA_ARRAY], + NumericArray(impl Iterator>) => [NUMERIC_ARRAY, MONEY_ARRAY], + DateTimeArray(impl Iterator>>) => [TIMESTAMP_ARRAY, TIMESTAMPTZ_ARRAY], + DateArray(impl Iterator>) => [DATE_ARRAY], + TimeArray(impl Iterator>) => [TIME_ARRAY, TIMETZ_ARRAY], + TextArray(impl Iterator>>) => [TEXT_ARRAY, NAME_ARRAY, VARCHAR_ARRAY, INET_ARRAY, CIDR_ARRAY, BIT_ARRAY, VARBIT_ARRAY, XML_ARRAY], + UuidArray(impl Iterator>) => [UUID_ARRAY], + JsonArray(impl Iterator>) => [JSON_ARRAY, JSONB_ARRAY], + + // For the cases where the Postgres type is not directly mappable to ColumnType, use the following: + // [PGColumnType => ColumnType] + [Enum => Text], + [EnumArray => TextArray], + [UnknownArray => TextArray], + [Unknown => Text] +} diff --git a/quaint/src/connector/postgres/native/conversion.rs b/quaint/src/connector/postgres/native/conversion.rs index 4479eed69c69..fefe62a96d9a 100644 --- a/quaint/src/connector/postgres/native/conversion.rs +++ b/quaint/src/connector/postgres/native/conversion.rs @@ -4,8 +4,11 @@ use crate::{ ast::{Value, ValueType}, connector::queryable::{GetRow, ToColumnNames}, error::{Error, ErrorKind}, + prelude::EnumVariant, }; +use super::column_type::*; + use bigdecimal::{num_bigint::BigInt, BigDecimal, FromPrimitive, ToPrimitive}; use bit_vec::BitVec; use bytes::BytesMut; @@ -13,7 +16,7 @@ use chrono::{DateTime, NaiveDateTime, Utc}; pub(crate) use decimal::DecimalWrapper; use postgres_types::{FromSql, ToSql, WrongType}; -use std::{convert::TryFrom, error::Error as StdError}; +use std::{borrow::Cow, convert::TryFrom, error::Error as StdError}; use tokio_postgres::{ types::{self, IsNull, Kind, Type as PostgresType}, Row as PostgresRow, Statement as PostgresStatement, @@ -162,411 +165,528 @@ impl<'a> FromSql<'a> for NaiveMoney { impl GetRow for PostgresRow { fn get_result_row(&self) -> crate::Result>> { fn convert(row: &PostgresRow, i: usize) -> crate::Result> { - let result = match *row.columns()[i].type_() { - PostgresType::BOOL => ValueType::Boolean(row.try_get(i)?).into_value(), - PostgresType::INT2 => match row.try_get(i)? { - Some(val) => { - let val: i16 = val; - Value::int32(val) - } - None => Value::null_int32(), - }, - PostgresType::INT4 => match row.try_get(i)? { - Some(val) => { - let val: i32 = val; - Value::int32(val) - } - None => Value::null_int32(), - }, - PostgresType::INT8 => match row.try_get(i)? { - Some(val) => { - let val: i64 = val; - Value::int64(val) - } - None => Value::null_int64(), - }, - PostgresType::FLOAT4 => match row.try_get(i)? { - Some(val) => { - let val: f32 = val; - Value::float(val) - } - None => Value::null_float(), - }, - PostgresType::FLOAT8 => match row.try_get(i)? 
{ - Some(val) => { - let val: f64 = val; - Value::double(val) - } - None => Value::null_double(), + let pg_ty = row.columns()[i].type_(); + let column_type = PGColumnType::from_pg_type(pg_ty); + + // This convoluted nested enum is macro-generated to ensure we have a single source of truth for + // the mapping between Postgres types and ColumnType. The macro is in `./column_type.rs`. + // PGColumnValidator are used to softly ensure that the correct `ValueType` variants are created. + // If you ever add a new type or change some mapping, please ensure you pass the data through `v.read()`. + let result = match column_type { + PGColumnType::Boolean(ty, v) => match ty { + PGColumnTypeBoolean::BOOL => ValueType::Boolean(v.read(row.try_get(i)?)), }, - PostgresType::BYTEA => match row.try_get(i)? { - Some(val) => { - let val: &[u8] = val; - Value::bytes(val.to_owned()) - } - None => Value::null_bytes(), - }, - PostgresType::BYTEA_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec>> = val; - let byteas = val.into_iter().map(|b| ValueType::Bytes(b.map(Into::into))); + PGColumnType::Int32(ty, v) => match ty { + PGColumnTypeInt32::INT2 => { + let val: Option = row.try_get(i)?; - Value::array(byteas) + ValueType::Int32(v.read(val.map(i32::from))) } - None => Value::null_array(), - }, - PostgresType::NUMERIC => { - let dw: Option = row.try_get(i)?; + PGColumnTypeInt32::INT4 => { + let val: Option = row.try_get(i)?; - ValueType::Numeric(dw.map(|dw| dw.0)).into_value() - } - PostgresType::MONEY => match row.try_get(i)? { - Some(val) => { - let val: NaiveMoney = val; - Value::numeric(val.0) + ValueType::Int32(v.read(val)) } - None => Value::null_numeric(), }, - PostgresType::TIMESTAMP => match row.try_get(i)? { - Some(val) => { - let ts: NaiveDateTime = val; - let dt = DateTime::::from_naive_utc_and_offset(ts, Utc); - Value::datetime(dt) + PGColumnType::Int64(ty, v) => match ty { + PGColumnTypeInt64::INT8 => { + let val = v.read(row.try_get(i)?); + + ValueType::Int64(val) } - None => Value::null_datetime(), - }, - PostgresType::TIMESTAMPTZ => match row.try_get(i)? { - Some(val) => { - let ts: DateTime = val; - Value::datetime(ts) + PGColumnTypeInt64::OID => { + let val: Option = row.try_get(i)?; + + ValueType::Int64(v.read(val.map(i64::from))) } - None => Value::null_datetime(), }, - PostgresType::DATE => match row.try_get(i)? { - Some(val) => Value::date(val), - None => Value::null_date(), + PGColumnType::Float(ty, v) => match ty { + PGColumnTypeFloat::FLOAT4 => ValueType::Float(v.read(row.try_get(i)?)), }, - PostgresType::TIME => match row.try_get(i)? { - Some(val) => Value::time(val), - None => Value::null_time(), + PGColumnType::Double(ty, v) => match ty { + PGColumnTypeDouble::FLOAT8 => ValueType::Double(v.read(row.try_get(i)?)), }, - PostgresType::TIMETZ => match row.try_get(i)? { - Some(val) => { - let time: TimeTz = val; - Value::time(time.0) + PGColumnType::Bytes(ty, v) => match ty { + PGColumnTypeBytes::BYTEA => { + let val: Option<&[u8]> = row.try_get(i)?; + let val = val.map(ToOwned::to_owned).map(Cow::Owned); + + ValueType::Bytes(v.read(val)) } - None => Value::null_time(), }, - PostgresType::UUID => match row.try_get(i)? 
{ - Some(val) => { - let val: Uuid = val; - Value::uuid(val) + PGColumnType::Text(ty, v) => match ty { + PGColumnTypeText::INET | PGColumnTypeText::CIDR => { + let val: Option = row.try_get(i)?; + let val = val.map(|val| val.to_string()).map(Cow::from); + + ValueType::Text(v.read(val)) } - None => ValueType::Uuid(None).into_value(), - }, - PostgresType::UUID_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let val = val.into_iter().map(ValueType::Uuid); + PGColumnTypeText::VARBIT | PGColumnTypeText::BIT => { + let val: Option = row.try_get(i)?; + let val_str = val.map(|val| bits_to_string(&val)).transpose()?.map(Cow::Owned); - Value::array(val) + ValueType::Text(v.read(val_str)) } - None => Value::null_array(), }, - PostgresType::JSON | PostgresType::JSONB => ValueType::Json(row.try_get(i)?).into_value(), - PostgresType::INT2_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let ints = val.into_iter().map(|i| ValueType::Int32(i.map(|i| i as i32))); + PGColumnType::Char(ty, v) => match ty { + PGColumnTypeChar::CHAR => { + let val: Option = row.try_get(i)?; + let val = val.map(|val| (val as u8) as char); - Value::array(ints) + ValueType::Char(v.read(val)) } - None => Value::null_array(), }, - PostgresType::INT4_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let ints = val.into_iter().map(ValueType::Int32); + PGColumnType::Numeric(ty, v) => match ty { + PGColumnTypeNumeric::NUMERIC => { + let dw: Option = row.try_get(i)?; + let val = dw.map(|dw| dw.0); - Value::array(ints) + ValueType::Numeric(v.read(val)) } - None => Value::null_array(), - }, - PostgresType::INT8_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let ints = val.into_iter().map(ValueType::Int64); + PGColumnTypeNumeric::MONEY => { + let val: Option = row.try_get(i)?; - Value::array(ints) + ValueType::Numeric(v.read(val.map(|val| val.0))) } - None => Value::null_array(), }, - PostgresType::FLOAT4_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let floats = val.into_iter().map(ValueType::Float); + PGColumnType::DateTime(ty, v) => match ty { + PGColumnTypeDateTime::TIMESTAMP => { + let ts: Option = row.try_get(i)?; + let dt = ts.map(|ts| DateTime::::from_naive_utc_and_offset(ts, Utc)); - Value::array(floats) + ValueType::DateTime(v.read(dt)) } - None => Value::null_array(), - }, - PostgresType::FLOAT8_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let floats = val.into_iter().map(ValueType::Double); + PGColumnTypeDateTime::TIMESTAMPTZ => { + let ts: Option> = row.try_get(i)?; - Value::array(floats) + ValueType::DateTime(v.read(ts)) } - None => Value::null_array(), }, - PostgresType::BOOL_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let bools = val.into_iter().map(ValueType::Boolean); + PGColumnType::Date(ty, v) => match ty { + PGColumnTypeDate::DATE => ValueType::Date(v.read(row.try_get(i)?)), + }, + PGColumnType::Time(ty, v) => match ty { + PGColumnTypeTime::TIME => ValueType::Time(v.read(row.try_get(i)?)), + PGColumnTypeTime::TIMETZ => { + let val: Option = row.try_get(i)?; - Value::array(bools) + ValueType::Time(v.read(val.map(|val| val.0))) } - None => Value::null_array(), }, - PostgresType::TIMESTAMP_ARRAY => match row.try_get(i)? 
{ - Some(val) => { - let val: Vec> = val; - - let dates = val.into_iter().map(|dt| { - ValueType::DateTime(dt.map(|dt| DateTime::::from_naive_utc_and_offset(dt, Utc))) - }); + PGColumnType::Json(ty, v) => match ty { + PGColumnTypeJson::JSON | PGColumnTypeJson::JSONB => ValueType::Json(v.read(row.try_get(i)?)), + }, + PGColumnType::Xml(ty, v) => match ty { + PGColumnTypeXml::XML => { + let val: Option = row.try_get(i)?; - Value::array(dates) + ValueType::Xml(v.read(val.map(|val| Cow::Owned(val.0)))) } - None => Value::null_array(), }, - PostgresType::NUMERIC_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; + PGColumnType::Uuid(ty, v) => match ty { + PGColumnTypeUuid::UUID => ValueType::Uuid(v.read(row.try_get(i)?)), + }, + PGColumnType::Int32Array(ty, v) => match ty { + PGColumnTypeInt32Array::INT2_ARRAY => { + let vals: Option>> = row.try_get(i)?; - let decimals = val - .into_iter() - .map(|dec| ValueType::Numeric(dec.map(|dec| dec.0.to_string().parse().unwrap()))); + match vals { + Some(vals) => { + let ints = vals.into_iter().map(|val| val.map(i32::from)); - Value::array(decimals) + ValueType::Array(Some( + v.read(ints).map(ValueType::Int32).map(ValueType::into_value).collect(), + )) + } + None => ValueType::Array(None), + } + } + PGColumnTypeInt32Array::INT4_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Int32) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::TEXT_ARRAY | PostgresType::NAME_ARRAY | PostgresType::VARCHAR_ARRAY => { - match row.try_get(i)? { - Some(val) => { - let strings: Vec> = val; - - Value::array(strings.into_iter().map(|s| s.map(|s| s.to_string()))) + PGColumnType::Int64Array(ty, v) => match ty { + PGColumnTypeInt64Array::INT8_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Int64) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), } - None => Value::null_array(), } - } - PostgresType::MONEY_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let nums = val.into_iter().map(|num| ValueType::Numeric(num.map(|num| num.0))); + PGColumnTypeInt64Array::OID_ARRAY => { + let vals: Option>> = row.try_get(i)?; - Value::array(nums) - } - None => Value::null_array(), - }, - PostgresType::OID_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let nums = val.into_iter().map(|oid| ValueType::Int64(oid.map(|oid| oid as i64))); + match vals { + Some(vals) => { + let oids = vals.into_iter().map(|oid| oid.map(i64::from)); - Value::array(nums) + ValueType::Array(Some( + v.read(oids).map(ValueType::Int64).map(ValueType::into_value).collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::TIMESTAMPTZ_ARRAY => match row.try_get(i)? 
{ - Some(val) => { - let val: Vec>> = val; - let dates = val.into_iter().map(ValueType::DateTime); - - Value::array(dates) + PGColumnType::FloatArray(ty, v) => match ty { + PGColumnTypeFloatArray::FLOAT4_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Float) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::DATE_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let dates = val.into_iter().map(ValueType::Date); - - Value::array(dates) + PGColumnType::DoubleArray(ty, v) => match ty { + PGColumnTypeDoubleArray::FLOAT8_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Double) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::TIME_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let times = val.into_iter().map(ValueType::Time); - - Value::array(times) + PGColumnType::TextArray(ty, v) => match ty { + PGColumnTypeTextArray::TEXT_ARRAY + | PGColumnTypeTextArray::NAME_ARRAY + | PGColumnTypeTextArray::VARCHAR_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let strings = vals.into_iter().map(|s| s.map(ToOwned::to_owned).map(Cow::Owned)); + + ValueType::Array(Some( + v.read(strings) + .map(ValueType::Text) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), - }, - PostgresType::TIMETZ_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let timetzs = val.into_iter().map(|time| ValueType::Time(time.map(|time| time.0))); + PGColumnTypeTextArray::INET_ARRAY | PGColumnTypeTextArray::CIDR_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let addrs = vals + .into_iter() + .map(|ip| ip.as_ref().map(ToString::to_string).map(Cow::Owned)); - Value::array(timetzs) + ValueType::Array(Some( + v.read(addrs).map(ValueType::Text).map(ValueType::into_value).collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), - }, - PostgresType::JSON_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let jsons = val.into_iter().map(ValueType::Json); + PGColumnTypeTextArray::BIT_ARRAY | PGColumnTypeTextArray::VARBIT_ARRAY => { + let vals: Option>> = row.try_get(i)?; - Value::array(jsons) + match vals { + Some(vals) => { + let vals = vals + .into_iter() + .map(|bits| bits.map(|bits| bits_to_string(&bits).map(Cow::Owned)).transpose()) + .collect::>>()?; + + ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Text) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), - }, - PostgresType::JSONB_ARRAY => match row.try_get(i)? 
{ - Some(val) => { - let val: Vec> = val; - let jsons = val.into_iter().map(ValueType::Json); + PGColumnTypeTextArray::XML_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let xmls = vals.into_iter().map(|xml| xml.map(|xml| xml.0).map(Cow::Owned)); - Value::array(jsons) + ValueType::Array(Some( + v.read(xmls).map(ValueType::Text).map(ValueType::into_value).collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::OID => match row.try_get(i)? { - Some(val) => { - let val: u32 = val; - Value::int64(val) + PGColumnType::BytesArray(ty, v) => match ty { + PGColumnTypeBytesArray::BYTEA_ARRAY => { + let vals: Option>>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(|b| b.map(Cow::Owned)) + .map(ValueType::Bytes) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_int64(), }, - PostgresType::CHAR => match row.try_get(i)? { - Some(val) => { - let val: i8 = val; - Value::character((val as u8) as char) + PGColumnType::BooleanArray(ty, v) => match ty { + PGColumnTypeBooleanArray::BOOL_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Boolean) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_character(), }, - PostgresType::INET | PostgresType::CIDR => match row.try_get(i)? { - Some(val) => { - let val: std::net::IpAddr = val; - Value::text(val.to_string()) + PGColumnType::NumericArray(ty, v) => match ty { + PGColumnTypeNumericArray::NUMERIC_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let decimals = vals.into_iter().map(|dec| dec.map(|dec| dec.0)); + + ValueType::Array(Some( + v.read(decimals.into_iter()) + .map(ValueType::Numeric) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_text(), - }, - PostgresType::INET_ARRAY | PostgresType::CIDR_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let addrs = val - .into_iter() - .map(|ip| ValueType::Text(ip.map(|ip| ip.to_string().into()))); - - Value::array(addrs) + PGColumnTypeNumericArray::MONEY_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let nums = vals.into_iter().map(|num| num.map(|num| num.0)); + + ValueType::Array(Some( + v.read(nums.into_iter()) + .map(ValueType::Numeric) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + } } - None => Value::null_array(), }, - PostgresType::BIT | PostgresType::VARBIT => match row.try_get(i)? { - Some(val) => { - let val: BitVec = val; - Value::text(bits_to_string(&val)?) + PGColumnType::JsonArray(ty, v) => match ty { + PGColumnTypeJsonArray::JSON_ARRAY | PGColumnTypeJsonArray::JSONB_ARRAY => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Json) + .map(ValueType::into_value) + .collect(), + )), + None => ValueType::Array(None), + } } - None => Value::null_text(), }, - PostgresType::BIT_ARRAY | PostgresType::VARBIT_ARRAY => match row.try_get(i)? 
{ - Some(val) => { - let val: Vec> = val; - val.into_iter() - .map(|bits| match bits { - Some(bits) => bits_to_string(&bits).map(|s| ValueType::Text(Some(s.into()))), - None => Ok(ValueType::Text(None)), - }) - .collect::>>() - .map(Value::array)? - } - None => Value::null_array(), + PGColumnType::UuidArray(ty, v) => match ty { + PGColumnTypeUuidArray::UUID_ARRAY => match row.try_get(i)? { + Some(vals) => { + let vals: Vec> = vals; + + ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Uuid) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + }, }, - PostgresType::XML => match row.try_get(i)? { - Some(val) => { - let val: XmlString = val; - Value::xml(val.0) - } - None => Value::null_xml(), + PGColumnType::DateTimeArray(ty, v) => match ty { + PGColumnTypeDateTimeArray::TIMESTAMP_ARRAY => match row.try_get(i)? { + Some(vals) => { + let vals: Vec> = vals; + let dates = vals + .into_iter() + .map(|dt| dt.map(|dt| DateTime::::from_naive_utc_and_offset(dt, Utc))); + + ValueType::Array(Some( + v.read(dates) + .map(ValueType::DateTime) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + }, + PGColumnTypeDateTimeArray::TIMESTAMPTZ_ARRAY => match row.try_get(i)? { + Some(vals) => { + let vals: Vec>> = vals; + + ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::DateTime) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + }, }, - PostgresType::XML_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let xmls = val.into_iter().map(|xml| xml.map(|xml| xml.0)); - - Value::array(xmls) - } - None => Value::null_array(), + PGColumnType::DateArray(ty, v) => match ty { + PGColumnTypeDateArray::DATE_ARRAY => match row.try_get(i)? { + Some(vals) => { + let vals: Vec> = vals; + + ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Date) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + }, }, - ref x => match x.kind() { - Kind::Enum => match row.try_get(i)? { + PGColumnType::TimeArray(ty, v) => match ty { + PGColumnTypeTimeArray::TIME_ARRAY => match row.try_get(i)? { + Some(vals) => { + let vals: Vec> = vals; + + ValueType::Array(Some( + v.read(vals.into_iter()) + .map(ValueType::Time) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + }, + PGColumnTypeTimeArray::TIMETZ_ARRAY => match row.try_get(i)? { Some(val) => { - let val: EnumString = val; - - Value::enum_variant(val.value) + let val: Vec> = val; + let timetzs = val.into_iter().map(|time| time.map(|time| time.0)); + + ValueType::Array(Some( + v.read(timetzs.into_iter()) + .map(ValueType::Time) + .map(ValueType::into_value) + .collect(), + )) } - None => Value::null_enum(), + None => ValueType::Array(None), }, - Kind::Array(inner) => match inner.kind() { - Kind::Enum => match row.try_get(i)? 
{ - Some(val) => { - let val: Vec> = val; - let variants = val - .into_iter() - .map(|x| ValueType::Enum(x.map(|x| x.value.into()), None)); - - Ok(Value::array(variants)) - } - None => Ok(Value::null_array()), - }, - _ => match row.try_get(i) { - Ok(Some(val)) => { - let val: Vec> = val; - let strings = val.into_iter().map(|str| ValueType::Text(str.map(Into::into))); - - Ok(Value::array(strings)) - } - Ok(None) => Ok(Value::null_array()), - Err(err) => { - if err.source().map(|err| err.is::()).unwrap_or(false) { - let kind = ErrorKind::UnsupportedColumnType { - column_type: x.to_string(), - }; - - return Err(Error::builder(kind).build()); - } else { - Err(err) - } - } - }, - }?, - _ => match row.try_get(i) { - Ok(Some(val)) => { - let val: String = val; + }, + PGColumnType::EnumArray(v) => { + let vals: Option>> = row.try_get(i)?; + + match vals { + Some(vals) => { + let enums = vals.into_iter().map(|val| val.map(|val| Cow::Owned(val.value))); + + ValueType::Array(Some( + v.read(enums) + .map(|variant| ValueType::Enum(variant.map(EnumVariant::new), None)) + .map(ValueType::into_value) + .collect(), + )) + } + None => ValueType::Array(None), + } + } + PGColumnType::Enum(v) => { + let val: Option = row.try_get(i)?; + let enum_variant = v.read(val.map(|x| Cow::Owned(x.value))); - Ok(Value::text(val)) + ValueType::Enum(enum_variant.map(EnumVariant::new), None) + } + PGColumnType::UnknownArray(v) => match row.try_get(i) { + Ok(Some(vals)) => { + let vals: Vec> = vals; + let strings = vals.into_iter().map(|str| str.map(Cow::Owned)); + + Ok(ValueType::Array(Some( + v.read(strings.into_iter()) + .map(ValueType::Text) + .map(ValueType::into_value) + .collect(), + ))) + } + Ok(None) => Ok(ValueType::Array(None)), + Err(err) => { + if err.source().map(|err| err.is::()).unwrap_or(false) { + let kind = ErrorKind::UnsupportedColumnType { + column_type: pg_ty.to_string(), + }; + + return Err(Error::builder(kind).build()); + } else { + Err(err) } - Ok(None) => Ok(Value::from(ValueType::Text(None))), - Err(err) => { - if err.source().map(|err| err.is::()).unwrap_or(false) { - let kind = ErrorKind::UnsupportedColumnType { - column_type: x.to_string(), - }; - - return Err(Error::builder(kind).build()); - } else { - Err(err) - } + } + }?, + PGColumnType::Unknown(v) => match row.try_get(i) { + Ok(Some(val)) => { + let val: String = val; + + Ok(ValueType::Text(v.read(Some(Cow::Owned(val))))) + } + Ok(None) => Ok(ValueType::Text(None)), + Err(err) => { + if err.source().map(|err| err.is::()).unwrap_or(false) { + let kind = ErrorKind::UnsupportedColumnType { + column_type: pg_ty.to_string(), + }; + + return Err(Error::builder(kind).build()); + } else { + Err(err) } - }?, - }, + } + }?, }; - Ok(result) + Ok(result.into_value()) } let num_columns = self.columns().len(); diff --git a/quaint/src/connector/postgres/native/explain.rs b/quaint/src/connector/postgres/native/explain.rs new file mode 100644 index 000000000000..6a3e8594f8d7 --- /dev/null +++ b/quaint/src/connector/postgres/native/explain.rs @@ -0,0 +1,57 @@ +#[derive(serde::Deserialize, Debug)] +#[serde(untagged)] +pub(crate) enum Explain { + // NOTE: the returned JSON may not contain a `plan` field, for example, with `CALL` statements: + // https://github.com/launchbadge/sqlx/issues/1449 + // + // In this case, we should just fall back to assuming all is nullable. 
+ // + // It may also contain additional fields we don't care about, which should not break parsing: + // https://github.com/launchbadge/sqlx/issues/2587 + // https://github.com/launchbadge/sqlx/issues/2622 + Plan { + #[serde(rename = "Plan")] + plan: Plan, + }, + + // This ensures that parsing never technically fails. + // + // We don't want to specifically expect `"Utility Statement"` because there might be other cases + // and we don't care unless it contains a query plan anyway. + Other(serde::de::IgnoredAny), +} + +#[derive(serde::Deserialize, Debug)] +pub(crate) struct Plan { + #[serde(rename = "Join Type")] + pub(crate) join_type: Option, + #[serde(rename = "Parent Relationship")] + pub(crate) parent_relation: Option, + #[serde(rename = "Output")] + pub(crate) output: Option>, + #[serde(rename = "Plans")] + pub(crate) plans: Option>, +} + +pub(crate) fn visit_plan(plan: &Plan, outputs: &[String], nullables: &mut Vec>) { + if let Some(plan_outputs) = &plan.output { + // all outputs of a Full Join must be marked nullable + // otherwise, all outputs of the inner half of an outer join must be marked nullable + if plan.join_type.as_deref() == Some("Full") || plan.parent_relation.as_deref() == Some("Inner") { + for output in plan_outputs { + if let Some(i) = outputs.iter().position(|o| o == output) { + // N.B. this may produce false positives but those don't cause runtime errors + nullables[i] = Some(true); + } + } + } + } + + if let Some(plans) = &plan.plans { + if let Some("Left") | Some("Right") = plan.join_type.as_deref() { + for plan in plans { + visit_plan(plan, outputs, nullables); + } + } + } +} diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 85bd34066b90..eb6618ce9dc7 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -1,14 +1,20 @@ //! Definitions for the Postgres connector. //! This module is not compatible with wasm32-* targets. //! This module is only available with the `postgresql-native` feature. +pub(crate) mod column_type; mod conversion; mod error; +mod explain; +mod websocket; -pub(crate) use crate::connector::postgres::url::PostgresUrl; +pub(crate) use crate::connector::postgres::url::PostgresNativeUrl; use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; -use crate::connector::{timeout, IsolationLevel, Transaction}; +use crate::connector::{ + timeout, ColumnType, DescribedColumn, DescribedParameter, DescribedQuery, IsolationLevel, Transaction, +}; use crate::error::NativeErrorKind; +use crate::ValueType; use crate::{ ast::{Query, Value}, connector::{metrics, queryable::*, ResultSet}, @@ -16,26 +22,33 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; +use column_type::PGColumnType; use futures::{future::FutureExt, lock::Mutex}; use lru_cache::LruCache; use native_tls::{Certificate, Identity, TlsConnector}; use postgres_native_tls::MakeTlsConnector; +use postgres_types::{Kind as PostgresKind, Type as PostgresType}; +use prisma_metrics::WithMetricsInstrumentation; use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ - borrow::Borrow, fmt::{Debug, Display}, fs, future::Future, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; +use tokio::sync::OnceCell; use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; +use tracing_futures::WithSubscriber; +use websocket::connect_via_websocket; /// The underlying postgres driver. Only available with the `expose-drivers` /// Cargo feature. 
#[cfg(feature = "expose-drivers")] pub use tokio_postgres; +use super::PostgresWebSocketUrl; + struct PostgresClient(Client); impl Debug for PostgresClient { @@ -44,6 +57,9 @@ impl Debug for PostgresClient { } } +const DB_SYSTEM_NAME_POSTGRESQL: &str = "postgresql"; +const DB_SYSTEM_NAME_COCKROACHDB: &str = "cockroachdb"; + /// A connector interface for the PostgreSQL database. #[derive(Debug)] pub struct PostgreSql { @@ -52,6 +68,9 @@ pub struct PostgreSql { socket_timeout: Option, statement_cache: Mutex, is_healthy: AtomicBool, + is_cockroachdb: bool, + is_materialize: bool, + db_system_name: &'static str, } /// Key uniquely representing an SQL statement in the prepared statements cache. @@ -151,7 +170,7 @@ impl SslParams { } } -impl PostgresUrl { +impl PostgresNativeUrl { pub(crate) fn cache(&self) -> StatementCache { if self.query_params.pg_bouncer { StatementCache::new(0) @@ -188,8 +207,8 @@ impl PostgresUrl { pub(crate) fn to_config(&self) -> Config { let mut config = Config::new(); - config.user(self.username().borrow()); - config.password(self.password().borrow() as &str); + config.user(self.username().as_ref()); + config.password(self.password().as_ref()); config.host(self.host()); config.port(self.port()); config.dbname(self.dbname()); @@ -219,35 +238,24 @@ impl PostgresUrl { impl PostgreSql { /// Create a new connection to the database. - pub async fn new(url: PostgresUrl) -> crate::Result { + pub async fn new(url: PostgresNativeUrl, tls_manager: &MakeTlsConnectorManager) -> crate::Result { let config = url.to_config(); - let mut tls_builder = TlsConnector::builder(); - - { - let ssl_params = url.ssl_params(); - let auth = ssl_params.to_owned().into_auth().await?; - - if let Some(certificate) = auth.certificate.0 { - tls_builder.add_root_certificate(certificate); - } - - tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); - - if let Some(identity) = auth.identity.0 { - tls_builder.identity(identity); - } - } - - let tls = MakeTlsConnector::new(tls_builder.build()?); + let tls = tls_manager.get_connector().await?; let (client, conn) = timeout::connect(url.connect_timeout(), config.connect(tls)).await?; - tokio::spawn(conn.map(|r| match r { - Ok(_) => (), - Err(e) => { - tracing::error!("Error in PostgreSQL connection: {:?}", e); - } - })); + let is_cockroachdb = conn.parameter("crdb_version").is_some(); + let is_materialize = conn.parameter("mz_version").is_some(); + + tokio::spawn( + conn.map(|r| { + if let Err(e) = r { + tracing::error!("Error in PostgreSQL connection: {e:?}"); + } + }) + .with_current_subscriber() + .with_current_recorder(), + ); // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. 
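For context on the flavour detection above: tokio_postgres exposes server-reported startup parameters on the connection half of the pair returned by connect, which is why `crdb_version`/`mz_version` are checked before the connection future is spawned onto a background task. A bare-bones sketch of that order of operations (placeholder connection string, no TLS, and without the tracing/metrics instrumentation used in this diff):

use tokio_postgres::NoTls;

#[tokio::main]
async fn main() -> Result<(), tokio_postgres::Error> {
    // Placeholder connection string; TLS is omitted for brevity.
    let (client, connection) =
        tokio_postgres::connect("host=localhost user=postgres dbname=postgres", NoTls).await?;

    // Server-reported startup parameters live on the Connection half, so read
    // them before moving it onto the background task that drives the socket.
    let is_cockroachdb = connection.parameter("crdb_version").is_some();
    let is_materialize = connection.parameter("mz_version").is_some();
    println!("cockroachdb: {is_cockroachdb}, materialize: {is_materialize}");

    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {e:?}");
        }
    });

    let row = client.query_one("SELECT 1::int4", &[]).await?;
    assert_eq!(row.get::<_, i32>(0), 1);
    Ok(())
}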
@@ -269,12 +277,37 @@ impl PostgreSql { } } + let db_system_name = if is_cockroachdb { + DB_SYSTEM_NAME_COCKROACHDB + } else { + DB_SYSTEM_NAME_POSTGRESQL + }; + Ok(Self { client: PostgresClient(client), socket_timeout: url.query_params.socket_timeout, pg_bouncer: url.query_params.pg_bouncer, statement_cache: Mutex::new(url.cache()), is_healthy: AtomicBool::new(true), + is_cockroachdb, + is_materialize, + db_system_name, + }) + } + + /// Create a new websocket connection to managed database + pub async fn new_with_websocket(url: PostgresWebSocketUrl) -> crate::Result { + let client = connect_via_websocket(url).await?; + + Ok(Self { + client: PostgresClient(client), + socket_timeout: None, + pg_bouncer: false, + statement_cache: Mutex::new(StatementCache::new(0)), + is_healthy: AtomicBool::new(true), + is_cockroachdb: false, + is_materialize: false, + db_system_name: DB_SYSTEM_NAME_POSTGRESQL, }) } @@ -348,6 +381,112 @@ impl PostgreSql { Ok(()) } } + + // All credits go to sqlx: https://github.com/launchbadge/sqlx/blob/a892ebc6e283f443145f92bbc7fce4ae44547331/sqlx-postgres/src/connection/describe.rs#L417 + pub(crate) async fn get_nullable_for_columns(&self, stmt: &Statement) -> crate::Result>> { + let columns = stmt.columns(); + + if columns.is_empty() { + return Ok(vec![]); + } + + let mut nullable_query = String::from("SELECT NOT pg_attribute.attnotnull as nullable FROM (VALUES "); + let mut args = Vec::with_capacity(columns.len() * 3); + + for (i, (column, bind)) in columns.iter().zip((1..).step_by(3)).enumerate() { + if !args.is_empty() { + nullable_query += ", "; + } + + nullable_query.push_str(&format!("(${}::int4, ${}::int8, ${}::int4)", bind, bind + 1, bind + 2)); + + args.push(Value::from(i as i32)); + args.push(ValueType::Int64(column.table_oid().map(i64::from)).into()); + args.push(ValueType::Int32(column.column_id().map(i32::from)).into()); + } + + nullable_query.push_str( + ") as col(idx, table_id, col_idx) \ + LEFT JOIN pg_catalog.pg_attribute \ + ON table_id IS NOT NULL \ + AND attrelid = table_id \ + AND attnum = col_idx \ + ORDER BY col.idx", + ); + + let nullable_query_result = self.query_raw(&nullable_query, &args).await?; + let mut nullables = Vec::with_capacity(nullable_query_result.len()); + + for row in nullable_query_result { + let nullable = row.at(0).and_then(|v| v.as_bool()); + + nullables.push(nullable) + } + + // If the server is CockroachDB or Materialize, skip this step (#1248). + if !self.is_cockroachdb && !self.is_materialize { + // patch up our null inference with data from EXPLAIN + let nullable_patch = self.nullables_from_explain(stmt).await?; + + for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) { + *nullable = patch.or(*nullable); + } + } + + Ok(nullables) + } + + /// Infer nullability for columns of this statement using EXPLAIN VERBOSE. + /// + /// This currently only marks columns that are on the inner half of an outer join + /// and returns `None` for all others. 
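The `Explain`/`Plan` types and `visit_plan` introduced above drive this nullability inference: the EXPLAIN output is parsed leniently (anything without a "Plan" key is ignored) and only the inner side of outer joins is forced to nullable. A self-contained sketch of that flow, using local copies of those shapes and a hand-written plan JSON (illustrative, not captured from a real server):

use serde::Deserialize;

#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum Explain {
    Plan {
        #[serde(rename = "Plan")]
        plan: Plan,
    },
    Other(serde::de::IgnoredAny),
}

#[derive(Deserialize, Debug)]
struct Plan {
    #[serde(rename = "Join Type")]
    join_type: Option<String>,
    #[serde(rename = "Parent Relationship")]
    parent_relation: Option<String>,
    #[serde(rename = "Output")]
    output: Option<Vec<String>>,
    #[serde(rename = "Plans")]
    plans: Option<Vec<Plan>>,
}

fn visit_plan(plan: &Plan, outputs: &[String], nullables: &mut Vec<Option<bool>>) {
    if let Some(plan_outputs) = &plan.output {
        // Full Join: everything is nullable; otherwise only the inner half of an outer join.
        if plan.join_type.as_deref() == Some("Full") || plan.parent_relation.as_deref() == Some("Inner") {
            for output in plan_outputs {
                if let Some(i) = outputs.iter().position(|o| o == output) {
                    nullables[i] = Some(true);
                }
            }
        }
    }

    if let Some(plans) = &plan.plans {
        if let Some("Left") | Some("Right") = plan.join_type.as_deref() {
            for plan in plans {
                visit_plan(plan, outputs, nullables);
            }
        }
    }
}

fn main() -> Result<(), serde_json::Error> {
    // Illustrative shape of `EXPLAIN (VERBOSE, FORMAT JSON)` for a LEFT JOIN:
    // the inner relation's output becomes nullable in the joined result.
    let raw = serde_json::json!([{
        "Plan": {
            "Join Type": "Left",
            "Output": ["a.id", "b.note"],
            "Plans": [
                { "Parent Relationship": "Outer", "Output": ["a.id"] },
                { "Parent Relationship": "Inner", "Output": ["b.note"] }
            ]
        }
    }]);

    let [explain]: [Explain; 1] = serde_json::from_value(raw)?;
    let mut nullables = Vec::new();

    if let Explain::Plan { plan } = &explain {
        if let Some(outputs) = &plan.output {
            nullables.resize(outputs.len(), None);
            visit_plan(plan, outputs, &mut nullables);
        }
    }

    assert_eq!(nullables, vec![None, Some(true)]);
    Ok(())
}

In the connector itself the JSON comes from `EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE <statement>` with `NULL` bound for every parameter, and output that does not parse into a plan simply leaves each column at `None`, which `describe_query` later treats as nullable.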
+ /// All credits go to sqlx: https://github.com/launchbadge/sqlx/blob/a892ebc6e283f443145f92bbc7fce4ae44547331/sqlx-postgres/src/connection/describe.rs#L482 + async fn nullables_from_explain(&self, stmt: &Statement) -> Result>, Error> { + use explain::{visit_plan, Explain, Plan}; + + let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {}", stmt.name()); + let params_len = stmt.params().len(); + let mut comma = false; + + if params_len > 0 { + explain += "("; + + // fill the arguments list with NULL, which should theoretically be valid + for _ in 0..params_len { + if comma { + explain += ", "; + } + + explain += "NULL"; + comma = true; + } + + explain += ")"; + } + + let explain_result = self.query_raw(&explain, &[]).await?.into_single()?; + let explains = explain_result + .into_single()? + .into_json() + .map(serde_json::from_value::<[Explain; 1]>) + .transpose()?; + let explain = explains.as_ref().and_then(|x| x.first()); + + let mut nullables = Vec::new(); + + if let Some(Explain::Plan { + plan: plan @ Plan { + output: Some(ref outputs), + .. + }, + }) = explain + { + nullables.resize(outputs.len(), None); + visit_plan(plan, outputs, &mut nullables); + } + + Ok(nullables) + } } // A SearchPath connection parameter (Display-impl) for connection initialization. @@ -400,61 +539,160 @@ impl Queryable for PostgreSql { async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { self.check_bind_variables_len(params)?; - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + metrics::query( + "postgres.query_raw", + self.db_system_name, + sql, + params, + move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + let col_types = stmt + .columns() + .iter() + .map(|c| PGColumnType::from_pg_type(c.type_())) + .map(ColumnType::from) + .collect::>(); + let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); - for row in rows { - result.rows.push(row.get_result_row()?); - } + for row in rows { + result.rows.push(row.get_result_row()?); + } - Ok(result) - }) + Ok(result) + }, + ) .await } async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { self.check_bind_variables_len(params)?; - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; + metrics::query( + "postgres.query_raw", + self.db_system_name, + sql, + params, + move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } - if stmt.params().len() != params.len() { - let kind = 
ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; + let col_types = stmt + .columns() + .iter() + .map(|c| PGColumnType::from_pg_type(c.type_())) + .map(ColumnType::from) + .collect::>(); + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; - return Err(Error::builder(kind).build()); + let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); + + for row in rows { + result.rows.push(row.get_result_row()?); + } + + Ok(result) + }, + ) + .await + } + + async fn describe_query(&self, sql: &str) -> crate::Result { + let stmt = self.fetch_cached(sql, &[]).await?; + + let mut columns: Vec = Vec::with_capacity(stmt.columns().len()); + let mut parameters: Vec = Vec::with_capacity(stmt.params().len()); + + let enums_results = self + .query_raw("SELECT oid, typname FROM pg_type WHERE typtype = 'e';", &[]) + .await?; + + fn find_enum_by_oid(enums: &ResultSet, enum_oid: u32) -> Option<&str> { + enums.iter().find_map(|row| { + let oid = row.get("oid")?.as_i64()?; + let name = row.get("typname")?.as_str()?; + + if enum_oid == u32::try_from(oid).unwrap() { + Some(name) + } else { + None + } + }) + } + + fn resolve_type(ty: &PostgresType, enums: &ResultSet) -> (ColumnType, Option) { + let column_type = ColumnType::from(ty); + + match ty.kind() { + PostgresKind::Enum => { + let enum_name = find_enum_by_oid(enums, ty.oid()) + .unwrap_or_else(|| panic!("Could not find enum with oid {}", ty.oid())); + + (column_type, Some(enum_name.to_string())) + } + _ => (column_type, None), } + } + + let nullables = self.get_nullable_for_columns(&stmt).await?; - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; + for (idx, (col, nullable)) in stmt.columns().iter().zip(nullables).enumerate() { + let (typ, enum_name) = resolve_type(col.type_(), &enums_results); - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + if col.name() == "?column?" { + let kind = ErrorKind::QueryInvalidInput(format!("Invalid column name '?column?' for index {idx}. Your SQL query must explicitly alias that column name.")); - for row in rows { - result.rows.push(row.get_result_row()?); + return Err(Error::builder(kind).build()); } - Ok(result) + columns.push( + DescribedColumn::new_named(col.name(), typ) + .with_enum_name(enum_name) + // Make fields nullable by default if we can't infer nullability. 
+ .is_nullable(nullable.unwrap_or(true)), + ); + } + + for param in stmt.params() { + let (typ, enum_name) = resolve_type(param, &enums_results); + + parameters.push(DescribedParameter::new_named(param.name(), typ).with_enum_name(enum_name)); + } + + let enum_names = enums_results + .into_iter() + .filter_map(|row| row.take("typname")) + .filter_map(|v| v.to_string()) + .collect::>(); + + Ok(DescribedQuery { + columns, + parameters, + enum_names: Some(enum_names), }) - .await } async fn execute(&self, q: Query<'_>) -> crate::Result { @@ -466,53 +704,65 @@ impl Queryable for PostgreSql { async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { self.check_bind_variables_len(params)?; - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + metrics::query( + "postgres.execute_raw", + self.db_system_name, + sql, + params, + move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; - Ok(changes) - }) + Ok(changes) + }, + ) .await } async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { self.check_bind_variables_len(params)?; - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + metrics::query( + "postgres.execute_raw", + self.db_system_name, + sql, + params, + move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; - Ok(changes) - }) + Ok(changes) + }, + ) .await } async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("postgres.raw_cmd", cmd, &[], move || async move { + metrics::query("postgres.raw_cmd", self.db_system_name, cmd, &[], move || async move { self.perform_io(self.client.0.simple_query(cmd)).await?; Ok(()) }) @@ -700,6 +950,48 @@ fn is_safe_identifier(ident: &str) -> bool { true } +pub struct MakeTlsConnectorManager { + url: PostgresNativeUrl, + connector: OnceCell, +} + +impl MakeTlsConnectorManager { + pub fn new(url: PostgresNativeUrl) -> Self { + MakeTlsConnectorManager { + url, + connector: OnceCell::new(), + } + } + + pub async fn get_connector(&self) -> crate::Result { + self.connector + .get_or_try_init(|| async { + let mut 
tls_builder = TlsConnector::builder(); + + { + let ssl_params = self.url.ssl_params(); + let auth = ssl_params.to_owned().into_auth().await?; + + if let Some(certificate) = auth.certificate.0 { + tls_builder.add_root_certificate(certificate); + } + + tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); + + if let Some(identity) = auth.identity.0 { + tls_builder.identity(identity); + } + } + + let tls_connector = MakeTlsConnector::new(tls_builder.build()?); + + Ok(tls_connector) + }) + .await + .cloned() + } +} + #[cfg(test)] mod tests { use super::*; @@ -715,10 +1007,12 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -767,10 +1061,12 @@ mod tests { url.query_pairs_mut().append_pair("schema", schema_name); url.query_pairs_mut().append_pair("pbbouncer", "true"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -818,10 +1114,12 @@ mod tests { let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -869,10 +1167,12 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Unknown); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -920,10 +1220,12 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Unknown); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW 
search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs new file mode 100644 index 000000000000..f278c9f099bc --- /dev/null +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -0,0 +1,97 @@ +use std::str::FromStr; + +use async_tungstenite::{ + tokio::connect_async, + tungstenite::{ + self, + client::IntoClientRequest, + http::{HeaderMap, HeaderValue, StatusCode}, + Error as TungsteniteError, + }, +}; +use futures::FutureExt; +use postgres_native_tls::TlsConnector; +use prisma_metrics::WithMetricsInstrumentation; +use tokio_postgres::{Client, Config}; +use tracing_futures::WithSubscriber; +use ws_stream_tungstenite::WsStream; + +use crate::{ + connector::PostgresWebSocketUrl, + error::{self, Error, ErrorKind, Name, NativeErrorKind}, +}; + +const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters"; +const HOST_HEADER: &str = "Prisma-Db-Host"; + +pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result { + let db_name = url.overriden_db_name().map(ToOwned::to_owned); + let (ws_stream, response) = connect_async(url).await?; + + let connection_params = require_header_value(response.headers(), CONNECTION_PARAMS_HEADER)?; + let db_host = require_header_value(response.headers(), HOST_HEADER)?; + + let mut config = Config::from_str(connection_params)?; + if let Some(db_name) = db_name { + config.dbname(&db_name); + } + let ws_byte_stream = WsStream::new(ws_stream); + + let tls = TlsConnector::new(native_tls::TlsConnector::new()?, db_host); + let (client, connection) = config.connect_raw(ws_byte_stream, tls).await?; + tokio::spawn( + connection + .map(|r| { + if let Err(e) = r { + tracing::error!("Error in PostgreSQL WebSocket connection: {e:?}"); + } + }) + .with_current_subscriber() + .with_current_recorder(), + ); + Ok(client) +} + +fn require_header_value<'a>(headers: &'a HeaderMap, name: &str) -> crate::Result<&'a str> { + let Some(header) = headers.get(name) else { + let message = format!("Missing response header {name}"); + let error = Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(message.into()))).build(); + return Err(error); + }; + + let value = header.to_str().map_err(|inner| { + Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(inner)))).build() + })?; + + Ok(value) +} + +impl IntoClientRequest for PostgresWebSocketUrl { + fn into_client_request(self) -> tungstenite::Result { + let mut request = self.url.to_string().into_client_request()?; + let bearer = format!("Bearer {}", self.api_key()); + let auth_header = HeaderValue::from_str(&bearer)?; + request.headers_mut().insert("Authorization", auth_header); + Ok(request) + } +} + +impl From for error::Error { + fn from(value: TungsteniteError) -> Self { + let builder = match value { + TungsteniteError::Tls(tls_error) => Error::builder(ErrorKind::Native(NativeErrorKind::TlsError { + message: tls_error.to_string(), + })), + + TungsteniteError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => { + Error::builder(ErrorKind::DatabaseAccessDenied { + db_name: Name::Unavailable, + }) + } + + _ => Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(value)))), + }; + + builder.build() + } +} diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs index 844da48c8d66..096484cdc87a 100644 --- a/quaint/src/connector/postgres/url.rs +++ 
b/quaint/src/connector/postgres/url.rs @@ -63,16 +63,74 @@ impl PostgresFlavour { } } +#[derive(Debug, Clone)] +pub enum PostgresUrl { + Native(Box), + WebSocket(PostgresWebSocketUrl), +} + +impl PostgresUrl { + pub fn new_native(url: Url) -> Result { + Ok(Self::Native(Box::new(PostgresNativeUrl::new(url)?))) + } + + pub fn new_websocket(url: Url, api_key: String) -> Result { + Ok(Self::WebSocket(PostgresWebSocketUrl::new(url, api_key))) + } + + pub fn dbname(&self) -> &str { + match self { + Self::Native(url) => url.dbname(), + Self::WebSocket(url) => url.dbname(), + } + } + + pub fn host(&self) -> &str { + match self { + Self::Native(native_url) => native_url.host(), + Self::WebSocket(ws_url) => ws_url.host(), + } + } + + pub fn port(&self) -> u16 { + match self { + Self::Native(native_url) => native_url.port(), + Self::WebSocket(ws_url) => ws_url.port(), + } + } + + pub fn username(&self) -> Cow<'_, str> { + match self { + Self::Native(native_url) => native_url.username(), + Self::WebSocket(_) => Cow::Borrowed(""), + } + } + + pub fn schema(&self) -> &str { + match self { + Self::Native(native_url) => native_url.schema(), + Self::WebSocket(_) => "public", + } + } + + pub fn socket_timeout(&self) -> Option { + match self { + Self::Native(native_url) => native_url.socket_timeout(), + Self::WebSocket(_) => None, + } + } +} + /// Wraps a connection url and exposes the parsing logic used by Quaint, /// including default values. #[derive(Debug, Clone)] -pub struct PostgresUrl { +pub struct PostgresNativeUrl { pub(crate) url: Url, pub(crate) query_params: PostgresUrlQueryParams, pub(crate) flavour: PostgresFlavour, } -impl PostgresUrl { +impl PostgresNativeUrl { /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection /// parameters. pub fn new(url: Url) -> Result { @@ -431,6 +489,47 @@ pub(crate) struct PostgresUrlQueryParams { pub(crate) ssl_mode: SslMode, } +#[derive(Debug, Clone)] +pub struct PostgresWebSocketUrl { + pub(crate) url: Url, + pub(crate) api_key: String, + pub(crate) db_name: Option, +} + +impl PostgresWebSocketUrl { + pub fn new(url: Url, api_key: String) -> Self { + Self { + url, + api_key, + db_name: None, + } + } + + pub fn override_db_name(&mut self, name: String) { + self.db_name = Some(name) + } + + pub fn api_key(&self) -> &str { + &self.api_key + } + + pub fn dbname(&self) -> &str { + self.overriden_db_name().unwrap_or("postgres") + } + + pub fn overriden_db_name(&self) -> Option<&str> { + self.db_name.as_deref() + } + + pub fn host(&self) -> &str { + self.url.host_str().unwrap_or("localhost") + } + + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(80) + } +} + #[cfg(test)] mod tests { use super::*; @@ -442,14 +541,15 @@ mod tests { #[test] fn should_parse_socket_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("/var/run/psql.sock", url.host()); } #[test] fn should_parse_escaped_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("/var/run/postgresql", url.host()); } @@ -457,63 +557,69 @@ mod tests { #[test] fn should_allow_changing_of_cache_size() { let url = - 
PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()) + .unwrap(); assert_eq!(420, url.cache().capacity()); } #[test] fn should_have_default_cache_size() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); assert_eq!(100, url.cache().capacity()); } #[test] fn should_have_application_name() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()) + .unwrap(); assert_eq!(Some("test"), url.application_name()); } #[test] fn should_have_channel_binding() { let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()) + .unwrap(); assert_eq!(ChannelBinding::Require, url.channel_binding()); } #[test] fn should_have_default_channel_binding() { let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()) + .unwrap(); assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); assert_eq!(ChannelBinding::Prefer, url.channel_binding()); } #[test] fn should_not_enable_caching_with_pgbouncer() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); assert_eq!(0, url.cache().capacity()); } #[test] fn should_parse_default_host() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("localhost", url.host()); } #[test] fn should_parse_ipv6_host() { - let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); assert_eq!("2001:db8:1234::ffff", url.host()); } #[test] fn should_handle_options_field() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) - .unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) + .unwrap(); assert_eq!("--cluster=my_cluster", url.options().unwrap()); } @@ -600,7 +706,7 @@ mod tests { url.query_pairs_mut().append_pair("schema", "hello"); url.query_pairs_mut().append_pair("pgbouncer", "true"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let config = pg_url.to_config(); @@ -616,7 +722,7 @@ mod tests { let mut url = 
Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "hello"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let config = pg_url.to_config(); @@ -630,7 +736,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "hello"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); let config = pg_url.to_config(); @@ -644,7 +750,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "HeLLo"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); let config = pg_url.to_config(); diff --git a/quaint/src/connector/queryable.rs b/quaint/src/connector/queryable.rs index bea2184aa8d1..5f0fd54dad6b 100644 --- a/quaint/src/connector/queryable.rs +++ b/quaint/src/connector/queryable.rs @@ -1,4 +1,4 @@ -use super::{IsolationLevel, ResultSet, Transaction}; +use super::{DescribedQuery, IsolationLevel, ResultSet, Transaction}; use crate::ast::*; use async_trait::async_trait; @@ -57,6 +57,9 @@ pub trait Queryable: Send + Sync { /// parsing or normalization. async fn version(&self) -> crate::Result>; + /// Prepares a statement and returns type information. + async fn describe_query(&self, sql: &str) -> crate::Result; + /// Returns false, if connection is considered to not be in a working state. fn is_healthy(&self) -> bool; diff --git a/quaint/src/connector/result_set.rs b/quaint/src/connector/result_set.rs index 7592a85ea15c..7e7d045f9842 100644 --- a/quaint/src/connector/result_set.rs +++ b/quaint/src/connector/result_set.rs @@ -8,19 +8,23 @@ use crate::{ast::Value, error::*}; use serde_json::Map; use std::sync::Arc; +use super::ColumnType; + /// Encapsulates a set of results and their respective column names. #[derive(Debug, Default)] pub struct ResultSet { pub(crate) columns: Arc>, + pub(crate) types: Vec, pub(crate) rows: Vec>>, pub(crate) last_insert_id: Option, } impl ResultSet { /// Creates a new instance, bound to the given column names and result rows. - pub fn new(names: Vec, rows: Vec>>) -> Self { + pub fn new(names: Vec, types: Vec, rows: Vec>>) -> Self { Self { columns: Arc::new(names), + types, rows, last_insert_id: None, } @@ -61,6 +65,7 @@ impl ResultSet { self.rows.get(index).map(|row| ResultRowRef { columns: Arc::clone(&self.columns), values: row, + types: self.types.clone(), }) } @@ -75,9 +80,14 @@ impl ResultSet { pub fn iter(&self) -> ResultSetIterator<'_> { ResultSetIterator { columns: self.columns.clone(), + types: self.types.clone(), internal_iterator: self.rows.iter(), } } + + pub fn types(&self) -> &[ColumnType] { + &self.types + } } impl IntoIterator for ResultSet { @@ -87,6 +97,7 @@ impl IntoIterator for ResultSet { fn into_iter(self) -> Self::IntoIter { ResultSetIntoIterator { columns: self.columns, + types: self.types.clone(), internal_iterator: self.rows.into_iter(), } } @@ -96,6 +107,7 @@ impl IntoIterator for ResultSet { /// Might become lazy one day. 
pub struct ResultSetIntoIterator { pub(crate) columns: Arc>, + pub(crate) types: Vec, pub(crate) internal_iterator: std::vec::IntoIter>>, } @@ -107,6 +119,7 @@ impl Iterator for ResultSetIntoIterator { Some(row) => Some(ResultRow { columns: Arc::clone(&self.columns), values: row, + types: self.types.clone(), }), None => None, } @@ -115,6 +128,7 @@ impl Iterator for ResultSetIntoIterator { pub struct ResultSetIterator<'a> { pub(crate) columns: Arc>, + pub(crate) types: Vec, pub(crate) internal_iterator: std::slice::Iter<'a, Vec>>, } @@ -126,6 +140,7 @@ impl<'a> Iterator for ResultSetIterator<'a> { Some(row) => Some(ResultRowRef { columns: Arc::clone(&self.columns), values: row, + types: self.types.clone(), }), None => None, } diff --git a/quaint/src/connector/result_set/result_row.rs b/quaint/src/connector/result_set/result_row.rs index a6a5b55c3c62..58c389b294d4 100644 --- a/quaint/src/connector/result_set/result_row.rs +++ b/quaint/src/connector/result_set/result_row.rs @@ -1,5 +1,6 @@ use crate::{ ast::Value, + connector::ColumnType, error::{Error, ErrorKind}, }; use std::sync::Arc; @@ -9,6 +10,7 @@ use std::sync::Arc; #[derive(Debug, PartialEq)] pub struct ResultRow { pub(crate) columns: Arc>, + pub(crate) types: Vec, pub(crate) values: Vec>, } @@ -38,6 +40,7 @@ impl IntoIterator for ResultRow { #[derive(Debug, PartialEq)] pub struct ResultRowRef<'a> { pub(crate) columns: Arc>, + pub(crate) types: Vec, pub(crate) values: &'a Vec>, } @@ -59,17 +62,26 @@ impl ResultRow { } } - /// Take a value with the given column name from the row. Usage + /// Get a value with the given column name from the row. Usage /// documentation in [ResultRowRef](struct.ResultRowRef.html). pub fn get(&self, name: &str) -> Option<&Value<'static>> { self.columns.iter().position(|c| c == name).map(|idx| &self.values[idx]) } + /// Take a value with the given column name from the row. + pub fn take(mut self, name: &str) -> Option> { + self.columns + .iter() + .position(|c| c == name) + .map(|idx| self.values.remove(idx)) + } + /// Make a referring [ResultRowRef](struct.ResultRowRef.html). pub fn as_ref(&self) -> ResultRowRef { ResultRowRef { columns: Arc::clone(&self.columns), values: &self.values, + types: self.types.clone(), } } diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 05f073d9c34e..3a30b38975d9 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -1,6 +1,7 @@ //! Wasm-compatible definitions for the SQLite connector. //! This module is only available with the `sqlite` feature. 
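`ResultRow` keeps the borrowing `get` and gains an owning `take`, and `ResultSet` now records a `ColumnType` per column alongside the rows. A hedged usage sketch against the public API (placeholder connection string; the exact `ColumnType` values depend on the connector's mapping):

use quaint::{prelude::*, single::Quaint};

// `url` is a placeholder; any reachable PostgreSQL connection string works the same way.
async fn take_vs_get(url: &str) -> quaint::Result<()> {
    let conn = Quaint::new(url).await?;
    let result_set = conn.query_raw("SELECT 1::int4 AS one, 'a'::text AS a", &[]).await?;

    // The result set now carries one ColumnType per column, independent of the row values.
    println!("column types: {:?}", result_set.types());

    for row in result_set {
        // `get` borrows a value by column name and leaves the row intact...
        assert_eq!(row.get("a").and_then(|v| v.as_str()), Some("a"));

        // ...while the new `take` consumes the row and moves the value out,
        // avoiding a clone when only a single column is needed.
        let owned: Option<Value<'static>> = row.take("a");
        assert!(owned.is_some());
    }

    Ok(())
}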
mod defaults; + pub(crate) mod error; mod ffi; pub(crate) mod params; diff --git a/quaint/src/connector/sqlite/native/column_type.rs b/quaint/src/connector/sqlite/native/column_type.rs new file mode 100644 index 000000000000..e8f1291a3a15 --- /dev/null +++ b/quaint/src/connector/sqlite/native/column_type.rs @@ -0,0 +1,14 @@ +use rusqlite::Column; + +use crate::connector::{ColumnType, TypeIdentifier}; + +impl From<&Column<'_>> for ColumnType { + fn from(value: &Column) -> Self { + if value.is_float() { + // Sqlite always returns Double for floats + ColumnType::Double + } else { + ColumnType::from_type_identifier(value) + } + } +} diff --git a/quaint/src/connector/sqlite/native/conversion.rs b/quaint/src/connector/sqlite/native/conversion.rs index afd0145fade8..b06be6487acd 100644 --- a/quaint/src/connector/sqlite/native/conversion.rs +++ b/quaint/src/connector/sqlite/native/conversion.rs @@ -16,7 +16,7 @@ use rusqlite::{ use chrono::TimeZone; -impl TypeIdentifier for Column<'_> { +impl TypeIdentifier for &Column<'_> { fn is_real(&self) -> bool { match self.decl_type() { Some(n) if n.starts_with("DECIMAL") => true, @@ -82,7 +82,6 @@ impl TypeIdentifier for Column<'_> { ) } - #[cfg(feature = "mysql")] fn is_time(&self) -> bool { false } @@ -119,12 +118,10 @@ impl TypeIdentifier for Column<'_> { matches!(self.decl_type(), Some("BOOLEAN") | Some("boolean")) } - #[cfg(feature = "mysql")] fn is_json(&self) -> bool { false } - #[cfg(feature = "mysql")] fn is_enum(&self) -> bool { false } @@ -146,8 +143,7 @@ impl<'a> GetRow for SqliteRow<'a> { c if c.is_int64() => Value::null_int64(), c if c.is_text() => Value::null_text(), c if c.is_bytes() => Value::null_bytes(), - c if c.is_float() => Value::null_float(), - c if c.is_double() => Value::null_double(), + c if c.is_float() || c.is_double() => Value::null_double(), c if c.is_real() => Value::null_numeric(), c if c.is_datetime() => Value::null_datetime(), c if c.is_date() => Value::null_date(), @@ -251,7 +247,9 @@ impl<'a> ToSql for Value<'a> { let value = match &self.typed { ValueType::Int32(integer) => integer.map(ToSqlOutput::from), ValueType::Int64(integer) => integer.map(ToSqlOutput::from), - ValueType::Float(float) => float.map(|f| f as f64).map(ToSqlOutput::from), + ValueType::Float(float) => { + float.map(|float| ToSqlOutput::from(float.to_string().parse::().expect("f32 is not a f64."))) + } ValueType::Double(double) => double.map(ToSqlOutput::from), ValueType::Text(cow) => cow.as_ref().map(|cow| ToSqlOutput::from(cow.as_ref())), ValueType::Enum(cow, _) => cow.as_ref().map(|cow| ToSqlOutput::from(cow.as_ref())), diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index 58ef03799e2f..2d738a7f087f 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -1,11 +1,12 @@ //! Definitions for the SQLite connector. //! This module is not compatible with wasm32-* targets. //! This module is only available with the `sqlite-native` feature. 
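The `ToSql` change above routes `Value::Float` through its decimal string instead of a plain `as f64` widening cast: the cast preserves the f32's binary value, which shows up as noise digits once SQLite stores and returns an 8-byte real. A small standalone check of the difference:

fn main() {
    let f: f32 = 1.1;

    // Plain widening keeps the f32's binary representation, so the extra
    // precision surfaces as noise once the value is treated as an f64.
    let widened = f as f64;

    // Round-tripping through the decimal representation yields the f64 that
    // actually corresponds to the literal "1.1".
    let round_tripped: f64 = f.to_string().parse().expect("f32 always formats as a valid f64");

    assert_eq!(f.to_string(), "1.1");
    assert_ne!(widened, 1.1_f64);
    assert_eq!(round_tripped, 1.1_f64);

    // e.g. widened = 1.100000023841858, round-tripped = 1.1
    println!("widened = {widened}, round-tripped = {round_tripped}");
}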
+mod column_type; mod conversion; mod error; -use crate::connector::sqlite::params::SqliteParams; use crate::connector::IsolationLevel; +use crate::connector::{sqlite::params::SqliteParams, ColumnType, DescribedQuery}; pub use rusqlite::{params_from_iter, version as sqlite_version}; @@ -28,6 +29,8 @@ pub struct Sqlite { pub(crate) client: Mutex, } +const DB_SYSTEM_NAME: &str = "sqlite"; + impl TryFrom<&str> for Sqlite { type Error = Error; @@ -99,13 +102,14 @@ impl Queryable for Sqlite { } async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { + metrics::query("sqlite.query_raw", DB_SYSTEM_NAME, sql, params, move || async move { let client = self.client.lock().await; let mut stmt = client.prepare_cached(sql)?; + let col_types = stmt.columns().iter().map(ColumnType::from).collect::>(); let mut rows = stmt.query(params_from_iter(params.iter()))?; - let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); + let mut result = ResultSet::new(rows.to_column_names(), col_types, Vec::new()); while let Some(row) = rows.next()? { result.rows.push(row.get_result_row()?); @@ -122,13 +126,17 @@ impl Queryable for Sqlite { self.query_raw(sql, params).await } + async fn describe_query(&self, _sql: &str) -> crate::Result { + unimplemented!("SQLite describe_query is implemented in the schema engine.") + } + async fn execute(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Sqlite::build(q)?; self.execute_raw(&sql, ¶ms).await } async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { + metrics::query("sqlite.query_raw", DB_SYSTEM_NAME, sql, params, move || async move { let client = self.client.lock().await; let mut stmt = client.prepare_cached(sql)?; let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; @@ -143,7 +151,7 @@ impl Queryable for Sqlite { } async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { + metrics::query("sqlite.raw_cmd", DB_SYSTEM_NAME, cmd, &[], move || async move { let client = self.client.lock().await; client.execute_batch(cmd)?; Ok(()) diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index b7e91e97f6a8..599efe1d99fc 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -1,13 +1,13 @@ +use std::{fmt, str::FromStr}; + +use async_trait::async_trait; +use prisma_metrics::guards::GaugeGuard; + use super::*; use crate::{ ast::*, error::{Error, ErrorKind}, }; -use async_trait::async_trait; -use metrics::{decrement_gauge, increment_gauge}; -use std::{fmt, str::FromStr}; - -extern crate metrics as metrics; #[async_trait] pub trait Transaction: Queryable { @@ -36,6 +36,7 @@ pub(crate) struct TransactionOptions { /// transaction object will panic. 
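In the transaction changes just below, the paired `increment_gauge!`/`decrement_gauge!` calls are replaced by a `GaugeGuard` owned by `DefaultTransaction`, so `prisma_client_queries_active` is bumped when the transaction is created and released when it commits or rolls back. A rough, self-contained illustration of that RAII shape using a plain atomic counter (illustrative types only, not the prisma_metrics API):

use std::sync::atomic::{AtomicI64, Ordering};

static ACTIVE: AtomicI64 = AtomicI64::new(0);

// Illustrative stand-in for a gauge guard: increments on creation and
// guarantees the matching decrement when dropped, even on early returns.
struct ActiveGuard;

impl ActiveGuard {
    fn increment() -> Self {
        ACTIVE.fetch_add(1, Ordering::Relaxed);
        ActiveGuard
    }
}

impl Drop for ActiveGuard {
    fn drop(&mut self) {
        ACTIVE.fetch_sub(1, Ordering::Relaxed);
    }
}

struct Transaction {
    _guard: ActiveGuard,
}

fn main() {
    {
        let _tx = Transaction {
            _guard: ActiveGuard::increment(),
        };
        assert_eq!(ACTIVE.load(Ordering::Relaxed), 1);
    } // dropping the transaction releases the guard

    assert_eq!(ACTIVE.load(Ordering::Relaxed), 0);
}

In the illustration the decrement happens on drop; the real guard also exposes an explicit `decrement()`, which `commit` and `rollback` call before issuing `COMMIT`/`ROLLBACK`.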
pub struct DefaultTransaction<'a> { pub inner: &'a dyn Queryable, + gauge: GaugeGuard, } impl<'a> DefaultTransaction<'a> { @@ -44,7 +45,10 @@ impl<'a> DefaultTransaction<'a> { begin_stmt: &str, tx_opts: TransactionOptions, ) -> crate::Result> { - let this = Self { inner }; + let this = Self { + inner, + gauge: GaugeGuard::increment("prisma_client_queries_active"), + }; if tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -62,7 +66,6 @@ impl<'a> DefaultTransaction<'a> { inner.server_reset_query(&this).await?; - increment_gauge!("prisma_client_queries_active", 1.0); Ok(this) } } @@ -71,7 +74,7 @@ impl<'a> DefaultTransaction<'a> { impl<'a> Transaction for DefaultTransaction<'a> { /// Commit the changes to the database and consume the transaction. async fn commit(&self) -> crate::Result<()> { - decrement_gauge!("prisma_client_queries_active", 1.0); + self.gauge.decrement(); self.inner.raw_cmd("COMMIT").await?; Ok(()) @@ -79,7 +82,7 @@ impl<'a> Transaction for DefaultTransaction<'a> { /// Rolls back the changes to the database. async fn rollback(&self) -> crate::Result<()> { - decrement_gauge!("prisma_client_queries_active", 1.0); + self.gauge.decrement(); self.inner.raw_cmd("ROLLBACK").await?; Ok(()) @@ -108,6 +111,10 @@ impl<'a> Queryable for DefaultTransaction<'a> { self.inner.query_raw_typed(sql, params).await } + async fn describe_query(&self, sql: &str) -> crate::Result { + self.inner.describe_query(sql).await + } + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { self.inner.execute_raw(sql, params).await } diff --git a/quaint/src/connector/type_identifier.rs b/quaint/src/connector/type_identifier.rs index ce27ea89a404..9fcc46f61c1c 100644 --- a/quaint/src/connector/type_identifier.rs +++ b/quaint/src/connector/type_identifier.rs @@ -5,15 +5,12 @@ pub(crate) trait TypeIdentifier { fn is_int32(&self) -> bool; fn is_int64(&self) -> bool; fn is_datetime(&self) -> bool; - #[cfg(feature = "mysql")] fn is_time(&self) -> bool; fn is_date(&self) -> bool; fn is_text(&self) -> bool; fn is_bytes(&self) -> bool; fn is_bool(&self) -> bool; - #[cfg(feature = "mysql")] fn is_json(&self) -> bool; - #[cfg(feature = "mysql")] fn is_enum(&self) -> bool; fn is_null(&self) -> bool; } diff --git a/quaint/src/lib.rs b/quaint/src/lib.rs index 45c2a10a1698..ab73ef7e66aa 100644 --- a/quaint/src/lib.rs +++ b/quaint/src/lib.rs @@ -110,11 +110,8 @@ compile_error!("one of 'sqlite', 'postgresql', 'mysql' or 'mssql' features must #[macro_use] mod macros; -#[macro_use] -extern crate metrics; - -pub extern crate bigdecimal; -pub extern crate chrono; +pub use bigdecimal; +pub use chrono; pub mod ast; pub mod connector; diff --git a/quaint/src/macros.rs b/quaint/src/macros.rs index cfb52bc0c6e1..9921359669c9 100644 --- a/quaint/src/macros.rs +++ b/quaint/src/macros.rs @@ -1,3 +1,5 @@ +use crate::{connector::ColumnType, Value}; + /// Convert given set of tuples into `Values`. /// /// ```rust @@ -173,7 +175,7 @@ macro_rules! expression { /// A test-generator to test types in the defined database. #[cfg(test)] macro_rules! test_type { - ($name:ident($db:ident, $sql_type:literal, $(($input:expr, $output:expr)),+ $(,)?)) => { + ($name:ident($db:ident, $sql_type:literal, $col_type:expr, $(($input:expr, $output:expr)),+ $(,)?)) => { paste::item! { #[test] fn [< test_type_ $name >] () -> crate::Result<()> { @@ -198,7 +200,9 @@ macro_rules! 
test_type { let select = Select::from_table(&table).column("value").order_by("id".descend()); let res = setup.conn().select(select).await?.into_single()?; + assert_eq!($col_type, res.types[0]); assert_eq!(Some(&output), res.at(0)); + assert_matching_value_and_column_type(&$col_type, res.at(0).unwrap()); )+ Result::<(), crate::error::Error>::Ok(()) @@ -209,7 +213,7 @@ macro_rules! test_type { } }; - ($name:ident($db:ident, $sql_type:literal, $($value:expr),+ $(,)?)) => { + ($name:ident($db:ident, $sql_type:literal, $col_type:expr, $($value:expr),+ $(,)?)) => { paste::item! { #[test] fn [< test_type_ $name >] () -> crate::Result<()> { @@ -232,7 +236,9 @@ macro_rules! test_type { let select = Select::from_table(&table).column("value").order_by("id".descend()); let res = setup.conn().select(select).await?.into_single()?; + assert_eq!($col_type, res.types[0]); assert_eq!(Some(&value), res.at(0)); + assert_matching_value_and_column_type(&$col_type, &value); )+ Result::<(), crate::error::Error>::Ok(()) @@ -243,3 +249,12 @@ macro_rules! test_type { } }; } + +#[allow(dead_code)] +pub(crate) fn assert_matching_value_and_column_type(col_type: &ColumnType, value: &Value) { + let inferred_column_type = ColumnType::from(&value.typed); + + if !inferred_column_type.is_unknown() { + assert_eq!(col_type, &inferred_column_type); + } +} diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 3e7e58c05e52..389005ab7bd3 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -274,7 +274,7 @@ impl Builder { /// pool. /// /// - Defaults to `false`, meaning connections are never tested on - /// `check_out`. + /// `check_out`. /// /// [`check_out`]: struct.Quaint.html#method.check_out pub fn test_on_check_out(&mut self, test_on_check_out: bool) { @@ -307,12 +307,14 @@ impl Builder { /// - Defaults to `PostgresFlavour::Unknown`. #[cfg(feature = "postgresql-native")] pub fn set_postgres_flavour(&mut self, flavour: crate::connector::PostgresFlavour) { - use crate::connector::NativeConnectionInfo; - if let ConnectionInfo::Native(NativeConnectionInfo::Postgres(ref mut url)) = self.connection_info { + use crate::connector::{NativeConnectionInfo, PostgresUrl}; + if let ConnectionInfo::Native(NativeConnectionInfo::Postgres(PostgresUrl::Native(ref mut url))) = + self.connection_info + { url.set_flavour(flavour); } - if let QuaintManager::Postgres { ref mut url } = self.manager { + if let QuaintManager::Postgres { ref mut url, .. 
} = self.manager { url.set_flavour(flavour); } } @@ -415,13 +417,14 @@ impl Quaint { } #[cfg(feature = "postgresql")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { - let url = crate::connector::PostgresUrl::new(url::Url::parse(s)?)?; + let url = crate::connector::PostgresNativeUrl::new(url::Url::parse(s)?)?; let connection_limit = url.connection_limit(); let pool_timeout = url.pool_timeout(); let max_connection_lifetime = url.max_connection_lifetime(); let max_idle_connection_lifetime = url.max_idle_connection_lifetime(); - let manager = QuaintManager::Postgres { url }; + let tls_manager = crate::connector::MakeTlsConnectorManager::new(url.clone()); + let manager = QuaintManager::Postgres { url, tls_manager }; let mut builder = Builder::new(s, manager)?; if let Some(limit) = connection_limit { diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 73441b7609ba..bf4d50eeea87 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -1,16 +1,21 @@ +use std::future::Future; + +use async_trait::async_trait; +use mobc::{Connection as MobcPooled, Manager}; +use prisma_metrics::WithMetricsInstrumentation; +use tracing_futures::WithSubscriber; + #[cfg(feature = "mssql-native")] use crate::connector::MssqlUrl; #[cfg(feature = "mysql-native")] use crate::connector::MysqlUrl; #[cfg(feature = "postgresql-native")] -use crate::connector::PostgresUrl; +use crate::connector::{MakeTlsConnectorManager, PostgresNativeUrl}; use crate::{ ast, connector::{self, impl_default_TransactionCapable, IsolationLevel, Queryable, Transaction, TransactionCapable}, error::Error, }; -use async_trait::async_trait; -use mobc::{Connection as MobcPooled, Manager}; /// A connection from the pool. Implements /// [Queryable](connector/trait.Queryable.html). @@ -34,6 +39,10 @@ impl Queryable for PooledConnection { self.inner.query_raw_typed(sql, params).await } + async fn describe_query(&self, sql: &str) -> crate::Result { + self.inner.describe_query(sql).await + } + async fn execute(&self, q: ast::Query<'_>) -> crate::Result { self.inner.execute(q).await } @@ -81,7 +90,10 @@ pub enum QuaintManager { Mysql { url: MysqlUrl }, #[cfg(feature = "postgresql")] - Postgres { url: PostgresUrl }, + Postgres { + url: PostgresNativeUrl, + tls_manager: MakeTlsConnectorManager, + }, #[cfg(feature = "sqlite")] Sqlite { url: String, db_name: String }, @@ -113,9 +125,9 @@ impl Manager for QuaintManager { } #[cfg(feature = "postgresql-native")] - QuaintManager::Postgres { url } => { + QuaintManager::Postgres { url, tls_manager } => { use crate::connector::PostgreSql; - Ok(Box::new(PostgreSql::new(url.clone()).await?) as Self::Connection) + Ok(Box::new(PostgreSql::new(url.clone(), tls_manager).await?) as Self::Connection) } #[cfg(feature = "mssql-native")] @@ -139,6 +151,14 @@ impl Manager for QuaintManager { fn validate(&self, conn: &mut Self::Connection) -> bool { conn.is_healthy() } + + fn spawn_task(&self, task: T) + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + tokio::spawn(task.with_current_subscriber().with_current_recorder()); + } } #[cfg(test)] diff --git a/quaint/src/prelude.rs b/quaint/src/prelude.rs index 1fe867ccd4cc..2984e49fdc94 100644 --- a/quaint/src/prelude.rs +++ b/quaint/src/prelude.rs @@ -1,7 +1,7 @@ //! A "prelude" for users of the `quaint` crate. 
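`MakeTlsConnectorManager`, introduced earlier in this diff and threaded through `QuaintManager::Postgres` here, builds the native-tls connector once behind a `tokio::sync::OnceCell` and hands out clones, so pooled connections stop re-reading certificates and identities on every checkout. A reduced sketch of that `get_or_try_init` caching pattern with a stand-in value (names are illustrative):

use tokio::sync::OnceCell;

// Stand-in for the TLS connector, which is expensive to build in the real code
// (certificate and identity loading).
#[derive(Clone)]
struct Connector {
    config: String,
}

struct ConnectorManager {
    cell: OnceCell<Connector>,
}

impl ConnectorManager {
    fn new() -> Self {
        Self { cell: OnceCell::new() }
    }

    async fn get_connector(&self) -> Result<Connector, std::io::Error> {
        self.cell
            .get_or_try_init(|| async {
                // Runs at most once; concurrent callers await the same initialization.
                Ok::<_, std::io::Error>(Connector {
                    config: "tls-config".to_owned(),
                })
            })
            .await
            .cloned()
    }
}

#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
    let manager = ConnectorManager::new();

    // Both calls observe the same cached connector; the initializer runs once.
    let first = manager.get_connector().await?;
    let second = manager.get_connector().await?;
    assert_eq!(first.config, second.config);
    Ok(())
}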
pub use crate::ast::*; pub use crate::connector::{ - ConnectionInfo, DefaultTransaction, ExternalConnectionInfo, Queryable, ResultRow, ResultSet, SqlFamily, + ColumnType, ConnectionInfo, DefaultTransaction, ExternalConnectionInfo, Queryable, ResultRow, ResultSet, SqlFamily, TransactionCapable, }; pub use crate::{col, val, values}; diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 6e76a8003f09..13be8c4bc857 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -148,8 +148,9 @@ impl Quaint { } #[cfg(feature = "postgresql-native")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { - let url = connector::PostgresUrl::new(url::Url::parse(s)?)?; - let psql = connector::PostgreSql::new(url).await?; + let url = connector::PostgresNativeUrl::new(url::Url::parse(s)?)?; + let tls_manager = connector::MakeTlsConnectorManager::new(url.clone()); + let psql = connector::PostgreSql::new(url, &tls_manager).await?; Arc::new(psql) as Arc } #[cfg(feature = "mssql-native")] @@ -209,6 +210,10 @@ impl Queryable for Quaint { self.inner.query_raw_typed(sql, params).await } + async fn describe_query(&self, sql: &str) -> crate::Result { + self.inner.describe_query(sql).await + } + async fn execute(&self, q: ast::Query<'_>) -> crate::Result { self.inner.execute(q).await } diff --git a/quaint/src/tests/query.rs b/quaint/src/tests/query.rs index 06bebe1a9601..6e83297a9a75 100644 --- a/quaint/src/tests/query.rs +++ b/quaint/src/tests/query.rs @@ -736,7 +736,7 @@ async fn returning_update(api: &mut dyn TestApi) -> crate::Result<()> { Ok(()) } -#[cfg(all(feature = "mssql", feature = "bigdecimal"))] +#[cfg(feature = "mssql")] #[test_each_connector(tags("mssql"))] async fn returning_decimal_insert_with_type_defs(api: &mut dyn TestApi) -> crate::Result<()> { use bigdecimal::BigDecimal; @@ -1388,15 +1388,13 @@ async fn unsigned_integers_are_handled(api: &mut dyn TestApi) -> crate::Result<( .create_temp_table("id int4 auto_increment primary key, big bigint unsigned") .await?; - let insert = Insert::multi_into(&table, ["big"]) - .values((2,)) - .values((std::i64::MAX,)); + let insert = Insert::multi_into(&table, ["big"]).values((2,)).values((i64::MAX,)); api.conn().insert(insert.into()).await?; let select = Select::from_table(&table).column("big").order_by("id"); let roundtripped = api.conn().select(select).await?; - let expected = &[2, std::i64::MAX]; + let expected = &[2, i64::MAX]; let actual: Vec = roundtripped .into_iter() .map(|row| row.at(0).unwrap().as_i64().unwrap()) diff --git a/quaint/src/tests/query/error.rs b/quaint/src/tests/query/error.rs index 69c57332b6d3..399866bd4a3b 100644 --- a/quaint/src/tests/query/error.rs +++ b/quaint/src/tests/query/error.rs @@ -162,7 +162,7 @@ async fn int_unsigned_negative_value_out_of_range(api: &mut dyn TestApi) -> crat // Value too big { - let insert = Insert::multi_into(&table, ["big"]).values((std::i64::MAX,)); + let insert = Insert::multi_into(&table, ["big"]).values((i64::MAX,)); let result = api.conn().insert(insert.into()).await; assert!(matches!(result.unwrap_err().kind(), ErrorKind::ValueOutOfRange { .. 
})); diff --git a/quaint/src/tests/types/mssql.rs b/quaint/src/tests/types/mssql.rs index ac404dd8af38..bd3ce6555a69 100644 --- a/quaint/src/tests/types/mssql.rs +++ b/quaint/src/tests/types/mssql.rs @@ -2,11 +2,14 @@ mod bigdecimal; -use crate::tests::test_api::*; +use crate::macros::assert_matching_value_and_column_type; +use crate::{connector::ColumnType, tests::test_api::*}; +use std::str::FromStr; test_type!(nvarchar_limited( mssql, "NVARCHAR(10)", + ColumnType::Text, Value::null_text(), Value::text("foobar"), Value::text("余"), @@ -15,6 +18,7 @@ test_type!(nvarchar_limited( test_type!(nvarchar_max( mssql, "NVARCHAR(max)", + ColumnType::Text, Value::null_text(), Value::text("foobar"), Value::text("余"), @@ -24,6 +28,7 @@ test_type!(nvarchar_max( test_type!(ntext( mssql, "NTEXT", + ColumnType::Text, Value::null_text(), Value::text("foobar"), Value::text("余"), @@ -32,6 +37,7 @@ test_type!(ntext( test_type!(varchar_limited( mssql, "VARCHAR(10)", + ColumnType::Text, Value::null_text(), Value::text("foobar"), )); @@ -39,15 +45,23 @@ test_type!(varchar_limited( test_type!(varchar_max( mssql, "VARCHAR(max)", + ColumnType::Text, Value::null_text(), Value::text("foobar"), )); -test_type!(text(mssql, "TEXT", Value::null_text(), Value::text("foobar"))); +test_type!(text( + mssql, + "TEXT", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); test_type!(tinyint( mssql, "tinyint", + ColumnType::Int32, Value::null_int32(), Value::int32(u8::MIN), Value::int32(u8::MAX), @@ -56,6 +70,7 @@ test_type!(tinyint( test_type!(smallint( mssql, "smallint", + ColumnType::Int32, Value::null_int32(), Value::int32(i16::MIN), Value::int32(i16::MAX), @@ -64,6 +79,7 @@ test_type!(smallint( test_type!(int( mssql, "int", + ColumnType::Int32, Value::null_int32(), Value::int32(i32::MIN), Value::int32(i32::MAX), @@ -72,27 +88,48 @@ test_type!(int( test_type!(bigint( mssql, "bigint", + ColumnType::Int64, Value::null_int64(), Value::int64(i64::MIN), Value::int64(i64::MAX), )); -test_type!(float_24(mssql, "float(24)", Value::null_float(), Value::float(1.23456),)); +test_type!(float_24( + mssql, + "float(24)", + ColumnType::Float, + Value::null_float(), + Value::float(1.23456), +)); -test_type!(real(mssql, "real", Value::null_float(), Value::float(1.123456))); +test_type!(real( + mssql, + "real", + ColumnType::Float, + Value::null_float(), + Value::float(1.123456) +)); test_type!(float_53( mssql, "float(53)", + ColumnType::Double, Value::null_double(), Value::double(1.1234567891) )); -test_type!(money(mssql, "money", Value::null_double(), Value::double(3.14))); +test_type!(money( + mssql, + "money", + ColumnType::Double, + Value::null_double(), + Value::double(3.14) +)); test_type!(smallmoney( mssql, "smallmoney", + ColumnType::Double, Value::null_double(), Value::double(3.14) )); @@ -100,6 +137,7 @@ test_type!(smallmoney( test_type!(boolean( mssql, "bit", + ColumnType::Boolean, Value::null_boolean(), Value::boolean(true), Value::boolean(false), @@ -108,6 +146,7 @@ test_type!(boolean( test_type!(binary( mssql, "binary(8)", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(b"DEADBEEF".to_vec()), )); @@ -115,6 +154,7 @@ test_type!(binary( test_type!(varbinary( mssql, "varbinary(8)", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(b"DEADBEEF".to_vec()), )); @@ -122,6 +162,7 @@ test_type!(varbinary( test_type!(image( mssql, "image", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(b"DEADBEEF".to_vec()), )); @@ -129,6 +170,7 @@ test_type!(image( test_type!(date( mssql, "date", + ColumnType::Date, 
Value::null_date(), Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()) )); @@ -136,26 +178,67 @@ test_type!(date( test_type!(time( mssql, "time", + ColumnType::Time, Value::null_time(), Value::time(chrono::NaiveTime::from_hms_opt(16, 20, 00).unwrap()) )); -test_type!(datetime2(mssql, "datetime2", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:00Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(datetime2( + mssql, + "datetime2", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:00Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); -test_type!(datetime(mssql, "datetime", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(datetime( + mssql, + "datetime", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); + +test_type!(datetimeoffset( + mssql, + "datetimeoffset", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); + +test_type!(smalldatetime( + mssql, + "smalldatetime", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:00Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); -test_type!(datetimeoffset(mssql, "datetimeoffset", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(uuid( + mssql, + "uniqueidentifier", + ColumnType::Uuid, + Value::null_uuid(), + Value::uuid(uuid::Uuid::from_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap()) +)); -test_type!(smalldatetime(mssql, "smalldatetime", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:00Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(xml( + mssql, + "xml", + ColumnType::Xml, + Value::null_xml(), + Value::xml("bar"), +)); diff --git a/quaint/src/tests/types/mssql/bigdecimal.rs b/quaint/src/tests/types/mssql/bigdecimal.rs index 8fe3761624d2..2a2ce02350d3 100644 --- a/quaint/src/tests/types/mssql/bigdecimal.rs +++ b/quaint/src/tests/types/mssql/bigdecimal.rs @@ -1,10 +1,12 @@ use super::*; -use crate::bigdecimal::BigDecimal; +use crate::macros::assert_matching_value_and_column_type; +use crate::{bigdecimal::BigDecimal, connector::ColumnType}; use std::str::FromStr; test_type!(numeric( mssql, "numeric(10,2)", + ColumnType::Numeric, Value::null_numeric(), Value::numeric(BigDecimal::from_str("3.14")?) )); @@ -12,6 +14,7 @@ test_type!(numeric( test_type!(numeric_10_2( mssql, "numeric(10,2)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950.123456")?), Value::numeric(BigDecimal::from_str("3950.12")?) @@ -21,6 +24,7 @@ test_type!(numeric_10_2( test_type!(numeric_35_6( mssql, "numeric(35, 6)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950")?), Value::numeric(BigDecimal::from_str("3950.000000")?) 
@@ -102,6 +106,7 @@ test_type!(numeric_35_6( test_type!(numeric_35_2( mssql, "numeric(35, 2)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950.123456")?), Value::numeric(BigDecimal::from_str("3950.12")?) @@ -115,18 +120,21 @@ test_type!(numeric_35_2( test_type!(numeric_4_0( mssql, "numeric(4, 0)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str("3950")?) )); test_type!(numeric_35_0( mssql, "numeric(35, 0)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str("79228162514264337593543950335")?), )); test_type!(numeric_35_1( mssql, "numeric(35, 1)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("79228162514264337593543950335")?), Value::numeric(BigDecimal::from_str("79228162514264337593543950335.0")?) @@ -141,12 +149,14 @@ test_type!(numeric_35_1( test_type!(numeric_38_6( mssql, "numeric(38, 6)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str("9343234567898765456789043634999.345678")?), )); test_type!(money( mssql, "money", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), (Value::numeric(BigDecimal::from_str("3.14")?), Value::double(3.14)) )); @@ -154,6 +164,7 @@ test_type!(money( test_type!(smallmoney( mssql, "smallmoney", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), (Value::numeric(BigDecimal::from_str("3.14")?), Value::double(3.14)) )); @@ -161,6 +172,7 @@ test_type!(smallmoney( test_type!(float_24( mssql, "float(24)", + ColumnType::Float, (Value::null_numeric(), Value::null_float()), ( Value::numeric(BigDecimal::from_str("1.123456")?), @@ -171,6 +183,7 @@ test_type!(float_24( test_type!(real( mssql, "real", + ColumnType::Float, (Value::null_numeric(), Value::null_float()), ( Value::numeric(BigDecimal::from_str("1.123456")?), @@ -181,6 +194,7 @@ test_type!(real( test_type!(float_53( mssql, "float(53)", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), ( Value::numeric(BigDecimal::from_str("1.123456789012345")?), diff --git a/quaint/src/tests/types/mysql.rs b/quaint/src/tests/types/mysql.rs index ade4e5d2a1f2..77444378735f 100644 --- a/quaint/src/tests/types/mysql.rs +++ b/quaint/src/tests/types/mysql.rs @@ -1,14 +1,15 @@ #![allow(clippy::approx_constant)] -use crate::tests::test_api::*; - use std::str::FromStr; use crate::bigdecimal::BigDecimal; +use crate::macros::assert_matching_value_and_column_type; +use crate::{connector::ColumnType, tests::test_api::*}; test_type!(tinyint( mysql, "tinyint(4)", + ColumnType::Int32, Value::null_int32(), Value::int32(i8::MIN), Value::int32(i8::MAX) @@ -17,6 +18,7 @@ test_type!(tinyint( test_type!(tinyint1( mysql, "tinyint(1)", + ColumnType::Int32, Value::int32(-1), Value::int32(1), Value::int32(0) @@ -25,6 +27,7 @@ test_type!(tinyint1( test_type!(tinyint_unsigned( mysql, "tinyint(4) unsigned", + ColumnType::Int32, Value::null_int32(), Value::int32(0), Value::int32(255) @@ -33,6 +36,7 @@ test_type!(tinyint_unsigned( test_type!(year( mysql, "year", + ColumnType::Int32, Value::null_int32(), Value::int32(1984), Value::int32(2049) @@ -41,6 +45,7 @@ test_type!(year( test_type!(smallint( mysql, "smallint", + ColumnType::Int32, Value::null_int32(), Value::int32(i16::MIN), Value::int32(i16::MAX) @@ -49,6 +54,7 @@ test_type!(smallint( test_type!(smallint_unsigned( mysql, "smallint unsigned", + ColumnType::Int32, Value::null_int32(), Value::int32(0), Value::int32(65535) @@ -57,6 +63,7 @@ test_type!(smallint_unsigned( test_type!(mediumint( mysql, "mediumint", + ColumnType::Int32, Value::null_int32(), Value::int32(-8388608), 
Value::int32(8388607) @@ -65,6 +72,7 @@ test_type!(mediumint( test_type!(mediumint_unsigned( mysql, "mediumint unsigned", + ColumnType::Int64, Value::null_int64(), Value::int64(0), Value::int64(16777215) @@ -73,6 +81,7 @@ test_type!(mediumint_unsigned( test_type!(int( mysql, "int", + ColumnType::Int32, Value::null_int32(), Value::int32(i32::MIN), Value::int32(i32::MAX) @@ -81,6 +90,7 @@ test_type!(int( test_type!(int_unsigned( mysql, "int unsigned", + ColumnType::Int64, Value::null_int64(), Value::int64(0), Value::int64(2173158296i64), @@ -90,6 +100,7 @@ test_type!(int_unsigned( test_type!(int_unsigned_not_null( mysql, "int unsigned not null", + ColumnType::Int64, Value::int64(0), Value::int64(2173158296i64), Value::int64(4294967295i64) @@ -98,6 +109,7 @@ test_type!(int_unsigned_not_null( test_type!(bigint( mysql, "bigint", + ColumnType::Int64, Value::null_int64(), Value::int64(i64::MIN), Value::int64(i64::MAX) @@ -106,6 +118,7 @@ test_type!(bigint( test_type!(decimal( mysql, "decimal(10,2)", + ColumnType::Numeric, Value::null_numeric(), Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()) )); @@ -114,6 +127,7 @@ test_type!(decimal( test_type!(decimal_65_6( mysql, "decimal(65, 6)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str( "93431006223456789876545678909876545678903434334567834369999.345678" )?), @@ -122,6 +136,7 @@ test_type!(decimal_65_6( test_type!(float_decimal( mysql, "float", + ColumnType::Float, (Value::null_numeric(), Value::null_float()), ( Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()), @@ -132,6 +147,7 @@ test_type!(float_decimal( test_type!(double_decimal( mysql, "double", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), ( Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()), @@ -142,6 +158,7 @@ test_type!(double_decimal( test_type!(bit1( mysql, "bit(1)", + ColumnType::Boolean, (Value::null_bytes(), Value::null_boolean()), (Value::int32(0), Value::boolean(false)), (Value::int32(1), Value::boolean(true)), @@ -150,28 +167,77 @@ test_type!(bit1( test_type!(bit64( mysql, "bit(64)", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(vec![0, 0, 0, 0, 0, 6, 107, 58]) )); -test_type!(char(mysql, "char(255)", Value::null_text(), Value::text("foobar"))); -test_type!(float(mysql, "float", Value::null_float(), Value::float(1.12345),)); -test_type!(double(mysql, "double", Value::null_double(), Value::double(1.12314124))); +test_type!(char( + mysql, + "char(255)", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); +test_type!(float( + mysql, + "float", + ColumnType::Float, + Value::null_float(), + Value::float(1.12345), +)); +test_type!(double( + mysql, + "double", + ColumnType::Double, + Value::null_double(), + Value::double(1.12314124) +)); test_type!(varchar( mysql, "varchar(255)", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); +test_type!(tinytext( + mysql, + "tinytext", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); +test_type!(text( + mysql, + "text", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); +test_type!(longtext( + mysql, + "longtext", + ColumnType::Text, Value::null_text(), Value::text("foobar") )); -test_type!(tinytext(mysql, "tinytext", Value::null_text(), Value::text("foobar"))); -test_type!(text(mysql, "text", Value::null_text(), Value::text("foobar"))); -test_type!(longtext(mysql, "longtext", Value::null_text(), Value::text("foobar"))); -test_type!(binary(mysql, "binary(5)", Value::bytes(vec![1, 2, 
3, 0, 0]))); -test_type!(varbinary(mysql, "varbinary(255)", Value::bytes(vec![1, 2, 3]))); +test_type!(binary( + mysql, + "binary(5)", + ColumnType::Bytes, + Value::bytes(vec![1, 2, 3, 0, 0]) +)); +test_type!(varbinary( + mysql, + "varbinary(255)", + ColumnType::Bytes, + Value::bytes(vec![1, 2, 3]) +)); test_type!(mediumtext( mysql, "mediumtext", + ColumnType::Text, Value::null_text(), Value::text("foobar") )); @@ -179,6 +245,7 @@ test_type!(mediumtext( test_type!(tinyblob( mysql, "tinyblob", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(vec![1, 2, 3]) )); @@ -186,6 +253,7 @@ test_type!(tinyblob( test_type!(mediumblob( mysql, "mediumblob", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(vec![1, 2, 3]) )); @@ -193,15 +261,23 @@ test_type!(mediumblob( test_type!(longblob( mysql, "longblob", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(vec![1, 2, 3]) )); -test_type!(blob(mysql, "blob", Value::null_bytes(), Value::bytes(vec![1, 2, 3]))); +test_type!(blob( + mysql, + "blob", + ColumnType::Bytes, + Value::null_bytes(), + Value::bytes(vec![1, 2, 3]) +)); test_type!(enum( mysql, "enum('pollicle_dogs','jellicle_cats')", + ColumnType::Enum, Value::null_enum(), Value::enum_variant("jellicle_cats"), Value::enum_variant("pollicle_dogs") @@ -210,28 +286,50 @@ test_type!(enum( test_type!(json( mysql, "json", + ColumnType::Json, Value::null_json(), Value::json(serde_json::json!({"this": "is", "a": "json", "number": 2})) )); -test_type!(date(mysql, "date", Value::null_date(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-04-20T00:00:00Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(date( + mysql, + "date", + ColumnType::Date, + (Value::null_date(), Value::null_date()), + ( + Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()), + Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()) + ), + ( + Value::datetime( + chrono::DateTime::parse_from_rfc3339("2020-04-20T00:00:00Z") + .unwrap() + .with_timezone(&chrono::Utc) + ), + Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()) + ) +)); test_type!(time( mysql, "time", + ColumnType::Time, Value::null_time(), Value::time(chrono::NaiveTime::from_hms_opt(16, 20, 00).unwrap()) )); -test_type!(datetime(mysql, "datetime", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(datetime( + mysql, + "datetime", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); -test_type!(timestamp(mysql, "timestamp", { +test_type!(timestamp(mysql, "timestamp", ColumnType::DateTime, { let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); Value::datetime(dt.with_timezone(&chrono::Utc)) })); diff --git a/quaint/src/tests/types/postgres.rs b/quaint/src/tests/types/postgres.rs index d69a8dbb3424..99b3d67bf6e3 100644 --- a/quaint/src/tests/types/postgres.rs +++ b/quaint/src/tests/types/postgres.rs @@ -1,11 +1,13 @@ mod bigdecimal; -use crate::tests::test_api::*; +use crate::macros::assert_matching_value_and_column_type; +use crate::{connector::ColumnType, tests::test_api::*}; use std::str::FromStr; test_type!(boolean( postgresql, "boolean", + ColumnType::Boolean, Value::null_boolean(), Value::boolean(true), Value::boolean(false), @@ -14,6 +16,7 @@ test_type!(boolean( test_type!(boolean_array( 
postgresql, "boolean[]", + ColumnType::BooleanArray, Value::null_array(), Value::array(vec![ Value::boolean(true), @@ -26,6 +29,7 @@ test_type!(boolean_array( test_type!(int2( postgresql, "int2", + ColumnType::Int32, Value::null_int32(), Value::int32(i16::MIN), Value::int32(i16::MAX), @@ -34,6 +38,7 @@ test_type!(int2( test_type!(int2_with_int64( postgresql, "int2", + ColumnType::Int32, (Value::null_int64(), Value::null_int32()), (Value::int64(i16::MIN), Value::int32(i16::MIN)), (Value::int64(i16::MAX), Value::int32(i16::MAX)) @@ -42,6 +47,7 @@ test_type!(int2_with_int64( test_type!(int2_array( postgresql, "int2[]", + ColumnType::Int32Array, Value::null_array(), Value::array(vec![ Value::int32(1), @@ -54,6 +60,7 @@ test_type!(int2_array( test_type!(int2_array_with_i64( postgresql, "int2[]", + ColumnType::Int32Array, ( Value::array(vec![ Value::int64(i16::MIN), @@ -71,6 +78,7 @@ test_type!(int2_array_with_i64( test_type!(int4( postgresql, "int4", + ColumnType::Int32, Value::null_int32(), Value::int32(i32::MIN), Value::int32(i32::MAX), @@ -79,6 +87,7 @@ test_type!(int4( test_type!(int4_with_i64( postgresql, "int4", + ColumnType::Int32, (Value::null_int64(), Value::null_int32()), (Value::int64(i32::MIN), Value::int32(i32::MIN)), (Value::int64(i32::MAX), Value::int32(i32::MAX)) @@ -87,6 +96,7 @@ test_type!(int4_with_i64( test_type!(int4_array( postgresql, "int4[]", + ColumnType::Int32Array, Value::null_array(), Value::array(vec![ Value::int32(i32::MIN), @@ -98,6 +108,7 @@ test_type!(int4_array( test_type!(int4_array_with_i64( postgresql, "int4[]", + ColumnType::Int32Array, ( Value::array(vec![ Value::int64(i32::MIN), @@ -115,6 +126,7 @@ test_type!(int4_array_with_i64( test_type!(int8( postgresql, "int8", + ColumnType::Int64, Value::null_int64(), Value::int64(i64::MIN), Value::int64(i64::MAX), @@ -123,6 +135,7 @@ test_type!(int8( test_type!(int8_array( postgresql, "int8[]", + ColumnType::Int64Array, Value::null_array(), Value::array(vec![ Value::int64(1), @@ -132,11 +145,18 @@ test_type!(int8_array( ]), )); -test_type!(float4(postgresql, "float4", Value::null_float(), Value::float(1.234))); +test_type!(float4( + postgresql, + "float4", + ColumnType::Float, + Value::null_float(), + Value::float(1.234) +)); test_type!(float4_array( postgresql, "float4[]", + ColumnType::FloatArray, Value::null_array(), Value::array(vec![Value::float(1.1234), Value::float(4.321), Value::null_float()]) )); @@ -144,6 +164,7 @@ test_type!(float4_array( test_type!(float8( postgresql, "float8", + ColumnType::Double, Value::null_double(), Value::double(1.12345764), )); @@ -151,6 +172,7 @@ test_type!(float8( test_type!(float8_array( postgresql, "float8[]", + ColumnType::DoubleArray, Value::null_array(), Value::array(vec![Value::double(1.1234), Value::double(4.321), Value::null_double()]) )); @@ -160,6 +182,7 @@ test_type!(float8_array( test_type!(oid_with_i32( postgresql, "oid", + ColumnType::Int64, (Value::null_int32(), Value::null_int64()), (Value::int32(i32::MAX), Value::int64(i32::MAX)), (Value::int32(u32::MIN as i32), Value::int64(u32::MIN)), @@ -168,6 +191,7 @@ test_type!(oid_with_i32( test_type!(oid_with_i64( postgresql, "oid", + ColumnType::Int64, Value::null_int64(), Value::int64(u32::MAX), Value::int64(u32::MIN), @@ -176,6 +200,7 @@ test_type!(oid_with_i64( test_type!(oid_array( postgresql, "oid[]", + ColumnType::Int64Array, Value::null_array(), Value::array(vec![ Value::int64(1), @@ -188,6 +213,7 @@ test_type!(oid_array( test_type!(serial2( postgresql, "serial2", + ColumnType::Int32, Value::int32(i16::MIN), 
Value::int32(i16::MAX), )); @@ -195,6 +221,7 @@ test_type!(serial2( test_type!(serial4( postgresql, "serial4", + ColumnType::Int32, Value::int32(i32::MIN), Value::int32(i32::MAX), )); @@ -202,15 +229,23 @@ test_type!(serial4( test_type!(serial8( postgresql, "serial8", + ColumnType::Int64, Value::int64(i64::MIN), Value::int64(i64::MAX), )); -test_type!(char(postgresql, "char(6)", Value::null_text(), Value::text("foobar"))); +test_type!(char( + postgresql, + "char(6)", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); test_type!(char_array( postgresql, "char(6)[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![Value::text("foobar"), Value::text("omgwtf"), Value::null_text()]) )); @@ -218,6 +253,7 @@ test_type!(char_array( test_type!(varchar( postgresql, "varchar(255)", + ColumnType::Text, Value::null_text(), Value::text("foobar") )); @@ -225,24 +261,39 @@ test_type!(varchar( test_type!(varchar_array( postgresql, "varchar(255)[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![Value::text("foobar"), Value::text("omgwtf"), Value::null_text()]) )); -test_type!(text(postgresql, "text", Value::null_text(), Value::text("foobar"))); +test_type!(text( + postgresql, + "text", + ColumnType::Text, + Value::null_text(), + Value::text("foobar") +)); test_type!(text_array( postgresql, "text[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![Value::text("foobar"), Value::text("omgwtf"), Value::null_text()]) )); -test_type!(bit(postgresql, "bit(4)", Value::null_text(), Value::text("1001"))); +test_type!(bit( + postgresql, + "bit(4)", + ColumnType::Text, + Value::null_text(), + Value::text("1001") +)); test_type!(bit_array( postgresql, "bit(4)[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![Value::text("1001"), Value::text("0110"), Value::null_text()]) )); @@ -250,6 +301,7 @@ test_type!(bit_array( test_type!(varbit( postgresql, "varbit(20)", + ColumnType::Text, Value::null_text(), Value::text("001010101") )); @@ -257,6 +309,7 @@ test_type!(varbit( test_type!(varbit_array( postgresql, "varbit(20)[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![ Value::text("001010101"), @@ -265,11 +318,18 @@ test_type!(varbit_array( ]) )); -test_type!(inet(postgresql, "inet", Value::null_text(), Value::text("127.0.0.1"))); +test_type!(inet( + postgresql, + "inet", + ColumnType::Text, + Value::null_text(), + Value::text("127.0.0.1") +)); test_type!(inet_array( postgresql, "inet[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![ Value::text("127.0.0.1"), @@ -281,6 +341,7 @@ test_type!(inet_array( test_type!(json( postgresql, "json", + ColumnType::Json, Value::null_json(), Value::json(serde_json::json!({"foo": "bar"})) )); @@ -288,6 +349,7 @@ test_type!(json( test_type!(json_array( postgresql, "json[]", + ColumnType::JsonArray, Value::null_array(), Value::array(vec![ Value::json(serde_json::json!({"foo": "bar"})), @@ -299,6 +361,7 @@ test_type!(json_array( test_type!(jsonb( postgresql, "jsonb", + ColumnType::Json, Value::null_json(), Value::json(serde_json::json!({"foo": "bar"})) )); @@ -306,6 +369,7 @@ test_type!(jsonb( test_type!(jsonb_array( postgresql, "jsonb[]", + ColumnType::JsonArray, Value::null_array(), Value::array(vec![ Value::json(serde_json::json!({"foo": "bar"})), @@ -314,11 +378,18 @@ test_type!(jsonb_array( ]) )); -test_type!(xml(postgresql, "xml", Value::null_xml(), Value::xml("1",))); +test_type!(xml( + postgresql, + "xml", + ColumnType::Xml, + Value::null_xml(), + Value::xml("1",) 
+)); test_type!(xml_array( postgresql, "xml[]", + ColumnType::TextArray, Value::null_array(), Value::array(vec![ Value::text("1"), @@ -330,6 +401,7 @@ test_type!(xml_array( test_type!(uuid( postgresql, "uuid", + ColumnType::Uuid, Value::null_uuid(), Value::uuid(uuid::Uuid::from_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap()) )); @@ -337,6 +409,7 @@ test_type!(uuid( test_type!(uuid_array( postgresql, "uuid[]", + ColumnType::UuidArray, Value::null_array(), Value::array(vec![ Value::uuid(uuid::Uuid::from_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap()), @@ -347,6 +420,7 @@ test_type!(uuid_array( test_type!(date( postgresql, "date", + ColumnType::Date, Value::null_date(), Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()) )); @@ -354,6 +428,7 @@ test_type!(date( test_type!(date_array( postgresql, "date[]", + ColumnType::DateArray, Value::null_array(), Value::array(vec![ Value::date(chrono::NaiveDate::from_ymd_opt(2020, 4, 20).unwrap()), @@ -364,6 +439,7 @@ test_type!(date_array( test_type!(time( postgresql, "time", + ColumnType::Time, Value::null_time(), Value::time(chrono::NaiveTime::from_hms_opt(16, 20, 00).unwrap()) )); @@ -371,6 +447,7 @@ test_type!(time( test_type!(time_array( postgresql, "time[]", + ColumnType::TimeArray, Value::null_array(), Value::array(vec![ Value::time(chrono::NaiveTime::from_hms_opt(16, 20, 00).unwrap()), @@ -378,37 +455,62 @@ test_type!(time_array( ]) )); -test_type!(timestamp(postgresql, "timestamp", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(timestamp( + postgresql, + "timestamp", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); -test_type!(timestamp_array(postgresql, "timestamp[]", Value::null_array(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); +test_type!(timestamp_array( + postgresql, + "timestamp[]", + ColumnType::DateTimeArray, + Value::null_array(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::array(vec![ - Value::datetime(dt.with_timezone(&chrono::Utc)), - Value::null_datetime(), - ]) -})); + Value::array(vec![ + Value::datetime(dt.with_timezone(&chrono::Utc)), + Value::null_datetime(), + ]) + } +)); -test_type!(timestamptz(postgresql, "timestamptz", Value::null_datetime(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::datetime(dt.with_timezone(&chrono::Utc)) -})); +test_type!(timestamptz( + postgresql, + "timestamptz", + ColumnType::DateTime, + Value::null_datetime(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); + Value::datetime(dt.with_timezone(&chrono::Utc)) + } +)); -test_type!(timestamptz_array(postgresql, "timestamptz[]", Value::null_array(), { - let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); +test_type!(timestamptz_array( + postgresql, + "timestamptz[]", + ColumnType::DateTimeArray, + Value::null_array(), + { + let dt = chrono::DateTime::parse_from_rfc3339("2020-02-27T19:10:22Z").unwrap(); - Value::array(vec![ - Value::datetime(dt.with_timezone(&chrono::Utc)), - Value::null_datetime(), - ]) -})); + Value::array(vec![ + Value::datetime(dt.with_timezone(&chrono::Utc)), + Value::null_datetime(), + ]) + } +)); test_type!(bytea( postgresql, 
"bytea", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(b"DEADBEEF".to_vec()) )); @@ -416,6 +518,7 @@ test_type!(bytea( test_type!(bytea_array( postgresql, "bytea[]", + ColumnType::BytesArray, Value::null_array(), Value::array(vec![ Value::bytes(b"DEADBEEF".to_vec()), diff --git a/quaint/src/tests/types/postgres/bigdecimal.rs b/quaint/src/tests/types/postgres/bigdecimal.rs index 894b2c967629..1f8dcd663164 100644 --- a/quaint/src/tests/types/postgres/bigdecimal.rs +++ b/quaint/src/tests/types/postgres/bigdecimal.rs @@ -1,9 +1,11 @@ use super::*; use crate::bigdecimal::BigDecimal; +use crate::macros::assert_matching_value_and_column_type; test_type!(decimal( postgresql, "decimal(10,2)", + ColumnType::Numeric, Value::null_numeric(), Value::numeric(BigDecimal::from_str("3.14")?) )); @@ -11,6 +13,7 @@ test_type!(decimal( test_type!(decimal_10_2( postgresql, "decimal(10, 2)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950.123456")?), Value::numeric(BigDecimal::from_str("3950.12")?) @@ -20,6 +23,7 @@ test_type!(decimal_10_2( test_type!(decimal_35_6( postgresql, "decimal(35, 6)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950")?), Value::numeric(BigDecimal::from_str("3950.000000")?) @@ -101,6 +105,7 @@ test_type!(decimal_35_6( test_type!(decimal_35_2( postgresql, "decimal(35, 2)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3950.123456")?), Value::numeric(BigDecimal::from_str("3950.12")?) @@ -114,12 +119,14 @@ test_type!(decimal_35_2( test_type!(decimal_4_0( postgresql, "decimal(4, 0)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str("3950")?) )); test_type!(decimal_65_30( postgresql, "decimal(65, 30)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("1.2")?), Value::numeric(BigDecimal::from_str("1.2000000000000000000000000000")?) @@ -133,6 +140,7 @@ test_type!(decimal_65_30( test_type!(decimal_65_34( postgresql, "decimal(65, 34)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("3.1415926535897932384626433832795028")?), Value::numeric(BigDecimal::from_str("3.1415926535897932384626433832795028")?) @@ -150,12 +158,14 @@ test_type!(decimal_65_34( test_type!(decimal_35_0( postgresql, "decimal(35, 0)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str("79228162514264337593543950335")?), )); test_type!(decimal_35_1( postgresql, "decimal(35, 1)", + ColumnType::Numeric, ( Value::numeric(BigDecimal::from_str("79228162514264337593543950335")?), Value::numeric(BigDecimal::from_str("79228162514264337593543950335.0")?) @@ -169,6 +179,7 @@ test_type!(decimal_35_1( test_type!(decimal_128_6( postgresql, "decimal(128, 6)", + ColumnType::Numeric, Value::numeric(BigDecimal::from_str( "93431006223456789876545678909876545678903434369343100622345678987654567890987654567890343436999999100622345678343699999910.345678" )?), @@ -177,6 +188,7 @@ test_type!(decimal_128_6( test_type!(decimal_array( postgresql, "decimal(10,2)[]", + ColumnType::NumericArray, Value::null_array(), Value::array(vec![BigDecimal::from_str("3.14")?, BigDecimal::from_str("5.12")?]) )); @@ -184,6 +196,7 @@ test_type!(decimal_array( test_type!(money( postgresql, "money", + ColumnType::Numeric, Value::null_numeric(), Value::numeric(BigDecimal::from_str("1.12")?) 
)); @@ -191,6 +204,7 @@ test_type!(money( test_type!(money_array( postgresql, "money[]", + ColumnType::NumericArray, Value::null_array(), Value::array(vec![BigDecimal::from_str("1.12")?, BigDecimal::from_str("1.12")?]) )); @@ -198,6 +212,7 @@ test_type!(money_array( test_type!(float4( postgresql, "float4", + ColumnType::Float, (Value::null_numeric(), Value::null_float()), ( Value::numeric(BigDecimal::from_str("1.123456")?), @@ -208,6 +223,7 @@ test_type!(float4( test_type!(float8( postgresql, "float8", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), ( Value::numeric(BigDecimal::from_str("1.123456")?), diff --git a/quaint/src/tests/types/sqlite.rs b/quaint/src/tests/types/sqlite.rs index c4950e748697..647f7217c83c 100644 --- a/quaint/src/tests/types/sqlite.rs +++ b/quaint/src/tests/types/sqlite.rs @@ -1,5 +1,7 @@ #![allow(clippy::approx_constant)] +use crate::connector::ColumnType; +use crate::macros::assert_matching_value_and_column_type; use crate::tests::test_api::sqlite_test_api; use crate::tests::test_api::TestApi; use crate::{ast::*, connector::Queryable}; @@ -9,6 +11,7 @@ use std::str::FromStr; test_type!(integer( sqlite, "INTEGER", + ColumnType::Int32, Value::null_int32(), Value::int32(i8::MIN), Value::int32(i8::MAX), @@ -21,17 +24,25 @@ test_type!(integer( test_type!(big_int( sqlite, "BIGINT", + ColumnType::Int64, Value::null_int64(), Value::int64(i64::MIN), Value::int64(i64::MAX), )); -test_type!(real(sqlite, "REAL", Value::null_double(), Value::double(1.12345))); +test_type!(real( + sqlite, + "REAL", + ColumnType::Double, + Value::null_double(), + Value::double(1.12345) +)); test_type!(float_decimal( sqlite, "FLOAT", - (Value::null_numeric(), Value::null_float()), + ColumnType::Double, + (Value::null_numeric(), Value::null_double()), ( Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()), Value::double(3.14) @@ -41,6 +52,7 @@ test_type!(float_decimal( test_type!(double_decimal( sqlite, "DOUBLE", + ColumnType::Double, (Value::null_numeric(), Value::null_double()), ( Value::numeric(bigdecimal::BigDecimal::from_str("3.14").unwrap()), @@ -48,27 +60,44 @@ test_type!(double_decimal( ) )); -test_type!(text(sqlite, "TEXT", Value::null_text(), Value::text("foobar huhuu"))); +test_type!(text( + sqlite, + "TEXT", + ColumnType::Text, + Value::null_text(), + Value::text("foobar huhuu") +)); test_type!(blob( sqlite, "BLOB", + ColumnType::Bytes, Value::null_bytes(), Value::bytes(b"DEADBEEF".to_vec()) )); -test_type!(float(sqlite, "FLOAT", Value::null_float(), Value::double(1.23))); +test_type!(float( + sqlite, + "FLOAT", + ColumnType::Double, + (Value::null_float(), Value::null_double()), + (Value::null_double(), Value::null_double()), + (Value::float(1.23456), Value::double(1.23456)), + (Value::double(1.2312313213), Value::double(1.2312313213)) +)); test_type!(double( sqlite, "DOUBLE", + ColumnType::Double, Value::null_double(), - Value::double(1.2312313213) + Value::double(1.2312313213), )); test_type!(boolean( sqlite, "BOOLEAN", + ColumnType::Boolean, Value::null_boolean(), Value::boolean(true), Value::boolean(false) @@ -77,6 +106,7 @@ test_type!(boolean( test_type!(date( sqlite, "DATE", + ColumnType::Date, Value::null_date(), Value::date(chrono::NaiveDate::from_ymd_opt(1984, 1, 1).unwrap()) )); @@ -84,6 +114,7 @@ test_type!(date( test_type!(datetime( sqlite, "DATETIME", + ColumnType::DateTime, Value::null_datetime(), Value::datetime(chrono::DateTime::from_str("2020-07-29T09:23:44.458Z").unwrap()) )); @@ -104,6 +135,7 @@ async fn 
test_type_text_datetime_rfc3339(api: &mut dyn TestApi) -> crate::Result let res = api.conn().select(select).await?.into_single()?; assert_eq!(Some(&Value::datetime(dt)), res.at(0)); + assert_matching_value_and_column_type(&res.types[0], res.at(0).unwrap()); Ok(()) } @@ -125,7 +157,9 @@ async fn test_type_text_datetime_rfc2822(api: &mut dyn TestApi) -> crate::Result let select = Select::from_table(&table).column("value").order_by("id".descend()); let res = api.conn().select(select).await?.into_single()?; + assert_eq!(ColumnType::DateTime, res.types[0]); assert_eq!(Some(&Value::datetime(dt)), res.at(0)); + assert_matching_value_and_column_type(&res.types[0], res.at(0).unwrap()); Ok(()) } @@ -147,7 +181,9 @@ async fn test_type_text_datetime_custom(api: &mut dyn TestApi) -> crate::Result< let naive = chrono::NaiveDateTime::parse_from_str("2020-04-20 16:20:00", "%Y-%m-%d %H:%M:%S").unwrap(); let expected = chrono::DateTime::from_naive_utc_and_offset(naive, chrono::Utc); + assert_eq!(ColumnType::DateTime, res.types[0]); assert_eq!(Some(&Value::datetime(expected)), res.at(0)); + assert_matching_value_and_column_type(&res.types[0], res.at(0).unwrap()); Ok(()) } diff --git a/query-engine/query-engine-node-api/tmp b/quaint/src/tests/types/utils.rs similarity index 100% rename from query-engine/query-engine-node-api/tmp rename to quaint/src/tests/types/utils.rs diff --git a/query-engine/black-box-tests/Cargo.toml b/query-engine/black-box-tests/Cargo.toml index c5f88c844dc7..e08ebb962e1d 100644 --- a/query-engine/black-box-tests/Cargo.toml +++ b/query-engine/black-box-tests/Cargo.toml @@ -14,5 +14,5 @@ tokio.workspace = true user-facing-errors.workspace = true insta = "1.7.1" enumflags2.workspace = true -query-engine-metrics = {path = "../metrics"} -regex = "1.9.3" +prisma-metrics.path = "../../libs/metrics" +regex.workspace = true diff --git a/query-engine/black-box-tests/tests/helpers/mod.rs b/query-engine/black-box-tests/tests/helpers/mod.rs index 95dc53313b8d..efebea03ac3f 100644 --- a/query-engine/black-box-tests/tests/helpers/mod.rs +++ b/query-engine/black-box-tests/tests/helpers/mod.rs @@ -52,10 +52,7 @@ pub(crate) fn query_engine_cmd(dml: &str) -> (process::Command, String) { cmd.env_clear(); let port = generate_free_port(); - cmd.env("PRISMA_DML", dml) - .arg("--port") - .arg(&port.to_string()) - .arg("-g"); + cmd.env("PRISMA_DML", dml).arg("--port").arg(port.to_string()).arg("-g"); (cmd, format!("http://0.0.0.0:{}", port)) } diff --git a/query-engine/black-box-tests/tests/metrics/smoke_tests.rs b/query-engine/black-box-tests/tests/metrics/smoke_tests.rs index b21f22265f98..91c3e719456f 100644 --- a/query-engine/black-box-tests/tests/metrics/smoke_tests.rs +++ b/query-engine/black-box-tests/tests/metrics/smoke_tests.rs @@ -122,7 +122,7 @@ mod smoke_tests { assert_eq!(metrics.matches("# TYPE prisma_datasource_queries_duration_histogram_ms histogram").count(), 1); // Check that exist as many metrics as being accepted - let accepted_metric_count = query_engine_metrics::ACCEPT_LIST.len(); + let accepted_metric_count = prisma_metrics::ACCEPT_LIST.len(); let displayed_metric_count = metrics.matches("# TYPE").count(); let non_prisma_metric_count = displayed_metric_count - metrics.matches("# TYPE prisma").count(); diff --git a/query-engine/connector-test-kit-rs/README.md b/query-engine/connector-test-kit-rs/README.md index d896358d06ed..ef8396f48045 100644 --- a/query-engine/connector-test-kit-rs/README.md +++ b/query-engine/connector-test-kit-rs/README.md @@ -83,7 +83,7 @@ drivers the code that 
actually communicates with the databases. See [`adapter-*` To run tests through a driver adapters, you should also configure the following environment variables: * `DRIVER_ADAPTER`: tells the test executor to use a particular driver adapter. Set to `neon`, `planetscale` or any other supported adapter. -* `DRIVER_ADAPTER_CONFIG`: a json string with the configuration for the driver adapter. This is adapter specific. See the [github workflow for driver adapter tests](.github/workflows/query-engine-driver-adapters.yml) for examples on how to configure the driver adapters. +* `DRIVER_ADAPTER_CONFIG`: a json string with the configuration for the driver adapter. This is adapter specific. See the [GitHub workflow for driver adapter tests](.github/workflows/query-engine-driver-adapters.yml) for examples on how to configure the driver adapters. * `ENGINE`: can be used to run either `wasm` or `napi` or `c-abi` version of the engine. Example: @@ -339,7 +339,7 @@ run_query!( **Accepting a snapshot update will replace, directly in your code, the expected output in the assertion.** -If you dislike the interactive view, you can also run `cargo insta accept` to automatically accept all snapshots and then use your git diff to check if everything is as intented. +If you dislike the interactive view, you can also run `cargo insta accept` to automatically accept all snapshots and then use your git diff to check if everything is as intended. ##### Without `cargo-insta` diff --git a/query-engine/connector-test-kit-rs/qe-setup/Cargo.toml b/query-engine/connector-test-kit-rs/qe-setup/Cargo.toml index b3b75f294fcc..dc225369c87b 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/Cargo.toml +++ b/query-engine/connector-test-kit-rs/qe-setup/Cargo.toml @@ -16,6 +16,6 @@ serde_json.workspace = true tokio.workspace = true connection-string = "*" -mongodb = "2.8.0" +mongodb.workspace = true url.workspace = true once_cell = "1.17.0" diff --git a/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs b/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs index 0d876d6b4dcf..c901b9ef3887 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs +++ b/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs @@ -5,7 +5,7 @@ use url::Url; pub(crate) async fn cockroach_setup(url: String, prisma_schema: &str) -> ConnectorResult<()> { let mut parsed_url = Url::parse(&url).map_err(ConnectorError::url_parse_error)?; - let mut quaint_url = quaint::connector::PostgresUrl::new(parsed_url.clone()).unwrap(); + let mut quaint_url = quaint::connector::PostgresNativeUrl::new(parsed_url.clone()).unwrap(); quaint_url.set_flavour(PostgresFlavour::Cockroach); let db_name = quaint_url.dbname(); diff --git a/query-engine/connector-test-kit-rs/qe-setup/src/lib.rs b/query-engine/connector-test-kit-rs/qe-setup/src/lib.rs index 530717cc94db..17d9ec06ab5e 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/src/lib.rs +++ b/query-engine/connector-test-kit-rs/qe-setup/src/lib.rs @@ -65,14 +65,11 @@ fn parse_configuration(datamodel: &str) -> ConnectorResult<(Datasource, String, /// (rather than just the Schema Engine), this function will call [`ExternalInitializer::init_with_migration`]. /// Otherwise, it will call [`ExternalInitializer::init`], and then proceed with the standard /// setup based on the Schema Engine. 
-pub async fn setup_external<'a, EI>( +pub async fn setup_external<'a>( driver_adapter: DriverAdapter, - initializer: EI, + initializer: impl ExternalInitializer<'a>, db_schemas: &[&str], -) -> ConnectorResult -where - EI: ExternalInitializer<'a> + ?Sized, -{ +) -> ConnectorResult { let prisma_schema = initializer.datamodel(); let (source, url, _preview_features) = parse_configuration(prisma_schema)?; diff --git a/query-engine/connector-test-kit-rs/qe-setup/src/mongodb.rs b/query-engine/connector-test-kit-rs/qe-setup/src/mongodb.rs index ebaf97f8f31d..7b66af53503d 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/src/mongodb.rs +++ b/query-engine/connector-test-kit-rs/qe-setup/src/mongodb.rs @@ -10,11 +10,12 @@ pub(crate) async fn mongo_setup(schema: &str, url: &str) -> ConnectorResult<()> client .database(&db_name) - .drop(Some( + .drop() + .with_options( mongodb::options::DropDatabaseOptions::builder() .write_concern(mongodb::options::WriteConcern::builder().journal(true).build()) .build(), - )) + ) .await .unwrap(); @@ -24,7 +25,7 @@ pub(crate) async fn mongo_setup(schema: &str, url: &str) -> ConnectorResult<()> for model in parsed_schema.db.walk_models() { client .database(&db_name) - .create_collection(model.database_name(), None) + .create_collection(model.database_name()) .await .unwrap(); } diff --git a/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs b/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs index 6bbba8564cae..536f51eb4834 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs +++ b/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs @@ -5,7 +5,7 @@ use url::Url; pub(crate) async fn postgres_setup(url: String, prisma_schema: &str, db_schemas: &[&str]) -> ConnectorResult<()> { let mut parsed_url = Url::parse(&url).map_err(ConnectorError::url_parse_error)?; - let mut quaint_url = quaint::connector::PostgresUrl::new(parsed_url.clone()).unwrap(); + let mut quaint_url = quaint::connector::PostgresNativeUrl::new(parsed_url.clone()).unwrap(); quaint_url.set_flavour(PostgresFlavour::Postgres); let (db_name, schema) = (quaint_url.dbname(), quaint_url.schema()); diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml b/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml index 46d1d4b845fd..60cfbca7ca18 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml +++ b/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml @@ -11,7 +11,7 @@ query-test-macros = { path = "../query-test-macros" } query-tests-setup = { path = "../query-tests-setup" } indoc.workspace = true tracing.workspace = true -tracing-futures = "0.2" +tracing-futures.workspace = true colored = "2" chrono.workspace = true psl.workspace = true @@ -20,9 +20,9 @@ uuid.workspace = true tokio.workspace = true user-facing-errors.workspace = true prisma-value = { path = "../../../libs/prisma-value" } -query-engine-metrics = { path = "../../metrics"} +prisma-metrics.path = "../../../libs/metrics" once_cell = "1.15.0" -futures = "0.3" +futures.workspace = true paste = "1.0.14" [dev-dependencies] diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/src/utils/querying.rs b/query-engine/connector-test-kit-rs/query-engine-tests/src/utils/querying.rs index 268fddef976b..ce64623645bd 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/src/utils/querying.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/src/utils/querying.rs @@ -34,6 +34,28 @@ macro_rules! 
match_connector_result { }; } +#[macro_export] +macro_rules! assert_connector_error { + ($runner:expr, $q:expr, $code:expr, $( $($matcher:pat_param)|+ $( if $pred:expr )? => $msg:expr ),*) => { + use query_tests_setup::*; + use query_tests_setup::ConnectorVersion::*; + + let connector = $runner.connector_version(); + + let mut results = match &connector { + $( + $( $matcher )|+ $( if $pred )? => $msg.to_string() + ),* + }; + + if results.len() == 0 { + panic!("No assertion failure defined for connector {connector}."); + } + + $runner.query($q).await?.assert_failure($code, Some(results)); + }; +} + #[macro_export] macro_rules! is_one_of { ($result:expr, $potential_results:expr) => { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs index da0db0a0e70a..100828697046 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs @@ -3,6 +3,8 @@ use std::borrow::Cow; #[test_suite(schema(generic), exclude(Sqlite("cfd1")))] mod interactive_tx { + use std::time::{Duration, Instant}; + use query_engine_tests::*; use tokio::time; @@ -213,7 +215,7 @@ mod interactive_tx { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js.wasm"), Sqlite("cfd1")))] + #[connector_test(exclude(Sqlite("cfd1")))] async fn batch_queries_failure(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. let tx_id = runner.start_tx(5000, 5000, None).await?; @@ -231,7 +233,9 @@ mod interactive_tx { let batch_results = runner.batch(queries, false, None).await?; batch_results.assert_failure(2002, None); + let now = Instant::now(); let res = runner.commit_tx(tx_id.clone()).await?; + assert!(now.elapsed() <= Duration::from_millis(5000)); if matches!(runner.connector_version(), ConnectorVersion::MongoDb(_)) { assert!(res.is_err()); @@ -256,7 +260,7 @@ mod interactive_tx { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js.wasm")))] + #[connector_test] async fn tx_expiration_failure_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one seconds. let tx_id = runner.start_tx(5000, 1000, None).await?; @@ -570,13 +574,13 @@ mod interactive_tx { #[test_suite(schema(generic), exclude(Sqlite("cfd1")))] mod itx_isolation { + use std::sync::Arc; + use query_engine_tests::*; + use tokio::task::JoinSet; // All (SQL) connectors support serializable. 
- // However, there's a bug in the PlanetScale driver adapter: - // "Transaction characteristics can't be changed while a transaction is in progress - // (errno 1568) (sqlstate 25001) during query: SET TRANSACTION ISOLATION LEVEL SERIALIZABLE" - #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm"), Sqlite("cfd1")))] + #[connector_test(exclude(MongoDb, Sqlite("cfd1")))] async fn basic_serializable(mut runner: Runner) -> TestResult<()> { let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await?; runner.set_active_tx(tx_id.clone()); @@ -598,9 +602,7 @@ mod itx_isolation { Ok(()) } - // On PlanetScale, this fails with: - // `InteractiveTransactionError("Error in connector: Error querying the database: Server error: `ERROR 25001 (1568): Transaction characteristics can't be changed while a transaction is in progress'")` - #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm"), Sqlite("cfd1")))] + #[connector_test(exclude(MongoDb, Sqlite("cfd1")))] async fn casing_doesnt_matter(mut runner: Runner) -> TestResult<()> { let tx_id = runner.start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned())).await?; runner.set_active_tx(tx_id.clone()); @@ -654,4 +656,45 @@ mod itx_isolation { Ok(()) } + + #[connector_test(exclude(Sqlite))] + async fn high_concurrency(runner: Runner) -> TestResult<()> { + let runner = Arc::new(runner); + let mut set = JoinSet::>::new(); + + for i in 1..=20 { + set.spawn({ + let runner = Arc::clone(&runner); + async move { + let tx_id = runner.start_tx(5000, 5000, None).await?; + + runner + .query_in_tx( + &tx_id, + format!( + r#"mutation {{ + createOneTestModel( + data: {{ + id: {i} + }} + ) {{ id }} + }}"# + ), + ) + .await? + .assert_success(); + + runner.commit_tx(tx_id).await?.expect("commit must succeed"); + + Ok(()) + } + }); + } + + while let Some(handle) = set.join_next().await { + handle.expect("task panicked or canceled")?; + } + + Ok(()) + } } diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs index 2a1cf89e9d3b..323f162a2111 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs @@ -9,9 +9,7 @@ use query_engine_tests::test_suite; ) )] mod metrics { - use query_engine_metrics::{ - PRISMA_CLIENT_QUERIES_ACTIVE, PRISMA_CLIENT_QUERIES_TOTAL, PRISMA_DATASOURCE_QUERIES_TOTAL, - }; + use prisma_metrics::{PRISMA_CLIENT_QUERIES_ACTIVE, PRISMA_CLIENT_QUERIES_TOTAL, PRISMA_DATASOURCE_QUERIES_TOTAL}; use query_engine_tests::ConnectorVersion::*; use query_engine_tests::*; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs index 131dbcf89591..9e347d476059 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs @@ -80,7 +80,7 @@ mod one2one_req { runner, "mutation { deleteOneParent(where: { id: 1 }) { id }}", 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) @@ -187,7 +187,7 @@ mod one2one_opt { runner, "mutation { deleteOneParent(where: { id: 1 }) { id }}", 2003, - "Foreign key constraint failed on the field" + 
"Foreign key constraint violated" ); Ok(()) @@ -293,7 +293,7 @@ mod one2many_req { runner, "mutation { deleteOneParent(where: { id: 1 }) { id }}", 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) @@ -397,7 +397,7 @@ mod one2many_opt { runner, "mutation { deleteOneParent(where: { id: 1 }) { id }}", 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/cascade.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/cascade.rs index 99cd190e161a..1da34e776286 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/cascade.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/cascade.rs @@ -33,13 +33,7 @@ mod one2one_req { schema.to_owned() } - #[connector_test(schema(required), exclude(Sqlite("cfd1")))] - /// On D1, this fails with: - /// - /// ```diff - /// - {"data":{"updateManyParent":{"count":1}}} - /// + {"data":{"updateManyParent":{"count":2}}} - /// ``` + #[connector_test(schema(required))] async fn update_parent_cascade(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { @@ -176,14 +170,7 @@ mod one2one_opt { schema.to_owned() } - #[connector_test(schema(optional), exclude(Sqlite("cfd1")))] - // Updating the parent updates the child FK as well. - // On D1, this fails with: - // - // ```diff - // - {"data":{"updateManyParent":{"count":1}}} - // + {"data":{"updateManyParent":{"count":2}}} - // ``` + #[connector_test(schema(optional))] async fn update_parent_cascade(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/restrict.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/restrict.rs index 99c3c0b094dc..f8bff4d30f06 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/restrict.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/restrict.rs @@ -255,13 +255,7 @@ mod one2many_req { Ok(()) } - #[connector_test(exclude(Sqlite("cfd1")))] - /// Updating the parent succeeds if no child is connected or if the linking fields aren't part of the update payload. - /// - /// ```diff - /// - {"data":{"updateManyParent":{"count":1}}} - /// + {"data":{"updateManyParent":{"count":2}}} - /// ``` + #[connector_test] async fn update_parent(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; run_query!( @@ -383,13 +377,7 @@ mod one2many_opt { Ok(()) } - #[connector_test(exclude(Sqlite("cfd1")))] - /// Updating the parent succeeds if no child is connected or if the linking fields aren't part of the update payload. 
- /// - /// ```diff - /// - {"data":{"updateManyParent":{"count":1}}} - /// + {"data":{"updateManyParent":{"count":2}}} - /// ``` + #[connector_test] async fn update_parent(runner: Runner) -> TestResult<()> { create_test_data(&runner).await?; run_query!( diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs index 99c2ffb63a5e..73ee612bbfc0 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs @@ -79,7 +79,7 @@ mod one2one_req { &runner, r#"mutation { updateOneParent(where: { id: 1 }, data: { uniq: "u1" }) { id }}"#, 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) @@ -182,7 +182,7 @@ mod one2one_opt { &runner, r#"mutation { updateOneParent(where: { id: 1 } data: { uniq: "u1" }) { id }}"#, 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) @@ -287,7 +287,7 @@ mod one2many_req { &runner, r#"mutation { updateOneParent(where: { id: 1 }, data: { uniq: "u1" }) { id }}"#, 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) @@ -390,7 +390,7 @@ mod one2many_opt { &runner, r#"mutation { updateOneParent(where: { id: 1 }, data: { uniq: "u1" }) { id }}"#, 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_null.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_null.rs index 8ef0ab0d1e8c..dcb4083dbba9 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_null.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_null.rs @@ -65,13 +65,7 @@ mod one2one_opt { Ok(()) } - #[connector_test(exclude(Sqlite("cfd1")))] - // On D1, this fails with: - // - // ```diff - // - {"data":{"updateManyParent":{"count":1}}} - // + {"data":{"updateManyParent":{"count":2}}} - // ``` + #[connector_test] async fn update_many_parent(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, uniq: "1", child: { create: { id: 1 }}}) { id }}"#), @@ -457,13 +451,7 @@ mod one2many_opt { Ok(()) } - #[connector_test(exclude(Sqlite("cfd1")))] - // On D1, this fails with: - // - // ```diff - // - {"data":{"updateManyParent":{"count":1}}} - // + {"data":{"updateManyParent":{"count":2}}} - // ``` + #[connector_test] async fn update_many_parent(runner: Runner) -> TestResult<()> { insta::assert_snapshot!( run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, uniq: "1", children: { create: { id: 1 }}}) { id }}"#), diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs index 4b014fa53f69..6ab70f2975d3 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs @@ -1,6 +1,7 @@ mod max_integer; mod prisma_10098; mod prisma_10935; +mod prisma_11750; mod prisma_11789; mod 
prisma_12572; mod prisma_12929; @@ -27,6 +28,7 @@ mod prisma_21901; mod prisma_22007; mod prisma_22298; mod prisma_22971; +mod prisma_24072; mod prisma_5952; mod prisma_6173; mod prisma_7010; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_11750.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_11750.rs new file mode 100644 index 000000000000..907aae408bf9 --- /dev/null +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_11750.rs @@ -0,0 +1,127 @@ +use query_engine_tests::*; + +/// Regression test for <https://github.com/prisma/prisma/issues/11750>. +/// +/// See also and +/// . +/// +/// This is a port of the TypeScript test from the client test suite. +/// +/// The test creates a user and then tries to update the same row in multiple concurrent +/// transactions. We don't assert that most operations succeed and merely log the errors happening +/// during update or commit, as those are expected to happen. We do fail the test if creating the +/// user fails, or if we fail to start a transaction, as those operations are expected to succeed. +/// +/// What we really test here, though, is that the query engine must not deadlock (leading to the +/// test never finishing). +/// +/// Some providers are skipped because these concurrent conflicting transactions cause troubles on +/// the database side and failures to start new transactions. +/// +/// For an example of an equivalent test that passes on all databases where the transactions don't +/// conflict and don't cause issues on the database side, see the `high_concurrency` test in the +/// `new::interactive_tx::interactive_tx` test suite. +#[test_suite(schema(user), exclude(Sqlite, MySql(8), SqlServer))] +mod prisma_11750 { + use std::sync::Arc; + use tokio::task::JoinSet; + + #[connector_test] + async fn test_itx_concurrent_updates_single_thread(runner: Runner) -> TestResult<()> { + let runner = Arc::new(runner); + + create_user(&runner, 1, "x").await?; + + for _ in 0..10 { + tokio::try_join!( + update_user(Arc::clone(&runner), "a"), + update_user(Arc::clone(&runner), "b"), + update_user(Arc::clone(&runner), "c"), + update_user(Arc::clone(&runner), "d"), + update_user(Arc::clone(&runner), "e"), + update_user(Arc::clone(&runner), "f"), + update_user(Arc::clone(&runner), "g"), + update_user(Arc::clone(&runner), "h"), + update_user(Arc::clone(&runner), "i"), + update_user(Arc::clone(&runner), "j"), + )?; + } + + create_user(&runner, 2, "y").await?; + + Ok(()) + } + + #[connector_test] + async fn test_itx_concurrent_updates_multi_thread(runner: Runner) -> TestResult<()> { + let runner = Arc::new(runner); + + create_user(&runner, 1, "x").await?; + + for _ in 0..10 { + let mut set = JoinSet::new(); + + for email in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] { + set.spawn(update_user(Arc::clone(&runner), email)); + } + + while let Some(handle) = set.join_next().await { + handle.expect("task panicked or canceled")?; + } + } + + create_user(&runner, 2, "y").await?; + + Ok(()) + } + + async fn create_user(runner: &Runner, id: u32, email: &str) -> TestResult<()> { + run_query!( + &runner, + format!( + r#"mutation {{ + createOneUser( + data: {{ + id: {id}, + first_name: "{email}", + last_name: "{email}", + email: "{email}" + }} + ) {{ id }} + }}"# + ) + ); + + Ok(()) + } + + async fn update_user(runner: Arc<Runner>, new_email: &str) -> TestResult<()> { + let tx_id = runner.start_tx(2000, 25, None).await?; + + let result = runner + .query_in_tx( + &tx_id, + format!( + r#"mutation
{{ + updateOneUser( + where: {{ id: 1 }}, + data: {{ email: "{new_email}" }} + ) {{ id }} + }}"# + ), + ) + .await; + + if let Err(err) = result { + tracing::error!(%err, "query error"); + } + + let result = runner.commit_tx(tx_id).await?; + + if let Err(err) = result { + tracing::error!(?err, "commit error"); + } + + Ok(()) + } +} diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs index e026a90016bd..7ed3cb9a8598 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs @@ -4,6 +4,7 @@ //! actors to allow test to continue even if one query is blocking. use indoc::indoc; +use prisma_metrics::{MetricRecorder, WithMetricsInstrumentation}; use query_engine_tests::{ query_core::TxId, render_test_datamodel, setup_metrics, test_tracing_subscriber, LogEmit, QueryResult, Runner, TestError, TestLogCapture, TestResult, WithSubscriber, CONFIG, ENV_LOG_LEVEL, @@ -50,13 +51,12 @@ impl Actor { /// Spawns a new query engine to the runtime. pub async fn spawn() -> TestResult<Self> { let (log_capture, log_tx) = TestLogCapture::new(); - async fn with_logs<T>(fut: impl Future<Output = T>, log_tx: LogEmit) -> T { - fut.with_subscriber(test_tracing_subscriber( - ENV_LOG_LEVEL.to_string(), - setup_metrics(), - log_tx, - )) - .await + let (metrics, recorder) = setup_metrics(); + + async fn with_observability<T>(fut: impl Future<Output = T>, log_tx: LogEmit, recorder: MetricRecorder) -> T { + fut.with_subscriber(test_tracing_subscriber(ENV_LOG_LEVEL.to_string(), log_tx)) + .with_recorder(recorder) + .await } let (query_sender, mut query_receiver) = mpsc::channel(100); @@ -73,21 +73,24 @@ impl Actor { Some("READ COMMITTED"), ); - let mut runner = Runner::load(datamodel, &[], version, tag, None, setup_metrics(), log_capture).await?; + let mut runner = Runner::load(datamodel, &[], version, tag, None, metrics, log_capture).await?; tokio::spawn(async move { while let Some(message) = query_receiver.recv().await { match message { Message::Query(query) => { - let result = with_logs(runner.query(query), log_tx.clone()).await; + let result = with_observability(runner.query(query), log_tx.clone(), recorder.clone()).await; response_sender.send(Response::Query(result)).await.unwrap(); } Message::BeginTransaction => { - let response = with_logs(runner.start_tx(10000, 10000, None), log_tx.clone()).await; + let response = + with_observability(runner.start_tx(10000, 10000, None), log_tx.clone(), recorder.clone()) + .await; response_sender.send(Response::Tx(response)).await.unwrap(); } Message::RollbackTransaction(tx_id) => { - let response = with_logs(runner.rollback_tx(tx_id), log_tx.clone()).await?; + let response = + with_observability(runner.rollback_tx(tx_id), log_tx.clone(), recorder.clone()).await?; response_sender.send(Response::Rollback(response)).await.unwrap(); } Message::SetActiveTx(tx_id) => { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_24072.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_24072.rs new file mode 100644 index 000000000000..14d402ee6af6 --- /dev/null +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_24072.rs @@ -0,0 +1,54 @@ +use indoc::indoc; +use query_engine_tests::*; + +// Skip databases that don't support `onDelete:
SetDefault` +#[test_suite( + schema(schema), + exclude( + MongoDb, + MySql(5.6), + MySql(5.7), + Vitess("planetscale.js"), + Vitess("planetscale.js.wasm") + ) +)] +mod prisma_24072 { + fn schema() -> String { + let schema = indoc! { + r#"model Parent { + #id(id, Int, @id) + child Child? + } + + model Child { + #id(id, Int, @id) + parent_id Int? @default(2) @unique + parent Parent? @relation(fields: [parent_id], references: [id], onDelete: NoAction) + }"# + }; + + schema.to_owned() + } + + // Deleting the parent without cascading to the child should fail with an explicitly named constraint violation, + // without any "(not available)" names. + #[connector_test] + async fn test_24072(runner: Runner) -> TestResult<()> { + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, child: { create: { id: 1 }}}) { id }}"#), + @r###"{"data":{"createOneParent":{"id":1}}}"### + ); + + assert_connector_error!( + &runner, + "mutation { deleteOneParent(where: { id: 1 }) { id }}", + 2003, + CockroachDb(_) | Postgres(_) | SqlServer(_) | Vitess(_) => "Foreign key constraint violated: `Child_parent_id_fkey (index)`", + MySql(_) => "Foreign key constraint violated: `parent_id`", + Sqlite(_) => "Foreign key constraint violated: `foreign key`", + _ => "Foreign key constraint violated" + ); + + Ok(()) + } +} diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_6173.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_6173.rs index 3aa1e9b0e2d2..e5644fc42d42 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_6173.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_6173.rs @@ -19,13 +19,13 @@ mod query_raw { let res = run_query_json!( &runner, r#" - mutation { - queryRaw( - query: "BEGIN NOT ATOMIC\n INSERT INTO Test VALUES(FLOOR(RAND()*1000));\n SELECT * FROM Test;\n END", - parameters: "[]" - ) - } - "# + mutation { + queryRaw( + query: "BEGIN NOT ATOMIC\n INSERT INTO Test VALUES(FLOOR(RAND()*1000));\n SELECT * FROM Test;\n END", + parameters: "[]" + ) + } + "# ); // fmt_execute_raw cannot run this query, doing it directly instead insta::assert_json_snapshot!(res, @@ -53,4 +53,34 @@ mod query_raw { Ok(()) } + + #[connector_test(only(MySQL("mariadb")))] + async fn mysql_call_2(runner: Runner) -> TestResult<()> { + let res = run_query_json!( + &runner, + r#" + mutation { + queryRaw( + query: "BEGIN NOT ATOMIC\n INSERT INTO Test VALUES(FLOOR(RAND()*1000));\n SELECT * FROM Test WHERE 1=0;\n END", + parameters: "[]" + ) + } + "# + ); + + insta::assert_json_snapshot!(res, + @r###" + { + "data": { + "queryRaw": { + "columns": [], + "types": [], + "rows": [] + } + } + } + "###); + + Ok(()) + } } diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by_having.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by_having.rs index 545c44cfe41c..15d11967178e 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by_having.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by_having.rs @@ -52,13 +52,7 @@ mod aggr_group_by_having { Ok(()) } - #[connector_test(exclude(Sqlite("cfd1")))] - // On D1, this fails with: - // - // ```diff - // - {"data":{"groupByTestModel":[{"string":"group1","_count":{"int":2}}]}} - // + {"data":{"groupByTestModel":[]}} - // ``` + 
#[connector_test] async fn having_count_scalar_filter(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, int: 1, string: "group1" }"#).await?; create_row(&runner, r#"{ id: 2, int: 2, string: "group1" }"#).await?; @@ -133,13 +127,7 @@ mod aggr_group_by_having { Ok(()) } - #[connector_test(exclude(Sqlite("cfd1")))] - // On D1, this fails with: - // - // ```diff - // - {"data":{"groupByTestModel":[{"string":"group1","_sum":{"float":16.0,"int":16}}]}} - // + {"data":{"groupByTestModel":[]}} - // ``` + #[connector_test] async fn having_sum_scalar_filter(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 10, int: 10, string: "group1" }"#).await?; create_row(&runner, r#"{ id: 2, float: 6, int: 6, string: "group1" }"#).await?; @@ -208,13 +196,7 @@ mod aggr_group_by_having { Ok(()) } - #[connector_test(exclude(Sqlite("cfd1")))] - // On D1, this fails with: - // - // ```diff - // - {"data":{"groupByTestModel":[{"string":"group1","_min":{"float":0.0,"int":0}},{"string":"group2","_min":{"float":0.0,"int":0}}]}} - // + {"data":{"groupByTestModel":[]}} - // ``` + #[connector_test] async fn having_min_scalar_filter(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 10, int: 10, string: "group1" }"#).await?; create_row(&runner, r#"{ id: 2, float: 0, int: 0, string: "group1" }"#).await?; @@ -282,13 +264,7 @@ mod aggr_group_by_having { Ok(()) } - #[connector_test(exclude(Sqlite("cfd1")))] - // On D1, this fails with: - // - // ```diff - // - {"data":{"groupByTestModel":[{"string":"group1","_max":{"float":10.0,"int":10}},{"string":"group2","_max":{"float":10.0,"int":10}}]}} - // + {"data":{"groupByTestModel":[]}} - // ``` + #[connector_test] async fn having_max_scalar_filter(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 10, int: 10, string: "group1" }"#).await?; create_row(&runner, r#"{ id: 2, float: 0, int: 0, string: "group1" }"#).await?; @@ -356,13 +332,7 @@ mod aggr_group_by_having { Ok(()) } - #[connector_test(exclude(Sqlite("cfd1")))] - // On D1, this fails with: - // - // ```diff - // - {"data":{"groupByTestModel":[{"string":"group1","_count":{"string":2}}]}} - // + {"data":{"groupByTestModel":[]}} - // ``` + #[connector_test] async fn having_count_non_numerical_field(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 10, int: 10, string: "group1" }"#).await?; create_row(&runner, r#"{ id: 2, float: 0, int: 0, string: "group1" }"#).await?; @@ -380,16 +350,7 @@ mod aggr_group_by_having { Ok(()) } - #[connector_test(exclude(Sqlite("cfd1")))] - // On D1, this panics with: - // - // ``` - // assertion `left == right` failed: Query result: {"data":{"groupByTestModel":[]}} is not part of the expected results: ["{\"data\":{\"groupByTestModel\":[{\"string\":\"group1\"},{\"string\":\"group2\"}]}}", "{\"data\":{\"groupByTestModel\":[{\"string\":\"group2\"},{\"string\":\"group1\"}]}}"] for connector SQLite (cfd1) - // left: false - // right: true - // note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace - // FAILED - // ``` + #[connector_test] async fn having_without_aggr_sel(runner: Runner) -> TestResult<()> { create_row(&runner, r#"{ id: 1, float: 10, int: 10, string: "group1" }"#).await?; create_row(&runner, r#"{ id: 2, float: 0, int: 0, string: "group1" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/many_count_relation.rs 
b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/many_count_relation.rs index 94fc36388af7..b757eb4d13c8 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/many_count_relation.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/many_count_relation.rs @@ -900,7 +900,7 @@ mod many_count_rel { // Nullable counts should be COALESCE'd to 0. insta::assert_snapshot!( run_query!(&runner, r#"{ - findManyPost { + findManyPost(orderBy: {id: "asc"}) { _count { comments } } } diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs index 6ebe42d0b089..cf0a769d3543 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs @@ -22,6 +22,12 @@ mod transactional { model ModelC { #id(id, Int, @id) } + + model User { + #id(id, Int, @id) + email String @unique + name String @unique + } "# }; @@ -44,6 +50,29 @@ mod transactional { Ok(()) } + #[connector_test()] + async fn two_query_for_batch(runner: Runner) -> TestResult<()> { + let queries = vec![ + r#"mutation { createOneUser(data: { id: 1, email: "test@test.com", name: "page" }) { id } }"#.to_string(), + ]; + + runner.batch(queries, true, None).await?; + + let queries2 = vec![ + r#"query { findUniqueUser(where: { email: "test@test.com" }) { id } }"#.to_string(), + r#"query { findUniqueUser(where: { name: "test" }) { id } }"#.to_string(), + ]; + + let batch_results = runner.batch(queries2, true, None).await?; + + insta::assert_snapshot!( + batch_results.to_string(), + @r###"{"batchResult":[{"data":{"findUniqueUser":{"id":1}}},{"data":{"findUniqueUser":null}}]}"### + ); + + Ok(()) + } + #[connector_test(exclude(Sqlite("cfd1")))] // On D1, this fails with: // @@ -116,9 +145,7 @@ mod transactional { Ok(()) } - // On PlanetScale, this fails with: - // "Error in connector: Error querying the database: Server error: `ERROR 25001 (1568): Transaction characteristics can't be changed while a transaction is in progress'"" - #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] + #[connector_test(exclude(MongoDb))] async fn valid_isolation_level(runner: Runner) -> TestResult<()> { let queries = vec![r#"mutation { createOneModelB(data: { id: 1 }) { id }}"#.to_string()]; @@ -176,7 +203,7 @@ mod transactional { let batch_results = runner.batch(queries, true, None).await?; insta::assert_snapshot!( batch_results.to_string(), - @r###"{"batchResult":[{"data":{"createOneModelB":{"id":1}}},{"data":{"executeRaw":1}},{"data":{"queryRaw":{"columns":["id"],"types":[],"rows":[]}}}]}"### + @r###"{"batchResult":[{"data":{"createOneModelB":{"id":1}}},{"data":{"executeRaw":1}},{"data":{"queryRaw":{"columns":["id"],"types":["int"],"rows":[]}}}]}"### ); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/json.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/json.rs index 362dfe5c3019..93f4abccf5db 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/json.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/data_types/json.rs @@ -216,6 +216,53 @@ mod json { Ok(()) } + fn schema_json_list() -> 
String { + let schema = indoc! { + r#"model TestModel { + #id(id, Int, @id) + + child Child? + } + + model Child { + #id(id, Int, @id) + json_list Json[] + + testId Int? @unique + test TestModel? @relation(fields: [testId], references: [id]) + }"# + }; + + schema.to_owned() + } + + #[connector_test( + schema(schema_json_list), + capabilities(Json, ScalarLists), + exclude(Mysql(5.6), CockroachDb) + )] + async fn json_list(runner: Runner) -> TestResult<()> { + create_row( + &runner, + r#"{ id: 1, child: { create: { id: 1, json_list: ["1", "2"] } } }"#, + ) + .await?; + create_row(&runner, r#"{ id: 2, child: { create: { id: 2, json_list: ["{}"] } } }"#).await?; + create_row( + &runner, + r#"{ id: 3, child: { create: { id: 3, json_list: ["\"hello\"", "\"world\""] } } }"#, + ) + .await?; + create_row(&runner, r#"{ id: 4, child: { create: { id: 4 } } }"#).await?; + + insta::assert_snapshot!( + run_query!(&runner, r#"{ findManyTestModel { child { json_list } } }"#), + @r###"{"data":{"findManyTestModel":[{"child":{"json_list":["1","2"]}},{"child":{"json_list":["{}"]}},{"child":{"json_list":["\"hello\"","\"world\""]}},{"child":{"json_list":[]}}]}}"### + ); + + Ok(()) + } + async fn create_test_data(runner: &Runner) -> TestResult<()> { create_row(runner, r#"{ id: 1, json: "{}" }"#).await?; create_row(runner, r#"{ id: 2, json: "{\"a\":\"b\"}" }"#).await?; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/many_relation.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/many_relation.rs index 2f50edbe2628..30e390785275 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/many_relation.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/many_relation.rs @@ -514,6 +514,316 @@ mod many_relation { Ok(()) } + fn schema_25103() -> String { + let schema = indoc! { + r#"model Contact { + #id(id, String, @id) + identities Identity[] + } + + model Identity { + #id(id, String, @id) + contactId String + contact Contact @relation(fields: [contactId], references: [id]) + subscriptions Subscription[] + } + + model Subscription { + #id(id, String, @id) + identityId String + audienceId String + optedOutAt DateTime? + audience Audience @relation(fields: [audienceId], references: [id]) + identity Identity @relation(fields: [identityId], references: [id]) + } + + model Audience { + #id(id, String, @id) + deletedAt DateTime? + subscriptions Subscription[] + }"# + }; + + schema.to_owned() + } + + // Regression test for https://github.com/prisma/prisma/issues/25103 + // SQL Server excluded because the m2m fragment does not support onUpdate/onDelete args which are needed. 
+ #[connector_test(schema(schema_25103), exclude(SqlServer))] + async fn prisma_25103(runner: Runner) -> TestResult<()> { + // Create some sample audiences + run_query!( + &runner, + r#"mutation { + createOneAudience(data: { + id: "audience1", + deletedAt: null + }) { + id + }}"# + ); + run_query!( + &runner, + r#"mutation { + createOneAudience(data: { + id: "audience2", + deletedAt: null + }) { + id + }}"# + ); + // Create a contact with identities and subscriptions + insta::assert_snapshot!( + run_query!( + &runner, + r#"mutation { + createOneContact(data: { + id: "contact1", + identities: { + create: [ + { + id: "identity1", + subscriptions: { + create: [ + { + id: "subscription1", + audienceId: "audience1", + optedOutAt: null + }, + { + id: "subscription2", + audienceId: "audience2", + optedOutAt: null + } + ] + } + } + ] + } + }) { + id, + identities (orderBy: { id: asc }) { + id, + subscriptions (orderBy: { id: asc }) { + id, + audienceId + } + } + }}"# + ), + @r###"{"data":{"createOneContact":{"id":"contact1","identities":[{"id":"identity1","subscriptions":[{"id":"subscription1","audienceId":"audience1"},{"id":"subscription2","audienceId":"audience2"}]}]}}}"### + ); + // Find contacts that include identities whose subscriptions have `optedOutAt = null` and include audiences with `deletedAt = null`` + insta::assert_snapshot!( + run_query!( + &runner, + r#"query { + findManyContact(orderBy: { id: asc }) { + id, + identities(orderBy: { id: asc }) { + id, + subscriptions(orderBy: { id: asc }, where: { optedOutAt: null, audience: { deletedAt: null } }) { + id, + identityId, + audience { + id, + deletedAt + } + } + } + } + }"# + ), + @r###"{"data":{"findManyContact":[{"id":"contact1","identities":[{"id":"identity1","subscriptions":[{"id":"subscription1","identityId":"identity1","audience":{"id":"audience1","deletedAt":null}},{"id":"subscription2","identityId":"identity1","audience":{"id":"audience2","deletedAt":null}}]}]}]}}"### + ); + + Ok(()) + } + + fn schema_25104() -> String { + let schema = indoc! { + r#" + model A { + #id(id, String, @id) + bs B[] + } + + model B { + #id(id, String, @id) + a A @relation(fields: [aId], references: [id]) + aId String + + cs C[] + } + + model C { + #id(id, String, @id) + name String + bs B[] + } + "# + }; + + schema.to_owned() + } + + #[connector_test(schema(schema_25104), exclude(MongoDb))] + async fn prisma_25104(runner: Runner) -> TestResult<()> { + insta::assert_snapshot!( + run_query!( + &runner, + r#" + query { + findManyA { + bs(where: { + cs: { + every: { + name: { equals: "a" } + } + } + }) { + id + } + } + } + "# + ), + @r###"{"data":{"findManyA":[]}}"### + ); + + Ok(()) + } + + fn schema_23742() -> String { + let schema = indoc! { + r#"model Top { + #id(id, Int, @id) + + middleId Int? + middle Middle? @relation(fields: [middleId], references: [id]) + + #m2m(bottoms, Bottom[], id, Int) + } + + model Middle { + #id(id, Int, @id) + bottoms Bottom[] + + tops Top[] + } + + model Bottom { + #id(id, Int, @id) + + middleId Int? + middle Middle? @relation(fields: [middleId], references: [id]) + + #m2m(tops, Top[], id, Int) + }"# + }; + + schema.to_owned() + } + + // Regression test for https://github.com/prisma/prisma/issues/23742 + // SQL Server excluded because the m2m fragment does not support onUpdate/onDelete args which are needed. 
+ #[connector_test(schema(schema_23742), exclude(SqlServer))] + async fn prisma_23742(runner: Runner) -> TestResult<()> { + run_query!( + &runner, + r#"mutation { + createOneTop(data: { + id: 1, + middle: { create: { id: 1, bottoms: { create: { id: 1, tops: { create: { id: 2 } } } } } } + }) { + id + }}"# + ); + + insta::assert_snapshot!( + run_query!(&runner, r#"{ + findUniqueTop(where: { id: 1 }) { + middle { + bottoms( + where: { tops: { some: { id: 2 } } } + ) { + id + } + } + } + } + "#), + @r###"{"data":{"findUniqueTop":{"middle":{"bottoms":[{"id":1}]}}}}"### + ); + + Ok(()) + } + + fn schema_nested_some_filter_m2m_different_pk() -> String { + let schema = indoc! { + r#" + model Top { + #id(topId, Int, @id) + + relatedMiddleId Int? + middle Middle? @relation(fields: [relatedMiddleId], references: [middleId]) + + #m2m(bottoms, Bottom[], bottomId, Int) + } + + model Middle { + #id(middleId, Int, @id) + + bottoms Bottom[] + tops Top[] + } + + model Bottom { + #id(bottomId, Int, @id) + + relatedMiddleId Int? + middle Middle? @relation(fields: [relatedMiddleId], references: [middleId]) + + #m2m(tops, Top[], topId, Int) + } + "# + }; + + schema.to_owned() + } + + #[connector_test(schema(schema_nested_some_filter_m2m_different_pk), exclude(SqlServer))] + async fn nested_some_filter_m2m_different_pk(runner: Runner) -> TestResult<()> { + run_query!( + &runner, + r#"mutation { + createOneTop(data: { + topId: 1, + middle: { create: { middleId: 1, bottoms: { create: { bottomId: 1, tops: { create: { topId: 2 } } } } } } + }) { + topId + }}"# + ); + + insta::assert_snapshot!( + run_query!(&runner, r#"{ + findUniqueTop(where: { topId: 1 }) { + middle { + bottoms( + where: { tops: { some: { topId: 2 } } } + ) { + bottomId + } + } + } + } + "#), + @r###"{"data":{"findUniqueTop":{"middle":{"bottoms":[{"bottomId":1}]}}}}"### + ); + + Ok(()) + } + async fn test_data(runner: &Runner) -> TestResult<()> { runner .query(indoc! 
{ r#" diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/scalar_list.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/scalar_list.rs index 20b1e853b64d..8cd4c59af670 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/scalar_list.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/scalar_list.rs @@ -254,12 +254,12 @@ mod scalar_list { "types": [ "int", "string-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array" + "int-array", + "bigint-array", + "double-array", + "bytes-array", + "bool-array", + "datetime-array" ], "rows": [ [ @@ -332,13 +332,13 @@ mod scalar_list { ], "types": [ "int", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array", - "unknown-array" + "string-array", + "int-array", + "bigint-array", + "double-array", + "bytes-array", + "bool-array", + "datetime-array" ], "rows": [ [ diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs index fa7b8d64692d..980392054bfd 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/raw/sql/typed_output.rs @@ -440,6 +440,40 @@ mod typed_output { Ok(()) } + #[connector_test(schema(generic), only(Mysql))] + async fn unknown_type_mysql(runner: Runner) -> TestResult<()> { + insta::assert_snapshot!( + run_query!(&runner, fmt_query_raw(r#"SELECT POINT(1, 1);"#, vec![])), + @r###"{"data":{"queryRaw":{"columns":["POINT(1, 1)"],"types":["bytes"],"rows":[["AAAAAAEBAAAAAAAAAAAA8D8AAAAAAADwPw=="]]}}}"### + ); + + Ok(()) + } + + #[connector_test(schema(generic), only(Postgres))] + async fn unknown_type_pg(runner: Runner) -> TestResult<()> { + assert_error!( + &runner, + fmt_query_raw(r#"SELECT POINT(1, 1);"#, vec![]), + 2010, + "Failed to deserialize column of type 'point'" + ); + + Ok(()) + } + + #[connector_test(schema(generic), only(SqlServer))] + async fn unknown_type_mssql(runner: Runner) -> TestResult<()> { + assert_error!( + &runner, + fmt_query_raw(r#"SELECT geometry::Parse('POINT(3 4 7 2.5)');"#, vec![]), + 2010, + "not yet implemented for Udt" + ); + + Ok(()) + } + async fn create_row(runner: &Runner, data: &str) -> TestResult<()> { runner .query(format!("mutation {{ createOneTestModel(data: {data}) {{ id }} }}")) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/uuid_create_graphql.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/uuid_create_graphql.rs index 5efd17555673..afacaa5292b5 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/uuid_create_graphql.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/ids/uuid_create_graphql.rs @@ -68,4 +68,66 @@ mod uuid_create_graphql { Ok(()) } + + fn schema_3() -> String { + let schema = indoc! 
{ + r#"model Todo { + #id(id, String, @id, @default(uuid(7))) + title String + }"# + }; + + schema.to_owned() + } + + // "Creating an item with an id field of model UUIDv7 and retrieving it" should "work" + #[connector_test(schema(schema_3))] + async fn create_uuid_v7_and_retrieve_it_should_work(runner: Runner) -> TestResult<()> { + let res = run_query_json!( + &runner, + r#"mutation { + createOneTodo(data: { title: "the title" }){ + id + } + }"# + ); + + let uuid = match &res["data"]["createOneTodo"]["id"] { + serde_json::Value::String(str) => str, + _ => unreachable!(), + }; + + // Validate that this is a valid UUIDv7 value + { + let uuid = uuid::Uuid::parse_str(uuid.as_str()).expect("Expected valid UUID but couldn't parse it."); + assert_eq!( + uuid.get_version().expect("Expected UUIDv7 but got something else."), + uuid::Version::SortRand + ); + } + + // Test findMany + let res = run_query_json!( + &runner, + r#"query { findManyTodo(where: { title: "the title" }) { id }}"# + ); + if let serde_json::Value::String(str) = &res["data"]["findManyTodo"][0]["id"] { + assert_eq!(str, uuid); + } else { + panic!("Expected UUID but got something else."); + } + + // Test findUnique + let res = run_query_json!( + &runner, + format!(r#"query {{ findUniqueTodo(where: {{ id: "{}" }}) {{ id }} }}"#, uuid) + ); + if let serde_json::Value::String(str) = &res["data"]["findUniqueTodo"]["id"] { + assert_eq!(str, uuid); + } else { + panic!("Expected UUID but got something else."); + } + + Ok(()) + } } diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs index 821b99f9fce8..30a671f61def 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/nested_mutations/not_using_schema_base/nested_create_many.rs @@ -140,7 +140,7 @@ mod nested_create_many { // Each DB allows a certain amount of params per single query, and a certain number of rows. // We create 1000 nested records. // "Nested createMany" should "allow creating a large number of records (horizontal partitioning check)" - #[connector_test(exclude(Sqlite("cfd1")))] + #[connector_test] async fn allow_create_large_number_records(runner: Runner) -> TestResult<()> { let records: Vec<_> = (1..=1000).map(|i| format!(r#"{{ id: {i}, str1: "{i}" }}"#)).collect(); diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs index f1f80eb93f23..61996fd993d4 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs @@ -355,7 +355,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_1(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` (`opt`, `req`) VALUES (null, ?), (?, ?) 
params=[1,2,2] @@ -397,7 +397,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_2(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` ( `opt_default_static`, `req_default_static`, `opt`, `req` ) VALUES (?, ?, null, ?), (?, ?, null, ?), (?, ?, null, ?) params=[1,1,1,2,1,2,1,3,3] @@ -436,7 +436,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_3(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` ( `req_default_static`, `req`, `opt_default`, `opt_default_static` ) VALUES (?, ?, ?, ?) params=[1,6,3,1] diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many_and_return.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many_and_return.rs index a55efb4e0cc9..9b92d99404b3 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many_and_return.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many_and_return.rs @@ -650,7 +650,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_1(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` (`opt`, `req`) VALUES (null, ?), (?, ?) params=[1,2,2] @@ -692,7 +692,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_2(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` ( `opt_default_static`, `req_default_static`, `opt`, `req` ) VALUES (?, ?, null, ?), (?, ?, null, ?), (?, ?, null, ?) params=[1,1,1,2,1,2,1,3,3] @@ -731,7 +731,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_3(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` ( `req_default_static`, `req`, `opt_default`, `opt_default_static` ) VALUES (?, ?, ?, ?) 
params=[1,6,3,1] diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/delete_many_relations.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/delete_many_relations.rs index ec9508347a6e..6bd3eab692ea 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/delete_many_relations.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/delete_many_relations.rs @@ -16,7 +16,7 @@ mod delete_many_rels { // On D1, this fails with: // // ```diff - // - {"data":{"deleteManyParent":{"count":1}}} + // - {"data":{"deleteManyParent":{"count":2}}} // + {"data":{"deleteManyParent":{"count":3}}} // ``` async fn p1_c1(runner: &Runner, _t: &DatamodelWithParams) -> TestResult<()> { diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml index cd8abc07331c..b2016e602b9c 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml +++ b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml @@ -15,14 +15,15 @@ sql-query-connector = { path = "../../connectors/sql-query-connector" } query-engine = { path = "../../query-engine" } psl.workspace = true user-facing-errors = { path = "../../../libs/user-facing-errors" } +telemetry = { path = "../../../libs/telemetry" } thiserror = "1.0" async-trait.workspace = true nom = "7.1" itertools.workspace = true -regex = "1" +regex.workspace = true serde.workspace = true tracing.workspace = true -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } tracing-error = "0.2" colored = "2" @@ -30,7 +31,7 @@ indoc.workspace = true enumflags2.workspace = true hyper = { version = "0.14", features = ["full"] } indexmap.workspace = true -query-engine-metrics = { path = "../../metrics" } +prisma-metrics.path = "../../../libs/metrics" quaint.workspace = true jsonrpc-core = "17" insta = "1.7.1" diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs index e94c14c6c574..87e9241fb469 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs @@ -22,10 +22,9 @@ pub use templating::*; use colored::Colorize; use once_cell::sync::Lazy; +use prisma_metrics::{MetricRecorder, MetricRegistry, WithMetricsInstrumentation}; use psl::datamodel_connector::ConnectorCapabilities; -use query_engine_metrics::MetricRegistry; use std::future::Future; -use std::sync::Once; use tokio::runtime::Builder; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tracing_futures::WithSubscriber; @@ -61,14 +60,10 @@ fn run_with_tokio>(fut: F) -> O { .block_on(fut) } -static METRIC_RECORDER: Once = Once::new(); - -pub fn setup_metrics() -> MetricRegistry { +pub fn setup_metrics() -> (MetricRegistry, MetricRecorder) { let metrics = MetricRegistry::new(); - METRIC_RECORDER.call_once(|| { - query_engine_metrics::setup(); - }); - metrics + let recorder = MetricRecorder::new(metrics.clone()).with_initialized_prisma_metrics(); + (metrics, recorder) } /// Taken from Reddit. 
Enables taking an async function pointer which takes references as param @@ -161,8 +156,7 @@ fn run_relation_link_test_impl( let datamodel = render_test_datamodel(&test_db_name, template, &[], None, Default::default(), Default::default(), None); let (connector_tag, version) = CONFIG.test_connector().unwrap(); - let metrics = setup_metrics(); - let metrics_for_subscriber = metrics.clone(); + let (metrics, recorder) = setup_metrics(); let (log_capture, log_tx) = TestLogCapture::new(); run_with_tokio( @@ -176,9 +170,8 @@ fn run_relation_link_test_impl( test_fn(&runner, &dm).with_subscriber(test_tracing_subscriber( ENV_LOG_LEVEL.to_string(), - metrics_for_subscriber, log_tx, - )) + )).with_recorder(recorder) .await.unwrap(); teardown_project(&datamodel, Default::default(), runner.schema_id()) @@ -275,8 +268,7 @@ fn run_connector_test_impl( None, ); let (connector_tag, version) = CONFIG.test_connector().unwrap(); - let metrics = crate::setup_metrics(); - let metrics_for_subscriber = metrics.clone(); + let (metrics, recorder) = crate::setup_metrics(); let (log_capture, log_tx) = TestLogCapture::new(); @@ -297,11 +289,8 @@ fn run_connector_test_impl( let schema_id = runner.schema_id(); if let Err(err) = test_fn(runner) - .with_subscriber(test_tracing_subscriber( - ENV_LOG_LEVEL.to_string(), - metrics_for_subscriber, - log_tx, - )) + .with_subscriber(test_tracing_subscriber(ENV_LOG_LEVEL.to_string(), log_tx)) + .with_recorder(recorder) .await { panic!("💥 Test failed due to an error: {err:?}"); diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/logging.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/logging.rs index 5520075e6d30..d95867c39b89 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/logging.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/logging.rs @@ -1,12 +1,11 @@ -use query_core::telemetry::helpers as telemetry_helpers; -use query_engine_metrics::MetricRegistry; +use telemetry::helpers as telemetry_helpers; use tracing::Subscriber; use tracing_error::ErrorLayer; use tracing_subscriber::{prelude::*, Layer}; use crate::LogEmit; -pub fn test_tracing_subscriber(log_config: String, metrics: MetricRegistry, log_tx: LogEmit) -> impl Subscriber { +pub fn test_tracing_subscriber(log_config: String, log_tx: LogEmit) -> impl Subscriber { let filter = telemetry_helpers::env_filter(true, telemetry_helpers::QueryEngineLogLevel::Override(log_config)); let fmt_layer = tracing_subscriber::fmt::layer() @@ -15,7 +14,6 @@ pub fn test_tracing_subscriber(log_config: String, metrics: MetricRegistry, log_ tracing_subscriber::registry() .with(fmt_layer.boxed()) - .with(metrics.boxed()) .with(ErrorLayer::default()) } diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs index e5808ace7fcc..de8ee9bd33be 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs @@ -8,13 +8,13 @@ use crate::{ ENGINE_PROTOCOL, }; use colored::Colorize; +use prisma_metrics::MetricRegistry; use query_core::{ protocol::EngineProtocol, relation_load_strategy, schema::{self, QuerySchemaRef}, QueryExecutor, TransactionOptions, TxId, }; -use query_engine_metrics::MetricRegistry; use request_handlers::{ BatchTransactionOption, ConnectorKind, GraphqlBody, JsonBatchQuery, JsonBody, JsonSingleQuery, MultiQuery, RequestBody, RequestHandler, @@ -306,7 +306,15 @@ impl Runner 
{ }) } - pub async fn query<T>(&self, query: T) -> TestResult<QueryResult> + pub async fn query(&self, query: impl Into<String>) -> TestResult<QueryResult> { + self.query_with_maybe_tx_id(self.current_tx_id.as_ref(), query).await + } + + pub async fn query_in_tx(&self, tx_id: &TxId, query: impl Into<String>) -> TestResult<QueryResult> { + self.query_with_maybe_tx_id(Some(tx_id), query).await + } + + async fn query_with_maybe_tx_id<T>(&self, tx_id: Option<&TxId>, query: T) -> TestResult<QueryResult> where T: Into<String>, { @@ -316,7 +324,7 @@ impl Runner { RunnerExecutor::Builtin(e) => e, RunnerExecutor::External(external) => match JsonRequest::from_graphql(&query, self.query_schema()) { Ok(json_query) => { - let mut response = external.query(json_query, self.current_tx_id.as_ref()).await?; + let mut response = external.query(json_query, tx_id).await?; response.detag(); return Ok(response); } @@ -353,7 +361,7 @@ impl Runner { } }; - let response = handler.handle(request_body, self.current_tx_id.clone(), None).await; + let response = handler.handle(request_body, tx_id.cloned(), None).await; let result: QueryResult = match self.protocol { EngineProtocol::Json => JsonResponse::from_graphql(response).into(), diff --git a/query-engine/connectors/mongodb-query-connector/Cargo.toml b/query-engine/connectors/mongodb-query-connector/Cargo.toml index 8c801f803550..0c4cefcce845 100644 --- a/query-engine/connectors/mongodb-query-connector/Cargo.toml +++ b/query-engine/connectors/mongodb-query-connector/Cargo.toml @@ -7,23 +7,22 @@ version = "0.1.0" anyhow = "1.0" async-trait.workspace = true bigdecimal = "0.3" -# bson = {version = "1.1.0", features = ["decimal128"]} -futures = "0.3" +futures.workspace = true itertools.workspace = true -mongodb = "2.8.0" -bson = { version = "2.4.0", features = ["chrono-0_4", "uuid-1"] } +mongodb.workspace = true +bson.workspace = true rand.workspace = true -regex = "1" +regex.workspace = true serde_json.workspace = true thiserror = "1.0" tokio.workspace = true tracing.workspace = true -tracing-futures = "0.2" +tracing-futures.workspace = true uuid.workspace = true indexmap.workspace = true -query-engine-metrics = { path = "../../metrics" } +prisma-metrics.path = "../../../libs/metrics" cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } -derive_more = "0.99.17" +derive_more.workspace = true [dependencies.query-structure] path = "../../query-structure" @@ -38,6 +37,9 @@ path = "../query-connector" [dependencies.prisma-value] path = "../../../libs/prisma-value" +[dependencies.telemetry] +path = "../../../libs/telemetry" + [dependencies.chrono] features = ["serde"] version = "0.4" diff --git a/query-engine/connectors/mongodb-query-connector/src/cursor.rs b/query-engine/connectors/mongodb-query-connector/src/cursor.rs index 1aaa22ef6b19..012200067f99 100644 --- a/query-engine/connectors/mongodb-query-connector/src/cursor.rs +++ b/query-engine/connectors/mongodb-query-connector/src/cursor.rs @@ -1,5 +1,5 @@ use crate::{orderby::OrderByData, IntoBson}; -use mongodb::bson::{doc, Document}; +use bson::{doc, Document}; use query_structure::{OrderBy, SelectionResult, SortOrder}; #[derive(Debug, Clone)] diff --git a/query-engine/connectors/mongodb-query-connector/src/error.rs b/query-engine/connectors/mongodb-query-connector/src/error.rs index a93350168387..8fcfaaaf041b 100644 --- a/query-engine/connectors/mongodb-query-connector/src/error.rs +++ b/query-engine/connectors/mongodb-query-connector/src/error.rs @@ -181,6 +181,33 @@ fn driver_error_to_connector_error(err: DriverError) -> ConnectorError { },
mongodb::error::ErrorKind::BulkWrite(err) => { + let errors = err + .write_errors + .iter() + .map(|(index, err)| match err.code { + 11000 => unique_violation_error(err.message.as_str()), + _ => ErrorKind::RawDatabaseError { + code: err.code.to_string(), + message: format!("Bulk write error on write index '{}': {}", index, err.message), + }, + }) + .chain(err.write_concern_errors.iter().map(|err| match err.code { + 11000 => unique_violation_error(err.message.as_str()), + _ => ErrorKind::RawDatabaseError { + code: err.code.to_string(), + message: format!("Bulk write concern error: {}", err.message), + }, + })) + .collect_vec(); + + if errors.len() == 1 { + ConnectorError::from_kind(errors.into_iter().next().unwrap()) + } else { + ConnectorError::from_kind(ErrorKind::MultiError(MultiError { errors })) + } + } + + mongodb::error::ErrorKind::InsertMany(err) => { let mut errors = match err.write_errors { Some(ref errors) => errors .iter() diff --git a/query-engine/connectors/mongodb-query-connector/src/filter.rs b/query-engine/connectors/mongodb-query-connector/src/filter.rs index 479b1c93ffe2..b29fe6f97344 100644 --- a/query-engine/connectors/mongodb-query-connector/src/filter.rs +++ b/query-engine/connectors/mongodb-query-connector/src/filter.rs @@ -1,5 +1,5 @@ use crate::{constants::group_by, error::MongoError, join::JoinStage, query_builder::AggregationType, IntoBson}; -use mongodb::bson::{doc, Bson, Document}; +use bson::{doc, Bson, Document}; use query_structure::*; #[derive(Debug, Clone)] diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs b/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs index 9f9232517528..72b0c6a3afb0 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs @@ -11,6 +11,7 @@ use connector_interface::{ use mongodb::{ClientSession, Database}; use query_structure::{prelude::*, RelationLoadStrategy, SelectionResult}; use std::collections::HashMap; +use telemetry::helpers::TraceParent; pub struct MongoDbConnection { /// The session to use for operations. @@ -57,7 +58,7 @@ impl WriteOperations for MongoDbConnection { args: WriteArgs, // The field selection on a create is never used on MongoDB as it cannot return more than the ID. 
_selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::create_record(&self.database, &mut self.session, model, args)).await } @@ -67,7 +68,7 @@ impl WriteOperations for MongoDbConnection { model: &Model, args: Vec, skip_duplicates: bool, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::create_records( &self.database, @@ -85,7 +86,7 @@ impl WriteOperations for MongoDbConnection { _args: Vec, _skip_duplicates: bool, _selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { unimplemented!() } @@ -95,7 +96,7 @@ impl WriteOperations for MongoDbConnection { model: &Model, record_filter: connector_interface::RecordFilter, args: WriteArgs, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(async move { let result = write::update_records( @@ -119,7 +120,7 @@ impl WriteOperations for MongoDbConnection { record_filter: connector_interface::RecordFilter, args: WriteArgs, selected_fields: Option, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(async move { let result = write::update_records( @@ -150,7 +151,7 @@ impl WriteOperations for MongoDbConnection { &mut self, model: &Model, record_filter: connector_interface::RecordFilter, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::delete_records( &self.database, @@ -166,7 +167,7 @@ impl WriteOperations for MongoDbConnection { model: &Model, record_filter: connector_interface::RecordFilter, selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::delete_record( &self.database, @@ -183,7 +184,7 @@ impl WriteOperations for MongoDbConnection { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result<()> { catch(write::m2m_connect( &self.database, @@ -200,7 +201,7 @@ impl WriteOperations for MongoDbConnection { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result<()> { catch(write::m2m_disconnect( &self.database, @@ -235,7 +236,7 @@ impl WriteOperations for MongoDbConnection { async fn native_upsert_record( &mut self, _upsert: connector_interface::NativeUpsert, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { unimplemented!("Native upsert is not currently supported.") } @@ -249,7 +250,7 @@ impl ReadOperations for MongoDbConnection { filter: &query_structure::Filter, selected_fields: &FieldSelection, _relation_load_strategy: RelationLoadStrategy, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(read::get_single_record( &self.database, @@ -267,7 +268,7 @@ impl ReadOperations for MongoDbConnection { query_arguments: query_structure::QueryArguments, selected_fields: &FieldSelection, _relation_load_strategy: RelationLoadStrategy, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(read::get_many_records( &self.database, @@ -283,7 +284,7 @@ impl ReadOperations for MongoDbConnection { &mut self, from_field: &RelationFieldRef, from_record_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(read::get_related_m2m_record_ids( &self.database, @@ 
-301,7 +302,7 @@ impl ReadOperations for MongoDbConnection { selections: Vec, group_by: Vec, having: Option, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(aggregate::aggregate( &self.database, diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/mod.rs b/query-engine/connectors/mongodb-query-connector/src/interface/mod.rs index 620d7628182f..126424af3d8d 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/mod.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/mod.rs @@ -60,7 +60,7 @@ impl Connector for MongoDb { ) -> connector_interface::Result> { let session = self .client - .start_session(None) + .start_session() .await .map_err(|err| MongoError::from(err).into_connector_error())?; diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs index e0933e7e840e..6045d06b442d 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs @@ -1,16 +1,20 @@ +use std::collections::HashMap; + +use connector_interface::{ConnectionLike, ReadOperations, Transaction, UpdateType, WriteOperations}; +use mongodb::options::{Acknowledgment, ReadConcern, TransactionOptions, WriteConcern}; +use prisma_metrics::{guards::GaugeGuard, PRISMA_CLIENT_QUERIES_ACTIVE}; +use query_structure::{RelationLoadStrategy, SelectionResult}; +use telemetry::helpers::TraceParent; + use super::*; use crate::{ error::MongoError, root_queries::{aggregate, read, write}, }; -use connector_interface::{ConnectionLike, ReadOperations, Transaction, UpdateType, WriteOperations}; -use mongodb::options::{Acknowledgment, ReadConcern, TransactionOptions, WriteConcern}; -use query_engine_metrics::{decrement_gauge, increment_gauge, metrics, PRISMA_CLIENT_QUERIES_ACTIVE}; -use query_structure::{RelationLoadStrategy, SelectionResult}; -use std::collections::HashMap; pub struct MongoDbTransaction<'conn> { connection: &'conn mut MongoDbConnection, + gauge: GaugeGuard, } impl<'conn> ConnectionLike for MongoDbTransaction<'conn> {} @@ -26,20 +30,22 @@ impl<'conn> MongoDbTransaction<'conn> { connection .session - .start_transaction(options) + .start_transaction() + .with_options(options) .await .map_err(|err| MongoError::from(err).into_connector_error())?; - increment_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); - - Ok(Self { connection }) + Ok(Self { + connection, + gauge: GaugeGuard::increment(PRISMA_CLIENT_QUERIES_ACTIVE), + }) } } #[async_trait] impl<'conn> Transaction for MongoDbTransaction<'conn> { async fn commit(&mut self) -> connector_interface::Result<()> { - decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); + self.gauge.decrement(); utils::commit_with_retry(&mut self.connection.session) .await @@ -49,7 +55,7 @@ impl<'conn> Transaction for MongoDbTransaction<'conn> { } async fn rollback(&mut self) -> connector_interface::Result<()> { - decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); + self.gauge.decrement(); self.connection .session @@ -77,7 +83,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { args: connector_interface::WriteArgs, // The field selection on a create is never used on MongoDB as it cannot return more than the ID. 
_selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::create_record( &self.connection.database, @@ -93,7 +99,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { model: &Model, args: Vec, skip_duplicates: bool, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::create_records( &self.connection.database, @@ -111,7 +117,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { _args: Vec, _skip_duplicates: bool, _selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { unimplemented!() } @@ -121,7 +127,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { model: &Model, record_filter: connector_interface::RecordFilter, args: connector_interface::WriteArgs, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(async move { let result = write::update_records( @@ -144,7 +150,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { record_filter: connector_interface::RecordFilter, args: connector_interface::WriteArgs, selected_fields: Option, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(async move { let result = write::update_records( @@ -174,7 +180,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { &mut self, model: &Model, record_filter: connector_interface::RecordFilter, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::delete_records( &self.connection.database, @@ -190,7 +196,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { model: &Model, record_filter: connector_interface::RecordFilter, selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::delete_record( &self.connection.database, @@ -205,7 +211,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { async fn native_upsert_record( &mut self, _upsert: connector_interface::NativeUpsert, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { unimplemented!("Native upsert is not currently supported.") } @@ -215,7 +221,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result<()> { catch(write::m2m_connect( &self.connection.database, @@ -232,7 +238,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result<()> { catch(write::m2m_disconnect( &self.connection.database, @@ -278,7 +284,7 @@ impl<'conn> ReadOperations for MongoDbTransaction<'conn> { filter: &query_structure::Filter, selected_fields: &FieldSelection, _relation_load_strategy: RelationLoadStrategy, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(read::get_single_record( &self.connection.database, @@ -296,7 +302,7 @@ impl<'conn> ReadOperations for MongoDbTransaction<'conn> { query_arguments: query_structure::QueryArguments, selected_fields: &FieldSelection, _relation_load_strategy: RelationLoadStrategy, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(read::get_many_records( &self.connection.database, @@ -312,7 +318,7 @@ 
impl<'conn> ReadOperations for MongoDbTransaction<'conn> { &mut self, from_field: &RelationFieldRef, from_record_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(read::get_related_m2m_record_ids( &self.connection.database, @@ -330,7 +336,7 @@ impl<'conn> ReadOperations for MongoDbTransaction<'conn> { selections: Vec, group_by: Vec, having: Option, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(aggregate::aggregate( &self.connection.database, diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/utils.rs b/query-engine/connectors/mongodb-query-connector/src/interface/utils.rs index 6b089ff700fd..21e35604045c 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/utils.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/utils.rs @@ -1,7 +1,7 @@ use std::time::{Duration, Instant}; use mongodb::{ - error::{Result, TRANSIENT_TRANSACTION_ERROR, UNKNOWN_TRANSACTION_COMMIT_RESULT}, + error::{CommandError, ErrorKind, Result, TRANSIENT_TRANSACTION_ERROR, UNKNOWN_TRANSACTION_COMMIT_RESULT}, ClientSession, }; @@ -14,9 +14,15 @@ pub async fn commit_with_retry(session: &mut ClientSession) -> Result<()> { let timeout = Instant::now(); while let Err(err) = session.commit_transaction().await { - if (err.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) || err.contains_label(TRANSIENT_TRANSACTION_ERROR)) - && timeout.elapsed() < MAX_TX_TIMEOUT_COMMIT_RETRY_LIMIT - { + // For some reason, MongoDB adds `TRANSIENT_TRANSACTION_ERROR` to errors about aborted + // transactions. Since transaction will not become less aborted in the future, we handle + // this case separately. + let is_aborted = matches!(err.kind.as_ref(), ErrorKind::Command(CommandError { code: 251, .. })); + let is_in_unknown_state = err.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT); + let is_transient = err.contains_label(TRANSIENT_TRANSACTION_ERROR); + let is_retryable = !is_aborted && (is_in_unknown_state || is_transient); + + if is_retryable && timeout.elapsed() < MAX_TX_TIMEOUT_COMMIT_RETRY_LIMIT { tokio::time::sleep(TX_RETRY_BACKOFF).await; continue; } else { diff --git a/query-engine/connectors/mongodb-query-connector/src/join.rs b/query-engine/connectors/mongodb-query-connector/src/join.rs index 24c8abe2fba9..501d9042dbe8 100644 --- a/query-engine/connectors/mongodb-query-connector/src/join.rs +++ b/query-engine/connectors/mongodb-query-connector/src/join.rs @@ -1,5 +1,5 @@ use crate::filter::MongoFilter; -use mongodb::bson::{doc, Document}; +use bson::{doc, Document}; use query_structure::{walkers, RelationFieldRef, ScalarFieldRef}; /// A join stage describes a tree of joins and nested joins to be performed on a collection. 
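Side note on the `commit_with_retry` change in `interface/utils.rs` above: the retry loop now refuses to retry a commit once the server reports that the transaction was already aborted (command error code 251), even though the driver attaches the transient-transaction label to such errors. A minimal sketch of that predicate, pulled out purely for illustration (the helper name is not part of the patch):

```rust
use mongodb::error::{
    CommandError, Error, ErrorKind, TRANSIENT_TRANSACTION_ERROR, UNKNOWN_TRANSACTION_COMMIT_RESULT,
};

/// Returns whether a failed `commit_transaction` call is worth retrying.
/// Code 251 (NoSuchTransaction) means the transaction was already aborted,
/// so retrying can never succeed, regardless of the transient label.
fn is_retryable_commit_error(err: &Error) -> bool {
    let is_aborted = matches!(
        err.kind.as_ref(),
        ErrorKind::Command(CommandError { code: 251, .. })
    );
    let is_in_unknown_state = err.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT);
    let is_transient = err.contains_label(TRANSIENT_TRANSACTION_ERROR);

    !is_aborted && (is_in_unknown_state || is_transient)
}
```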
diff --git a/query-engine/connectors/mongodb-query-connector/src/lib.rs b/query-engine/connectors/mongodb-query-connector/src/lib.rs index 3903acaf8371..c3af1d6f253f 100644 --- a/query-engine/connectors/mongodb-query-connector/src/lib.rs +++ b/query-engine/connectors/mongodb-query-connector/src/lib.rs @@ -14,11 +14,10 @@ mod query_strings; mod root_queries; mod value; +use bson::Bson; +use bson::Document; use error::MongoError; -use mongodb::{ - bson::{Bson, Document}, - ClientSession, SessionCursor, -}; +use mongodb::{ClientSession, SessionCursor}; pub use interface::*; diff --git a/query-engine/connectors/mongodb-query-connector/src/orderby.rs b/query-engine/connectors/mongodb-query-connector/src/orderby.rs index 6248cbc8ab41..eeac8619a374 100644 --- a/query-engine/connectors/mongodb-query-connector/src/orderby.rs +++ b/query-engine/connectors/mongodb-query-connector/src/orderby.rs @@ -1,6 +1,6 @@ use crate::join::JoinStage; +use bson::{doc, Document}; use itertools::Itertools; -use mongodb::bson::{doc, Document}; use query_structure::{OrderBy, OrderByHop, OrderByToManyAggregation, SortOrder}; use std::{fmt::Display, iter}; diff --git a/query-engine/connectors/mongodb-query-connector/src/projection.rs b/query-engine/connectors/mongodb-query-connector/src/projection.rs index 96948fc0e027..af80992cbf29 100644 --- a/query-engine/connectors/mongodb-query-connector/src/projection.rs +++ b/query-engine/connectors/mongodb-query-connector/src/projection.rs @@ -1,5 +1,5 @@ use crate::IntoBson; -use mongodb::bson::{Bson, Document}; +use bson::{Bson, Document}; use query_structure::{FieldSelection, SelectedField}; /// Used as projection document for Mongo queries. diff --git a/query-engine/connectors/mongodb-query-connector/src/query_builder/group_by_builder.rs b/query-engine/connectors/mongodb-query-connector/src/query_builder/group_by_builder.rs index f5ac3659f1b5..c40c3ee8d0dc 100644 --- a/query-engine/connectors/mongodb-query-connector/src/query_builder/group_by_builder.rs +++ b/query-engine/connectors/mongodb-query-connector/src/query_builder/group_by_builder.rs @@ -1,7 +1,7 @@ use crate::constants::*; +use bson::{doc, Bson, Document}; use connector_interface::AggregationSelection; -use mongodb::bson::{doc, Bson, Document}; use query_structure::{AggregationFilter, Filter, ScalarFieldRef}; use std::collections::HashSet; diff --git a/query-engine/connectors/mongodb-query-connector/src/query_builder/read_query_builder.rs b/query-engine/connectors/mongodb-query-connector/src/query_builder/read_query_builder.rs index 27185de5c917..bfe48d5f851c 100644 --- a/query-engine/connectors/mongodb-query-connector/src/query_builder/read_query_builder.rs +++ b/query-engine/connectors/mongodb-query-connector/src/query_builder/read_query_builder.rs @@ -10,15 +10,13 @@ use crate::{ root_queries::observing, vacuum_cursor, BsonTransform, IntoBson, }; +use bson::{doc, Document}; use connector_interface::AggregationSelection; use itertools::Itertools; -use mongodb::{ - bson::{doc, Document}, - options::AggregateOptions, - ClientSession, Collection, -}; +use mongodb::{options::AggregateOptions, ClientSession, Collection}; use query_structure::{FieldSelection, Filter, Model, QueryArguments, ScalarFieldRef, VirtualSelection}; use std::convert::TryFrom; +use std::future::IntoFuture; // Mongo Driver broke usage of the simple API, can't be used by us anymore. 
// As such the read query will always be based on aggregation pipeline @@ -37,7 +35,11 @@ impl ReadQuery { let opts = AggregateOptions::builder().allow_disk_use(true).build(); let query_string_builder = Aggregate::new(&self.stages, on_collection.name()); let cursor = observing(&query_string_builder, || { - on_collection.aggregate_with_session(self.stages.clone(), opts, with_session) + on_collection + .aggregate(self.stages.clone()) + .with_options(opts) + .session(&mut *with_session) + .into_future() }) .await?; diff --git a/query-engine/connectors/mongodb-query-connector/src/query_strings.rs b/query-engine/connectors/mongodb-query-connector/src/query_strings.rs index bdccd09849c8..764f22fe080e 100644 --- a/query-engine/connectors/mongodb-query-connector/src/query_strings.rs +++ b/query-engine/connectors/mongodb-query-connector/src/query_strings.rs @@ -5,8 +5,8 @@ //! There is a struct for each different type of query to generate. Each of them implement the //! QueryStringBuilder trait, which is dynamically dispatched to a specific query string builder by //! `root_queries::observing` +use bson::{Bson, Document}; use derive_more::Constructor; -use mongodb::bson::{Bson, Document}; use std::fmt::Write; pub(crate) trait QueryString: Sync + Send { diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/aggregate.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/aggregate.rs index 05ff57053e95..797e34127f8a 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/aggregate.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/aggregate.rs @@ -108,7 +108,7 @@ fn to_aggregation_rows( for field in fields { let meta = selection_meta.get(field.db_name()).unwrap(); - let bson = doc.remove(&format!("count_{}", field.db_name())).unwrap(); + let bson = doc.remove(format!("count_{}", field.db_name())).unwrap(); let field_val = value_from_bson(bson, meta)?; row.push(AggregationResult::Count(Some(field.clone()), field_val)); @@ -117,7 +117,7 @@ fn to_aggregation_rows( AggregationSelection::Average(fields) => { for field in fields { let meta = selection_meta.get(field.db_name()).unwrap(); - let bson = doc.remove(&format!("avg_{}", field.db_name())).unwrap(); + let bson = doc.remove(format!("avg_{}", field.db_name())).unwrap(); let field_val = value_from_bson(bson, meta)?; row.push(AggregationResult::Average(field.clone(), field_val)); @@ -126,7 +126,7 @@ fn to_aggregation_rows( AggregationSelection::Sum(fields) => { for field in fields { let meta = selection_meta.get(field.db_name()).unwrap(); - let bson = doc.remove(&format!("sum_{}", field.db_name())).unwrap(); + let bson = doc.remove(format!("sum_{}", field.db_name())).unwrap(); let field_val = value_from_bson(bson, meta)?; row.push(AggregationResult::Sum(field.clone(), field_val)); @@ -135,7 +135,7 @@ fn to_aggregation_rows( AggregationSelection::Min(fields) => { for field in fields { let meta = selection_meta.get(field.db_name()).unwrap(); - let bson = doc.remove(&format!("min_{}", field.db_name())).unwrap(); + let bson = doc.remove(format!("min_{}", field.db_name())).unwrap(); let field_val = value_from_bson(bson, meta)?; row.push(AggregationResult::Min(field.clone(), field_val)); @@ -144,7 +144,7 @@ fn to_aggregation_rows( AggregationSelection::Max(fields) => { for field in fields { let meta = selection_meta.get(field.db_name()).unwrap(); - let bson = doc.remove(&format!("max_{}", field.db_name())).unwrap(); + let bson = doc.remove(format!("max_{}", 
field.db_name())).unwrap(); let field_val = value_from_bson(bson, meta)?; row.push(AggregationResult::Max(field.clone(), field_val)); diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/mod.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/mod.rs index dd2962d192e8..85099be0a694 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/mod.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/mod.rs @@ -10,16 +10,19 @@ use crate::query_strings::QueryString; use crate::{ error::DecorateErrorWithFieldInformationExtension, output_meta::OutputMetaMapping, value::value_from_bson, }; +use bson::Bson; +use bson::Document; use futures::Future; -use mongodb::bson::Bson; -use mongodb::bson::Document; -use query_engine_metrics::{ - histogram, increment_counter, metrics, PRISMA_DATASOURCE_QUERIES_DURATION_HISTOGRAM_MS, - PRISMA_DATASOURCE_QUERIES_TOTAL, +use prisma_metrics::{ + counter, histogram, PRISMA_DATASOURCE_QUERIES_DURATION_HISTOGRAM_MS, PRISMA_DATASOURCE_QUERIES_TOTAL, }; use query_structure::*; +use std::sync::Arc; use std::time::Instant; -use tracing::debug; +use tracing::{debug, info_span}; +use tracing_futures::Instrument; + +const DB_SYSTEM_NAME: &str = "mongodb"; /// Transforms a document to a `Record`, fields ordered as defined in `fields`. fn document_to_record(mut doc: Document, fields: &[String], meta_mapping: &OutputMetaMapping) -> crate::Result { @@ -59,19 +62,34 @@ where F: FnOnce() -> U + 'a, U: Future>, { + // TODO: build the string lazily in the Display impl so it doesn't have to be built if neither + // logs nor traces are enabled. This is tricky because whatever we store in the span has to be + // 'static, and all `QueryString` implementations aren't, so this requires some refactoring. + let query_string: Arc = builder.build().into(); + + let span = info_span!( + "prisma:engine:db_query", + user_facing = true, + "db.system" = DB_SYSTEM_NAME, + "db.statement" = %Arc::clone(&query_string), + "db.operation.name" = builder.query_type(), + "otel.kind" = "client" + ); + + if let Some(coll) = builder.collection() { + span.record("db.collection.name", coll); + } + let start = Instant::now(); - let res = f().await; + let res = f().instrument(span).await; let elapsed = start.elapsed().as_millis() as f64; - histogram!(PRISMA_DATASOURCE_QUERIES_DURATION_HISTOGRAM_MS, elapsed); - increment_counter!(PRISMA_DATASOURCE_QUERIES_TOTAL); + histogram!(PRISMA_DATASOURCE_QUERIES_DURATION_HISTOGRAM_MS).record(elapsed); + counter!(PRISMA_DATASOURCE_QUERIES_TOTAL).increment(1); - // TODO: emit tracing event only when "debug" level query logs are enabled. // TODO prisma/team-orm#136: fix log subscription. - let query_string = builder.build(); // NOTE: `params` is a part of the interface for query logs. 
- let params: Vec = vec![]; - debug!(target: "mongodb_query_connector::query", item_type = "query", is_query = true, query = %query_string, params = ?params, duration_ms = elapsed); + debug!(target: "mongodb_query_connector::query", item_type = "query", is_query = true, query = %query_string, params = %"[]", duration_ms = elapsed); res } diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/read.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/read.rs index 9fc322312265..acec7c57ead1 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/read.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/read.rs @@ -5,7 +5,7 @@ use crate::{ }; use mongodb::{bson::doc, options::FindOptions, ClientSession, Database}; use query_structure::*; -use tracing::{info_span, Instrument}; +use std::future::IntoFuture; /// Finds a single record. Joins are not required at the moment because the selector is always a unique one. pub async fn get_single_record<'conn>( @@ -17,12 +17,6 @@ pub async fn get_single_record<'conn>( ) -> crate::Result> { let coll = database.collection(model.db_name()); - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.findOne(*)", coll.name()) - ); - let meta_mapping = output_meta::from_selected_fields(selected_fields); let query_arguments: QueryArguments = (model.clone(), filter.clone()).into(); let query = MongoReadQueryBuilder::from_args(query_arguments)? @@ -30,7 +24,7 @@ pub async fn get_single_record<'conn>( .with_virtual_fields(selected_fields.virtuals())? .build()?; - let docs = query.execute(coll, session).instrument(span).await?; + let docs = query.execute(coll, session).await?; if docs.is_empty() { Ok(None) @@ -59,12 +53,6 @@ pub async fn get_many_records<'conn>( ) -> crate::Result { let coll = database.collection(model.db_name()); - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.findMany(*)", coll.name()) - ); - let reverse_order = query_arguments.take.map(|t| t < 0).unwrap_or(false); let field_names: Vec<_> = selected_fields.db_names().collect(); @@ -80,7 +68,7 @@ pub async fn get_many_records<'conn>( .with_virtual_fields(selected_fields.virtuals())? 
.build()?; - let docs = query.execute(coll, session).instrument(span).await?; + let docs = query.execute(coll, session).await?; for doc in docs { let record = document_to_record(doc, &field_names, &meta_mapping)?; records.push(record) @@ -126,7 +114,10 @@ pub async fn get_related_m2m_record_ids<'conn>( let find_options = FindOptions::builder().projection(projection.clone()).build(); let cursor = observing(&query_string_builder, || { - coll.find_with_session(filter.clone(), Some(find_options), session) + coll.find(filter.clone()) + .with_options(find_options) + .session(&mut *session) + .into_future() }) .await?; let docs = vacuum_cursor(cursor, session).await?; diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/expression.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/expression.rs index b92a0cf94d56..f1be40e57416 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/expression.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/expression.rs @@ -1,8 +1,8 @@ use super::{into_expression::IntoUpdateExpression, operation}; +use bson::{doc, Bson, Document}; use connector_interface::FieldPath; use indexmap::IndexMap; -use mongodb::bson::{doc, Bson, Document}; /// `UpdateExpression` is an intermediary AST that's used to represent MongoDB expressions. /// It is meant to be transformed into `BSON`. diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_bson.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_bson.rs index b4bba0f89857..da8dbf31ab56 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_bson.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_bson.rs @@ -1,8 +1,8 @@ use super::expression::*; use crate::IntoBson; +use bson::{doc, Bson, Document}; use itertools::Itertools; -use mongodb::bson::{doc, Bson, Document}; impl IntoBson for Set { fn into_bson(self) -> crate::Result { diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_expression.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_expression.rs index 359b9f06dadf..3fdf1925fa9c 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_expression.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_expression.rs @@ -1,8 +1,8 @@ use super::{expression::*, operation::*}; use crate::{filter, IntoBson}; +use bson::{doc, Bson}; use itertools::Itertools; -use mongodb::bson::{doc, Bson}; pub(crate) trait IntoUpdateExpression { fn into_update_expression(self) -> crate::Result; diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_operation.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_operation.rs index 01ff5abcbd13..52ffbd70e338 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_operation.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/into_operation.rs @@ -1,8 +1,8 @@ use super::operation::*; use crate::*; +use bson::doc; use connector_interface::{CompositeWriteOperation, FieldPath, ScalarWriteOperation, WriteOperation}; -use mongodb::bson::doc; use query_structure::{Field, PrismaValue}; pub(crate) trait IntoUpdateOperation { diff --git 
a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/mod.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/mod.rs index d91100d037b0..481ce7a95d29 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/mod.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/mod.rs @@ -7,10 +7,10 @@ mod operation; use super::*; use crate::*; +use bson::Document; use connector_interface::{FieldPath, WriteOperation}; use into_expression::IntoUpdateExpressions; use into_operation::IntoUpdateOperation; -use mongodb::bson::Document; pub(crate) trait IntoUpdateDocumentExtension { fn into_update_docs(self, field: &Field, path: FieldPath) -> crate::Result>; diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/operation.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/operation.rs index 0fa814d81af7..62502f03a4aa 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/update/operation.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/update/operation.rs @@ -1,6 +1,6 @@ use super::{expression, into_expression::IntoUpdateExpression}; +use bson::{doc, Document}; use connector_interface::FieldPath; -use mongodb::bson::{doc, Document}; use query_structure::Filter; /// `UpdateOperation` is an intermediary AST used to perform preliminary transformations from a `WriteOperation`. diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs index f66057d1ff22..2564b56e3717 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs @@ -16,8 +16,8 @@ use mongodb::{ ClientSession, Collection, Database, }; use query_structure::{Model, PrismaValue, SelectionResult}; +use std::future::IntoFuture; use std::{collections::HashMap, convert::TryInto}; -use tracing::{info_span, Instrument}; use update::IntoUpdateDocumentExtension; /// Create a single record to the database resulting in a @@ -30,12 +30,6 @@ pub async fn create_record<'conn>( ) -> crate::Result { let coll = database.collection::(model.db_name()); - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.insertOne(*)", coll.name()) - ); - let id_field = pick_singular_id(model); // Fields to write to the document. 
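The `*_with_session` hunks above and the `root_queries/write.rs` changes that follow all apply the same migration: calls like `find_with_session`, `insert_one_with_session` and `update_many_with_session` are replaced by the driver's builder-style actions. Because `observing(...)` takes a closure that must return a future, the builder is converted explicitly with `into_future()` instead of being awaited in place. A rough sketch of the pattern on a plain `Collection<Document>` (the function and filter are illustrative, not code from the patch):

```rust
use std::future::IntoFuture;

use bson::{doc, Document};
use mongodb::{error::Result, ClientSession, Collection};

// mongodb driver actions are builders: options and the session are attached
// fluently, and `.into_future()` turns the builder into an awaitable future.
async fn soft_delete_all(coll: &Collection<Document>, session: &mut ClientSession) -> Result<u64> {
    let filter = doc! { "deleted": false };
    let update = doc! { "$set": { "deleted": true } };

    // Equivalent of the old `update_many_with_session(filter, update, None, session)`.
    // In the patch this chain is returned from the closure passed to `observing`,
    // which is why `.into_future()` is spelled out rather than awaiting directly.
    let result = coll
        .update_many(filter, update)
        .session(session)
        .into_future()
        .await?;

    Ok(result.modified_count)
}
```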
@@ -65,9 +59,7 @@ pub async fn create_record<'conn>( } let query_builder = InsertOne::new(&doc, coll.name()); - let insert_result = observing(&query_builder, || coll.insert_one_with_session(&doc, None, session)) - .instrument(span) - .await?; + let insert_result = observing(&query_builder, || coll.insert_one(&doc).session(session).into_future()).await?; let id_value = value_from_bson(insert_result.inserted_id, &id_meta)?; Ok(SingleRecord { @@ -85,12 +77,6 @@ pub async fn create_records<'conn>( ) -> crate::Result { let coll = database.collection::(model.db_name()); - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.insertMany(*)", coll.name()) - ); - let num_records = args.len(); let fields: Vec<_> = model.fields().non_relational(); @@ -123,14 +109,25 @@ pub async fn create_records<'conn>( let query_string_builder = InsertMany::new(&docs, coll.name(), ordered); let docs_iter = docs.iter(); let insert = observing(&query_string_builder, || { - coll.insert_many_with_session(docs_iter, options, session) - }) - .instrument(span); + coll.insert_many(docs_iter) + .with_options(options) + .session(session) + .into_future() + }); match insert.await { Ok(insert_result) => Ok(insert_result.inserted_ids.len()), Err(err) if skip_duplicates => match err.kind.as_ref() { - ErrorKind::BulkWrite(ref failure) => match failure.write_errors { + ErrorKind::BulkWrite(ref failure) => { + let errs = &failure.write_errors; + if !errs.iter().any(|(_, err)| err.code != 11000) { + Ok(num_records - errs.len()) + } else { + Err(err.into()) + } + } + + ErrorKind::InsertMany(ref failure) => match failure.write_errors { Some(ref errs) if !errs.iter().any(|err| err.code != 11000) => Ok(num_records - errs.len()), _ => Err(err.into()), }, @@ -171,19 +168,13 @@ pub async fn update_records<'conn>( .collect::>>()? } else { let filter = MongoFilterVisitor::new(FilterPrefix::default(), false).visit(record_filter.filter)?; - find_ids(database, coll.clone(), session, model, filter).await? + find_ids(coll.clone(), session, model, filter).await? }; if ids.is_empty() { return Ok(vec![]); } - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.updateMany(*)", coll.name()) - ); - let filter = doc! { id_field.db_name(): { "$in": ids.clone() } }; let fields: Vec<_> = model .fields() @@ -205,9 +196,10 @@ pub async fn update_records<'conn>( if !update_docs.is_empty() { let query_string_builder = UpdateMany::new(&filter, &update_docs, coll.name()); let res = observing(&query_string_builder, || { - coll.update_many_with_session(filter.clone(), update_docs.clone(), None, session) + coll.update_many(filter.clone(), update_docs.clone()) + .session(session) + .into_future() }) - .instrument(span) .await?; // It's important we check the `matched_count` and not the `modified_count` here. @@ -251,25 +243,18 @@ pub async fn delete_records<'conn>( .collect::>>()? } else { let filter = MongoFilterVisitor::new(FilterPrefix::default(), false).visit(record_filter.filter)?; - find_ids(database, coll.clone(), session, model, filter).await? + find_ids(coll.clone(), session, model, filter).await? }; if ids.is_empty() { return Ok(0); } - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.deleteMany(*)", coll.name()) - ); - let filter = doc! 
{ id_field.db_name(): { "$in": ids } }; let query_string_builder = DeleteMany::new(&filter, coll.name()); let delete_result = observing(&query_string_builder, || { - coll.delete_many_with_session(filter.clone(), None, session) + coll.delete_many(filter.clone()).session(session).into_future() }) - .instrument(span) .await?; Ok(delete_result.deleted_count as usize) @@ -297,16 +282,10 @@ pub async fn delete_record<'conn>( "$expr": filter, }; - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.findAndModify(*)", coll.name()) - ); let query_string_builder = DeleteOne::new(&filter, coll.name()); let document = observing(&query_string_builder, || { - coll.find_one_and_delete_with_session(filter.clone(), None, session) + coll.find_one_and_delete(filter.clone()).session(session).into_future() }) - .instrument(span) .await? .ok_or(MongoError::RecordDoesNotExist { cause: "Record to delete does not exist.".to_owned(), @@ -320,20 +299,11 @@ pub async fn delete_record<'conn>( /// Retrives document ids based on the given filter. async fn find_ids( - database: &Database, collection: Collection, session: &mut ClientSession, model: &Model, filter: MongoFilter, ) -> crate::Result> { - let coll = database.collection::(model.db_name()); - - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.findMany(*)", coll.name()) - ); - let id_field = model.primary_identifier(); let mut builder = MongoReadQueryBuilder::new(model.clone()); @@ -348,7 +318,7 @@ async fn find_ids( let builder = builder.with_model_projection(id_field)?; let query = builder.build()?; - let docs = query.execute(collection, session).instrument(span).await?; + let docs = query.execute(collection, session).await?; let ids = docs.into_iter().map(|mut doc| doc.remove("_id").unwrap()).collect(); Ok(ids) @@ -394,7 +364,10 @@ pub async fn m2m_connect<'conn>( let query_string_builder = UpdateOne::new(&parent_filter, &parent_update, parent_coll.name()); observing(&query_string_builder, || { - parent_coll.update_one_with_session(parent_filter.clone(), parent_update.clone(), None, session) + parent_coll + .update_one(parent_filter.clone(), parent_update.clone()) + .session(&mut *session) + .into_future() }) .await?; @@ -415,7 +388,10 @@ pub async fn m2m_connect<'conn>( let child_updates = vec![child_update.clone()]; let query_string_builder = UpdateMany::new(&child_filter, &child_updates, child_coll.name()); observing(&query_string_builder, || { - child_coll.update_many_with_session(child_filter.clone(), child_update.clone(), None, session) + child_coll + .update_many(child_filter.clone(), child_update.clone()) + .session(&mut *session) + .into_future() }) .await?; @@ -460,7 +436,10 @@ pub async fn m2m_disconnect<'conn>( // First update the parent and remove all child IDs to the m:n scalar field. 
let query_string_builder = UpdateOne::new(&parent_filter, &parent_update, parent_coll.name()); observing(&query_string_builder, || { - parent_coll.update_one_with_session(parent_filter.clone(), parent_update.clone(), None, session) + parent_coll + .update_one(parent_filter.clone(), parent_update.clone()) + .session(&mut *session) + .into_future() }) .await?; @@ -482,7 +461,10 @@ pub async fn m2m_disconnect<'conn>( let child_updates = vec![child_update.clone()]; let query_string_builder = UpdateMany::new(&child_filter, &child_updates, child_coll.name()); observing(&query_string_builder, || { - child_coll.update_many_with_session(child_filter.clone(), child_update, None, session) + child_coll + .update_many(child_filter.clone(), child_update) + .session(session) + .into_future() }) .await?; @@ -506,13 +488,6 @@ pub async fn query_raw<'conn>( inputs: HashMap, query_type: Option, ) -> crate::Result { - let db_statement = get_raw_db_statement(&query_type, &model, database); - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &&db_statement.as_str() - ); - let mongo_command = MongoCommand::from_raw_query(model, inputs, query_type)?; async { @@ -520,7 +495,7 @@ pub async fn query_raw<'conn>( MongoCommand::Raw { cmd } => { let query_string_builder = RunCommand::new(&cmd); let mut result = observing(&query_string_builder, || { - database.run_command_with_session(cmd.clone(), None, session) + database.run_command(cmd.clone()).session(session).into_future() }) .await?; @@ -547,7 +522,10 @@ pub async fn query_raw<'conn>( .unwrap_or_default(); let query_string_builder = Find::new(&unwrapped_filter, &projection, coll.name()); let cursor = observing(&query_string_builder, || { - coll.find_with_session(filter, options, session) + coll.find(filter.unwrap_or_default()) + .with_options(options) + .session(&mut *session) + .into_future() }) .await?; @@ -556,7 +534,10 @@ pub async fn query_raw<'conn>( MongoOperation::Aggregate(pipeline, options) => { let query_string_builder = Aggregate::new(&pipeline, coll.name()); let cursor = observing(&query_string_builder, || { - coll.aggregate_with_session(pipeline.clone(), options, session) + coll.aggregate(pipeline.clone()) + .with_options(options) + .session(&mut *session) + .into_future() }) .await?; @@ -568,17 +549,5 @@ pub async fn query_raw<'conn>( Ok(RawJson::try_new(json_result)?) 
} - .instrument(span) .await } - -fn get_raw_db_statement(query_type: &Option, model: &Option<&Model>, database: &Database) -> String { - match (query_type.as_deref(), model) { - (Some("findRaw"), Some(m)) => format!("db.{}.findRaw(*)", database.collection::(m.db_name()).name()), - (Some("aggregateRaw"), Some(m)) => format!( - "db.{}.aggregateRaw(*)", - database.collection::(m.db_name()).name() - ), - _ => "db.runCommandRaw(*)".to_string(), - } -} diff --git a/query-engine/connectors/mongodb-query-connector/src/value.rs b/query-engine/connectors/mongodb-query-connector/src/value.rs index a9f79e941f7e..abbfaf7108f3 100644 --- a/query-engine/connectors/mongodb-query-connector/src/value.rs +++ b/query-engine/connectors/mongodb-query-connector/src/value.rs @@ -4,9 +4,9 @@ use crate::{ IntoBson, MongoError, }; use bigdecimal::{BigDecimal, FromPrimitive, ToPrimitive}; +use bson::{oid::ObjectId, spec::BinarySubtype, Binary, Bson, Document, Timestamp}; use chrono::{TimeZone, Utc}; use itertools::Itertools; -use mongodb::bson::{oid::ObjectId, spec::BinarySubtype, Binary, Bson, Document, Timestamp}; use psl::builtin_connectors::MongoDbType; use query_structure::{ CompositeFieldRef, Field, PrismaValue, RelationFieldRef, ScalarFieldRef, SelectedField, TypeIdentifier, diff --git a/query-engine/connectors/query-connector/Cargo.toml b/query-engine/connectors/query-connector/Cargo.toml index 52555d256baa..125be0895492 100644 --- a/query-engine/connectors/query-connector/Cargo.toml +++ b/query-engine/connectors/query-connector/Cargo.toml @@ -7,7 +7,7 @@ version = "0.1.0" anyhow = "1.0" async-trait.workspace = true chrono.workspace = true -futures = "0.3" +futures.workspace = true itertools.workspace = true query-structure = {path = "../../query-structure"} prisma-value = {path = "../../../libs/prisma-value"} @@ -17,3 +17,4 @@ thiserror = "1.0" user-facing-errors = {path = "../../../libs/user-facing-errors", features = ["sql"]} uuid.workspace = true indexmap.workspace = true +telemetry = {path = "../../../libs/telemetry"} diff --git a/query-engine/connectors/query-connector/src/interface.rs b/query-engine/connectors/query-connector/src/interface.rs index cbdafcaeeee3..05e8f1e1098f 100644 --- a/query-engine/connectors/query-connector/src/interface.rs +++ b/query-engine/connectors/query-connector/src/interface.rs @@ -3,13 +3,17 @@ use async_trait::async_trait; use prisma_value::PrismaValue; use query_structure::{ast::FieldArity, *}; use std::collections::HashMap; +use telemetry::helpers::TraceParent; #[async_trait] pub trait Connector { /// Returns a connection to a data source. async fn get_connection(&self) -> crate::Result>; - /// Returns the name of the connector. + /// Returns the database system name, as per the OTEL spec. + /// Reference: + /// - https://opentelemetry.io/docs/specs/semconv/database/sql/ + /// - https://opentelemetry.io/docs/specs/semconv/database/mongodb/ fn name(&self) -> &'static str; /// Returns whether a connector should retry an entire transaction when that transaction failed during its execution @@ -194,7 +198,7 @@ pub trait ReadOperations { filter: &Filter, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> crate::Result>; /// Gets multiple records from the database. 
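The documentation change to `Connector::name()` above repurposes it to return the OTel `db.system` value rather than an arbitrary connector label. Later hunks in this diff (the driver-adapter `js.rs` connector and the native `postgresql.rs` connector) implement this by mapping the SQL family or Postgres flavour onto the spec names. Condensed into one illustrative mapping (the enum is a stand-in, not a type from the codebase):

```rust
/// Stand-in for quaint's SqlFamily / PostgresFlavour, only here to show the
/// `db.system` values the connectors in this patch report.
enum Backend {
    Postgres,
    CockroachDb,
    MySql,
    Sqlite,
    Mssql,
    MongoDb,
}

fn db_system_name(backend: Backend) -> &'static str {
    match backend {
        Backend::Postgres => "postgresql",
        Backend::CockroachDb => "cockroachdb",
        Backend::MySql => "mysql",
        Backend::Sqlite => "sqlite",
        Backend::Mssql => "mssql",
        Backend::MongoDb => "mongodb",
    }
}
```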
@@ -209,7 +213,7 @@ pub trait ReadOperations { query_arguments: QueryArguments, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Retrieves pairs of IDs that belong together from a intermediate join @@ -223,7 +227,7 @@ pub trait ReadOperations { &mut self, from_field: &RelationFieldRef, from_record_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> crate::Result>; /// Aggregates records for a specific model based on the given selections. @@ -238,7 +242,7 @@ pub trait ReadOperations { selections: Vec, group_by: Vec, having: Option, - trace_id: Option, + traceparent: Option, ) -> crate::Result>; } @@ -250,7 +254,7 @@ pub trait WriteOperations { model: &Model, args: WriteArgs, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Inserts many records at once into the database. @@ -259,7 +263,7 @@ pub trait WriteOperations { model: &Model, args: Vec, skip_duplicates: bool, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Inserts many records at once into the database and returns their @@ -272,7 +276,7 @@ pub trait WriteOperations { args: Vec, skip_duplicates: bool, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Update records in the `Model` with the given `WriteArgs` filtered by the @@ -282,7 +286,7 @@ pub trait WriteOperations { model: &Model, record_filter: RecordFilter, args: WriteArgs, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Update record in the `Model` with the given `WriteArgs` filtered by the @@ -293,7 +297,7 @@ pub trait WriteOperations { record_filter: RecordFilter, args: WriteArgs, selected_fields: Option, - trace_id: Option, + traceparent: Option, ) -> crate::Result>; /// Native upsert @@ -301,7 +305,7 @@ pub trait WriteOperations { async fn native_upsert_record( &mut self, upsert: NativeUpsert, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Delete records in the `Model` with the given `Filter`. @@ -309,7 +313,7 @@ pub trait WriteOperations { &mut self, model: &Model, record_filter: RecordFilter, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Delete single record in the `Model` with the given `Filter` and returns @@ -321,7 +325,7 @@ pub trait WriteOperations { model: &Model, record_filter: RecordFilter, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> crate::Result; // We plan to remove the methods below in the future. We want emulate them with the ones above. Those should suffice. @@ -332,7 +336,7 @@ pub trait WriteOperations { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> crate::Result<()>; /// Disconnect the children from the parent (m2m relation only). @@ -341,7 +345,7 @@ pub trait WriteOperations { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> crate::Result<()>; /// Execute the raw query in the database as-is. 
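Across `ReadOperations` and `WriteOperations` the former trace-id string parameter becomes `traceparent: Option<TraceParent>`, with `TraceParent` coming from the new `libs/telemetry` crate. That type's API is not shown in this diff; as a rough illustration of the data it represents, a W3C Trace Context `traceparent` value has the shape `{version}-{trace_id}-{parent_span_id}-{flags}` (the struct below is a hypothetical stand-in, not the telemetry crate's type):

```rust
use std::fmt;

/// Hypothetical stand-in for `telemetry::helpers::TraceParent`, shown only to
/// illustrate the W3C `traceparent` format the connectors now pass around.
#[derive(Clone, Copy)]
struct TraceParentExample {
    trace_id: u128,
    parent_span_id: u64,
    sampled: bool,
}

impl fmt::Display for TraceParentExample {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let flags: u8 = if self.sampled { 0x01 } else { 0x00 };
        // e.g. "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
        write!(
            f,
            "00-{:032x}-{:016x}-{:02x}",
            self.trace_id, self.parent_span_id, flags
        )
    }
}
```

On the SQL side this value ends up attached to generated statements through `add_traceparent(ctx.traceparent)` in the query builders further down, replacing the old `add_trace_id`.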
diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index 5fbe41370202..2e3e0fe2fe5c 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -38,15 +38,15 @@ psl.workspace = true anyhow = "1.0" async-trait.workspace = true bigdecimal = "0.3" -futures = "0.3" +futures.workspace = true itertools.workspace = true once_cell = "1.3" rand.workspace = true serde_json.workspace = true thiserror = "1.0" -tokio = { version = "1.0", features = ["macros", "time"] } +tokio = { version = "1", features = ["macros", "time"] } tracing = { workspace = true, features = ["log"] } -tracing-futures = "0.2" +tracing-futures.workspace = true uuid.workspace = true opentelemetry = { version = "0.17", features = ["tokio"] } tracing-opentelemetry = "0.17.3" @@ -66,6 +66,9 @@ path = "../../query-structure" [dependencies.prisma-value] path = "../../../libs/prisma-value" +[dependencies.telemetry] +path = "../../../libs/telemetry" + [dependencies.chrono] features = ["serde"] version = "0.4" diff --git a/query-engine/connectors/sql-query-connector/src/context.rs b/query-engine/connectors/sql-query-connector/src/context.rs index 5b2887451ef1..2519439b13f1 100644 --- a/query-engine/connectors/sql-query-connector/src/context.rs +++ b/query-engine/connectors/sql-query-connector/src/context.rs @@ -1,8 +1,9 @@ use quaint::prelude::ConnectionInfo; +use telemetry::helpers::TraceParent; pub(super) struct Context<'a> { connection_info: &'a ConnectionInfo, - pub(crate) trace_id: Option<&'a str>, + pub(crate) traceparent: Option, /// Maximum rows allowed at once for an insert query. /// None is unlimited. pub(crate) max_insert_rows: Option, @@ -12,13 +13,13 @@ pub(super) struct Context<'a> { } impl<'a> Context<'a> { - pub(crate) fn new(connection_info: &'a ConnectionInfo, trace_id: Option<&'a str>) -> Self { + pub(crate) fn new(connection_info: &'a ConnectionInfo, traceparent: Option) -> Self { let max_insert_rows = connection_info.max_insert_rows(); let max_bind_values = connection_info.max_bind_values(); Context { connection_info, - trace_id, + traceparent, max_insert_rows, max_bind_values: Some(max_bind_values), } diff --git a/query-engine/connectors/sql-query-connector/src/database/connection.rs b/query-engine/connectors/sql-query-connector/src/database/connection.rs index f928fcacdfa5..1222f0425ea0 100644 --- a/query-engine/connectors/sql-query-connector/src/database/connection.rs +++ b/query-engine/connectors/sql-query-connector/src/database/connection.rs @@ -15,6 +15,7 @@ use quaint::{ }; use query_structure::{prelude::*, Filter, QueryArguments, RelationLoadStrategy, SelectionResult}; use std::{collections::HashMap, str::FromStr}; +use telemetry::helpers::TraceParent; pub(crate) struct SqlConnection { inner: C, @@ -89,10 +90,10 @@ where filter: &Filter, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { // [Composites] todo: FieldSelection -> ModelProjection conversion - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::get_single_record( @@ -113,9 +114,9 @@ where query_arguments: QueryArguments, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = 
Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::get_many_records( @@ -134,9 +135,9 @@ where &mut self, from_field: &RelationFieldRef, from_record_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::get_related_m2m_record_ids(&self.inner, from_field, from_record_ids, &ctx), @@ -151,9 +152,9 @@ where selections: Vec, group_by: Vec, having: Option, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::aggregate(&self.inner, model, query_arguments, selections, group_by, having, &ctx), @@ -172,9 +173,9 @@ where model: &Model, args: WriteArgs, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_record( @@ -194,9 +195,9 @@ where model: &Model, args: Vec, skip_duplicates: bool, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_records_count(&self.inner, model, args, skip_duplicates, &ctx), @@ -210,9 +211,9 @@ where args: Vec, skip_duplicates: bool, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_records_returning(&self.inner, model, args, skip_duplicates, selected_fields, &ctx), @@ -225,9 +226,9 @@ where model: &Model, record_filter: RecordFilter, args: WriteArgs, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::update_records(&self.inner, model, record_filter, args, &ctx), @@ -241,9 +242,9 @@ where record_filter: RecordFilter, args: WriteArgs, selected_fields: Option, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::update_record(&self.inner, model, record_filter, args, selected_fields, &ctx), @@ -255,9 +256,9 @@ where &mut self, model: &Model, record_filter: RecordFilter, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::delete_records(&self.inner, model, record_filter, &ctx), @@ -270,9 +271,9 @@ where model: &Model, record_filter: RecordFilter, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + 
let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::delete_record(&self.inner, model, record_filter, selected_fields, &ctx), @@ -283,9 +284,9 @@ where async fn native_upsert_record( &mut self, upsert: connector_interface::NativeUpsert, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch(&self.connection_info, upsert::native_upsert(&self.inner, upsert, &ctx)).await } @@ -294,9 +295,9 @@ where field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result<()> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::m2m_connect(&self.inner, field, parent_id, child_ids, &ctx), @@ -309,9 +310,9 @@ where field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result<()> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::m2m_disconnect(&self.inner, field, parent_id, child_ids, &ctx), diff --git a/query-engine/connectors/sql-query-connector/src/database/js.rs b/query-engine/connectors/sql-query-connector/src/database/js.rs index 9badc8659738..40ca0caa0025 100644 --- a/query-engine/connectors/sql-query-connector/src/database/js.rs +++ b/query-engine/connectors/sql-query-connector/src/database/js.rs @@ -48,7 +48,16 @@ impl Connector for Js { } fn name(&self) -> &'static str { - "js" + match self.connection_info.sql_family() { + #[cfg(feature = "postgresql")] + SqlFamily::Postgres => "postgresql", + #[cfg(feature = "mysql")] + SqlFamily::Mysql => "mysql", + #[cfg(feature = "sqlite")] + SqlFamily::Sqlite => "sqlite", + #[cfg(feature = "mssql")] + SqlFamily::Mssql => "mssql", + } } fn should_retry_on_transient_error(&self) -> bool { @@ -90,6 +99,10 @@ impl QuaintQueryable for DriverAdapter { self.connector.query_raw_typed(sql, params).await } + async fn describe_query(&self, sql: &str) -> quaint::Result { + self.connector.describe_query(sql).await + } + async fn execute(&self, q: Query<'_>) -> quaint::Result { self.connector.execute(q).await } diff --git a/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs index 701fa9a1a0c5..9e59e4f232c7 100644 --- a/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs @@ -13,6 +13,7 @@ pub struct PostgreSql { pool: Quaint, connection_info: ConnectionInfo, features: psl::PreviewFeatures, + flavour: PostgresFlavour, } impl PostgreSql { @@ -60,6 +61,7 @@ impl FromSource for PostgreSql { pool, connection_info, features, + flavour, }) } } @@ -76,7 +78,10 @@ impl Connector for PostgreSql { } fn name(&self) -> &'static str { - "postgres" + match self.flavour { + PostgresFlavour::Postgres | PostgresFlavour::Unknown => "postgresql", + PostgresFlavour::Cockroach => "cockroachdb", + } } fn should_retry_on_transient_error(&self) -> bool { diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/read/coerce.rs 
b/query-engine/connectors/sql-query-connector/src/database/operations/read/coerce.rs index ba3dd50e1768..87fbe8ce86ef 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/read/coerce.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/read/coerce.rs @@ -120,10 +120,23 @@ fn coerce_json_relation_to_pv(value: serde_json::Value, rs: &RelationSelection) } pub(crate) fn coerce_json_scalar_to_pv(value: serde_json::Value, sf: &ScalarField) -> crate::Result { - if sf.type_identifier().is_json() { + if sf.type_identifier().is_json() && !sf.is_list() { return Ok(PrismaValue::Json(serde_json::to_string(&value)?)); } + if sf.type_identifier().is_json() && sf.is_list() { + return match value { + serde_json::Value::Null => Ok(PrismaValue::List(vec![])), + serde_json::Value::Array(values) => Ok(PrismaValue::List( + values + .iter() + .map(|v| PrismaValue::Json(serde_json::to_string(v).unwrap())) + .collect(), + )), + _ => unreachable!("Invalid JSON value for JSON list field."), + }; + } + match value { serde_json::Value::Null => { if sf.is_list() { diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/update.rs b/query-engine/connectors/sql-query-connector/src/database/operations/update.rs index 54e04651d2f4..0dc8081f97d2 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/update.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/update.rs @@ -9,7 +9,6 @@ use crate::{Context, QueryExt, Queryable}; use connector_interface::*; use itertools::Itertools; use query_structure::*; -use std::usize; /// Performs an update with an explicit selection set. /// This function is called for connectors that supports the `UpdateReturning` capability. 
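The `coerce_json_scalar_to_pv` change above adds handling for JSON list fields: the existing early return is now guarded with `!sf.is_list()`, a SQL NULL turns into an empty list, and a JSON array is split so that each element is re-serialized as its own `Json` value. A reduced sketch of the new branch, using a local enum instead of `PrismaValue` to stay self-contained:

```rust
use serde_json::Value;

/// Local stand-in for the subset of `PrismaValue` relevant here.
#[derive(Debug, PartialEq)]
enum Pv {
    Json(String),
    List(Vec<Pv>),
}

/// NULL -> empty list, array -> one `Json` value per element. Anything else is
/// invalid input for a JSON list column; the connector treats that case as
/// unreachable, the sketch returns an error instead.
fn coerce_json_list(value: Value) -> Result<Pv, String> {
    match value {
        Value::Null => Ok(Pv::List(vec![])),
        Value::Array(values) => Ok(Pv::List(
            values
                .iter()
                .map(|v| Pv::Json(serde_json::to_string(v).expect("JSON value is serializable")))
                .collect(),
        )),
        other => Err(format!("invalid JSON value for a JSON list field: {other}")),
    }
}
```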
diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index b56118b1c698..137bff50ca58 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -18,7 +18,6 @@ use std::borrow::Cow; use std::{ collections::{HashMap, HashSet}, ops::Deref, - usize, }; use user_facing_errors::query_engine::DatabaseConstraint; @@ -73,7 +72,7 @@ async fn generate_id( // db generate values only if needed if need_select { - let pk_select = id_select.add_trace_id(ctx.trace_id); + let pk_select = id_select.add_traceparent(ctx.traceparent); let pk_result = conn.query(pk_select.into()).await?; let result = try_convert(&(id_field.into()), pk_result)?; diff --git a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index 263c541f6b42..387b18f63ee2 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -10,6 +10,7 @@ use prisma_value::PrismaValue; use quaint::prelude::ConnectionInfo; use query_structure::{prelude::*, Filter, QueryArguments, RelationLoadStrategy, SelectionResult}; use std::collections::HashMap; +use telemetry::helpers::TraceParent; pub struct SqlConnectorTransaction<'tx> { inner: Box, @@ -73,9 +74,9 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { filter: &Filter, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::get_single_record( @@ -96,9 +97,9 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { query_arguments: QueryArguments, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::get_many_records( @@ -117,9 +118,9 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { &mut self, from_field: &RelationFieldRef, from_record_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch(&self.connection_info, async { read::get_related_m2m_record_ids(self.inner.as_queryable(), from_field, from_record_ids, &ctx).await }) @@ -133,9 +134,9 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { selections: Vec, group_by: Vec, having: Option, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::aggregate( @@ -159,9 +160,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { model: &Model, args: WriteArgs, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, 
trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_record( @@ -181,9 +182,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { model: &Model, args: Vec, skip_duplicates: bool, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_records_count(self.inner.as_queryable(), model, args, skip_duplicates, &ctx), @@ -197,9 +198,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { args: Vec, skip_duplicates: bool, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_records_returning( @@ -219,9 +220,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { model: &Model, record_filter: RecordFilter, args: WriteArgs, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::update_records(self.inner.as_queryable(), model, record_filter, args, &ctx), @@ -235,9 +236,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { record_filter: RecordFilter, args: WriteArgs, selected_fields: Option, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::update_record( @@ -256,10 +257,10 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { &mut self, model: &Model, record_filter: RecordFilter, - trace_id: Option, + traceparent: Option, ) -> connector::Result { catch(&self.connection_info, async { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); write::delete_records(self.inner.as_queryable(), model, record_filter, &ctx).await }) .await @@ -270,9 +271,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { model: &Model, record_filter: RecordFilter, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::delete_record(self.inner.as_queryable(), model, record_filter, selected_fields, &ctx), @@ -283,10 +284,10 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { async fn native_upsert_record( &mut self, upsert: connector_interface::NativeUpsert, - trace_id: Option, + traceparent: Option, ) -> connector::Result { catch(&self.connection_info, async { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); upsert::native_upsert(self.inner.as_queryable(), upsert, &ctx).await }) .await @@ -297,10 +298,10 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result<()> { 
catch(&self.connection_info, async { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); write::m2m_connect(self.inner.as_queryable(), field, parent_id, child_ids, &ctx).await }) .await @@ -311,9 +312,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result<()> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::m2m_disconnect(self.inner.as_queryable(), field, parent_id, child_ids, &ctx), diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs index 6a0572ecc0da..84323e0f52b2 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs @@ -100,7 +100,7 @@ impl SelectDefinition for QueryArguments { .so_that(conditions) .offset(skip as usize) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id); + .add_traceparent(ctx.traceparent); let select_ast = order_by_definitions .iter() @@ -137,7 +137,7 @@ where let (select, additional_selection_set) = query.into_select(model, virtual_selections, ctx); let select = columns.fold(select, |acc, col| acc.column(col)); - let select = select.append_trace(&Span::current()).add_trace_id(ctx.trace_id); + let select = select.append_trace(&Span::current()).add_traceparent(ctx.traceparent); additional_selection_set .into_iter() @@ -183,7 +183,7 @@ pub(crate) fn aggregate( selections.iter().fold( Select::from_table(sub_table) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id), + .add_traceparent(ctx.traceparent), |select, next_op| match next_op { AggregationSelection::Field(field) => select.column( Column::from(field.db_name().to_owned()) @@ -269,7 +269,9 @@ pub(crate) fn group_by_aggregate( }); let grouped = group_by.into_iter().fold( - select_query.append_trace(&Span::current()).add_trace_id(ctx.trace_id), + select_query + .append_trace(&Span::current()) + .add_traceparent(ctx.traceparent), |query, field| query.group_by(field.as_column(ctx)), ); diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs b/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs index 945c3fc2e51e..9c0139c6cd81 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs @@ -654,11 +654,7 @@ fn extract_filter_scalars(f: &Filter) -> Vec { } fn join_fields(rf: &RelationField) -> Vec { - if rf.is_inlined_on_enclosing_model() { - rf.scalar_fields() - } else { - rf.related_field().referenced_fields() - } + rf.linking_fields().as_scalar_fields().unwrap_or_default() } fn join_alias_name(rf: &RelationField) -> String { diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/write.rs b/query-engine/connectors/sql-query-connector/src/query_builder/write.rs index c089f0834dcb..c07e3600e149 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/write.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/write.rs @@ -34,7 +34,7 @@ pub(crate) fn create_record( Insert::from(insert) 
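// `add_traceparent` (see the sql_trace.rs changes later in this diff) appends a
// `traceparent='...'` comment to the generated INSERT when the traceparent is sampled,
// so the statement observed on the database side can be correlated with the originating trace.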
.returning(selected_fields.as_columns(ctx).map(|c| c.set_is_selected(true))) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id) + .add_traceparent(ctx.traceparent) } /// `INSERT` new records into the database based on the given write arguments, @@ -84,7 +84,7 @@ pub(crate) fn create_records_nonempty( let insert = Insert::multi_into(model.as_table(ctx), columns); let insert = values.into_iter().fold(insert, |stmt, values| stmt.values(values)); let insert: Insert = insert.into(); - let mut insert = insert.append_trace(&Span::current()).add_trace_id(ctx.trace_id); + let mut insert = insert.append_trace(&Span::current()).add_traceparent(ctx.traceparent); if let Some(selected_fields) = selected_fields { insert = insert.returning(projection_into_columns(selected_fields, ctx)); @@ -105,7 +105,7 @@ pub(crate) fn create_records_empty( ctx: &Context<'_>, ) -> Insert<'static> { let insert: Insert<'static> = Insert::single_into(model.as_table(ctx)).into(); - let mut insert = insert.append_trace(&Span::current()).add_trace_id(ctx.trace_id); + let mut insert = insert.append_trace(&Span::current()).add_traceparent(ctx.traceparent); if let Some(selected_fields) = selected_fields { insert = insert.returning(projection_into_columns(selected_fields, ctx)); @@ -175,7 +175,7 @@ pub(crate) fn build_update_and_set_query( acc.set(name, value) }); - let query = query.append_trace(&Span::current()).add_trace_id(ctx.trace_id); + let query = query.append_trace(&Span::current()).add_traceparent(ctx.traceparent); let query = if let Some(selected_fields) = selected_fields { query.returning(selected_fields.as_columns(ctx).map(|c| c.set_is_selected(true))) @@ -222,7 +222,7 @@ pub(crate) fn delete_returning( .so_that(filter) .returning(projection_into_columns(selected_fields, ctx)) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id) + .add_traceparent(ctx.traceparent) .into() } @@ -234,7 +234,7 @@ pub(crate) fn delete_many_from_filter( Delete::from_table(model.as_table(ctx)) .so_that(filter_condition) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id) + .add_traceparent(ctx.traceparent) .into() } @@ -301,5 +301,5 @@ pub(crate) fn delete_relation_table_records( Delete::from_table(relation.as_table(ctx)) .so_that(parent_id_criteria.and(child_id_criteria)) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id) + .add_traceparent(ctx.traceparent) } diff --git a/query-engine/connectors/sql-query-connector/src/query_ext.rs b/query-engine/connectors/sql-query-connector/src/query_ext.rs index c0d511f9e6d8..a843f4a1525c 100644 --- a/query-engine/connectors/sql-query-connector/src/query_ext.rs +++ b/query-engine/connectors/sql-query-connector/src/query_ext.rs @@ -25,18 +25,18 @@ impl QueryExt for Q { idents: &[ColumnMetadata<'_>], ctx: &Context<'_>, ) -> crate::Result> { - let span = info_span!("filter read query"); + let span = info_span!("prisma:engine:filter_read_query"); let otel_ctx = span.context(); let span_ref = otel_ctx.span(); let span_ctx = span_ref.span_context(); - let q = match (q, ctx.trace_id) { + let q = match (q, ctx.traceparent) { (Query::Select(x), _) if span_ctx.trace_flags() == TraceFlags::SAMPLED => { Query::Select(Box::from(x.comment(trace_parent_to_string(span_ctx)))) } // This is part of the required changes to pass a traceid - (Query::Select(x), trace_id) => Query::Select(Box::from(x.add_trace_id(trace_id))), + (Query::Select(x), traceparent) => Query::Select(Box::from(x.add_traceparent(traceparent))), (q, _) => q, }; @@ -119,7 +119,7 @@ impl QueryExt for Q { let 
select = Select::from_table(model.as_table(ctx)) .columns(id_cols) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id) + .add_traceparent(ctx.traceparent) .so_that(condition); self.select_ids(select, model_id, ctx).await diff --git a/query-engine/connectors/sql-query-connector/src/ser_raw.rs b/query-engine/connectors/sql-query-connector/src/ser_raw.rs index 3c80b91d34ee..bbb23735704f 100644 --- a/query-engine/connectors/sql-query-connector/src/ser_raw.rs +++ b/query-engine/connectors/sql-query-connector/src/ser_raw.rs @@ -1,5 +1,5 @@ use quaint::{ - connector::{ResultRowRef, ResultSet}, + connector::{ColumnType, ResultRowRef, ResultSet}, Value, ValueType, }; use serde::{ser::*, Serialize, Serializer}; @@ -9,7 +9,7 @@ pub struct SerializedResultSet(pub ResultSet); #[derive(Debug, Serialize)] struct InnerSerializedResultSet<'a> { columns: SerializedColumns<'a>, - types: &'a SerializedTypes, + types: SerializedTypes<'a>, rows: SerializedRows<'a>, } @@ -22,7 +22,7 @@ impl serde::Serialize for SerializedResultSet { InnerSerializedResultSet { columns: SerializedColumns(this), - types: &SerializedTypes::new(this), + types: SerializedTypes(this), rows: SerializedRows(this), } .serialize(serializer) @@ -39,75 +39,65 @@ impl<'a> Serialize for SerializedColumns<'a> { { let this = &self.0; - if this.is_empty() { - return this.columns().serialize(serializer); - } - - let first_row = this.first().unwrap(); - - let mut seq = serializer.serialize_seq(Some(first_row.len()))?; + this.columns().serialize(serializer) + } +} - for (idx, _) in first_row.iter().enumerate() { - if let Some(column_name) = this.columns().get(idx) { - seq.serialize_element(column_name)?; - } else { - // `query_raw` does not return column names in `ResultSet` when a call to a stored procedure is done - // See https://github.com/prisma/prisma/issues/6173 - seq.serialize_element(&format!("f{idx}"))?; +#[derive(Debug)] +struct SerializedTypes<'a>(&'a ResultSet); + +impl<'a> SerializedTypes<'a> { + fn infer_unknown_column_types(&self) -> Vec { + let rows = self.0; + + let mut types = rows.types().to_owned(); + // Find all the unknown column types to avoid unnecessary iterations. + let unknown_indexes = rows + .types() + .iter() + .enumerate() + .filter_map(|(idx, ty)| match ty.is_unknown() { + true => Some(idx), + false => None, + }); + + for unknown_idx in unknown_indexes { + // While quaint already infers `ColumnType`s from the database, it can still have ColumnType::Unknown. + // In this case, we try to infer the types from the actual response data. + for row in self.0.iter() { + let current_type = types[unknown_idx]; + let inferred_type = ColumnType::from(&row[unknown_idx]); + + if current_type.is_unknown() && !inferred_type.is_unknown() { + types[unknown_idx] = inferred_type; + break; + } } } - seq.end() + if !self.0.is_empty() { + // Client doesn't know how to handle unknown types. 
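// The assertion below enforces that invariant: whenever rows are present, the inference
// loop above must have resolved every `ColumnType::Unknown` before the types are
// serialized for the client.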
+ assert!(!types.contains(&ColumnType::Unknown)); + } + + types } } -#[derive(Debug, Serialize)] -#[serde(transparent)] -struct SerializedTypes(Vec); - -impl SerializedTypes { - fn new(rows: &ResultSet) -> Self { - if rows.is_empty() { - return Self(Vec::with_capacity(0)); - } +impl Serialize for SerializedTypes<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let types = self.infer_unknown_column_types(); - let row_len = rows.first().unwrap().len(); - let mut types = vec![SerializedValueType::Unknown; row_len]; - let mut types_found = 0; - - // This attempts to infer types based on `quaint::Value` present in the rows. - // We need to go through every row because because empty and null arrays don't encode their inner type. - // In the best case scenario, this loop stops at the first row. - // In the worst case scenario, it'll keep looping until it finds an array with a non-null value. - 'outer: for row in rows.iter() { - for (idx, value) in row.iter().enumerate() { - let current_type = types[idx]; - - if matches!( - current_type, - SerializedValueType::Unknown | SerializedValueType::UnknownArray - ) { - let inferred_type = SerializedValueType::infer_from(value); - - if inferred_type != SerializedValueType::Unknown && inferred_type != current_type { - types[idx] = inferred_type; - - if inferred_type != SerializedValueType::UnknownArray { - types_found += 1; - } - } - } + let mut seq = serializer.serialize_seq(Some(types.len()))?; - if types_found == row_len { - break 'outer; - } - } + for column_type in types { + seq.serialize_element(&column_type.to_string())?; } - // Client doesn't know how to handle unknown types. - assert!(!types.contains(&SerializedValueType::Unknown)); - - Self(types) + seq.end() } } @@ -200,361 +190,3 @@ impl<'a> Serialize for SerializedValue<'a> { } } } - -#[derive(Debug, Copy, Clone, PartialEq, Serialize)] -enum SerializedValueType { - #[serde(rename = "int")] - Int32, - #[serde(rename = "bigint")] - Int64, - #[serde(rename = "float")] - Float, - #[serde(rename = "double")] - Double, - #[serde(rename = "string")] - Text, - #[serde(rename = "enum")] - Enum, - #[serde(rename = "bytes")] - Bytes, - #[serde(rename = "bool")] - Boolean, - #[serde(rename = "char")] - Char, - #[serde(rename = "decimal")] - Numeric, - #[serde(rename = "json")] - Json, - #[serde(rename = "xml")] - Xml, - #[serde(rename = "uuid")] - Uuid, - #[serde(rename = "datetime")] - DateTime, - #[serde(rename = "date")] - Date, - #[serde(rename = "time")] - Time, - - #[serde(rename = "int-array")] - Int32Array, - #[serde(rename = "bigint-array")] - Int64Array, - #[serde(rename = "float-array")] - FloatArray, - #[serde(rename = "double-array")] - DoubleArray, - #[serde(rename = "string-array")] - TextArray, - #[serde(rename = "bytes-array")] - BytesArray, - #[serde(rename = "bool-array")] - BooleanArray, - #[serde(rename = "char-array")] - CharArray, - #[serde(rename = "decimal-array")] - NumericArray, - #[serde(rename = "json-array")] - JsonArray, - #[serde(rename = "xml-array")] - XmlArray, - #[serde(rename = "uuid-array")] - UuidArray, - #[serde(rename = "datetime-array")] - DateTimeArray, - #[serde(rename = "date-array")] - DateArray, - #[serde(rename = "time-array")] - TimeArray, - - #[serde(rename = "unknown-array")] - UnknownArray, - - #[serde(rename = "unknown")] - Unknown, -} - -impl SerializedValueType { - fn infer_from(value: &Value) -> SerializedValueType { - match &value.typed { - ValueType::Int32(_) => SerializedValueType::Int32, - ValueType::Int64(_) => 
SerializedValueType::Int64, - ValueType::Float(_) => SerializedValueType::Float, - ValueType::Double(_) => SerializedValueType::Double, - ValueType::Text(_) => SerializedValueType::Text, - ValueType::Enum(_, _) => SerializedValueType::Enum, - ValueType::EnumArray(_, _) => SerializedValueType::TextArray, - ValueType::Bytes(_) => SerializedValueType::Bytes, - ValueType::Boolean(_) => SerializedValueType::Boolean, - ValueType::Char(_) => SerializedValueType::Char, - ValueType::Numeric(_) => SerializedValueType::Numeric, - ValueType::Json(_) => SerializedValueType::Json, - ValueType::Xml(_) => SerializedValueType::Xml, - ValueType::Uuid(_) => SerializedValueType::Uuid, - ValueType::DateTime(_) => SerializedValueType::DateTime, - ValueType::Date(_) => SerializedValueType::Date, - ValueType::Time(_) => SerializedValueType::Time, - - ValueType::Array(Some(values)) => { - if values.is_empty() { - return SerializedValueType::UnknownArray; - } - - match &values[0].typed { - ValueType::Int32(_) => SerializedValueType::Int32Array, - ValueType::Int64(_) => SerializedValueType::Int64Array, - ValueType::Float(_) => SerializedValueType::FloatArray, - ValueType::Double(_) => SerializedValueType::DoubleArray, - ValueType::Text(_) => SerializedValueType::TextArray, - ValueType::Bytes(_) => SerializedValueType::BytesArray, - ValueType::Boolean(_) => SerializedValueType::BooleanArray, - ValueType::Char(_) => SerializedValueType::CharArray, - ValueType::Numeric(_) => SerializedValueType::NumericArray, - ValueType::Json(_) => SerializedValueType::JsonArray, - ValueType::Xml(_) => SerializedValueType::XmlArray, - ValueType::Uuid(_) => SerializedValueType::UuidArray, - ValueType::DateTime(_) => SerializedValueType::DateTimeArray, - ValueType::Date(_) => SerializedValueType::DateArray, - ValueType::Time(_) => SerializedValueType::TimeArray, - ValueType::Enum(_, _) => SerializedValueType::TextArray, - ValueType::Array(_) | ValueType::EnumArray(_, _) => { - unreachable!("Only PG supports scalar lists and tokio-postgres does not support 2d arrays") - } - } - } - ValueType::Array(None) => SerializedValueType::UnknownArray, - } - } -} - -#[cfg(test)] -mod tests { - use super::SerializedResultSet; - use bigdecimal::BigDecimal; - use chrono::{DateTime, Utc}; - use expect_test::expect; - use quaint::{ - ast::{EnumName, EnumVariant}, - connector::ResultSet, - Value, - }; - use std::str::FromStr; - - #[test] - fn serialize_result_set() { - let names = vec![ - "int32".to_string(), - "int64".to_string(), - "float".to_string(), - "double".to_string(), - "text".to_string(), - "enum".to_string(), - "bytes".to_string(), - "boolean".to_string(), - "char".to_string(), - "numeric".to_string(), - "json".to_string(), - "xml".to_string(), - "uuid".to_string(), - "datetime".to_string(), - "date".to_string(), - "time".to_string(), - "intArray".to_string(), - ]; - let rows = vec![vec![ - Value::int32(42), - Value::int64(42), - Value::float(42.523), - Value::double(42.523), - Value::text("heLlo"), - Value::enum_variant_with_name("Red", EnumName::new("Color", Option::::None)), - Value::bytes(b"hello".to_vec()), - Value::boolean(true), - Value::character('c'), - Value::numeric(BigDecimal::from_str("123456789.123456789").unwrap()), - Value::json(serde_json::json!({"hello": "world"})), - Value::xml("world"), - Value::uuid(uuid::Uuid::from_str("550e8400-e29b-41d4-a716-446655440000").unwrap()), - Value::datetime( - chrono::DateTime::parse_from_rfc3339("2021-01-01T02:00:00Z") - .map(DateTime::::from) - .unwrap(), - ), - 
Value::date(chrono::NaiveDate::from_ymd_opt(2021, 1, 1).unwrap()), - Value::time(chrono::NaiveTime::from_hms_opt(2, 0, 0).unwrap()), - Value::array(vec![Value::int32(42), Value::int32(42)]), - ]]; - let result_set = ResultSet::new(names, rows); - - let serialized = serde_json::to_string_pretty(&SerializedResultSet(result_set)).unwrap(); - - let expected = expect![[r#" - { - "columns": [ - "int32", - "int64", - "float", - "double", - "text", - "enum", - "bytes", - "boolean", - "char", - "numeric", - "json", - "xml", - "uuid", - "datetime", - "date", - "time", - "intArray" - ], - "types": [ - "int", - "bigint", - "float", - "double", - "string", - "enum", - "bytes", - "bool", - "char", - "decimal", - "json", - "xml", - "uuid", - "datetime", - "date", - "time", - "int-array" - ], - "rows": [ - [ - 42, - "42", - 42.523, - 42.523, - "heLlo", - "Red", - "aGVsbG8=", - true, - "c", - "123456789.123456789", - { - "hello": "world" - }, - "world", - "550e8400-e29b-41d4-a716-446655440000", - "2021-01-01T02:00:00+00:00", - "2021-01-01", - "02:00:00", - [ - 42, - 42 - ] - ] - ] - }"#]]; - - expected.assert_eq(&serialized); - } - - #[test] - fn serialize_empty_result_set() { - let names = vec!["hello".to_string()]; - let result_set = ResultSet::new(names, vec![]); - - let serialized = serde_json::to_string_pretty(&SerializedResultSet(result_set)).unwrap(); - - let expected = expect![[r#" - { - "columns": [ - "hello" - ], - "types": [], - "rows": [] - }"#]]; - - expected.assert_eq(&serialized) - } - - #[test] - fn serialize_arrays() { - let names = vec!["array".to_string()]; - let rows = vec![ - vec![Value::null_array()], - vec![Value::array(vec![Value::int32(42), Value::int64(42)])], - vec![Value::array(vec![Value::text("heLlo"), Value::null_text()])], - ]; - let result_set = ResultSet::new(names, rows); - - let serialized = serde_json::to_string_pretty(&SerializedResultSet(result_set)).unwrap(); - - let expected = expect![[r#" - { - "columns": [ - "array" - ], - "types": [ - "int-array" - ], - "rows": [ - [ - null - ], - [ - [ - 42, - "42" - ] - ], - [ - [ - "heLlo", - null - ] - ] - ] - }"#]]; - - expected.assert_eq(&serialized); - } - - #[test] - fn serialize_enum_array() { - let names = vec!["array".to_string()]; - let rows = vec![ - vec![Value::enum_array_with_name( - vec![EnumVariant::new("A"), EnumVariant::new("B")], - EnumName::new("Alphabet", Some("foo")), - )], - vec![Value::null_enum_array()], - ]; - let result_set = ResultSet::new(names, rows); - - let serialized = serde_json::to_string_pretty(&SerializedResultSet(result_set)).unwrap(); - - let expected = expect![[r#" - { - "columns": [ - "array" - ], - "types": [ - "string-array" - ], - "rows": [ - [ - [ - "A", - "B" - ] - ], - [ - null - ] - ] - }"#]]; - - expected.assert_eq(&serialized); - } -} diff --git a/query-engine/connectors/sql-query-connector/src/sql_trace.rs b/query-engine/connectors/sql-query-connector/src/sql_trace.rs index bffaf1174311..4fa88a64d2e6 100644 --- a/query-engine/connectors/sql-query-connector/src/sql_trace.rs +++ b/query-engine/connectors/sql-query-connector/src/sql_trace.rs @@ -1,5 +1,6 @@ use opentelemetry::trace::{SpanContext, TraceContextExt, TraceFlags}; use quaint::ast::{Delete, Insert, Select, Update}; +use telemetry::helpers::TraceParent; use tracing::Span; use tracing_opentelemetry::OpenTelemetrySpanExt; @@ -8,12 +9,12 @@ pub fn trace_parent_to_string(context: &SpanContext) -> String { let span_id = context.span_id(); // see https://www.w3.org/TR/trace-context/#traceparent-header-field-values - 
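// A traceparent value has four dash-separated fields: the version "00", a 16-byte trace id
// (32 hex chars), an 8-byte parent/span id (16 hex chars) and two hex flag digits, e.g.
// 00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01,
// which is why the span id below is rendered with `{:016x}` rather than `{:032x}`.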
format!("traceparent='00-{trace_id:032x}-{span_id:032x}-01'") + format!("traceparent='00-{trace_id:032x}-{span_id:016x}-01'") } pub trait SqlTraceComment: Sized { fn append_trace(self, span: &Span) -> Self; - fn add_trace_id(self, trace_id: Option<&str>) -> Self; + fn add_traceparent(self, traceparent: Option) -> Self; } macro_rules! sql_trace { @@ -30,14 +31,15 @@ macro_rules! sql_trace { self } } + // Temporary method to pass the traceid in an operation - fn add_trace_id(self, trace_id: Option<&str>) -> Self { - if let Some(traceparent) = trace_id { - if should_sample(&traceparent) { - self.comment(format!("traceparent='{}'", traceparent)) - } else { - self - } + fn add_traceparent(self, traceparent: Option) -> Self { + let Some(traceparent) = traceparent else { + return self; + }; + + if traceparent.sampled() { + self.comment(format!("traceparent='{traceparent}'")) } else { self } @@ -46,10 +48,6 @@ macro_rules! sql_trace { }; } -fn should_sample(traceparent: &str) -> bool { - traceparent.split('-').count() == 4 && traceparent.ends_with("-01") -} - sql_trace!(Insert<'_>); sql_trace!(Update<'_>); diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index bd5df5ca166e..6005b091f55a 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -4,7 +4,7 @@ name = "query-core" version = "0.1.0" [features] -metrics = ["query-engine-metrics"] +metrics = ["prisma-metrics"] graphql-protocol = [] [dependencies] @@ -15,7 +15,7 @@ connection-string.workspace = true connector = { path = "../connectors/query-connector", package = "query-connector" } crossbeam-channel = "0.5.6" psl.workspace = true -futures = "0.3" +futures.workspace = true indexmap.workspace = true itertools.workspace = true once_cell = "1" @@ -24,13 +24,13 @@ query-structure = { path = "../query-structure", features = [ "default_generators", ] } opentelemetry = { version = "0.17.0", features = ["rt-tokio", "serialize"] } -query-engine-metrics = { path = "../metrics", optional = true } +prisma-metrics = { path = "../../libs/metrics", optional = true } serde.workspace = true serde_json.workspace = true thiserror = "1.0" -tokio = { version = "1.0", features = ["macros", "time"] } +tokio = { version = "1", features = ["macros", "time"] } tracing = { workspace = true, features = ["attributes"] } -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-opentelemetry = "0.17.4" user-facing-errors = { path = "../../libs/user-facing-errors" } @@ -38,5 +38,7 @@ uuid.workspace = true cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } schema = { path = "../schema" } crosstarget-utils = { path = "../../libs/crosstarget-utils" } +telemetry = { path = "../../libs/telemetry" } lru = "0.7.7" enumflags2.workspace = true +derive_more.workspace = true diff --git a/query-engine/core/src/error.rs b/query-engine/core/src/error.rs index 3a3803bf0d67..b067a325a4a5 100644 --- a/query-engine/core/src/error.rs +++ b/query-engine/core/src/error.rs @@ -4,6 +4,8 @@ use query_structure::DomainError; use thiserror::Error; use user_facing_errors::UnknownError; +use crate::response_ir::{Item, Map}; + #[derive(Debug, Error)] #[error( "Error converting field \"{field}\" of expected non-nullable type \"{expected_type}\", found incompatible value of \"{found}\"." 
@@ -62,6 +64,9 @@ pub enum CoreError { #[error("Error in batch request {request_idx}: {error}")] BatchError { request_idx: usize, error: Box }, + + #[error("Query timed out")] + QueryTimeout, } impl CoreError { @@ -227,3 +232,27 @@ impl From for user_facing_errors::Error { } } } + +#[derive(Debug, serde::Serialize, PartialEq)] +pub struct ExtendedUserFacingError { + #[serde(flatten)] + user_facing_error: user_facing_errors::Error, + + #[serde(skip_serializing_if = "indexmap::IndexMap::is_empty")] + extensions: Map, +} + +impl ExtendedUserFacingError { + pub fn set_extension(&mut self, key: String, val: serde_json::Value) { + self.extensions.entry(key).or_insert(Item::Json(val)); + } +} + +impl From for ExtendedUserFacingError { + fn from(error: CoreError) -> Self { + ExtendedUserFacingError { + user_facing_error: error.into(), + extensions: Default::default(), + } + } +} diff --git a/query-engine/core/src/executor/execute_operation.rs b/query-engine/core/src/executor/execute_operation.rs index 6ef445d8364e..986741182b93 100644 --- a/query-engine/core/src/executor/execute_operation.rs +++ b/query-engine/core/src/executor/execute_operation.rs @@ -10,13 +10,17 @@ use connector::{Connection, ConnectionLike, Connector}; use crosstarget_utils::time::ElapsedTimeCounter; use futures::future; +#[cfg(not(feature = "metrics"))] +use crate::metrics::MetricsInstrumentationStub; #[cfg(feature = "metrics")] -use query_engine_metrics::{ - histogram, increment_counter, metrics, PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, PRISMA_CLIENT_QUERIES_TOTAL, +use prisma_metrics::{ + counter, histogram, WithMetricsInstrumentation, PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, + PRISMA_CLIENT_QUERIES_TOTAL, }; use schema::{QuerySchema, QuerySchemaRef}; use std::time::Duration; +use telemetry::helpers::TraceParent; use tracing::Instrument; use tracing_futures::WithSubscriber; @@ -24,18 +28,15 @@ pub async fn execute_single_operation( query_schema: QuerySchemaRef, conn: &mut dyn ConnectionLike, operation: &Operation, - trace_id: Option, + traceparent: Option, ) -> crate::Result { let operation_timer = ElapsedTimeCounter::start(); let (graph, serializer) = build_graph(&query_schema, operation.clone())?; - let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id).await; + let result = execute_on(conn, graph, serializer, query_schema.as_ref(), traceparent).await; #[cfg(feature = "metrics")] - histogram!( - PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, - operation_timer.elapsed_time() - ); + histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS).record(operation_timer.elapsed_time()); result } @@ -44,7 +45,7 @@ pub async fn execute_many_operations( query_schema: QuerySchemaRef, conn: &mut dyn ConnectionLike, operations: &[Operation], - trace_id: Option, + traceparent: Option, ) -> crate::Result>> { let queries = operations .iter() @@ -55,13 +56,10 @@ pub async fn execute_many_operations( for (i, (graph, serializer)) in queries.into_iter().enumerate() { let operation_timer = ElapsedTimeCounter::start(); - let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; + let result = execute_on(conn, graph, serializer, query_schema.as_ref(), traceparent).await; #[cfg(feature = "metrics")] - histogram!( - PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, - operation_timer.elapsed_time() - ); + histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS).record(operation_timer.elapsed_time()); match result { Ok(result) => results.push(Ok(result)), @@ -81,13 +79,13 @@ pub async fn 
execute_single_self_contained( connector: &C, query_schema: QuerySchemaRef, operation: Operation, - trace_id: Option, + traceparent: Option, force_transactions: bool, ) -> crate::Result { let conn_span = info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = connector.name() + "db.system" = connector.name(), ); let conn = connector.get_connection().instrument(conn_span).await?; @@ -97,7 +95,7 @@ pub async fn execute_single_self_contained( operation, force_transactions, connector.should_retry_on_transient_error(), - trace_id, + traceparent, ) .await } @@ -106,21 +104,20 @@ pub async fn execute_many_self_contained( connector: &C, query_schema: QuerySchemaRef, operations: &[Operation], - trace_id: Option, + traceparent: Option, force_transactions: bool, engine_protocol: EngineProtocol, ) -> crate::Result>> { let mut futures = Vec::with_capacity(operations.len()); - let dispatcher = crate::get_current_dispatcher(); for op in operations { #[cfg(feature = "metrics")] - increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); + counter!(PRISMA_CLIENT_QUERIES_TOTAL).increment(1); let conn_span = info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = connector.name(), + "db.system" = connector.name(), ); let conn = connector.get_connection().instrument(conn_span).await?; @@ -133,10 +130,11 @@ pub async fn execute_many_self_contained( op.clone(), force_transactions, connector.should_retry_on_transient_error(), - trace_id.clone(), + traceparent, ), ) - .with_subscriber(dispatcher.clone()), + .with_current_subscriber() + .with_current_recorder(), )); } @@ -156,7 +154,7 @@ async fn execute_self_contained( operation: Operation, force_transactions: bool, retry_on_transient_error: bool, - trace_id: Option, + traceparent: Option, ) -> crate::Result { let operation_timer = ElapsedTimeCounter::start(); let result = if retry_on_transient_error { @@ -166,20 +164,18 @@ async fn execute_self_contained( operation, force_transactions, ElapsedTimeCounter::start(), - trace_id, + traceparent, ) .await } else { let (graph, serializer) = build_graph(&query_schema, operation)?; - execute_self_contained_without_retry(conn, graph, serializer, force_transactions, &query_schema, trace_id).await + execute_self_contained_without_retry(conn, graph, serializer, force_transactions, &query_schema, traceparent) + .await }; #[cfg(feature = "metrics")] - histogram!( - PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, - operation_timer.elapsed_time() - ); + histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS).record(operation_timer.elapsed_time()); result } @@ -190,13 +186,13 @@ async fn execute_self_contained_without_retry<'a>( serializer: IrSerializer<'a>, force_transactions: bool, query_schema: &'a QuerySchema, - trace_id: Option, + traceparent: Option, ) -> crate::Result { if force_transactions || graph.needs_transaction() { - return execute_in_tx(&mut conn, graph, serializer, query_schema, trace_id).await; + return execute_in_tx(&mut conn, graph, serializer, query_schema, traceparent).await; } - execute_on(conn.as_connection_like(), graph, serializer, query_schema, trace_id).await + execute_on(conn.as_connection_like(), graph, serializer, query_schema, traceparent).await } // As suggested by the MongoDB documentation @@ -212,12 +208,12 @@ async fn execute_self_contained_with_retry( operation: Operation, force_transactions: bool, retry_timeout: ElapsedTimeCounter, - trace_id: Option, + traceparent: Option, ) -> crate::Result { let (graph, serializer) = build_graph(&query_schema, operation.clone())?; if 
force_transactions || graph.needs_transaction() { - let res = execute_in_tx(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; + let res = execute_in_tx(conn, graph, serializer, query_schema.as_ref(), traceparent).await; if !is_transient_error(&res) { return res; @@ -225,7 +221,7 @@ async fn execute_self_contained_with_retry( loop { let (graph, serializer) = build_graph(&query_schema, operation.clone())?; - let res = execute_in_tx(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; + let res = execute_in_tx(conn, graph, serializer, query_schema.as_ref(), traceparent).await; if is_transient_error(&res) && retry_timeout.elapsed_time() < MAX_TX_TIMEOUT_RETRY_LIMIT { crosstarget_utils::time::sleep(TX_RETRY_BACKOFF).await; @@ -240,7 +236,7 @@ async fn execute_self_contained_with_retry( graph, serializer, query_schema.as_ref(), - trace_id, + traceparent, ) .await } @@ -251,17 +247,10 @@ async fn execute_in_tx<'a>( graph: QueryGraph, serializer: IrSerializer<'a>, query_schema: &'a QuerySchema, - trace_id: Option, + traceparent: Option, ) -> crate::Result { let mut tx = conn.start_transaction(None).await?; - let result = execute_on( - tx.as_connection_like(), - graph, - serializer, - query_schema, - trace_id.clone(), - ) - .await; + let result = execute_on(tx.as_connection_like(), graph, serializer, query_schema, traceparent).await; if result.is_ok() { tx.commit().await?; @@ -278,14 +267,14 @@ async fn execute_on<'a>( graph: QueryGraph, serializer: IrSerializer<'a>, query_schema: &'a QuerySchema, - trace_id: Option, + traceparent: Option, ) -> crate::Result { #[cfg(feature = "metrics")] - increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); + counter!(PRISMA_CLIENT_QUERIES_TOTAL).increment(1); let interpreter = QueryInterpreter::new(conn); QueryPipeline::new(graph, interpreter, serializer) - .execute(query_schema, trace_id) + .execute(query_schema, traceparent) .await } diff --git a/query-engine/core/src/executor/interpreting_executor.rs b/query-engine/core/src/executor/interpreting_executor.rs index 0408361b766d..2e391461c718 100644 --- a/query-engine/core/src/executor/interpreting_executor.rs +++ b/query-engine/core/src/executor/interpreting_executor.rs @@ -1,13 +1,15 @@ use super::execute_operation::{execute_many_operations, execute_many_self_contained, execute_single_self_contained}; use super::request_context; +use crate::ItxManager; use crate::{ protocol::EngineProtocol, BatchDocumentTransaction, CoreError, Operation, QueryExecutor, ResponseData, - TransactionActorManager, TransactionError, TransactionManager, TransactionOptions, TxId, + TransactionError, TransactionManager, TransactionOptions, TxId, }; use async_trait::async_trait; use connector::Connector; use schema::QuerySchemaRef; +use telemetry::helpers::TraceParent; use tokio::time::Duration; use tracing_futures::Instrument; @@ -16,7 +18,7 @@ pub struct InterpretingExecutor { /// The loaded connector connector: C, - itx_manager: TransactionActorManager, + itx_manager: ItxManager, /// Flag that forces individual operations to run in a transaction. /// Does _not_ force batches to use transactions. @@ -31,7 +33,7 @@ where InterpretingExecutor { connector, force_transactions, - itx_manager: TransactionActorManager::new(), + itx_manager: ItxManager::new(), } } } @@ -48,25 +50,24 @@ where tx_id: Option, operation: Operation, query_schema: QuerySchemaRef, - trace_id: Option, + traceparent: Option, engine_protocol: EngineProtocol, ) -> crate::Result { - // If a Tx id is provided, execute on that one. 
Else execute normally as a single operation. - if let Some(tx_id) = tx_id { - self.itx_manager.execute(&tx_id, operation, trace_id).await - } else { - request_context::with_request_context(engine_protocol, async move { + request_context::with_request_context(engine_protocol, async move { + if let Some(tx_id) = tx_id { + self.itx_manager.execute(&tx_id, operation, traceparent).await + } else { execute_single_self_contained( &self.connector, query_schema, operation, - trace_id, + traceparent, self.force_transactions, ) .await - }) - .await - } + } + }) + .await } /// Executes a batch of operations. @@ -87,53 +88,50 @@ where operations: Vec, transaction: Option, query_schema: QuerySchemaRef, - trace_id: Option, + traceparent: Option, engine_protocol: EngineProtocol, ) -> crate::Result>> { - if let Some(tx_id) = tx_id { - let batch_isolation_level = transaction.and_then(|t| t.isolation_level()); - if batch_isolation_level.is_some() { - return Err(CoreError::UnsupportedFeatureError( - "Can not set batch isolation level within interactive transaction".into(), - )); - } - self.itx_manager.batch_execute(&tx_id, operations, trace_id).await - } else if let Some(transaction) = transaction { - let conn_span = info_span!( - "prisma:engine:connection", - user_facing = true, - "db.type" = self.connector.name(), - ); - let mut conn = self.connector.get_connection().instrument(conn_span).await?; - let mut tx = conn.start_transaction(transaction.isolation_level()).await?; - - let results = request_context::with_request_context( - engine_protocol, - execute_many_operations(query_schema, tx.as_connection_like(), &operations, trace_id), - ) - .await; - - if results.is_err() { - tx.rollback().await?; + request_context::with_request_context(engine_protocol, async move { + if let Some(tx_id) = tx_id { + let batch_isolation_level = transaction.and_then(|t| t.isolation_level()); + if batch_isolation_level.is_some() { + return Err(CoreError::UnsupportedFeatureError( + "Can not set batch isolation level within interactive transaction".into(), + )); + } + self.itx_manager.batch_execute(&tx_id, operations, traceparent).await + } else if let Some(transaction) = transaction { + let conn_span = info_span!( + "prisma:engine:connection", + user_facing = true, + "db.system" = self.connector.name(), + ); + let mut conn = self.connector.get_connection().instrument(conn_span).await?; + let mut tx = conn.start_transaction(transaction.isolation_level()).await?; + + let results = + execute_many_operations(query_schema, tx.as_connection_like(), &operations, traceparent).await; + + if results.is_err() { + tx.rollback().await?; + } else { + tx.commit().await?; + } + + results } else { - tx.commit().await?; - } - - results - } else { - request_context::with_request_context(engine_protocol, async move { execute_many_self_contained( &self.connector, query_schema, &operations, - trace_id, + traceparent, self.force_transactions, engine_protocol, ) .await - }) - .await - } + } + }) + .await } fn primary_connector(&self) -> &(dyn Connector + Send + Sync) { @@ -158,11 +156,10 @@ where let valid_for_millis = tx_opts.valid_for_millis; let id = tx_opts.new_tx_id.unwrap_or_default(); - trace!("[{}] Starting...", id); let conn_span = info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = self.connector.name() + "db.system" = self.connector.name() ); let conn = crosstarget_utils::time::timeout( Duration::from_millis(tx_opts.max_acquisition_millis), @@ -180,23 +177,19 @@ where conn, isolation_level, 
Duration::from_millis(valid_for_millis), - engine_protocol, ) .await?; - debug!("[{}] Started.", id); Ok(id) }) .await } async fn commit_tx(&self, tx_id: TxId) -> crate::Result<()> { - trace!("[{}] Committing.", tx_id); self.itx_manager.commit_tx(&tx_id).await } async fn rollback_tx(&self, tx_id: TxId) -> crate::Result<()> { - trace!("[{}] Rolling back.", tx_id); self.itx_manager.rollback_tx(&tx_id).await } } diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index fee7bc68fe7b..c7846f7ff7cb 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -14,6 +14,7 @@ mod request_context; pub use self::{execute_operation::*, interpreting_executor::InterpretingExecutor}; pub(crate) use request_context::*; +use telemetry::helpers::TraceParent; use crate::{ protocol::EngineProtocol, query_document::Operation, response_ir::ResponseData, schema::QuerySchemaRef, @@ -22,7 +23,6 @@ use crate::{ use async_trait::async_trait; use connector::Connector; use serde::{Deserialize, Serialize}; -use tracing::Dispatch; #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] #[cfg_attr(not(target_arch = "wasm32"), async_trait)] @@ -35,7 +35,7 @@ pub trait QueryExecutor: TransactionManager { tx_id: Option, operation: Operation, query_schema: QuerySchemaRef, - trace_id: Option, + traceparent: Option, engine_protocol: EngineProtocol, ) -> crate::Result; @@ -51,7 +51,7 @@ pub trait QueryExecutor: TransactionManager { operations: Vec, transaction: Option, query_schema: QuerySchemaRef, - trace_id: Option, + traceparent: Option, engine_protocol: EngineProtocol, ) -> crate::Result>>; @@ -89,10 +89,10 @@ impl TransactionOptions { /// Generates a new transaction id before the transaction is started and returns a modified version /// of self with the new predefined_id set. - pub fn with_new_transaction_id(&mut self) -> TxId { - let tx_id: TxId = Default::default(); + pub fn with_new_transaction_id(mut self) -> Self { + let tx_id = TxId::default(); self.new_tx_id = Some(tx_id.clone()); - tx_id + self } } #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] @@ -116,20 +116,3 @@ pub trait TransactionManager { /// Rolls back a transaction. async fn rollback_tx(&self, tx_id: TxId) -> crate::Result<()>; } - -// With the node-api when a future is spawned in a new thread `tokio:spawn` it will not -// use the current dispatcher and its logs will not be captured anymore. 
We can use this -// method to get the current dispatcher and combine it with `with_subscriber` -// let dispatcher = get_current_dispatcher(); -// tokio::spawn(async { -// my_async_ops.await -// }.with_subscriber(dispatcher)); -// -// -// Finally, this can be replaced with with_current_collector -// https://github.com/tokio-rs/tracing/blob/master/tracing-futures/src/lib.rs#L234 -// once this is in a release - -pub fn get_current_dispatcher() -> Dispatch { - tracing::dispatcher::get_default(|current| current.clone()) -} diff --git a/query-engine/core/src/executor/pipeline.rs b/query-engine/core/src/executor/pipeline.rs index 2193410a57e1..bd1ba73d5e8b 100644 --- a/query-engine/core/src/executor/pipeline.rs +++ b/query-engine/core/src/executor/pipeline.rs @@ -1,5 +1,6 @@ use crate::{Env, Expressionista, IrSerializer, QueryGraph, QueryInterpreter, ResponseData}; use schema::QuerySchema; +use telemetry::helpers::TraceParent; use tracing::Instrument; #[derive(Debug)] @@ -25,7 +26,7 @@ impl<'conn, 'schema> QueryPipeline<'conn, 'schema> { pub(crate) async fn execute( mut self, query_schema: &'schema QuerySchema, - trace_id: Option, + traceparent: Option, ) -> crate::Result { let serializer = self.serializer; let expr = Expressionista::translate(self.graph)?; @@ -34,7 +35,7 @@ impl<'conn, 'schema> QueryPipeline<'conn, 'schema> { let result = self .interpreter - .interpret(expr, Env::default(), 0, trace_id) + .interpret(expr, Env::default(), 0, traceparent) .instrument(span) .await; diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs deleted file mode 100644 index e6c1c7fbd1dc..000000000000 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ /dev/null @@ -1,160 +0,0 @@ -use crate::{protocol::EngineProtocol, ClosedTx, Operation, ResponseData}; -use connector::Connection; -use crosstarget_utils::task::JoinHandle; -use lru::LruCache; -use once_cell::sync::Lazy; -use schema::QuerySchemaRef; -use std::{collections::HashMap, sync::Arc}; -use tokio::{ - sync::{ - mpsc::{channel, Sender}, - RwLock, - }, - time::Duration, -}; - -use super::{spawn_client_list_clear_actor, spawn_itx_actor, ITXClient, TransactionError, TxId}; - -pub static CLOSED_TX_CACHE_SIZE: Lazy = Lazy::new(|| match std::env::var("CLOSED_TX_CACHE_SIZE") { - Ok(size) => size.parse().unwrap_or(100), - Err(_) => 100, -}); - -static CHANNEL_SIZE: usize = 100; - -pub struct TransactionActorManager { - /// Map of active ITx clients - pub(crate) clients: Arc>>, - /// Cache of closed transactions. We keep the last N closed transactions in memory to - /// return better error messages if operations are performed on closed transactions. - pub(crate) closed_txs: Arc>>>, - /// Channel used to signal an ITx is closed and can be moved to the list of closed transactions. - send_done: Sender<(TxId, Option)>, - /// Handle to the task in charge of clearing actors. - /// Used to abort the task when the TransactionActorManager is dropped. 
- bg_reader_clear: JoinHandle<()>, -} - -impl Drop for TransactionActorManager { - fn drop(&mut self) { - self.bg_reader_clear.abort(); - } -} - -impl Default for TransactionActorManager { - fn default() -> Self { - Self::new() - } -} - -impl TransactionActorManager { - pub fn new() -> Self { - let clients = Arc::new(RwLock::new(HashMap::new())); - let closed_txs = Arc::new(RwLock::new(LruCache::new(*CLOSED_TX_CACHE_SIZE))); - - let (send_done, rx) = channel(CHANNEL_SIZE); - let handle = spawn_client_list_clear_actor(clients.clone(), closed_txs.clone(), rx); - - Self { - clients, - closed_txs, - send_done, - bg_reader_clear: handle, - } - } - - pub(crate) async fn create_tx( - &self, - query_schema: QuerySchemaRef, - tx_id: TxId, - conn: Box, - isolation_level: Option, - timeout: Duration, - engine_protocol: EngineProtocol, - ) -> crate::Result<()> { - let client = spawn_itx_actor( - query_schema.clone(), - tx_id.clone(), - conn, - isolation_level, - timeout, - CHANNEL_SIZE, - self.send_done.clone(), - engine_protocol, - ) - .await?; - - self.clients.write().await.insert(tx_id, client); - Ok(()) - } - - async fn get_client(&self, tx_id: &TxId, from_operation: &str) -> crate::Result { - if let Some(client) = self.clients.read().await.get(tx_id) { - Ok(client.clone()) - } else if let Some(closed_tx) = self.closed_txs.read().await.peek(tx_id) { - Err(TransactionError::Closed { - reason: match closed_tx { - Some(ClosedTx::Committed) => { - format!("A {from_operation} cannot be executed on a committed transaction") - } - Some(ClosedTx::RolledBack) => { - format!("A {from_operation} cannot be executed on a transaction that was rolled back") - } - Some(ClosedTx::Expired { start_time, timeout }) => { - format!( - "A {from_operation} cannot be executed on an expired transaction. \ - The timeout for this transaction was {} ms, however {} ms passed since the start \ - of the transaction. 
Consider increasing the interactive transaction timeout \ - or doing less work in the transaction", - timeout.as_millis(), - start_time.elapsed_time().as_millis(), - ) - } - None => { - error!("[{tx_id}] no details about closed transaction"); - format!("A {from_operation} cannot be executed on a closed transaction") - } - }, - } - .into()) - } else { - Err(TransactionError::NotFound.into()) - } - } - - pub async fn execute( - &self, - tx_id: &TxId, - operation: Operation, - traceparent: Option, - ) -> crate::Result { - let client = self.get_client(tx_id, "query").await?; - - client.execute(operation, traceparent).await - } - - pub async fn batch_execute( - &self, - tx_id: &TxId, - operations: Vec, - traceparent: Option, - ) -> crate::Result>> { - let client = self.get_client(tx_id, "batch query").await?; - - client.batch_execute(operations, traceparent).await - } - - pub async fn commit_tx(&self, tx_id: &TxId) -> crate::Result<()> { - let client = self.get_client(tx_id, "commit").await?; - client.commit().await?; - - Ok(()) - } - - pub async fn rollback_tx(&self, tx_id: &TxId) -> crate::Result<()> { - let client = self.get_client(tx_id, "rollback").await?; - client.rollback().await?; - - Ok(()) - } -} diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs deleted file mode 100644 index 86ebd5c13b84..000000000000 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ /dev/null @@ -1,425 +0,0 @@ -use super::{CachedTx, TransactionError, TxOpRequest, TxOpRequestMsg, TxOpResponse}; -use crate::{ - execute_many_operations, execute_single_operation, protocol::EngineProtocol, ClosedTx, Operation, ResponseData, - TxId, -}; -use connector::Connection; -use crosstarget_utils::task::{spawn, spawn_controlled, JoinHandle}; -use crosstarget_utils::time::ElapsedTimeCounter; -use schema::QuerySchemaRef; -use std::{collections::HashMap, sync::Arc}; -use tokio::{ - sync::{ - mpsc::{channel, Receiver, Sender}, - oneshot, RwLock, - }, - time::Duration, -}; -use tracing::Span; -use tracing_futures::Instrument; -use tracing_futures::WithSubscriber; - -#[cfg(feature = "metrics")] -use crate::telemetry::helpers::set_span_link_from_traceparent; - -#[derive(PartialEq)] -enum RunState { - Continue, - Finished, -} - -pub struct ITXServer<'a> { - id: TxId, - pub cached_tx: CachedTx<'a>, - pub timeout: Duration, - receive: Receiver, - query_schema: QuerySchemaRef, -} - -impl<'a> ITXServer<'a> { - pub fn new( - id: TxId, - tx: CachedTx<'a>, - timeout: Duration, - receive: Receiver, - query_schema: QuerySchemaRef, - ) -> Self { - Self { - id, - cached_tx: tx, - timeout, - receive, - query_schema, - } - } - - // RunState is used to tell if the run loop should continue - async fn process_msg(&mut self, op: TxOpRequest) -> RunState { - match op.msg { - TxOpRequestMsg::Single(ref operation, traceparent) => { - let result = self.execute_single(operation, traceparent).await; - let _ = op.respond_to.send(TxOpResponse::Single(result)); - RunState::Continue - } - TxOpRequestMsg::Batch(ref operations, traceparent) => { - let result = self.execute_batch(operations, traceparent).await; - let _ = op.respond_to.send(TxOpResponse::Batch(result)); - RunState::Continue - } - TxOpRequestMsg::Commit => { - let resp = self.commit().await; - let _ = op.respond_to.send(TxOpResponse::Committed(resp)); - RunState::Finished - } - TxOpRequestMsg::Rollback => { - let resp = self.rollback(false).await; - let _ = op.respond_to.send(TxOpResponse::RolledBack(resp)); - 
RunState::Finished - } - } - } - - async fn execute_single( - &mut self, - operation: &Operation, - traceparent: Option, - ) -> crate::Result { - let span = info_span!("prisma:engine:itx_query_builder", user_facing = true); - - #[cfg(feature = "metrics")] - set_span_link_from_traceparent(&span, traceparent.clone()); - - let conn = self.cached_tx.as_open()?; - execute_single_operation( - self.query_schema.clone(), - conn.as_connection_like(), - operation, - traceparent, - ) - .instrument(span) - .await - } - - async fn execute_batch( - &mut self, - operations: &[Operation], - traceparent: Option, - ) -> crate::Result>> { - let span = info_span!("prisma:engine:itx_execute", user_facing = true); - - let conn = self.cached_tx.as_open()?; - execute_many_operations( - self.query_schema.clone(), - conn.as_connection_like(), - operations, - traceparent, - ) - .instrument(span) - .await - } - - pub(crate) async fn commit(&mut self) -> crate::Result<()> { - if let CachedTx::Open(_) = self.cached_tx { - let open_tx = self.cached_tx.as_open()?; - trace!("[{}] committing.", self.id.to_string()); - open_tx.commit().await?; - self.cached_tx = CachedTx::Committed; - } - - Ok(()) - } - - pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result<()> { - debug!("[{}] rolling back, was timed out = {was_timeout}", self.name()); - if let CachedTx::Open(_) = self.cached_tx { - let open_tx = self.cached_tx.as_open()?; - open_tx.rollback().await?; - if was_timeout { - trace!("[{}] Expired Rolling back", self.id.to_string()); - self.cached_tx = CachedTx::Expired; - } else { - self.cached_tx = CachedTx::RolledBack; - trace!("[{}] Rolling back", self.id.to_string()); - } - } - - Ok(()) - } - - pub(crate) fn name(&self) -> String { - format!("itx-{:?}", self.id.to_string()) - } -} - -#[derive(Clone)] -pub struct ITXClient { - send: Sender, - tx_id: TxId, -} - -impl ITXClient { - pub(crate) async fn commit(&self) -> crate::Result<()> { - let msg = self.send_and_receive(TxOpRequestMsg::Commit).await?; - - if let TxOpResponse::Committed(resp) = msg { - debug!("[{}] COMMITTED {:?}", self.tx_id, resp); - resp - } else { - Err(self.handle_error(msg).into()) - } - } - - pub(crate) async fn rollback(&self) -> crate::Result<()> { - let msg = self.send_and_receive(TxOpRequestMsg::Rollback).await?; - - if let TxOpResponse::RolledBack(resp) = msg { - resp - } else { - Err(self.handle_error(msg).into()) - } - } - - pub async fn execute(&self, operation: Operation, traceparent: Option) -> crate::Result { - let msg_req = TxOpRequestMsg::Single(operation, traceparent); - let msg = self.send_and_receive(msg_req).await?; - - if let TxOpResponse::Single(resp) = msg { - resp - } else { - Err(self.handle_error(msg).into()) - } - } - - pub(crate) async fn batch_execute( - &self, - operations: Vec, - traceparent: Option, - ) -> crate::Result>> { - let msg_req = TxOpRequestMsg::Batch(operations, traceparent); - - let msg = self.send_and_receive(msg_req).await?; - - if let TxOpResponse::Batch(resp) = msg { - resp - } else { - Err(self.handle_error(msg).into()) - } - } - - async fn send_and_receive(&self, msg: TxOpRequestMsg) -> Result { - let (receiver, req) = self.create_receive_and_req(msg); - if let Err(err) = self.send.send(req).await { - debug!("channel send error {err}"); - return Err(TransactionError::Closed { - reason: "Could not perform operation".to_string(), - } - .into()); - } - - match receiver.await { - Ok(resp) => Ok(resp), - Err(_err) => Err(TransactionError::Closed { - reason: "Could not perform 
operation".to_string(), - } - .into()), - } - } - - fn create_receive_and_req(&self, msg: TxOpRequestMsg) -> (oneshot::Receiver, TxOpRequest) { - let (send, rx) = oneshot::channel::(); - let request = TxOpRequest { msg, respond_to: send }; - (rx, request) - } - - fn handle_error(&self, msg: TxOpResponse) -> TransactionError { - match msg { - TxOpResponse::Committed(..) => { - let reason = "Transaction is no longer valid. Last state: 'Committed'".to_string(); - TransactionError::Closed { reason } - } - TxOpResponse::RolledBack(..) => { - let reason = "Transaction is no longer valid. Last state: 'RolledBack'".to_string(); - TransactionError::Closed { reason } - } - other => { - error!("Unexpected iTx response, {}", other); - let reason = format!("response '{other}'"); - TransactionError::Closed { reason } - } - } - } -} - -#[allow(clippy::too_many_arguments)] -pub(crate) async fn spawn_itx_actor( - query_schema: QuerySchemaRef, - tx_id: TxId, - mut conn: Box, - isolation_level: Option, - timeout: Duration, - channel_size: usize, - send_done: Sender<(TxId, Option)>, - engine_protocol: EngineProtocol, -) -> crate::Result { - let span = Span::current(); - let tx_id_str = tx_id.to_string(); - span.record("itx_id", tx_id_str.as_str()); - let dispatcher = crate::get_current_dispatcher(); - - let (tx_to_server, rx_from_client) = channel::(channel_size); - let client = ITXClient { - send: tx_to_server, - tx_id: tx_id.clone(), - }; - let (open_transaction_send, open_transaction_rcv) = oneshot::channel(); - - spawn( - crate::executor::with_request_context(engine_protocol, async move { - // We match on the result in order to send the error to the parent task and abort this - // task, on error. This is a separate task (actor), not a function where we can just bubble up the - // result. - let c_tx = match conn.start_transaction(isolation_level).await { - Ok(c_tx) => { - open_transaction_send.send(Ok(())).unwrap(); - c_tx - } - Err(err) => { - open_transaction_send.send(Err(err)).unwrap(); - return; - } - }; - - let mut server = ITXServer::new( - tx_id.clone(), - CachedTx::Open(c_tx), - timeout, - rx_from_client, - query_schema, - ); - - let start_time = ElapsedTimeCounter::start(); - let sleep = crosstarget_utils::time::sleep(timeout); - tokio::pin!(sleep); - - loop { - tokio::select! { - _ = &mut sleep => { - trace!("[{}] interactive transaction timed out", server.id.to_string()); - let _ = server.rollback(true).await; - break; - } - msg = server.receive.recv() => { - if let Some(op) = msg { - let run_state = server.process_msg(op).await; - - if run_state == RunState::Finished { - break - } - } else { - break; - } - } - } - } - - trace!("[{}] completed with {}", server.id.to_string(), server.cached_tx); - - let _ = send_done - .send(( - server.id.clone(), - server.cached_tx.to_closed(start_time, server.timeout), - )) - .await; - - trace!("[{}] has stopped with {}", server.id.to_string(), server.cached_tx); - }) - .instrument(span) - .with_subscriber(dispatcher), - ); - - open_transaction_rcv.await.unwrap()?; - - Ok(client) -} - -/// Spawn the client list clear actor -/// It waits for messages from completed ITXServers and removes -/// the ITXClient from the clients hashmap - -/* A future improvement to this would be to change this to keep a queue of - clients to remove from the list and then periodically remove them. This - would be a nice optimization because we would take less write locks on the - hashmap. 
- - The downside to consider is that we can introduce a race condition where the - ITXServer has stopped running but the client hasn't been removed from the hashmap - yet. When the client tries to send a message to the ITXServer there will be a - send error. This isn't a huge obstacle but something to handle correctly. - And example implementation for this would be: - - ``` - let mut queue: Vec = Vec::new(); - - let sleep_duration = Duration::from_millis(100); - let clear_sleeper = time::sleep(sleep_duration); - tokio::pin!(clear_sleeper); - - loop { - tokio::select! { - _ = &mut clear_sleeper => { - let mut list = clients.write().await; - for id in queue.drain(..) { - trace!("removing {} from client list", id); - list.remove(&id); - } - clear_sleeper.as_mut().reset(Instant::now() + sleep_duration); - } - msg = rx.recv() => { - if let Some(id) = msg { - queue.push(id); - } - } - } - } - ``` -*/ -pub(crate) fn spawn_client_list_clear_actor( - clients: Arc>>, - closed_txs: Arc>>>, - mut rx: Receiver<(TxId, Option)>, -) -> JoinHandle<()> { - // Note: tasks implemented via loops cannot be cancelled implicitly, so we need to spawn them in a - // "controlled" way, via `spawn_controlled`. - // The `rx_exit` receiver is used to signal the loop to exit, and that signal is emitted whenever - // the task is aborted (likely, due to the engine shutting down and cleaning up the allocated resources). - spawn_controlled(Box::new( - |mut rx_exit: tokio::sync::broadcast::Receiver<()>| async move { - loop { - tokio::select! { - result = rx.recv() => { - match result { - Some((id, closed_tx)) => { - trace!("removing {} from client list", id); - - let mut clients_guard = clients.write().await; - - clients_guard.remove(&id); - drop(clients_guard); - - closed_txs.write().await.put(id, closed_tx); - } - None => { - // the `rx` channel is closed. 
- tracing::error!("rx channel is closed!"); - break; - } - } - }, - _ = rx_exit.recv() => { - break; - }, - } - } - }, - )) -} diff --git a/query-engine/core/src/interactive_transactions/error.rs b/query-engine/core/src/interactive_transactions/error.rs index 8189e2ce7420..146d69f103b5 100644 --- a/query-engine/core/src/interactive_transactions/error.rs +++ b/query-engine/core/src/interactive_transactions/error.rs @@ -1,10 +1,5 @@ use thiserror::Error; -use crate::{ - response_ir::{Item, Map}, - CoreError, -}; - #[derive(Debug, Error, PartialEq)] pub enum TransactionError { #[error("Unable to start a transaction in the given time.")] @@ -22,27 +17,3 @@ pub enum TransactionError { #[error("Unexpected response: {reason}.")] Unknown { reason: String }, } - -#[derive(Debug, serde::Serialize, PartialEq)] -pub struct ExtendedTransactionUserFacingError { - #[serde(flatten)] - user_facing_error: user_facing_errors::Error, - - #[serde(skip_serializing_if = "indexmap::IndexMap::is_empty")] - extensions: Map, -} - -impl ExtendedTransactionUserFacingError { - pub fn set_extension(&mut self, key: String, val: serde_json::Value) { - self.extensions.entry(key).or_insert(Item::Json(val)); - } -} - -impl From for ExtendedTransactionUserFacingError { - fn from(error: CoreError) -> Self { - ExtendedTransactionUserFacingError { - user_facing_error: error.into(), - extensions: Default::default(), - } - } -} diff --git a/query-engine/core/src/interactive_transactions/manager.rs b/query-engine/core/src/interactive_transactions/manager.rs new file mode 100644 index 000000000000..d9873c4383a7 --- /dev/null +++ b/query-engine/core/src/interactive_transactions/manager.rs @@ -0,0 +1,192 @@ +use crate::{ClosedTransaction, InteractiveTransaction, Operation, ResponseData}; +use connector::Connection; +use lru::LruCache; +use once_cell::sync::Lazy; +use schema::QuerySchemaRef; +use std::{collections::HashMap, sync::Arc}; +use telemetry::helpers::TraceParent; +use tokio::{ + sync::{ + mpsc::{unbounded_channel, UnboundedSender}, + Mutex, RwLock, + }, + time::Duration, +}; +use tracing_futures::WithSubscriber; + +#[cfg(not(feature = "metrics"))] +use crate::metrics::MetricsInstrumentationStub; +#[cfg(feature = "metrics")] +use prisma_metrics::WithMetricsInstrumentation; + +use super::{TransactionError, TxId}; + +pub static CLOSED_TX_CACHE_SIZE: Lazy = Lazy::new(|| match std::env::var("CLOSED_TX_CACHE_SIZE") { + Ok(size) => size.parse().unwrap_or(100), + Err(_) => 100, +}); + +pub struct ItxManager { + /// Stores all current transactions (some of them might be already committed/expired/rolled back). + /// + /// There are two tiers of locks here: + /// 1. Lock on the entire hashmap. This *must* be taken only for short periods of time - for + /// example to insert/delete transaction or to clone transaction inside. + /// 2. Lock on the individual transactions. This one can be taken for prolonged periods of time - for + /// example to perform an I/O operation. + /// + /// The rationale behind this design is to make shared path (lock on the entire hashmap) as free + /// from contention as possible. Individual transactions are not capable of concurrency, so + /// taking a lock on them to serialise operations is acceptable. + /// + /// Note that since we clone transaction from the shared hashmap to perform operations on it, it + /// is possible to end up in a situation where we cloned the transaction, but it was then + /// immediately removed by the background task from the common hashmap. 
In this case, either + /// our operation will be first or the background cleanup task will be first. Both cases are + /// an acceptable outcome. + transactions: Arc>>>>, + + /// Cache of closed transactions. We keep the last N closed transactions in memory to + /// return better error messages if operations are performed on closed transactions. + closed_txs: Arc>>, + + /// Sender part of the channel to which transaction id is sent when the timeout of the + /// transaction expires. + timeout_sender: UnboundedSender, +} + +impl ItxManager { + pub fn new() -> Self { + let transactions = Arc::new(RwLock::new(HashMap::<_, Arc>>::default())); + let closed_txs = Arc::new(RwLock::new(LruCache::new(*CLOSED_TX_CACHE_SIZE))); + let (timeout_sender, mut timeout_receiver) = unbounded_channel(); + + // This task rollbacks and removes any open transactions with expired timeouts from the + // `self.transactions`. It also removes any closed transactions to avoid `self.transactions` + // growing infinitely in size over time. + // Note that this task automatically exits when all transactions finish and the `ItxManager` + // is dropped, because that causes the `timeout_receiver` to become closed. + crosstarget_utils::task::spawn({ + let transactions = Arc::clone(&transactions); + let closed_txs = Arc::clone(&closed_txs); + async move { + while let Some(tx_id) = timeout_receiver.recv().await { + let transaction_entry = match transactions.write().await.remove(&tx_id) { + Some(transaction_entry) => transaction_entry, + None => { + // Transaction was committed or rolled back already. + continue; + } + }; + let mut transaction = transaction_entry.lock().await; + + // If transaction was already committed, rollback will error. + let _ = transaction.rollback(true).await; + + let closed_tx = transaction + .as_closed() + .expect("transaction must be closed after rollback"); + + closed_txs.write().await.put(tx_id, closed_tx); + } + } + .with_current_subscriber() + .with_current_recorder() + }); + + Self { + transactions, + closed_txs, + timeout_sender, + } + } + + pub async fn create_tx( + &self, + query_schema: QuerySchemaRef, + tx_id: TxId, + conn: Box, + isolation_level: Option, + timeout: Duration, + ) -> crate::Result<()> { + // This task notifies the task spawned in `new()` method that the timeout for this + // transaction has expired. + crosstarget_utils::task::spawn({ + let timeout_sender = self.timeout_sender.clone(); + let tx_id = tx_id.clone(); + async move { + crosstarget_utils::time::sleep(timeout).await; + timeout_sender.send(tx_id).expect("receiver must exist"); + } + }); + + let transaction = + InteractiveTransaction::new(tx_id.clone(), conn, timeout, query_schema, isolation_level).await?; + + self.transactions + .write() + .await + .insert(tx_id, Arc::new(Mutex::new(transaction))); + Ok(()) + } + + async fn get_transaction( + &self, + tx_id: &TxId, + from_operation: &str, + ) -> crate::Result>> { + if let Some(transaction) = self.transactions.read().await.get(tx_id) { + Ok(Arc::clone(transaction)) + } else { + Err(if let Some(closed_tx) = self.closed_txs.read().await.peek(tx_id) { + TransactionError::Closed { + reason: closed_tx.error_message_for(from_operation), + } + .into() + } else { + TransactionError::NotFound.into() + }) + } + } + + pub async fn execute( + &self, + tx_id: &TxId, + operation: Operation, + traceparent: Option, + ) -> crate::Result { + self.get_transaction(tx_id, "query") + .await? 
+ .lock() + .await + .execute_single(&operation, traceparent) + .await + } + + pub async fn batch_execute( + &self, + tx_id: &TxId, + operations: Vec, + traceparent: Option, + ) -> crate::Result>> { + self.get_transaction(tx_id, "batch query") + .await? + .lock() + .await + .execute_batch(&operations, traceparent) + .await + } + + pub async fn commit_tx(&self, tx_id: &TxId) -> crate::Result<()> { + self.get_transaction(tx_id, "commit").await?.lock().await.commit().await + } + + pub async fn rollback_tx(&self, tx_id: &TxId) -> crate::Result<()> { + self.get_transaction(tx_id, "rollback") + .await? + .lock() + .await + .rollback(false) + .await + } +} diff --git a/query-engine/core/src/interactive_transactions/messages.rs b/query-engine/core/src/interactive_transactions/messages.rs deleted file mode 100644 index 0dba2c096a8a..000000000000 --- a/query-engine/core/src/interactive_transactions/messages.rs +++ /dev/null @@ -1,46 +0,0 @@ -use crate::{Operation, ResponseData}; -use std::fmt::Display; -use tokio::sync::oneshot; - -#[derive(Debug)] -pub enum TxOpRequestMsg { - Commit, - Rollback, - Single(Operation, Option), - Batch(Vec, Option), -} - -pub struct TxOpRequest { - pub msg: TxOpRequestMsg, - pub respond_to: oneshot::Sender, -} - -impl Display for TxOpRequest { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.msg { - TxOpRequestMsg::Commit => write!(f, "Commit"), - TxOpRequestMsg::Rollback => write!(f, "Rollback"), - TxOpRequestMsg::Single(..) => write!(f, "Single"), - TxOpRequestMsg::Batch(..) => write!(f, "Batch"), - } - } -} - -#[derive(Debug)] -pub enum TxOpResponse { - Committed(crate::Result<()>), - RolledBack(crate::Result<()>), - Single(crate::Result), - Batch(crate::Result>>), -} - -impl Display for TxOpResponse { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Committed(..) => write!(f, "Committed"), - Self::RolledBack(..) => write!(f, "RolledBack"), - Self::Single(..) => write!(f, "Single"), - Self::Batch(..) => write!(f, "Batch"), - } - } -} diff --git a/query-engine/core/src/interactive_transactions/mod.rs b/query-engine/core/src/interactive_transactions/mod.rs index c3ee76703a06..009cab37ccfd 100644 --- a/query-engine/core/src/interactive_transactions/mod.rs +++ b/query-engine/core/src/interactive_transactions/mod.rs @@ -1,49 +1,19 @@ -use crate::CoreError; -use connector::Transaction; -use crosstarget_utils::time::ElapsedTimeCounter; +use derive_more::Display; use serde::Deserialize; -use std::fmt::Display; -use tokio::time::Duration; -mod actor_manager; -mod actors; mod error; -mod messages; +mod manager; +mod transaction; pub use error::*; -pub(crate) use actor_manager::*; -pub(crate) use actors::*; -pub(crate) use messages::*; +pub(crate) use manager::*; +pub(crate) use transaction::*; -/// How Interactive Transactions work -/// The Interactive Transactions (iTx) follow an actor model design. Where each iTx is created in its own process. -/// When a prisma client requests to start a new transaction, the Transaction Actor Manager spawns a new ITXServer. The ITXServer runs in its own -/// process and waits for messages to arrive via its receive channel to process. -/// The Transaction Actor Manager will also create an ITXClient and add it to hashmap managed by an RwLock. The ITXClient is the only way to communicate -/// with the ITXServer. - -/// Once Prisma Client receives the iTx Id it can perform database operations using that iTx id. 
When an operation request is received by the -/// TransactionActorManager, it looks for the client in the hashmap and passes the operation to the client. The ITXClient sends a message to the -/// ITXServer and waits for a response. The ITXServer will then perform the operation and return the result. The ITXServer will perform one -/// operation at a time. All other operations will sit in the message queue waiting to be processed. -/// -/// The ITXServer will handle all messages until: -/// - It transitions state, e.g "rollback" or "commit" -/// - It exceeds its timeout, in which case the iTx is rolledback and the connection to the database is closed. - -/// Once the ITXServer is done handling messages from the iTx Client, it sends a last message to the Background Client list Actor to say that it is completed and then shuts down. -/// The Background Client list Actor removes the client from the list of active clients and keeps in cache the iTx id of the closed transaction. - -/// We keep a list of closed transactions so that if any further messages are received for this iTx id, -/// the TransactionActorManager can reply with a helpful error message which explains that no operation can be performed on a closed transaction -/// rather than an error message stating that the transaction does not exist. - -#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize)] +#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Display)] +#[display(fmt = "{}", _0)] pub struct TxId(String); -const MINIMUM_TX_ID_LENGTH: usize = 24; - impl Default for TxId { fn default() -> Self { #[allow(deprecated)] @@ -56,9 +26,11 @@ where T: Into, { fn from(s: T) -> Self { + const MINIMUM_TX_ID_LENGTH: usize = 24; + let contents = s.into(); // This postcondition is to ensure that the TxId is long enough as to be able to derive - // a TraceId from it. + // a TraceId from it. See `TxTraceExt` trait for more details. assert!( contents.len() >= MINIMUM_TX_ID_LENGTH, "minimum length for a TxId ({}) is {}, but was {}", @@ -69,57 +41,3 @@ where Self(contents) } } - -impl Display for TxId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Display::fmt(&self.0, f) - } -} - -pub enum CachedTx<'a> { - Open(Box), - Committed, - RolledBack, - Expired, -} - -impl Display for CachedTx<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - CachedTx::Open(_) => f.write_str("Open"), - CachedTx::Committed => f.write_str("Committed"), - CachedTx::RolledBack => f.write_str("Rolled back"), - CachedTx::Expired => f.write_str("Expired"), - } - } -} - -impl<'a> CachedTx<'a> { - /// Requires this cached TX to be `Open`, else an error will be raised that it is no longer valid. - pub(crate) fn as_open(&mut self) -> crate::Result<&mut Box> { - if let Self::Open(ref mut otx) = self { - Ok(otx) - } else { - let reason = format!("Transaction is no longer valid. 
Last state: '{self}'"); - Err(CoreError::from(TransactionError::Closed { reason })) - } - } - - pub(crate) fn to_closed(&self, start_time: ElapsedTimeCounter, timeout: Duration) -> Option { - match self { - CachedTx::Open(_) => None, - CachedTx::Committed => Some(ClosedTx::Committed), - CachedTx::RolledBack => Some(ClosedTx::RolledBack), - CachedTx::Expired => Some(ClosedTx::Expired { start_time, timeout }), - } - } -} - -pub(crate) enum ClosedTx { - Committed, - RolledBack, - Expired { - start_time: ElapsedTimeCounter, - timeout: Duration, - }, -} diff --git a/query-engine/core/src/interactive_transactions/transaction.rs b/query-engine/core/src/interactive_transactions/transaction.rs new file mode 100644 index 000000000000..4e84155ad78e --- /dev/null +++ b/query-engine/core/src/interactive_transactions/transaction.rs @@ -0,0 +1,253 @@ +#![allow(unsafe_code)] + +use std::pin::Pin; + +use crate::{ + execute_many_operations, execute_single_operation, CoreError, Operation, ResponseData, TransactionError, TxId, +}; +use connector::{Connection, Transaction}; +use crosstarget_utils::time::ElapsedTimeCounter; +use schema::QuerySchemaRef; +use telemetry::helpers::TraceParent; +use tokio::time::Duration; +use tracing::Span; +use tracing_futures::Instrument; + +// Note: it's important to maintain the correct state of the transaction throughout execution. If +// the transaction is ever left in the `Open` state after rollback or commit operations, it means +// that the corresponding connection will never be returned to the connection pool. +enum TransactionState { + Open { + // Note: field order is important here because fields are dropped in the declaration order. + // First, we drop the `tx`, which may reference `_conn`. Only after that we drop `_conn`. + tx: Box, + _conn: Pin>, + }, + Committed, + RolledBack, + Expired { + start_time: ElapsedTimeCounter, + timeout: Duration, + }, +} + +pub enum ClosedTransaction { + Committed, + RolledBack, + Expired { + start_time: ElapsedTimeCounter, + timeout: Duration, + }, +} + +impl ClosedTransaction { + pub fn error_message_for(&self, operation: &str) -> String { + match self { + ClosedTransaction::Committed => { + format!("A {operation} cannot be executed on a committed transaction") + } + ClosedTransaction::RolledBack => { + format!("A {operation} cannot be executed on a transaction that was rolled back") + } + ClosedTransaction::Expired { start_time, timeout } => { + format!( + "A {operation} cannot be executed on an expired transaction. \ + The timeout for this transaction was {} ms, however {} ms passed since the start \ + of the transaction. Consider increasing the interactive transaction timeout \ + or doing less work in the transaction", + timeout.as_millis(), + start_time.elapsed_time().as_millis(), + ) + } + } + } +} + +impl TransactionState { + async fn start_transaction( + conn: Box, + isolation_level: Option, + ) -> crate::Result { + // Note: This method creates a self-referential struct, which is why we need unsafe. Field + // `tx` is referencing field `conn` in the `Self::Open` variant. + let mut conn = Box::into_pin(conn); + + // SAFETY: We do not move out of `conn`. + let conn_mut: &mut (dyn Connection + Send + Sync) = unsafe { conn.as_mut().get_unchecked_mut() }; + + // This creates a transaction, which borrows from the connection. + let tx_borrowed_from_conn: Box = conn_mut.start_transaction(isolation_level).await?; + + // SAFETY: This transmute only erases the lifetime from `conn_mut`. 
Normally, borrow checker + // guarantees that the borrowed value is not dropped. In this case, we guarantee ourselves + // through the use of `Pin` on the connection. + let tx_with_erased_lifetime: Box = + unsafe { std::mem::transmute(tx_borrowed_from_conn) }; + + Ok(Self::Open { + tx: tx_with_erased_lifetime, + _conn: conn, + }) + } + + fn as_open(&mut self, from_operation: &str) -> crate::Result<&mut Box> { + match self { + Self::Open { tx, .. } => Ok(tx), + tx => Err(CoreError::from(TransactionError::Closed { + reason: tx.as_closed().unwrap().error_message_for(from_operation), + })), + } + } + + fn as_closed(&self) -> Option { + match self { + Self::Open { .. } => None, + Self::Committed => Some(ClosedTransaction::Committed), + Self::RolledBack => Some(ClosedTransaction::RolledBack), + Self::Expired { start_time, timeout } => Some(ClosedTransaction::Expired { + start_time: *start_time, + timeout: *timeout, + }), + } + } +} + +pub struct InteractiveTransaction { + id: TxId, + state: TransactionState, + start_time: ElapsedTimeCounter, + timeout: Duration, + query_schema: QuerySchemaRef, +} + +/// This macro executes the future until it's ready or the transaction's timeout expires. +macro_rules! tx_timeout { + ($self:expr, $operation:expr, $fut:expr) => {{ + let remaining_time = $self + .timeout + .checked_sub($self.start_time.elapsed_time()) + .unwrap_or(Duration::ZERO); + tokio::select! { + biased; + _ = crosstarget_utils::time::sleep(remaining_time) => { + let _ = $self.rollback(true).await; + Err(TransactionError::Closed { + reason: $self.as_closed().unwrap().error_message_for($operation), + }.into()) + } + result = $fut => { + result + } + } + }}; +} + +impl InteractiveTransaction { + pub async fn new( + id: TxId, + conn: Box, + timeout: Duration, + query_schema: QuerySchemaRef, + isolation_level: Option, + ) -> crate::Result { + Span::current().record("itx_id", id.to_string()); + + Ok(Self { + id, + state: TransactionState::start_transaction(conn, isolation_level).await?, + start_time: ElapsedTimeCounter::start(), + timeout, + query_schema, + }) + } + + pub async fn execute_single( + &mut self, + operation: &Operation, + traceparent: Option, + ) -> crate::Result { + tx_timeout!(self, "query", async { + let conn = self.state.as_open("query")?; + execute_single_operation( + self.query_schema.clone(), + conn.as_connection_like(), + operation, + traceparent, + ) + .instrument(info_span!("prisma:engine:itx_execute_single", user_facing = true)) + .await + }) + } + + pub async fn execute_batch( + &mut self, + operations: &[Operation], + traceparent: Option, + ) -> crate::Result>> { + tx_timeout!(self, "batch query", async { + let conn = self.state.as_open("batch query")?; + execute_many_operations( + self.query_schema.clone(), + conn.as_connection_like(), + operations, + traceparent, + ) + .instrument(info_span!("prisma:engine:itx_execute_batch", user_facing = true)) + .await + }) + } + + pub async fn commit(&mut self) -> crate::Result<()> { + tx_timeout!(self, "commit", async { + let name = self.name(); + let conn = self.state.as_open("commit")?; + let span = info_span!("prisma:engine:itx_commit", user_facing = true); + + if let Err(err) = conn.commit().instrument(span).await { + error!(?err, ?name, "transaction failed to commit"); + // We don't know if the transaction was committed or not. Because of that, we cannot + // leave it in "open" state. We attempt to rollback to get the transaction into a + // known state. 
+ let _ = self.rollback(false).await; + Err(err.into()) + } else { + debug!(?name, "transaction committed"); + self.state = TransactionState::Committed; + Ok(()) + } + }) + } + + pub async fn rollback(&mut self, was_timeout: bool) -> crate::Result<()> { + let name = self.name(); + let conn = self.state.as_open("rollback")?; + let span = info_span!("prisma:engine:itx_rollback", user_facing = true); + + let result = conn.rollback().instrument(span).await; + if let Err(err) = &result { + error!(?err, ?was_timeout, ?name, "transaction failed to roll back"); + } else { + debug!(?was_timeout, ?name, "transaction rolled back"); + } + + // Ensure that the transaction isn't left in the "open" state after the rollback. + if was_timeout { + self.state = TransactionState::Expired { + start_time: self.start_time, + timeout: self.timeout, + }; + } else { + self.state = TransactionState::RolledBack; + } + + result.map_err(<_>::into) + } + + pub fn as_closed(&self) -> Option { + self.state.as_closed() + } + + pub fn name(&self) -> String { + format!("itx-{}", self.id) + } +} diff --git a/query-engine/core/src/interpreter/interpreter_impl.rs b/query-engine/core/src/interpreter/interpreter_impl.rs index 012bbc953b1a..e25d157b7ada 100644 --- a/query-engine/core/src/interpreter/interpreter_impl.rs +++ b/query-engine/core/src/interpreter/interpreter_impl.rs @@ -8,6 +8,7 @@ use connector::ConnectionLike; use futures::future::BoxFuture; use query_structure::prelude::*; use std::{collections::HashMap, fmt}; +use telemetry::helpers::TraceParent; use tracing::Instrument; #[derive(Debug, Clone)] @@ -178,7 +179,7 @@ impl<'conn> QueryInterpreter<'conn> { exp: Expression, env: Env, level: usize, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'_, InterpretationResult> { match exp { Expression::Func { func } => { @@ -186,7 +187,7 @@ impl<'conn> QueryInterpreter<'conn> { Box::pin(async move { self.log_line(level, || "execute {"); - let result = self.interpret(expr?, env, level + 1, trace_id).await; + let result = self.interpret(expr?, env, level + 1, traceparent).await; self.log_line(level, || "}"); result }) @@ -204,7 +205,7 @@ impl<'conn> QueryInterpreter<'conn> { let mut results = Vec::with_capacity(seq.len()); for expr in seq { - results.push(self.interpret(expr, env.clone(), level + 1, trace_id.clone()).await?); + results.push(self.interpret(expr, env.clone(), level + 1, traceparent).await?); self.log_line(level + 1, || ","); } @@ -227,7 +228,7 @@ impl<'conn> QueryInterpreter<'conn> { self.log_line(level + 1, || format!("{} = {{", &binding.name)); let result = self - .interpret(binding.expr, env.clone(), level + 2, trace_id.clone()) + .interpret(binding.expr, env.clone(), level + 2, traceparent) .await?; inner_env.insert(binding.name, result); @@ -242,7 +243,7 @@ impl<'conn> QueryInterpreter<'conn> { }; self.log_line(level, || "in {"); - let result = self.interpret(next_expression, inner_env, level + 1, trace_id).await; + let result = self.interpret(next_expression, inner_env, level + 1, traceparent).await; self.log_line(level, || "}"); result }) @@ -253,7 +254,7 @@ impl<'conn> QueryInterpreter<'conn> { Query::Read(read) => { self.log_line(level, || format!("readExecute {read}")); let span = info_span!("prisma:engine:read-execute"); - Ok(read::execute(self.conn, read, None, trace_id) + Ok(read::execute(self.conn, read, None, traceparent) .instrument(span) .await .map(ExpressionResult::Query)?) 
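A note on the `trace_id: Option<String>` → `traceparent: Option<TraceParent>` change running through these interpreter hunks: the old code had to call `.clone()` on the string before every recursive `interpret` call, while the new code passes `traceparent` by value inside loops, which suggests `TraceParent` is a `Copy` type. Below is a minimal sketch of that difference, using a hypothetical stand-in for `TraceParent` (the real type lives in the `telemetry` crate and its fields may differ):

```
// Hypothetical stand-in for telemetry::helpers::TraceParent; the real fields differ,
// but the hunks above imply the type is small and `Copy`.
#[derive(Clone, Copy, Debug)]
struct TraceParent {
    trace_id: u128,
    span_id: u64,
}

fn interpret(expr: &str, traceparent: Option<TraceParent>) {
    println!("{expr} (traceparent: {traceparent:?})");
}

fn interpret_all(traceparent: Option<TraceParent>, exprs: &[&str]) {
    for &expr in exprs {
        // `Option<TraceParent>` is `Copy`, so it can be passed by value on every
        // iteration; the previous `Option<String>` needed a `.clone()` here instead.
        interpret(expr, traceparent);
    }
}

fn main() {
    let traceparent = TraceParent { trace_id: 1, span_id: 1 };
    interpret_all(Some(traceparent), &["readExecute", "writeExecute"]);
}
```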
@@ -262,7 +263,7 @@ impl<'conn> QueryInterpreter<'conn> { Query::Write(write) => { self.log_line(level, || format!("writeExecute {write}")); let span = info_span!("prisma:engine:write-execute"); - Ok(write::execute(self.conn, write, trace_id) + Ok(write::execute(self.conn, write, traceparent) .instrument(span) .await .map(ExpressionResult::Query)?) @@ -297,10 +298,10 @@ impl<'conn> QueryInterpreter<'conn> { self.log_line(level, || format!("if = {predicate} {{")); let result = if predicate { - self.interpret(Expression::Sequence { seq: then }, env, level + 1, trace_id) + self.interpret(Expression::Sequence { seq: then }, env, level + 1, traceparent) .await } else { - self.interpret(Expression::Sequence { seq: elze }, env, level + 1, trace_id) + self.interpret(Expression::Sequence { seq: elze }, env, level + 1, traceparent) .await }; self.log_line(level, || "}"); diff --git a/query-engine/core/src/interpreter/query_interpreters/nested_read.rs b/query-engine/core/src/interpreter/query_interpreters/nested_read.rs index 790728104fd3..95e5945c18dc 100644 --- a/query-engine/core/src/interpreter/query_interpreters/nested_read.rs +++ b/query-engine/core/src/interpreter/query_interpreters/nested_read.rs @@ -3,12 +3,13 @@ use crate::{interpreter::InterpretationResult, query_ast::*}; use connector::ConnectionLike; use query_structure::*; use std::collections::HashMap; +use telemetry::helpers::TraceParent; pub(crate) async fn m2m( tx: &mut dyn ConnectionLike, query: &mut RelatedRecordsQuery, parent_result: Option<&ManyRecords>, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { let processor = InMemoryRecordProcessor::new_from_query_args(&mut query.args); @@ -31,7 +32,7 @@ pub(crate) async fn m2m( } let ids = tx - .get_related_m2m_record_ids(&query.parent_field, &parent_ids, trace_id.clone()) + .get_related_m2m_record_ids(&query.parent_field, &parent_ids, traceparent) .await?; if ids.is_empty() { return Ok(ManyRecords::empty(&query.selected_fields)); @@ -70,7 +71,7 @@ pub(crate) async fn m2m( args, &query.selected_fields, RelationLoadStrategy::Query, - trace_id.clone(), + traceparent, ) .await? }; @@ -137,7 +138,7 @@ pub async fn one2m( parent_result: Option<&ManyRecords>, mut query_args: QueryArguments, selected_fields: &FieldSelection, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { let parent_model_id = parent_field.model().primary_identifier(); let parent_link_id = parent_field.linking_fields(); @@ -208,7 +209,7 @@ pub async fn one2m( args, selected_fields, RelationLoadStrategy::Query, - trace_id, + traceparent, ) .await? 
}; diff --git a/query-engine/core/src/interpreter/query_interpreters/read.rs b/query-engine/core/src/interpreter/query_interpreters/read.rs index 7e194993b75c..d79f4fd5c998 100644 --- a/query-engine/core/src/interpreter/query_interpreters/read.rs +++ b/query-engine/core/src/interpreter/query_interpreters/read.rs @@ -4,20 +4,21 @@ use connector::{error::ConnectorError, ConnectionLike}; use futures::future::{BoxFuture, FutureExt}; use psl::can_support_relation_load_strategy; use query_structure::{ManyRecords, RelationLoadStrategy, RelationSelection}; +use telemetry::helpers::TraceParent; use user_facing_errors::KnownError; pub(crate) fn execute<'conn>( tx: &'conn mut dyn ConnectionLike, query: ReadQuery, parent_result: Option<&'conn ManyRecords>, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'conn, InterpretationResult> { let fut = async move { match query { - ReadQuery::RecordQuery(q) => read_one(tx, q, trace_id).await, - ReadQuery::ManyRecordsQuery(q) => read_many(tx, q, trace_id).await, - ReadQuery::RelatedRecordsQuery(q) => read_related(tx, q, parent_result, trace_id).await, - ReadQuery::AggregateRecordsQuery(q) => aggregate(tx, q, trace_id).await, + ReadQuery::RecordQuery(q) => read_one(tx, q, traceparent).await, + ReadQuery::ManyRecordsQuery(q) => read_many(tx, q, traceparent).await, + ReadQuery::RelatedRecordsQuery(q) => read_related(tx, q, parent_result, traceparent).await, + ReadQuery::AggregateRecordsQuery(q) => aggregate(tx, q, traceparent).await, } }; @@ -28,7 +29,7 @@ pub(crate) fn execute<'conn>( fn read_one( tx: &mut dyn ConnectionLike, query: RecordQuery, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'_, InterpretationResult> { let fut = async move { let model = query.model; @@ -39,7 +40,7 @@ fn read_one( &filter, &query.selected_fields, query.relation_load_strategy, - trace_id, + traceparent, ) .await?; @@ -97,18 +98,18 @@ fn read_one( fn read_many( tx: &mut dyn ConnectionLike, query: ManyRecordsQuery, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'_, InterpretationResult> { match query.relation_load_strategy { - RelationLoadStrategy::Join => read_many_by_joins(tx, query, trace_id), - RelationLoadStrategy::Query => read_many_by_queries(tx, query, trace_id), + RelationLoadStrategy::Join => read_many_by_joins(tx, query, traceparent), + RelationLoadStrategy::Query => read_many_by_queries(tx, query, traceparent), } } fn read_many_by_queries( tx: &mut dyn ConnectionLike, mut query: ManyRecordsQuery, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'_, InterpretationResult> { let processor = if query.args.requires_inmemory_processing() { Some(InMemoryRecordProcessor::new_from_query_args(&mut query.args)) @@ -123,7 +124,7 @@ fn read_many_by_queries( query.args.clone(), &query.selected_fields, query.relation_load_strategy, - trace_id, + traceparent, ) .await?; @@ -156,7 +157,7 @@ fn read_many_by_queries( fn read_many_by_joins( tx: &mut dyn ConnectionLike, query: ManyRecordsQuery, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'_, InterpretationResult> { if !can_support_relation_load_strategy() { unreachable!() @@ -168,7 +169,7 @@ fn read_many_by_joins( query.args.clone(), &query.selected_fields, query.relation_load_strategy, - trace_id, + traceparent, ) .await?; @@ -209,13 +210,13 @@ fn read_related<'conn>( tx: &'conn mut dyn ConnectionLike, mut query: RelatedRecordsQuery, parent_result: Option<&'conn ManyRecords>, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'conn, InterpretationResult> { let fut = async move { 
let relation = query.parent_field.relation(); let records = if relation.is_many_to_many() { - nested_read::m2m(tx, &mut query, parent_result, trace_id).await? + nested_read::m2m(tx, &mut query, parent_result, traceparent).await? } else { nested_read::one2m( tx, @@ -224,7 +225,7 @@ fn read_related<'conn>( parent_result, query.args.clone(), &query.selected_fields, - trace_id, + traceparent, ) .await? }; @@ -248,7 +249,7 @@ fn read_related<'conn>( async fn aggregate( tx: &mut dyn ConnectionLike, query: AggregateRecordsQuery, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { let selection_order = query.selection_order; @@ -259,7 +260,7 @@ async fn aggregate( query.selectors, query.group_by, query.having, - trace_id, + traceparent, ) .await?; diff --git a/query-engine/core/src/interpreter/query_interpreters/write.rs b/query-engine/core/src/interpreter/query_interpreters/write.rs index d3146c383639..453964369801 100644 --- a/query-engine/core/src/interpreter/query_interpreters/write.rs +++ b/query-engine/core/src/interpreter/query_interpreters/write.rs @@ -7,24 +7,25 @@ use crate::{ }; use connector::{ConnectionLike, DatasourceFieldName, NativeUpsert, WriteArgs}; use query_structure::{ManyRecords, Model, RawJson}; +use telemetry::helpers::TraceParent; pub(crate) async fn execute( tx: &mut dyn ConnectionLike, write_query: WriteQuery, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { match write_query { - WriteQuery::CreateRecord(q) => create_one(tx, q, trace_id).await, - WriteQuery::CreateManyRecords(q) => create_many(tx, q, trace_id).await, - WriteQuery::UpdateRecord(q) => update_one(tx, q, trace_id).await, - WriteQuery::DeleteRecord(q) => delete_one(tx, q, trace_id).await, - WriteQuery::UpdateManyRecords(q) => update_many(tx, q, trace_id).await, - WriteQuery::DeleteManyRecords(q) => delete_many(tx, q, trace_id).await, - WriteQuery::ConnectRecords(q) => connect(tx, q, trace_id).await, - WriteQuery::DisconnectRecords(q) => disconnect(tx, q, trace_id).await, + WriteQuery::CreateRecord(q) => create_one(tx, q, traceparent).await, + WriteQuery::CreateManyRecords(q) => create_many(tx, q, traceparent).await, + WriteQuery::UpdateRecord(q) => update_one(tx, q, traceparent).await, + WriteQuery::DeleteRecord(q) => delete_one(tx, q, traceparent).await, + WriteQuery::UpdateManyRecords(q) => update_many(tx, q, traceparent).await, + WriteQuery::DeleteManyRecords(q) => delete_many(tx, q, traceparent).await, + WriteQuery::ConnectRecords(q) => connect(tx, q, traceparent).await, + WriteQuery::DisconnectRecords(q) => disconnect(tx, q, traceparent).await, WriteQuery::ExecuteRaw(q) => execute_raw(tx, q).await, WriteQuery::QueryRaw(q) => query_raw(tx, q).await, - WriteQuery::Upsert(q) => native_upsert(tx, q, trace_id).await, + WriteQuery::Upsert(q) => native_upsert(tx, q, traceparent).await, } } @@ -46,9 +47,11 @@ async fn execute_raw(tx: &mut dyn ConnectionLike, q: RawQuery) -> Interpretation async fn create_one( tx: &mut dyn ConnectionLike, q: CreateRecord, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { - let res = tx.create_record(&q.model, q.args, q.selected_fields, trace_id).await?; + let res = tx + .create_record(&q.model, q.args, q.selected_fields, traceparent) + .await?; Ok(QueryResult::RecordSelection(Some(Box::new(RecordSelection { name: q.name, @@ -63,15 +66,15 @@ async fn create_one( async fn create_many( tx: &mut dyn ConnectionLike, q: CreateManyRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { if 
q.split_by_shape { - return create_many_split_by_shape(tx, q, trace_id).await; + return create_many_split_by_shape(tx, q, traceparent).await; } if let Some(selected_fields) = q.selected_fields { let records = tx - .create_records_returning(&q.model, q.args, q.skip_duplicates, selected_fields.fields, trace_id) + .create_records_returning(&q.model, q.args, q.skip_duplicates, selected_fields.fields, traceparent) .await?; let nested: Vec = super::read::process_nested(tx, selected_fields.nested, Some(&records)).await?; @@ -87,7 +90,9 @@ async fn create_many( Ok(QueryResult::RecordSelection(Some(Box::new(selection)))) } else { - let affected_records = tx.create_records(&q.model, q.args, q.skip_duplicates, trace_id).await?; + let affected_records = tx + .create_records(&q.model, q.args, q.skip_duplicates, traceparent) + .await?; Ok(QueryResult::Count(affected_records)) } @@ -100,7 +105,7 @@ async fn create_many( async fn create_many_split_by_shape( tx: &mut dyn ConnectionLike, q: CreateManyRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { let mut args_by_shape: HashMap> = Default::default(); let model = &q.model; @@ -121,7 +126,7 @@ async fn create_many_split_by_shape( args, q.skip_duplicates, selected_fields.fields.clone(), - trace_id.clone(), + traceparent, ) .await?; @@ -139,7 +144,7 @@ async fn create_many_split_by_shape( result } else { // Empty result means that the list of arguments was empty as well. - tx.create_records_returning(&q.model, vec![], q.skip_duplicates, selected_fields.fields, trace_id) + tx.create_records_returning(&q.model, vec![], q.skip_duplicates, selected_fields.fields, traceparent) .await? }; @@ -161,7 +166,7 @@ async fn create_many_split_by_shape( for args in args_by_shape.into_values() { let affected_records = tx - .create_records(&q.model, args, q.skip_duplicates, trace_id.clone()) + .create_records(&q.model, args, q.skip_duplicates, traceparent) .await?; result += affected_records; } @@ -205,7 +210,7 @@ fn create_many_shape(write_args: &WriteArgs, model: &Model) -> CreateManyShape { async fn update_one( tx: &mut dyn ConnectionLike, q: UpdateRecord, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { let res = tx .update_record( @@ -213,7 +218,7 @@ async fn update_one( q.record_filter().clone(), q.args().clone(), q.selected_fields(), - trace_id, + traceparent, ) .await?; @@ -245,9 +250,9 @@ async fn update_one( async fn native_upsert( tx: &mut dyn ConnectionLike, query: NativeUpsert, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { - let scalars = tx.native_upsert_record(query.clone(), trace_id).await?; + let scalars = tx.native_upsert_record(query.clone(), traceparent).await?; Ok(RecordSelection { name: query.name().to_string(), @@ -263,7 +268,7 @@ async fn native_upsert( async fn delete_one( tx: &mut dyn ConnectionLike, q: DeleteRecord, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { // We need to ensure that we have a record finder, else we delete everything (conversion to empty filter). 
let filter = match q.record_filter { @@ -276,7 +281,7 @@ async fn delete_one( if let Some(selected_fields) = q.selected_fields { let record = tx - .delete_record(&q.model, filter, selected_fields.fields, trace_id) + .delete_record(&q.model, filter, selected_fields.fields, traceparent) .await?; let selection = RecordSelection { name: q.name, @@ -289,7 +294,7 @@ async fn delete_one( Ok(QueryResult::RecordSelection(Some(Box::new(selection)))) } else { - let result = tx.delete_records(&q.model, filter, trace_id).await?; + let result = tx.delete_records(&q.model, filter, traceparent).await?; Ok(QueryResult::Count(result)) } } @@ -297,9 +302,11 @@ async fn delete_one( async fn update_many( tx: &mut dyn ConnectionLike, q: UpdateManyRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { - let res = tx.update_records(&q.model, q.record_filter, q.args, trace_id).await?; + let res = tx + .update_records(&q.model, q.record_filter, q.args, traceparent) + .await?; Ok(QueryResult::Count(res)) } @@ -307,9 +314,9 @@ async fn update_many( async fn delete_many( tx: &mut dyn ConnectionLike, q: DeleteManyRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { - let res = tx.delete_records(&q.model, q.record_filter, trace_id).await?; + let res = tx.delete_records(&q.model, q.record_filter, traceparent).await?; Ok(QueryResult::Count(res)) } @@ -317,13 +324,13 @@ async fn delete_many( async fn connect( tx: &mut dyn ConnectionLike, q: ConnectRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { tx.m2m_connect( &q.relation_field, &q.parent_id.expect("Expected parent record ID to be set for connect"), &q.child_ids, - trace_id, + traceparent, ) .await?; @@ -333,13 +340,13 @@ async fn connect( async fn disconnect( tx: &mut dyn ConnectionLike, q: DisconnectRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { tx.m2m_disconnect( &q.relation_field, &q.parent_id.expect("Expected parent record ID to be set for disconnect"), &q.child_ids, - trace_id, + traceparent, ) .await?; diff --git a/query-engine/core/src/lib.rs b/query-engine/core/src/lib.rs index bf993d6bce18..7e1868cc017f 100644 --- a/query-engine/core/src/lib.rs +++ b/query-engine/core/src/lib.rs @@ -10,13 +10,11 @@ pub mod query_document; pub mod query_graph_builder; pub mod relation_load_strategy; pub mod response_ir; -pub mod telemetry; -pub use self::telemetry::*; pub use self::{ - error::{CoreError, FieldConversionError}, + error::{CoreError, ExtendedUserFacingError, FieldConversionError}, executor::{QueryExecutor, TransactionOptions}, - interactive_transactions::{ExtendedTransactionUserFacingError, TransactionError, TxId}, + interactive_transactions::{TransactionError, TxId}, query_document::*, }; @@ -28,6 +26,7 @@ pub use connector::{ mod error; mod interactive_transactions; mod interpreter; +mod metrics; mod query_ast; mod query_graph; mod result_ast; diff --git a/query-engine/core/src/metrics.rs b/query-engine/core/src/metrics.rs new file mode 100644 index 000000000000..736096634ad9 --- /dev/null +++ b/query-engine/core/src/metrics.rs @@ -0,0 +1,13 @@ +/// When the `metrics` feature is disabled, we don't compile the `prisma-metrics` crate and +/// thus can't use the metrics instrumentation. To avoid the boilerplate of putting every +/// `with_current_recorder` call behind `#[cfg]`, we use this stub trait that does nothing but +/// allows the code that relies on `WithMetricsInstrumentation` trait to be in scope compile. 
+#[cfg(not(feature = "metrics"))] +pub(crate) trait MetricsInstrumentationStub: Sized { + fn with_current_recorder(self) -> Self { + self + } +} + +#[cfg(not(feature = "metrics"))] +impl MetricsInstrumentationStub for T {} diff --git a/query-engine/core/src/query_document/mod.rs b/query-engine/core/src/query_document/mod.rs index b379f74d3129..465109bf797e 100644 --- a/query-engine/core/src/query_document/mod.rs +++ b/query-engine/core/src/query_document/mod.rs @@ -32,6 +32,7 @@ use crate::{ query_ast::{QueryOption, QueryOptions}, query_graph_builder::resolve_compound_field, }; +use itertools::Itertools; use query_structure::Model; use schema::{constants::*, QuerySchema}; use std::collections::HashMap; @@ -285,12 +286,15 @@ impl CompactedDocument { .collect(); // Gets the argument keys for later mapping. - let keys: Vec<_> = arguments[0] + let keys: Vec<_> = arguments .iter() - .flat_map(|pair| match pair { - (_, ArgumentValue::Object(obj)) => obj.keys().map(ToOwned::to_owned).collect(), - (key, _) => vec![key.to_owned()], + .flat_map(|map| { + map.iter().flat_map(|(key, value)| match value { + ArgumentValue::Object(obj) => obj.keys().map(ToOwned::to_owned).collect::>(), + _ => vec![key.to_owned()], + }) }) + .unique() .collect(); Self { diff --git a/query-engine/core/src/query_graph/mod.rs b/query-engine/core/src/query_graph/mod.rs index 458be8280a3a..8459584a0c42 100644 --- a/query-engine/core/src/query_graph/mod.rs +++ b/query-engine/core/src/query_graph/mod.rs @@ -594,6 +594,7 @@ impl QueryGraph { /// - ... not an `if`-flow node themself /// - ... not already connected to the current `if`-flow node in any form (to prevent double edges) /// - ... not connected to another `if`-flow node with control flow edges (indirect sibling) + /// /// will be ordered below the currently processed `if`-flow node in execution predence. /// /// ```text diff --git a/query-engine/core/src/query_graph_builder/extractors/filters/mod.rs b/query-engine/core/src/query_graph_builder/extractors/filters/mod.rs index c87da451ff2b..803dd6100c43 100644 --- a/query-engine/core/src/query_graph_builder/extractors/filters/mod.rs +++ b/query-engine/core/src/query_graph_builder/extractors/filters/mod.rs @@ -193,11 +193,11 @@ where /// are merged together to optimize the generated SQL statements. /// This is done in three steps (below transformations are using pseudo-code): /// 1. We flatten the filter tree. -/// eg: `Filter(And([ScalarFilter, ScalarFilter], And([ScalarFilter])))` -> `Filter(And([ScalarFilter, ScalarFilter, ScalarFilter]))` +/// eg: `Filter(And([ScalarFilter, ScalarFilter], And([ScalarFilter])))` -> `Filter(And([ScalarFilter, ScalarFilter, ScalarFilter]))` /// 2. We index search filters by their query. -/// eg: `Filter(And([SearchFilter("query", [FieldA]), SearchFilter("query", [FieldB])]))` -> `{ "query": [FieldA, FieldB] }` +/// eg: `Filter(And([SearchFilter("query", [FieldA]), SearchFilter("query", [FieldB])]))` -> `{ "query": [FieldA, FieldB] }` /// 3. 
We reconstruct the filter tree and merge the search filters that have the same query along the way -/// eg: `Filter(And([SearchFilter("query", [FieldA]), SearchFilter("query", [FieldB])]))` -> `Filter(And([SearchFilter("query", [FieldA, FieldB])]))` +/// eg: `Filter(And([SearchFilter("query", [FieldA]), SearchFilter("query", [FieldB])]))` -> `Filter(And([SearchFilter("query", [FieldA, FieldB])]))` fn merge_search_filters(filter: Filter) -> Filter { // The filter tree _needs_ to be flattened for the merge to work properly let flattened = fold_filter(filter); @@ -309,7 +309,7 @@ fn extract_relation_filters( // Implicit is ParsedInputValue::Map(filter_map) => { - extract_filter(filter_map, &field.related_model()).map(|filter| vec![field.to_one_related(filter)]) + extract_filter(filter_map, field.related_model()).map(|filter| vec![field.to_one_related(filter)]) } x => Err(QueryGraphBuilderError::InputError(format!( diff --git a/query-engine/core/src/query_graph_builder/extractors/filters/relation.rs b/query-engine/core/src/query_graph_builder/extractors/filters/relation.rs index 47ec7ab9d193..3e497fc4d3a8 100644 --- a/query-engine/core/src/query_graph_builder/extractors/filters/relation.rs +++ b/query-engine/core/src/query_graph_builder/extractors/filters/relation.rs @@ -13,14 +13,14 @@ pub fn parse( match (filter_key, value) { // Relation list filters - (filters::SOME, Some(value)) => Ok(field.at_least_one_related(extract_filter(value, &field.related_model())?)), - (filters::NONE, Some(value)) => Ok(field.no_related(extract_filter(value, &field.related_model())?)), - (filters::EVERY, Some(value)) => Ok(field.every_related(extract_filter(value, &field.related_model())?)), + (filters::SOME, Some(value)) => Ok(field.at_least_one_related(extract_filter(value, field.related_model())?)), + (filters::NONE, Some(value)) => Ok(field.no_related(extract_filter(value, field.related_model())?)), + (filters::EVERY, Some(value)) => Ok(field.every_related(extract_filter(value, field.related_model())?)), // One-relation filters - (filters::IS, Some(value)) => Ok(field.to_one_related(extract_filter(value, &field.related_model())?)), + (filters::IS, Some(value)) => Ok(field.to_one_related(extract_filter(value, field.related_model())?)), (filters::IS, None) => Ok(field.one_relation_is_null()), - (filters::IS_NOT, Some(value)) => Ok(field.no_related(extract_filter(value, &field.related_model())?)), + (filters::IS_NOT, Some(value)) => Ok(field.no_related(extract_filter(value, field.related_model())?)), (filters::IS_NOT, None) => Ok(Filter::not(vec![field.one_relation_is_null()])), _ => Err(QueryGraphBuilderError::InputError(format!( diff --git a/query-engine/core/src/query_graph_builder/write/nested/connect_nested.rs b/query-engine/core/src/query_graph_builder/write/nested/connect_nested.rs index 81038c18a57e..a83bc6adfa85 100644 --- a/query-engine/core/src/query_graph_builder/write/nested/connect_nested.rs +++ b/query-engine/core/src/query_graph_builder/write/nested/connect_nested.rs @@ -327,7 +327,7 @@ fn handle_one_to_many( /// - Parent gets injected with a child on x, because that's what the connect is supposed to do. /// - The update runs, the relation is updated. /// - Now the check runs, because it's dependent on the parent's ID... but the check finds an existing child and fails... -/// ... because we just updated the relation. +/// ... because we just updated the relation. /// /// This is why we need to have an extra update at the end if it's inlined on the parent and a non-create. 
fn handle_one_to_one( diff --git a/query-engine/core/src/query_graph_builder/write/nested/create_nested.rs b/query-engine/core/src/query_graph_builder/write/nested/create_nested.rs index 72a299c472f6..7414018d818e 100644 --- a/query-engine/core/src/query_graph_builder/write/nested/create_nested.rs +++ b/query-engine/core/src/query_graph_builder/write/nested/create_nested.rs @@ -415,7 +415,7 @@ fn handle_one_to_many( /// - Parent gets injected with a child on x, because that's what the nested create is supposed to do. /// - The update runs, the relation is updated. /// - Now the check runs, because it's dependent on the parent's ID... but the check finds an existing child and fails... -/// ... because we just updated the relation. +/// ... because we just updated the relation. /// /// For these reasons, we need to have an extra update at the end if it's inlined on the parent and a non-create. fn handle_one_to_one( diff --git a/query-engine/core/src/telemetry/capturing/tx_ext.rs b/query-engine/core/src/telemetry/capturing/tx_ext.rs deleted file mode 100644 index 6b1b4905ab57..000000000000 --- a/query-engine/core/src/telemetry/capturing/tx_ext.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::collections::HashMap; - -pub trait TxTraceExt { - fn into_trace_id(self) -> opentelemetry::trace::TraceId; - fn into_trace_context(self) -> opentelemetry::Context; - fn as_traceparent(&self) -> String; -} - -impl TxTraceExt for crate::TxId { - // in order to convert a TxId (a 48 bytes cuid) into a TraceId (16 bytes), we remove the first byte, - // (always 'c') and get the next 16 bytes, which are random enough to be used as a trace id. - // this is a typical cuid: "c-lct0q6ma-0004-rb04-h6en1roa" - // - // - first letter is always the same - // - next 7-8 byte are random a timestamp. There's more entropy in the least significative bytes - // - next 4 bytes are a counter since the server started - // - next 4 bytes are a system fingerprint, invariant for the same server instance - // - least significative 8 bytes. Totally random. - // - // We want the most entropic slice of 16 bytes that's deterministicly determined - fn into_trace_id(self) -> opentelemetry::trace::TraceId { - let mut buffer = [0; 16]; - let str = self.to_string(); - let tx_id_bytes = str.as_bytes(); - let len = tx_id_bytes.len(); - - // bytes [len-20 to len-12): least significative 4 bytes of the timestamp + 4 bytes counter - for (i, source_idx) in (len - 20..len - 12).enumerate() { - buffer[i] = tx_id_bytes[source_idx]; - } - // bytes [len-8 to len): the random blocks - for (i, source_idx) in (len - 8..len).enumerate() { - buffer[i + 8] = tx_id_bytes[source_idx]; - } - - opentelemetry::trace::TraceId::from_bytes(buffer) - } - // This is a bit of a hack, but it's the only way to have a default trace span for a whole - // transaction when no traceparent is propagated from the client. - // - // This is done so we can capture traces happening accross the different queries in a - // transaction. Otherwise, if a traceparent is not propagated from the client, each query in - // the transaction will run within a span that has already been generated at the begining of the - // transaction, and held active in the actor in charge of running the queries. Thus, making - // impossible to capture traces happening in the individual queries, as they won't be aware of - // the transaction they are part of. - // - // By generating this "fake" traceparent based on the transaction id, we can have a common - // trace_id for all transaction operations. 
- fn into_trace_context(self) -> opentelemetry::Context { - let extractor: HashMap = - HashMap::from_iter(vec![("traceparent".to_string(), self.as_traceparent())]); - opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor)) - } - - fn as_traceparent(&self) -> String { - let trace_id = self.clone().into_trace_id(); - format!("00-{trace_id}-0000000000000001-01") - } -} - -// tests for txid into traits -#[cfg(test)] -mod test { - use super::*; - use crate::TxId; - - #[test] - fn test_txid_into_traceid() { - let fixture = vec![ - ("clct0q6ma0000rb04768tiqbj", "71366d6130303030373638746971626a"), - // counter changed, trace id changed: - ("clct0q6ma0002rb04cpa6zkmx", "71366d6130303032637061367a6b6d78"), - // fingerprint changed, trace id did not change, as that chunk is ignored: - ("clct0q6ma00020000cpa6zkmx", "71366d6130303032637061367a6b6d78"), - // first 5 bytes changed, trace id did not change, as that chunk is ignored: - ("00000q6ma00020000cpa6zkmx", "71366d6130303032637061367a6b6d78"), - // 6 th byte changed, trace id changed, as that chunk is part of the lsb of the timestamp - ("0000006ma00020000cpa6zkmx", "30366d6130303032637061367a6b6d78"), - ]; - - for (txid, expected_trace_id) in fixture { - let txid: TxId = txid.into(); - let trace_id: opentelemetry::trace::TraceId = txid.into_trace_id(); - assert_eq!(trace_id.to_string(), expected_trace_id); - } - } -} diff --git a/query-engine/core/src/telemetry/helpers.rs b/query-engine/core/src/telemetry/helpers.rs deleted file mode 100644 index 30c63ed6693f..000000000000 --- a/query-engine/core/src/telemetry/helpers.rs +++ /dev/null @@ -1,128 +0,0 @@ -use super::models::TraceSpan; -use once_cell::sync::Lazy; -use opentelemetry::sdk::export::trace::SpanData; -use opentelemetry::trace::{TraceContextExt, TraceId}; -use opentelemetry::Context; -use serde_json::{json, Value}; -use std::collections::HashMap; -use tracing::{Metadata, Span}; -use tracing_opentelemetry::OpenTelemetrySpanExt; -use tracing_subscriber::EnvFilter; - -pub static SHOW_ALL_TRACES: Lazy = Lazy::new(|| match std::env::var("PRISMA_SHOW_ALL_TRACES") { - Ok(enabled) => enabled.eq_ignore_ascii_case("true"), - Err(_) => false, -}); - -pub fn spans_to_json(spans: Vec) -> String { - let json_spans: Vec = spans.into_iter().map(|span| json!(TraceSpan::from(span))).collect(); - let span_result = json!({ - "span": true, - "spans": json_spans - }); - serde_json::to_string(&span_result).unwrap_or_default() -} - -// set the parent context and return the traceparent -pub fn set_parent_context_from_json_str(span: &Span, trace: &str) -> Option { - let trace: HashMap = serde_json::from_str(trace).unwrap_or_default(); - let trace_id = trace.get("traceparent").map(String::from); - let cx = opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&trace)); - span.set_parent(cx); - trace_id -} - -pub fn set_span_link_from_traceparent(span: &Span, traceparent: Option) { - if let Some(traceparent) = traceparent { - let trace: HashMap = HashMap::from([("traceparent".to_string(), traceparent)]); - let cx = opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&trace)); - let context_span = cx.span(); - span.add_link(context_span.span_context().clone()); - } -} - -pub fn get_trace_parent_from_span(span: &Span) -> String { - let cx = span.context(); - let binding = cx.span(); - let span_context = binding.span_context(); - - format!("00-{}-{}-01", span_context.trace_id(), span_context.span_id()) -} - -pub fn 
get_trace_id_from_span(span: &Span) -> TraceId { - let cx = span.context(); - get_trace_id_from_context(&cx) -} - -pub fn get_trace_id_from_context(context: &Context) -> TraceId { - let context_span = context.span(); - context_span.span_context().trace_id() -} - -pub fn get_trace_id_from_traceparent(traceparent: Option<&str>) -> TraceId { - traceparent - .unwrap_or("0-0-0-0") - .split('-') - .nth(1) - .map(|id| TraceId::from_hex(id).unwrap_or(TraceId::INVALID)) - .unwrap() -} - -pub enum QueryEngineLogLevel { - FromEnv, - Override(String), -} - -impl QueryEngineLogLevel { - fn level(self) -> Option { - match self { - Self::FromEnv => std::env::var("QE_LOG_LEVEL").ok(), - Self::Override(l) => Some(l), - } - } -} - -#[rustfmt::skip] -pub fn env_filter(log_queries: bool, qe_log_level: QueryEngineLogLevel) -> EnvFilter { - let mut filter = EnvFilter::from_default_env() - .add_directive("tide=error".parse().unwrap()) - .add_directive("tonic=error".parse().unwrap()) - .add_directive("h2=error".parse().unwrap()) - .add_directive("hyper=error".parse().unwrap()) - .add_directive("tower=error".parse().unwrap()); - - if let Some(level) = qe_log_level.level() { - filter = filter - .add_directive(format!("query_engine={}", &level).parse().unwrap()) - .add_directive(format!("query_core={}", &level).parse().unwrap()) - .add_directive(format!("query_connector={}", &level).parse().unwrap()) - .add_directive(format!("sql_query_connector={}", &level).parse().unwrap()) - .add_directive(format!("mongodb_query_connector={}", &level).parse().unwrap()); - } - - if log_queries { - filter = filter - .add_directive("quaint[{is_query}]=trace".parse().unwrap()) - .add_directive("mongodb_query_connector=debug".parse().unwrap()); - } - - filter -} - -pub fn user_facing_span_only_filter(meta: &Metadata<'_>) -> bool { - if !meta.is_span() { - return false; - } - - if *SHOW_ALL_TRACES { - return true; - } - - if meta.fields().iter().any(|f| f.name() == "user_facing") { - return true; - } - - // spans describing a quaint query. - // TODO: should this span be made user_facing in quaint? 
- meta.target() == "quaint::connector::metrics" && meta.name() == "quaint:query" -} diff --git a/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz b/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz index f65efcc2c5a7..9190b3618205 100644 Binary files a/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz and b/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz differ diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index 5bda20fc10a2..606b33e96422 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -10,21 +10,20 @@ postgresql = ["quaint/postgresql"] [dependencies] async-trait.workspace = true +futures.workspace = true once_cell = "1.15" +prisma-metrics.path = "../../libs/metrics" serde.workspace = true serde_json.workspace = true tracing.workspace = true tracing-core = "0.1" -metrics = "0.18" uuid.workspace = true -pin-project = "1" +pin-project.workspace = true serde_repr.workspace = true -futures = "0.3" - [dev-dependencies] expect-test = "1" -tokio = { version = "1.0", features = ["macros", "time", "sync"] } +tokio = { version = "1", features = ["macros", "time", "sync"] } wasm-rs-dbg.workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dependencies] diff --git a/query-engine/driver-adapters/executor/src/recording.ts b/query-engine/driver-adapters/executor/src/recording.ts index 88b9d369bc23..5ac0f52b4cb7 100644 --- a/query-engine/driver-adapters/executor/src/recording.ts +++ b/query-engine/driver-adapters/executor/src/recording.ts @@ -21,7 +21,7 @@ function recorder(adapter: DriverAdapter, recordings: Recordings) { return { provider: adapter.provider, adapterName: adapter.adapterName, - startTransaction: () => { + transactionContext: () => { throw new Error("Not implemented"); }, getConnectionInfo: () => { @@ -43,7 +43,7 @@ function replayer(adapter: DriverAdapter, recordings: Recordings) { provider: adapter.provider, adapterName: adapter.adapterName, recordings: recordings, - startTransaction: () => { + transactionContext: () => { throw new Error("Not implemented"); }, getConnectionInfo: () => { diff --git a/query-engine/driver-adapters/src/conversion/js_arg_type.rs b/query-engine/driver-adapters/src/conversion/js_arg_type.rs new file mode 100644 index 000000000000..e1ea7c1c5754 --- /dev/null +++ b/query-engine/driver-adapters/src/conversion/js_arg_type.rs @@ -0,0 +1,93 @@ +/// `JSArgType` is a 1:1 mapping of [`quaint::ValueType`] that: +/// - only includes the type tag (e.g. `Int32`, `Text`, `Enum`, etc.) +/// - doesn't care for the optionality of the actual value (e.g., `quaint::Value::Int32(None)` -> `JSArgType::Int32`) +/// - is used to guide the JS side on how to serialize the query argument value before sending it to the JS driver. +#[derive(Debug, PartialEq)] +pub enum JSArgType { + /// 32-bit signed integer. + Int32, + /// 64-bit signed integer. + Int64, + /// 32-bit floating point. + Float, + /// 64-bit floating point. + Double, + /// String value. + Text, + /// Database enum value. + Enum, + /// Database enum array (PostgreSQL specific). + EnumArray, + /// Bytes value. + Bytes, + /// Boolean value. + Boolean, + /// A single character. + Char, + /// An array value (PostgreSQL). + Array, + /// A numeric value. + Numeric, + /// A JSON value. + Json, + /// A XML value. + Xml, + /// An UUID value. + Uuid, + /// A datetime value. + DateTime, + /// A date value. + Date, + /// A time value. 
+ Time, +} + +impl core::fmt::Display for JSArgType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + JSArgType::Int32 => "Int32", + JSArgType::Int64 => "Int64", + JSArgType::Float => "Float", + JSArgType::Double => "Double", + JSArgType::Text => "Text", + JSArgType::Enum => "Enum", + JSArgType::EnumArray => "EnumArray", + JSArgType::Bytes => "Bytes", + JSArgType::Boolean => "Boolean", + JSArgType::Char => "Char", + JSArgType::Array => "Array", + JSArgType::Numeric => "Numeric", + JSArgType::Json => "Json", + JSArgType::Xml => "Xml", + JSArgType::Uuid => "Uuid", + JSArgType::DateTime => "DateTime", + JSArgType::Date => "Date", + JSArgType::Time => "Time", + }; + + write!(f, "{}", s) + } +} + +pub fn value_to_js_arg_type(value: &quaint::Value) -> JSArgType { + match &value.typed { + quaint::ValueType::Int32(_) => JSArgType::Int32, + quaint::ValueType::Int64(_) => JSArgType::Int64, + quaint::ValueType::Float(_) => JSArgType::Float, + quaint::ValueType::Double(_) => JSArgType::Double, + quaint::ValueType::Text(_) => JSArgType::Text, + quaint::ValueType::Enum(_, _) => JSArgType::Enum, + quaint::ValueType::EnumArray(_, _) => JSArgType::EnumArray, + quaint::ValueType::Bytes(_) => JSArgType::Bytes, + quaint::ValueType::Boolean(_) => JSArgType::Boolean, + quaint::ValueType::Char(_) => JSArgType::Char, + quaint::ValueType::Array(_) => JSArgType::Array, + quaint::ValueType::Numeric(_) => JSArgType::Numeric, + quaint::ValueType::Json(_) => JSArgType::Json, + quaint::ValueType::Xml(_) => JSArgType::Xml, + quaint::ValueType::Uuid(_) => JSArgType::Uuid, + quaint::ValueType::DateTime(_) => JSArgType::DateTime, + quaint::ValueType::Date(_) => JSArgType::Date, + quaint::ValueType::Time(_) => JSArgType::Time, + } +} diff --git a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs index 8fb07d6f6230..0282248701ad 100644 --- a/query-engine/driver-adapters/src/conversion/js_to_quaint.rs +++ b/query-engine/driver-adapters/src/conversion/js_to_quaint.rs @@ -5,7 +5,7 @@ pub use crate::types::{ColumnType, JSResultSet}; use quaint::bigdecimal::BigDecimal; use quaint::chrono::{DateTime, NaiveDate, NaiveTime, Utc}; use quaint::{ - connector::ResultSet as QuaintResultSet, + connector::{ColumnType as QuaintColumnType, ResultSet as QuaintResultSet}, error::{Error as QuaintError, ErrorKind}, Value as QuaintValue, }; @@ -22,6 +22,7 @@ impl TryFrom for QuaintResultSet { } = js_result_set; let mut quaint_rows = Vec::with_capacity(rows.len()); + let quaint_column_types = column_types.iter().map(QuaintColumnType::from).collect::>(); for row in rows { let mut quaint_row = Vec::with_capacity(column_types.len()); @@ -37,7 +38,7 @@ impl TryFrom for QuaintResultSet { } let last_insert_id = last_insert_id.and_then(|id| id.parse::().ok()); - let mut quaint_result_set = QuaintResultSet::new(column_names, quaint_rows); + let mut quaint_result_set = QuaintResultSet::new(column_names, quaint_column_types, quaint_rows); // Not a fan of this (extracting the `Some` value from an `Option` and pass it to a method that creates a new `Some` value), // but that's Quaint's ResultSet API and that's how the MySQL connector does it. 
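For illustration, a minimal sketch of how the `value_to_js_arg_type` helper added above is meant to be used when preparing query arguments for a driver adapter. It assumes the items from this diff (`value_to_js_arg_type` and the `Display` impl on `JSArgType`) are in scope inside the driver-adapters crate; the helper function itself is hypothetical.

use quaint::Value;

// Tag every argument with its JSArgType name (e.g. "Int32", "Text") so the
// JS driver knows how to deserialize it before binding it to the query.
// The tag ignores optionality: a NULL Int32 argument still yields "Int32".
fn arg_type_tags(values: &[Value<'_>]) -> Vec<String> {
    values
        .iter()
        .map(|value| value_to_js_arg_type(value).to_string())
        .collect()
}

This mirrors what `build_query` does later in this diff, where the `arg_types` vector is sent to the JS side alongside `args`.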
diff --git a/query-engine/driver-adapters/src/conversion/mod.rs b/query-engine/driver-adapters/src/conversion/mod.rs index 8b7808517ba4..5d07fd33dc9b 100644 --- a/query-engine/driver-adapters/src/conversion/mod.rs +++ b/query-engine/driver-adapters/src/conversion/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod js_arg; +pub(crate) mod js_arg_type; pub(crate) mod js_to_quaint; #[cfg(feature = "mysql")] @@ -9,3 +10,4 @@ pub(crate) mod postgres; pub(crate) mod sqlite; pub use js_arg::JSArg; +pub use js_arg_type::{value_to_js_arg_type, JSArgType}; diff --git a/query-engine/driver-adapters/src/conversion/sqlite.rs b/query-engine/driver-adapters/src/conversion/sqlite.rs index af070ec0b2cd..0648f713da89 100644 --- a/query-engine/driver-adapters/src/conversion/sqlite.rs +++ b/query-engine/driver-adapters/src/conversion/sqlite.rs @@ -36,7 +36,7 @@ mod test { #[rustfmt::skip] fn test_value_to_js_arg() { let test_cases = vec![ - ( + ( // This is different than how mysql or postgres processes integral BigInt values. ValueType::Numeric(Some(1.into())), JSArg::Value(Value::Number("1.0".parse().unwrap())) diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 55c7de41eb8d..137df06d7315 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -117,14 +117,14 @@ mod arch { pub(crate) fn get_named_property(object: &::napi::JsObject, name: &str) -> JsResult where - T: ::napi::bindgen_prelude::FromNapiValue, + T: ::napi::bindgen_prelude::FromNapiValue + ::napi::bindgen_prelude::ValidateNapiValue, { object.get_named_property(name) } pub(crate) fn get_optional_named_property(object: &::napi::JsObject, name: &str) -> JsResult> where - T: ::napi::bindgen_prelude::FromNapiValue, + T: ::napi::bindgen_prelude::FromNapiValue + ::napi::bindgen_prelude::ValidateNapiValue, { if has_named_property(object, name)? { Ok(Some(get_named_property(object, name)?)) diff --git a/query-engine/driver-adapters/src/napi/adapter_method.rs b/query-engine/driver-adapters/src/napi/adapter_method.rs index ae92117c2713..658c2003e9e9 100644 --- a/query-engine/driver-adapters/src/napi/adapter_method.rs +++ b/query-engine/driver-adapters/src/napi/adapter_method.rs @@ -13,7 +13,7 @@ use crate::AdapterResult; /// - Automatically unrefs the function so it won't hold off event loop /// - Awaits for returned Promise /// - Unpacks JS `Result` type into Rust `Result` type and converts the error -/// into `quaint::Error`. +/// into `quaint::Error`. 
/// - Catches panics and converts them to `quaint:Error` pub(crate) struct AdapterMethod where @@ -79,3 +79,24 @@ where Self::from_threadsafe_function(threadsafe_fn, env) } } + +impl ValidateNapiValue for AdapterMethod +where + ArgType: ToNapiValue + 'static, + ReturnType: FromNapiValue + 'static, +{ +} + +impl TypeName for AdapterMethod +where + ArgType: ToNapiValue + 'static, + ReturnType: FromNapiValue + 'static, +{ + fn type_name() -> &'static str { + "AdapterMethod" + } + + fn value_type() -> ValueType { + ValueType::Function + } +} diff --git a/query-engine/driver-adapters/src/napi/conversion.rs b/query-engine/driver-adapters/src/napi/conversion.rs index 2fab5a28bb73..e1ff0ec4b99c 100644 --- a/query-engine/driver-adapters/src/napi/conversion.rs +++ b/query-engine/driver-adapters/src/napi/conversion.rs @@ -1,4 +1,4 @@ -pub(crate) use crate::conversion::JSArg; +pub(crate) use crate::conversion::{JSArg, JSArgType}; use napi::bindgen_prelude::{FromNapiValue, ToNapiValue}; use napi::NapiValue; @@ -12,6 +12,12 @@ impl FromNapiValue for JSArg { } } +impl FromNapiValue for JSArgType { + unsafe fn from_napi_value(_env: napi::sys::napi_env, _napi_value: napi::sys::napi_value) -> napi::Result { + unreachable!() + } +} + // ToNapiValue is the napi equivalent to serde::Serialize. impl ToNapiValue for JSArg { unsafe fn to_napi_value(env: napi::sys::napi_env, value: Self) -> napi::Result { @@ -46,3 +52,9 @@ impl ToNapiValue for JSArg { } } } + +impl ToNapiValue for JSArgType { + unsafe fn to_napi_value(env: napi::sys::napi_env, value: Self) -> napi::Result { + ToNapiValue::to_napi_value(env, value.to_string()) + } +} diff --git a/query-engine/driver-adapters/src/napi/result.rs b/query-engine/driver-adapters/src/napi/result.rs index 529455bf9a0b..466658df3295 100644 --- a/query-engine/driver-adapters/src/napi/result.rs +++ b/query-engine/driver-adapters/src/napi/result.rs @@ -1,5 +1,8 @@ use crate::error::DriverAdapterError; -use napi::{bindgen_prelude::FromNapiValue, Env, JsUnknown, NapiValue}; +use napi::{ + bindgen_prelude::{FromNapiValue, TypeName, ValidateNapiValue}, + Env, JsUnknown, NapiValue, +}; impl FromNapiValue for DriverAdapterError { unsafe fn from_napi_value(napi_env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { @@ -9,6 +12,18 @@ impl FromNapiValue for DriverAdapterError { } } +impl ValidateNapiValue for DriverAdapterError {} + +impl TypeName for DriverAdapterError { + fn type_name() -> &'static str { + "DriverAdapterError" + } + + fn value_type() -> napi::ValueType { + napi::ValueType::Object + } +} + /// Wrapper for JS-side result type. /// This Napi-specific implementation has the same shape and API as the Wasm implementation, /// but it asks for a `FromNapiValue` bound on the generic type. 
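Note that `JSArgType` only ever travels from Rust to JS, which is why the Napi `FromNapiValue` impl above is `unreachable!()`; in the outgoing direction the tag crosses the boundary as its `Display` string. A small sketch under that assumption, using only the enum and impls introduced in this diff:

fn main() {
    // The string produced by Display is exactly what the JS side receives
    // as the type tag for this argument.
    let tag = JSArgType::DateTime;
    assert_eq!(tag.to_string(), "DateTime");
}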
diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index 8e1d39138cb6..cf78a4cbb88d 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -1,13 +1,13 @@ -use crate::send_future::UnsafeFuture; use crate::types::JsConnectionInfo; pub use crate::types::{JSResultSet, Query, TransactionOptions}; use crate::{ from_js_value, get_named_property, get_optional_named_property, to_rust_str, AdapterMethod, JsObject, JsResult, JsString, JsTransaction, }; +use crate::{send_future::UnsafeFuture, transaction::JsTransactionContext}; use futures::Future; -use metrics::increment_gauge; +use prisma_metrics::gauge; use std::sync::atomic::{AtomicBool, Ordering}; /// Proxy is a struct wrapping a javascript object that exhibits basic primitives for @@ -28,8 +28,19 @@ pub(crate) struct CommonProxy { /// This is a JS proxy for accessing the methods specific to top level /// JS driver objects pub(crate) struct DriverProxy { - start_transaction: AdapterMethod<(), JsTransaction>, + /// Retrieve driver-specific info, such as the maximum number of query parameters get_connection_info: Option>, + + /// Provide a transaction context, in which raw commands are guaranteed to be executed in + /// the same scope as a future transaction, which can be spawned via + /// [`driver_adapters::transaction::JsTransactionContext::start_transaction`]. + /// This was first introduced for supporting Isolation Levels in PlanetScale. + transaction_context: AdapterMethod<(), JsTransactionContext>, +} + +/// This is a JS proxy for accessing the methods specific to JS transaction contexts. +pub(crate) struct TransactionContextProxy { + start_transaction: AdapterMethod<(), JsTransaction>, } /// This a JS proxy for accessing the methods, specific @@ -48,6 +59,7 @@ pub(crate) struct TransactionProxy { closed: AtomicBool, } +// TypeScript: Queryable impl CommonProxy { pub fn new(object: &JsObject) -> JsResult { let provider: JsString = get_named_property(object, "provider")?; @@ -68,11 +80,12 @@ impl CommonProxy { } } +// TypeScript: DriverAdapter impl DriverProxy { pub fn new(object: &JsObject) -> JsResult { Ok(Self { - start_transaction: get_named_property(object, "startTransaction")?, get_connection_info: get_optional_named_property(object, "getConnectionInfo")?, + transaction_context: get_named_property(object, "transactionContext")?, }) } @@ -87,6 +100,20 @@ impl DriverProxy { .await } + pub async fn transaction_context(&self) -> quaint::Result { + let ctx = self.transaction_context.call_as_async(()).await?; + + Ok(ctx) + } +} + +impl TransactionContextProxy { + pub fn new(object: &JsObject) -> JsResult { + let start_transaction = get_named_property(object, "startTransaction")?; + + Ok(Self { start_transaction }) + } + async fn start_transaction_inner(&self) -> quaint::Result> { + let tx = self.start_transaction.call_as_async(()).await?; + @@ -94,11 +121,11 @@ impl DriverProxy { // Previously, it was done in JsTransaction::new, similar to the native Transaction. // However, correct Dispatcher is lost there and increment does not register, so we moved // it here instead. - increment_gauge!("prisma_client_queries_active", 1.0); + gauge!("prisma_client_queries_active").increment(1.0); Ok(Box::new(tx)) } - pub fn start_transaction(&self) -> UnsafeFuture>> + '_> { + pub fn start_transaction(&self) -> impl Future>> + '_ { UnsafeFuture(self.start_transaction_inner()) } } @@ -184,6 +211,8 @@ macro_rules!
impl_send_sync_on_wasm { // Assume the proxy object will not be sent to service workers, we can unsafe impl Send + Sync. impl_send_sync_on_wasm!(TransactionProxy); +impl_send_sync_on_wasm!(JsTransaction); +impl_send_sync_on_wasm!(TransactionContextProxy); +impl_send_sync_on_wasm!(JsTransactionContext); impl_send_sync_on_wasm!(DriverProxy); impl_send_sync_on_wasm!(CommonProxy); -impl_send_sync_on_wasm!(JsTransaction); diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index 278c4e4f514e..8aa7579762f6 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -6,7 +6,7 @@ use super::conversion; use crate::send_future::UnsafeFuture; use async_trait::async_trait; use futures::Future; -use quaint::connector::{ExternalConnectionInfo, ExternalConnector}; +use quaint::connector::{DescribedQuery, ExternalConnectionInfo, ExternalConnector}; use quaint::{ connector::{metrics, IsolationLevel, Transaction}, error::{Error, ErrorKind}, @@ -30,12 +30,18 @@ use tracing::{info_span, Instrument}; pub(crate) struct JsBaseQueryable { pub(crate) proxy: CommonProxy, pub provider: AdapterFlavour, + pub(crate) db_system_name: &'static str, } impl JsBaseQueryable { pub(crate) fn new(proxy: CommonProxy) -> Self { let provider: AdapterFlavour = proxy.provider.parse().unwrap(); - Self { proxy, provider } + let db_system_name = provider.db_system_name(); + Self { + proxy, + provider, + db_system_name, + } } /// visit a quaint query AST according to the provider of the JS connector @@ -53,7 +59,7 @@ impl JsBaseQueryable { async fn build_query(&self, sql: &str, values: &[quaint::Value<'_>]) -> quaint::Result { let sql: String = sql.to_string(); - let converter = match self.provider { + let args_converter = match self.provider { #[cfg(feature = "postgresql")] AdapterFlavour::Postgres => conversion::postgres::value_to_js_arg, #[cfg(feature = "sqlite")] @@ -64,10 +70,15 @@ impl JsBaseQueryable { let args = values .iter() - .map(converter) + .map(args_converter) .collect::>>()?; - Ok(Query { sql, args }) + let arg_types = values + .iter() + .map(conversion::value_to_js_arg_type) + .collect::>(); + + Ok(Query { sql, args, arg_types }) } } @@ -79,7 +90,7 @@ impl QuaintQueryable for JsBaseQueryable { } async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - metrics::query("js.query_raw", sql, params, move || async move { + metrics::query("js.query_raw", self.db_system_name, sql, params, move || async move { self.do_query_raw(sql, params).await }) .await @@ -89,13 +100,17 @@ impl QuaintQueryable for JsBaseQueryable { self.query_raw(sql, params).await } + async fn describe_query(&self, sql: &str) -> quaint::Result { + self.describe_query(sql).await + } + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { let (sql, params) = self.visit_quaint_query(q)?; self.execute_raw(&sql, ¶ms).await } async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - metrics::query("js.execute_raw", sql, params, move || async move { + metrics::query("js.execute_raw", self.db_system_name, sql, params, move || async move { self.do_execute_raw(sql, params).await }) .await @@ -107,7 +122,7 @@ impl QuaintQueryable for JsBaseQueryable { async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { let params = &[]; - metrics::query("js.raw_cmd", cmd, params, move || async move { + metrics::query("js.raw_cmd", self.db_system_name, cmd, params, move || async move { 
self.do_execute_raw(cmd, params).await?; Ok(()) }) @@ -165,7 +180,7 @@ impl JsBaseQueryable { let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); let query = self.build_query(sql, params).instrument(serialization_span).await?; - let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); + let sql_span = info_span!("js:query:sql", user_facing = true, "db.system" = %self.db_system_name, "db.statement" = %sql, "otel.kind" = "client"); let result_set = self.proxy.query_raw(query).instrument(sql_span).await?; let len = result_set.len(); @@ -187,7 +202,7 @@ impl JsBaseQueryable { let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); let query = self.build_query(sql, params).instrument(serialization_span).await?; - let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); + let sql_span = info_span!("js:query:sql", user_facing = true, "db.system" = %self.db_system_name, "db.statement" = %sql, "otel.kind" = "client"); let affected_rows = self.proxy.execute_raw(query).instrument(sql_span).await?; Ok(affected_rows as u64) @@ -255,6 +270,10 @@ impl QuaintQueryable for JsQueryable { self.inner.query_raw_typed(sql, params).await } + async fn describe_query(&self, sql: &str) -> quaint::Result { + self.inner.describe_query(sql).await + } + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { self.inner.execute(q).await } @@ -288,25 +307,32 @@ impl QuaintQueryable for JsQueryable { } } -#[async_trait] -impl TransactionCapable for JsQueryable { - async fn start_transaction<'a>( +impl JsQueryable { + async fn start_transaction_inner<'a>( &'a self, isolation: Option, ) -> quaint::Result> { - let tx = self.driver_proxy.start_transaction().await?; + // 1. Obtain a transaction context from the driver. + // Any command run on this context is guaranteed to be part of the same session + // as the transaction spawned from it. + let tx_ctx = self.driver_proxy.transaction_context().await?; - let isolation_first = tx.requires_isolation_first(); + let requires_isolation_first = tx_ctx.requires_isolation_first(); - if isolation_first { + // 2. Set the isolation level (if specified) if the provider requires it to be set before + // creating the transaction. + if requires_isolation_first { if let Some(isolation) = isolation { - tx.set_tx_isolation_level(isolation).await?; + tx_ctx.set_tx_isolation_level(isolation).await?; } } - let begin_stmt = tx.begin_statement(); + // 3. Spawn a transaction from the context. + let tx = tx_ctx.start_transaction().await?; + let begin_stmt = tx.begin_statement(); let tx_opts = tx.options(); + if tx_opts.use_phantom_query { let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); tx.raw_phantom_cmd(begin_stmt.as_str()).await?; @@ -314,7 +340,8 @@ impl TransactionCapable for JsQueryable { tx.raw_cmd(begin_stmt).await?; } - if !isolation_first { + // 4. Set the isolation level (if specified) if we didn't do it before. 
+ if !requires_isolation_first { if let Some(isolation) = isolation { tx.set_tx_isolation_level(isolation).await?; } @@ -326,6 +353,16 @@ impl TransactionCapable for JsQueryable { } } +#[async_trait] +impl TransactionCapable for JsQueryable { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> quaint::Result> { + UnsafeFuture(self.start_transaction_inner(isolation)).await + } +} + pub fn from_js(driver: JsObject) -> JsQueryable { let common = CommonProxy::new(&driver).unwrap(); let driver_proxy = DriverProxy::new(&driver).unwrap(); diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index 264c363ea608..3a1167159ae5 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,15 +1,85 @@ +use std::future::Future; + use async_trait::async_trait; -use metrics::decrement_gauge; +use prisma_metrics::gauge; use quaint::{ - connector::{IsolationLevel, Transaction as QuaintTransaction}, + connector::{DescribedQuery, IsolationLevel, Transaction as QuaintTransaction}, prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; -use crate::proxy::{TransactionOptions, TransactionProxy}; +use crate::proxy::{TransactionContextProxy, TransactionOptions, TransactionProxy}; use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::UnsafeFuture}; use crate::{JsObject, JsResult}; +pub(crate) struct JsTransactionContext { + tx_ctx_proxy: TransactionContextProxy, + inner: JsBaseQueryable, +} + +// Wrapper around JS transaction context objects that implements Queryable. Can be used in place of quaint transaction, +// context, but delegates most operations to JS +impl JsTransactionContext { + pub(crate) fn new(inner: JsBaseQueryable, tx_ctx_proxy: TransactionContextProxy) -> Self { + Self { inner, tx_ctx_proxy } + } + + pub fn start_transaction(&self) -> impl Future>> + '_ { + UnsafeFuture(self.tx_ctx_proxy.start_transaction()) + } +} + +#[async_trait] +impl Queryable for JsTransactionContext { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.query(q).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw(sql, params).await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw_typed(sql, params).await + } + + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.execute(q).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw(sql, params).await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw_typed(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + self.inner.raw_cmd(cmd).await + } + + async fn version(&self) -> quaint::Result> { + self.inner.version().await + } + + async fn describe_query(&self, sql: &str) -> quaint::Result { + self.inner.describe_query(sql).await + } + + fn is_healthy(&self) -> bool { + self.inner.is_healthy() + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + self.inner.set_tx_isolation_level(isolation_level).await + } + + fn requires_isolation_first(&self) -> bool { + self.inner.requires_isolation_first() + } +} + // Wrapper around JS transaction objects that implements Queryable // and quaint::Transaction. 
Can be used in place of quaint transaction, // but delegates most operations to JS @@ -29,7 +99,14 @@ impl JsTransaction { pub async fn raw_phantom_cmd(&self, cmd: &str) -> quaint::Result<()> { let params = &[]; - quaint::connector::metrics::query("js.raw_phantom_cmd", cmd, params, move || async move { Ok(()) }).await + quaint::connector::metrics::query( + "js.raw_phantom_cmd", + self.inner.db_system_name, + cmd, + params, + move || async move { Ok(()) }, + ) + .await } } @@ -37,7 +114,7 @@ impl JsTransaction { impl QuaintTransaction for JsTransaction { async fn commit(&self) -> quaint::Result<()> { // increment of this gauge is done in DriverProxy::startTransaction - decrement_gauge!("prisma_client_queries_active", 1.0); + gauge!("prisma_client_queries_active").decrement(1.0); let commit_stmt = "COMMIT"; @@ -53,7 +130,7 @@ impl QuaintTransaction for JsTransaction { async fn rollback(&self) -> quaint::Result<()> { // increment of this gauge is done in DriverProxy::startTransaction - decrement_gauge!("prisma_client_queries_active", 1.0); + gauge!("prisma_client_queries_active").decrement(1.0); let rollback_stmt = "ROLLBACK"; @@ -86,6 +163,10 @@ impl Queryable for JsTransaction { self.inner.query_raw_typed(sql, params).await } + async fn describe_query(&self, sql: &str) -> quaint::Result { + self.inner.describe_query(sql).await + } + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { self.inner.execute(q).await } @@ -145,3 +226,30 @@ impl ::napi::bindgen_prelude::FromNapiValue for JsTransaction { Ok(Self::new(JsBaseQueryable::new(common_proxy), tx_proxy)) } } + +#[cfg(target_arch = "wasm32")] +impl super::wasm::FromJsValue for JsTransactionContext { + fn from_js_value(value: wasm_bindgen::prelude::JsValue) -> JsResult { + use wasm_bindgen::JsCast; + + let object = value.dyn_into::()?; + let common_proxy = CommonProxy::new(&object)?; + let base = JsBaseQueryable::new(common_proxy); + let tx_ctx_proxy = TransactionContextProxy::new(&object)?; + + Ok(Self::new(base, tx_ctx_proxy)) + } +} + +/// Implementing unsafe `from_napi_value` allows retrieving a threadsafe `JsTransactionContext` in `DriverProxy` +/// while keeping derived futures `Send`. 
+#[cfg(not(target_arch = "wasm32"))] +impl ::napi::bindgen_prelude::FromNapiValue for JsTransactionContext { + unsafe fn from_napi_value(env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> JsResult { + let object = JsObject::from_napi_value(env, napi_val)?; + let common_proxy = CommonProxy::new(&object)?; + let tx_ctx_proxy = TransactionContextProxy::new(&object)?; + + Ok(Self::new(JsBaseQueryable::new(common_proxy), tx_ctx_proxy)) + } +} diff --git a/query-engine/driver-adapters/src/types.rs b/query-engine/driver-adapters/src/types.rs index 1b4cbe531359..03f9c5d6325a 100644 --- a/query-engine/driver-adapters/src/types.rs +++ b/query-engine/driver-adapters/src/types.rs @@ -6,11 +6,11 @@ use std::str::FromStr; #[cfg(not(target_arch = "wasm32"))] use napi::bindgen_prelude::{FromNapiValue, ToNapiValue}; -use quaint::connector::{ExternalConnectionInfo, SqlFamily}; +use quaint::connector::{ColumnType as QuaintColumnType, ExternalConnectionInfo, SqlFamily}; #[cfg(target_arch = "wasm32")] use tsify::Tsify; -use crate::conversion::JSArg; +use crate::conversion::{JSArg, JSArgType}; use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; @@ -25,6 +25,19 @@ pub enum AdapterFlavour { Sqlite, } +impl AdapterFlavour { + pub fn db_system_name(&self) -> &'static str { + match self { + #[cfg(feature = "mysql")] + Self::Mysql => "mysql", + #[cfg(feature = "postgresql")] + Self::Postgres => "postgresql", + #[cfg(feature = "sqlite")] + Self::Sqlite => "sqlite", + } + } +} + impl FromStr for AdapterFlavour { type Err = String; @@ -126,129 +139,150 @@ impl JSResultSet { } } -#[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi)] -#[cfg_attr(target_arch = "wasm32", derive(Clone, Copy, Deserialize_repr))] -#[repr(u8)] -#[derive(Debug)] -pub enum ColumnType { - // [PLANETSCALE_TYPE] (MYSQL_TYPE) -> [TypeScript example] +macro_rules! js_column_type { + ($($(#[$($attrss:tt)*])*$name:ident($val:expr) => $quaint_name:ident,)*) => { + #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi)] + #[cfg_attr(target_arch = "wasm32", derive(Clone, Copy, Deserialize_repr))] + #[repr(u8)] + #[derive(Debug)] + pub enum ColumnType { + $( + $(#[$($attrss)*])* + $name = $val, + )* + } + + impl From<&ColumnType> for QuaintColumnType { + fn from(value: &ColumnType) -> Self { + match value { + $(ColumnType::$name => QuaintColumnType::$quaint_name,)* + } + } + } + }; +} + +// JsColumnType(discriminant) => quaint::ColumnType +js_column_type! { + /// [PLANETSCALE_TYPE] (MYSQL_TYPE) -> [TypeScript example] /// The following PlanetScale type IDs are mapped into Int32: /// - INT8 (TINYINT) -> e.g. `127` /// - INT16 (SMALLINT) -> e.g. `32767` /// - INT24 (MEDIUMINT) -> e.g. `8388607` /// - INT32 (INT) -> e.g. `2147483647` - Int32 = 0, + Int32(0) => Int32, /// The following PlanetScale type IDs are mapped into Int64: /// - INT64 (BIGINT) -> e.g. `"9223372036854775807"` (String-encoded) - Int64 = 1, + Int64(1) => Int64, /// The following PlanetScale type IDs are mapped into Float: /// - FLOAT32 (FLOAT) -> e.g. `3.402823466` - Float = 2, + Float(2) => Float, /// The following PlanetScale type IDs are mapped into Double: /// - FLOAT64 (DOUBLE) -> e.g. `1.7976931348623157` - Double = 3, + Double(3) => Double, /// The following PlanetScale type IDs are mapped into Numeric: /// - DECIMAL (DECIMAL) -> e.g. `"99999999.99"` (String-encoded) - Numeric = 4, + Numeric(4) => Numeric, /// The following PlanetScale type IDs are mapped into Boolean: /// - BOOLEAN (BOOLEAN) -> e.g. 
`1` - Boolean = 5, + Boolean(5) => Boolean, - Character = 6, + + Character(6) => Char, /// The following PlanetScale type IDs are mapped into Text: /// - TEXT (TEXT) -> e.g. `"foo"` (String-encoded) /// - VARCHAR (VARCHAR) -> e.g. `"foo"` (String-encoded) - Text = 7, + Text(7) => Text, /// The following PlanetScale type IDs are mapped into Date: /// - DATE (DATE) -> e.g. `"2023-01-01"` (String-encoded, yyyy-MM-dd) - Date = 8, + Date(8) => Date, /// The following PlanetScale type IDs are mapped into Time: /// - TIME (TIME) -> e.g. `"23:59:59"` (String-encoded, HH:mm:ss) - Time = 9, + Time(9) => Time, + /// The following PlanetScale type IDs are mapped into DateTime: /// - DATETIME (DATETIME) -> e.g. `"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) /// - TIMESTAMP (TIMESTAMP) -> e.g. `"2023-01-01 23:59:59"` (String-encoded, yyyy-MM-dd HH:mm:ss) - DateTime = 10, + DateTime(10) => DateTime, /// The following PlanetScale type IDs are mapped into Json: /// - JSON (JSON) -> e.g. `"{\"key\": \"value\"}"` (String-encoded) - Json = 11, + Json(11) => Json, /// The following PlanetScale type IDs are mapped into Enum: /// - ENUM (ENUM) -> e.g. `"foo"` (String-encoded) - Enum = 12, + Enum(12) => Enum, + /// The following PlanetScale type IDs are mapped into Bytes: /// - BLOB (BLOB) -> e.g. `"\u0012"` (String-encoded) /// - VARBINARY (VARBINARY) -> e.g. `"\u0012"` (String-encoded) /// - BINARY (BINARY) -> e.g. `"\u0012"` (String-encoded) /// - GEOMETRY (GEOMETRY) -> e.g. `"\u0012"` (String-encoded) - Bytes = 13, + Bytes(13) => Bytes, + /// The following PlanetScale type IDs are mapped into Set: /// - SET (SET) -> e.g. `"foo,bar"` (String-encoded, comma-separated) /// This is currently unhandled, and will panic if encountered. - Set = 14, + Set(14) => Text, /// UUID from postgres-flavored driver adapters is mapped to this type. - Uuid = 15, + Uuid(15) => Uuid, - /* - * Scalar arrays - */ /// Int32 array (INT2_ARRAY and INT4_ARRAY in PostgreSQL) - Int32Array = 64, + Int32Array(64) => Int32Array, /// Int64 array (INT8_ARRAY in PostgreSQL) - Int64Array = 65, + Int64Array(65) => Int64Array, /// Float array (FLOAT4_ARRAY in PostgreSQL) - FloatArray = 66, + FloatArray(66) => FloatArray, /// Double array (FLOAT8_ARRAY in PostgreSQL) - DoubleArray = 67, + DoubleArray(67) => DoubleArray, /// Numeric array (NUMERIC_ARRAY, MONEY_ARRAY etc in PostgreSQL) - NumericArray = 68, + NumericArray(68) => NumericArray, /// Boolean array (BOOL_ARRAY in PostgreSQL) - BooleanArray = 69, + BooleanArray(69) => BooleanArray, /// Char array (CHAR_ARRAY in PostgreSQL) - CharacterArray = 70, + CharacterArray(70) => CharArray, /// Text array (TEXT_ARRAY in PostgreSQL) - TextArray = 71, + TextArray(71) => TextArray, /// Date array (DATE_ARRAY in PostgreSQL) - DateArray = 72, + DateArray(72) => DateArray, /// Time array (TIME_ARRAY in PostgreSQL) - TimeArray = 73, + TimeArray(73) => TimeArray, /// DateTime array (TIMESTAMP_ARRAY in PostgreSQL) - DateTimeArray = 74, + DateTimeArray(74) => DateTimeArray, /// Json array (JSON_ARRAY in PostgreSQL) - JsonArray = 75, + JsonArray(75) => JsonArray, - /// Enum array - EnumArray = 76, + /// Enum array (ENUM_ARRAY in PostgreSQL) + EnumArray(76) => TextArray, /// Bytes array (BYTEA_ARRAY in PostgreSQL) - BytesArray = 77, + BytesArray(77) => BytesArray, /// Uuid array (UUID_ARRAY in PostgreSQL) - UuidArray = 78, + UuidArray(78) => UuidArray, /* * Below there are custom types that don't have a 1:1 translation with a quaint::Value. 
@@ -259,7 +293,7 @@ pub enum ColumnType { /// /// It's used by some driver adapters, like libsql to return aggregation values like AVG, or /// COUNT, and it can be mapped to either Int64, or Double - UnknownNumber = 128, + UnknownNumber(128) => Unknown, } #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] @@ -267,6 +301,7 @@ pub enum ColumnType { pub struct Query { pub sql: String, pub args: Vec, + pub arg_types: Vec, } #[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))] diff --git a/query-engine/driver-adapters/src/wasm/conversion.rs b/query-engine/driver-adapters/src/wasm/conversion.rs index d2039210a626..b697a49577c5 100644 --- a/query-engine/driver-adapters/src/wasm/conversion.rs +++ b/query-engine/driver-adapters/src/wasm/conversion.rs @@ -1,4 +1,4 @@ -use crate::conversion::JSArg; +use crate::conversion::{JSArg, JSArgType}; use super::to_js::{serde_serialize, ToJsValue}; use crate::types::Query; @@ -8,8 +8,10 @@ use wasm_bindgen::JsValue; impl ToJsValue for Query { fn to_js_value(&self) -> Result { let object = Object::new(); + let sql = self.sql.to_js_value()?; Reflect::set(&object, &JsValue::from(JsString::from("sql")), &sql)?; + let args = Array::new(); for arg in &self.args { let value = arg.to_js_value()?; @@ -17,6 +19,12 @@ impl ToJsValue for Query { } Reflect::set(&object, &JsValue::from(JsString::from("args")), &args)?; + let arg_types = Array::new(); + for arg_type in &self.arg_types { + arg_types.push(&arg_type.to_js_value()?); + } + Reflect::set(&object, &JsValue::from(JsString::from("argTypes")), &arg_types)?; + Ok(JsValue::from(object)) } } @@ -42,3 +50,9 @@ impl ToJsValue for JSArg { } } } + +impl ToJsValue for JSArgType { + fn to_js_value(&self) -> Result { + Ok(JsValue::from(self.to_string())) + } +} diff --git a/query-engine/metrics/src/recorder.rs b/query-engine/metrics/src/recorder.rs deleted file mode 100644 index 94d2c050f60a..000000000000 --- a/query-engine/metrics/src/recorder.rs +++ /dev/null @@ -1,122 +0,0 @@ -use std::sync::Arc; - -use metrics::KeyName; -use metrics::{Counter, CounterFn, Gauge, GaugeFn, Histogram, HistogramFn, Key, Recorder, Unit}; -use tracing::trace; - -use super::common::KeyLabels; -use super::{METRIC_COUNTER, METRIC_DESCRIPTION, METRIC_GAUGE, METRIC_HISTOGRAM, METRIC_TARGET}; - -#[derive(Default)] -pub(crate) struct MetricRecorder; - -impl MetricRecorder { - fn register_description(&self, name: &str, description: &str) { - trace!( - target: METRIC_TARGET, - name = name, - metric_type = METRIC_DESCRIPTION, - description = description - ); - } -} - -impl Recorder for MetricRecorder { - fn describe_counter(&self, key_name: KeyName, _unit: Option, description: &'static str) { - self.register_description(key_name.as_str(), description); - } - - fn describe_gauge(&self, key_name: KeyName, _unit: Option, description: &'static str) { - self.register_description(key_name.as_str(), description); - } - - fn describe_histogram(&self, key_name: KeyName, _unit: Option, description: &'static str) { - self.register_description(key_name.as_str(), description); - } - - fn register_counter(&self, key: &Key) -> Counter { - Counter::from_arc(Arc::new(MetricHandle(key.clone()))) - } - - fn register_gauge(&self, key: &Key) -> Gauge { - Gauge::from_arc(Arc::new(MetricHandle(key.clone()))) - } - - fn register_histogram(&self, key: &Key) -> Histogram { - Histogram::from_arc(Arc::new(MetricHandle(key.clone()))) - } -} - -pub(crate) struct MetricHandle(Key); - -impl CounterFn for MetricHandle { - fn increment(&self, value: u64) { - 
let keylabels: KeyLabels = self.0.clone().into(); - let json_string = serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_COUNTER, - increment = value, - ); - } - - fn absolute(&self, value: u64) { - let keylabels: KeyLabels = self.0.clone().into(); - let json_string = serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_COUNTER, - absolute = value, - ); - } -} - -impl GaugeFn for MetricHandle { - fn increment(&self, value: f64) { - let keylabels: KeyLabels = self.0.clone().into(); - let json_string = serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_GAUGE, - gauge_inc = value, - ); - } - - fn decrement(&self, value: f64) { - let keylabels: KeyLabels = self.0.clone().into(); - let json_string = serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_GAUGE, - gauge_dec = value, - ); - } - - fn set(&self, value: f64) { - let keylabels: KeyLabels = self.0.clone().into(); - let json_string = serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_GAUGE, - gauge_set = value, - ); - } -} - -impl HistogramFn for MetricHandle { - fn record(&self, value: f64) { - let keylabels: KeyLabels = self.0.clone().into(); - let json_string = serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_HISTOGRAM, - hist_record = value, - ); - } -} diff --git a/query-engine/query-engine-c-abi/Cargo.toml b/query-engine/query-engine-c-abi/Cargo.toml index 21e127519ddb..5542ba38a434 100644 --- a/query-engine/query-engine-c-abi/Cargo.toml +++ b/query-engine/query-engine-c-abi/Cargo.toml @@ -17,6 +17,7 @@ request-handlers = { path = "../request-handlers", features = [ ] } query-connector = { path = "../connectors/query-connector" } query-engine-common = { path = "../../libs/query-engine-common" } +telemetry = { path = "../../libs/telemetry" } user-facing-errors = { path = "../../libs/user-facing-errors" } psl = { workspace = true, features = ["sqlite"] } sql-connector = { path = "../connectors/sql-query-connector", package = "sql-query-connector" } @@ -25,7 +26,7 @@ chrono.workspace = true quaint = { path = "../../quaint", default-features = false, features = [ "sqlite", ] } -rusqlite = "0.29" +rusqlite = "0.31" uuid.workspace = true thiserror = "1" connection-string.workspace = true @@ -36,13 +37,14 @@ indoc.workspace = true tracing = "0.1" tracing-subscriber = { version = "0.3" } -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-opentelemetry = "0.17.3" opentelemetry = { version = "0.17" } tokio.workspace = true -futures = "0.3" +futures.workspace = true once_cell = "1.19.0" [build-dependencies] cbindgen = "0.24.0" +build-utils.path = "../../libs/build-utils" diff --git a/query-engine/query-engine-c-abi/build.rs b/query-engine/query-engine-c-abi/build.rs index 0739d31bf255..b8f3fcdbaff7 100644 --- a/query-engine/query-engine-c-abi/build.rs +++ b/query-engine/query-engine-c-abi/build.rs @@ -1,13 +1,4 @@ -extern crate cbindgen; - use std::env; -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = 
String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} fn generate_c_headers() { let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); @@ -28,6 +19,6 @@ fn main() { // Tell Cargo that if the given file changes, to rerun this build script. println!("cargo:rerun-if-changed=src/engine.rs"); // println!("✅ Running build.rs"); - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); generate_c_headers(); } diff --git a/query-engine/query-engine-c-abi/src/engine.rs b/query-engine/query-engine-c-abi/src/engine.rs index 4d3d81636755..e0dff1ca1da6 100644 --- a/query-engine/query-engine-c-abi/src/engine.rs +++ b/query-engine/query-engine-c-abi/src/engine.rs @@ -9,7 +9,7 @@ use once_cell::sync::Lazy; use query_core::{ protocol::EngineProtocol, schema::{self}, - telemetry, TransactionOptions, TxId, + TransactionOptions, TxId, }; use request_handlers::{load_executor, RequestBody, RequestHandler}; use serde_json::json; @@ -20,11 +20,13 @@ use std::{ ptr::null_mut, sync::Arc, }; +use telemetry::helpers::TraceParent; use tokio::{ runtime::{self, Runtime}, sync::RwLock, }; -use tracing::{field, instrument::WithSubscriber, level_filters::LevelFilter, Instrument}; +use tracing::{instrument::WithSubscriber, level_filters::LevelFilter, Instrument}; +use tracing_opentelemetry::OpenTelemetrySpanExt; use query_engine_common::Result; use query_engine_common::{ @@ -201,8 +203,9 @@ impl QueryEngine { let trace_string = get_cstr_safe(trace).expect("Connect trace is missing"); - let span = tracing::info_span!("prisma:engine:connect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace_string); + let span = tracing::info_span!("prisma:engine:connect", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace_string); + span.set_parent(parent_context); let mut inner = self.inner.write().await; let builder = inner.as_builder()?; @@ -238,7 +241,7 @@ impl QueryEngine { let conn_span = tracing::info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = connector.name(), + "db.system" = connector.name(), ); connector.get_connection().instrument(conn_span).await?; @@ -293,12 +296,14 @@ impl QueryEngine { let query = RequestBody::try_from_str(&body, engine.engine_protocol())?; - let span = tracing::info_span!("prisma:engine", user_facing = true); - let trace_id = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:query", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); async move { let handler = RequestHandler::new(engine.executor(), engine.query_schema(), engine.engine_protocol()); - let response = handler.handle(query, tx_id.map(TxId::from), trace_id).await; + let response = handler.handle(query, tx_id.map(TxId::from), traceparent).await; let serde_span = tracing::info_span!("prisma:engine:response_json_serialization", user_facing = true); Ok(serde_span.in_scope(|| serde_json::to_string(&response))?) 
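The connect and query hunks above follow the span-parenting pattern that recurs throughout this PR: deserialize the remote OpenTelemetry context from the trace JSON sent by the client, attach it as the parent of the local span, and derive a traceparent for the request handler. A condensed sketch, assuming the `telemetry` helpers from this diff are in scope; the wrapper function itself is hypothetical.

use tracing_opentelemetry::OpenTelemetrySpanExt;

// Restore the remote context, parent the local span on it, and hand the span
// back to the caller. A traceparent can be derived from the same context for
// the request handler (it is simply discarded in this sketch).
fn parented_query_span(trace_json: &str) -> tracing::Span {
    let span = tracing::info_span!("prisma:engine:query", user_facing = true);
    let parent_context = telemetry::helpers::restore_remote_context_from_json_str(trace_json);
    let _traceparent = telemetry::helpers::TraceParent::from_remote_context(&parent_context);
    span.set_parent(parent_context);
    span
}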
@@ -315,8 +320,9 @@ impl QueryEngine { let trace = get_cstr_safe(trace_str).expect("Trace is needed"); let dispatcher = self.logger.dispatcher(); async { - let span = tracing::info_span!("prisma:engine:disconnect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:disconnect", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); async { let mut inner = self.inner.write().await; @@ -393,8 +399,9 @@ impl QueryEngine { let dispatcher = self.logger.dispatcher(); async move { - let span = tracing::info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); - telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:start_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); let tx_opts: TransactionOptions = serde_json::from_str(&input)?; match engine @@ -412,15 +419,20 @@ impl QueryEngine { } // If connected, attempts to commit a transaction with id `tx_id` in the core. - pub async fn commit_transaction(&self, tx_id_str: *const c_char, _trace: *const c_char) -> Result { + pub async fn commit_transaction(&self, tx_id_str: *const c_char, trace: *const c_char) -> Result { let tx_id = get_cstr_safe(tx_id_str).expect("Input string missing"); + let trace = get_cstr_safe(trace).expect("trace is required in transactions"); let inner = self.inner.read().await; let engine = inner.as_engine()?; let dispatcher = self.logger.dispatcher(); async move { - match engine.executor().commit_tx(TxId::from(tx_id)).await { + let span = tracing::info_span!("prisma:engine:commit_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); + + match engine.executor().commit_tx(TxId::from(tx_id)).instrument(span).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), } @@ -430,15 +442,19 @@ impl QueryEngine { } // If connected, attempts to roll back a transaction with id `tx_id` in the core. 
- pub async fn rollback_transaction(&self, tx_id_str: *const c_char, _trace: *const c_char) -> Result { + pub async fn rollback_transaction(&self, tx_id_str: *const c_char, trace: *const c_char) -> Result { let tx_id = get_cstr_safe(tx_id_str).expect("Input string missing"); - // let trace = get_cstr_safe(trace_str).expect("trace is required in transactions"); + let trace = get_cstr_safe(trace).expect("trace is required in transactions"); let inner = self.inner.read().await; let engine = inner.as_engine()?; let dispatcher = self.logger.dispatcher(); async move { + let span = tracing::info_span!("prisma:engine:rollback_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); + match engine.executor().rollback_tx(TxId::from(tx_id)).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), diff --git a/query-engine/query-engine-c-abi/src/logger.rs b/query-engine/query-engine-c-abi/src/logger.rs index 3585b94e14a1..fbb69c38e679 100644 --- a/query-engine/query-engine-c-abi/src/logger.rs +++ b/query-engine/query-engine-c-abi/src/logger.rs @@ -1,7 +1,6 @@ use core::fmt; -use query_core::telemetry; + use query_engine_common::logger::StringCallback; -// use query_engine_metrics::MetricRegistry; use serde_json::Value; use std::sync::Arc; use std::{collections::BTreeMap, fmt::Display}; @@ -20,7 +19,6 @@ pub(crate) type LogCallback = Box; pub(crate) struct Logger { dispatcher: Dispatch, - // metrics: Option, } impl Logger { @@ -58,26 +56,14 @@ impl Logger { let layer = CallbackLayer::new(log_callback).with_filter(filters); - // let metrics = if enable_metrics { - // query_engine_metrics::setup(); - // Some(MetricRegistry::new()) - // } else { - // None - // }; - Self { dispatcher: Dispatch::new(Registry::default().with(telemetry).with(layer)), - // metrics, } } pub fn dispatcher(&self) -> Dispatch { self.dispatcher.clone() } - - // pub fn metrics(&self) -> Option { - // self.metrics.clone() - // } } pub struct JsonVisitor<'a> { diff --git a/query-engine/query-engine-c-abi/src/migrations.rs b/query-engine/query-engine-c-abi/src/migrations.rs index 4cd374705ba5..93a80cde5fb7 100644 --- a/query-engine/query-engine-c-abi/src/migrations.rs +++ b/query-engine/query-engine-c-abi/src/migrations.rs @@ -44,14 +44,14 @@ impl From for MigrationDirectory { #[derive(Debug, Clone)] pub struct MigrationRecord { /// A unique, randomly generated identifier. - pub id: String, + pub _id: String, /// The timestamp at which the migration completed *successfully*. pub finished_at: Option, /// The name of the migration, i.e. the name of migration directory /// containing the migration script. pub migration_name: String, /// The time the migration started being applied. 
- pub started_at: Timestamp, + pub _started_at: Timestamp, /// The time the migration failed pub failed_at: Option, } @@ -142,9 +142,9 @@ pub fn list_migrations(database_filename: &Path) -> Result> let failed_at: Option = row.get(4).unwrap(); entries.push(MigrationRecord { - id, + _id: id, migration_name, - started_at, + _started_at: started_at, finished_at, failed_at, }); diff --git a/query-engine/query-engine-node-api/Cargo.toml b/query-engine/query-engine-node-api/Cargo.toml index cbe4f455b58b..b4ec9eb5f36c 100644 --- a/query-engine/query-engine-node-api/Cargo.toml +++ b/query-engine/query-engine-node-api/Cargo.toml @@ -24,6 +24,7 @@ request-handlers = { path = "../request-handlers", features = ["all"] } query-connector = { path = "../connectors/query-connector" } query-engine-common = { path = "../../libs/query-engine-common" } user-facing-errors = { path = "../../libs/user-facing-errors" } +telemetry = { path = "../../libs/telemetry" } psl = { workspace = true, features = ["all"] } sql-connector = { path = "../connectors/sql-query-connector", package = "sql-query-connector", features = [ "all-native", @@ -45,14 +46,15 @@ serde.workspace = true tracing.workspace = true tracing-subscriber = { version = "0.3" } -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-opentelemetry = "0.17.3" opentelemetry = { version = "0.17" } quaint.workspace = true tokio.workspace = true -futures = "0.3" -query-engine-metrics = { path = "../metrics" } +futures.workspace = true +prisma-metrics.path = "../../libs/metrics" [build-dependencies] napi-build = "1" +build-utils.path = "../../libs/build-utils" diff --git a/query-engine/query-engine-node-api/build.rs b/query-engine/query-engine-node-api/build.rs index 2ed42a66137c..eb0c9b2fe749 100644 --- a/query-engine/query-engine-node-api/build.rs +++ b/query-engine/query-engine-node-api/build.rs @@ -1,14 +1,4 @@ -extern crate napi_build; - -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} - fn main() { - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); napi_build::setup() } diff --git a/query-engine/query-engine-node-api/src/engine.rs b/query-engine/query-engine-node-api/src/engine.rs index 01c78b6e2c17..7e9515e4d220 100644 --- a/query-engine/query-engine-node-api/src/engine.rs +++ b/query-engine/query-engine-node-api/src/engine.rs @@ -2,20 +2,22 @@ use crate::{error::ApiError, logger::Logger}; use futures::FutureExt; use napi::{threadsafe_function::ThreadSafeCallContext, Env, JsFunction, JsObject, JsUnknown}; use napi_derive::napi; +use prisma_metrics::{MetricFormat, WithMetricsInstrumentation}; use psl::PreviewFeature; use quaint::connector::ExternalConnector; -use query_core::{protocol::EngineProtocol, relation_load_strategy, schema, telemetry, TransactionOptions, TxId}; +use query_core::{protocol::EngineProtocol, relation_load_strategy, schema, TransactionOptions, TxId}; use query_engine_common::engine::{ map_known_error, stringify_env_values, ConnectedEngine, ConnectedEngineNative, ConstructorOptions, ConstructorOptionsNative, EngineBuilder, EngineBuilderNative, Inner, }; -use query_engine_metrics::MetricFormat; use request_handlers::{load_executor, render_graphql_schema, ConnectorKind, RequestBody, RequestHandler}; use serde::Deserialize; use serde_json::json; use std::{collections::HashMap, future::Future, 
marker::PhantomData, panic::AssertUnwindSafe, sync::Arc}; +use telemetry::helpers::TraceParent; use tokio::sync::RwLock; -use tracing::{field, instrument::WithSubscriber, Instrument, Span}; +use tracing::{instrument::WithSubscriber, Instrument}; +use tracing_opentelemetry::OpenTelemetrySpanExt; use tracing_subscriber::filter::LevelFilter; use user_facing_errors::Error; @@ -71,12 +73,12 @@ impl QueryEngine { native, } = napi_env.from_js_value(options).expect( r###" - Failed to deserialize constructor options. - - This usually happens when the javascript object passed to the constructor is missing + Failed to deserialize constructor options. + + This usually happens when the javascript object passed to the constructor is missing properties for the ConstructorOptions fields that must have some value. - - If you set some of these in javascript trough environment variables, make sure there are + + If you set some of these in javascript through environment variables, make sure there are values for data_model, log_level, and any field that is not Option "###, ); @@ -149,21 +151,6 @@ impl QueryEngine { let log_level = log_level.parse::().unwrap(); let logger = Logger::new(log_queries, log_level, log_callback, enable_metrics, enable_tracing); - // Describe metrics adds all the descriptions and default values for our metrics - // this needs to run once our metrics pipeline has been configured and it needs to - // use the correct logging subscriber(our dispatch) so that the metrics recorder recieves - // it - if enable_metrics { - napi_env.execute_tokio_future( - async { - query_engine_metrics::initialize_metrics(); - Ok(()) - } - .with_subscriber(logger.dispatcher()), - |&mut _env, _data| Ok(()), - )?; - } - Ok(Self { connector_mode, inner: RwLock::new(Inner::Builder(builder)), @@ -175,10 +162,12 @@ impl QueryEngine { #[napi] pub async fn connect(&self, trace: String) -> napi::Result<()> { let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); async_panic_to_js_error(async { - let span = tracing::info_span!("prisma:engine:connect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:connect", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); let mut inner = self.inner.write().await; let builder = inner.as_builder()?; @@ -224,7 +213,7 @@ impl QueryEngine { let conn_span = tracing::info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = connector.name(), + "db.system" = connector.name(), ); let conn = connector.get_connection().instrument(conn_span).await?; @@ -268,6 +257,7 @@ impl QueryEngine { Ok(()) }) .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await?; Ok(()) @@ -277,10 +267,12 @@ impl QueryEngine { #[napi] pub async fn disconnect(&self, trace: String) -> napi::Result<()> { let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); async_panic_to_js_error(async { - let span = tracing::info_span!("prisma:engine:disconnect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:disconnect", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); // TODO: when using Node Drivers, we need to call Driver::close() here. 
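Every async entry point in this Node-API engine is now wrapped in both the tracing dispatcher and the optional metrics recorder before being awaited. A hedged sketch of that wrapper: `Logger::recorder()` comes from the logger.rs change further below, `with_optional_recorder` is assumed to be provided by the `WithMetricsInstrumentation` trait imported above, and the wrapper function itself is hypothetical.

use prisma_metrics::WithMetricsInstrumentation;
use tracing::instrument::WithSubscriber;

// Attach the logger's dispatcher and, when metrics are enabled, its recorder
// to a future before awaiting it, mirroring the chained calls in this file.
async fn instrumented<F, T>(logger: &Logger, fut: F) -> T
where
    F: std::future::Future<Output = T>,
{
    fut.with_subscriber(logger.dispatcher())
        .with_optional_recorder(logger.recorder())
        .await
}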
@@ -305,6 +297,7 @@ impl QueryEngine { .await }) .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await } @@ -312,6 +305,7 @@ impl QueryEngine { #[napi] pub async fn query(&self, body: String, trace: String, tx_id: Option) -> napi::Result { let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); async_panic_to_js_error(async { let inner = self.inner.read().await; @@ -319,17 +313,14 @@ impl QueryEngine { let query = RequestBody::try_from_str(&body, engine.engine_protocol())?; - let span = if tx_id.is_none() { - tracing::info_span!("prisma:engine", user_facing = true) - } else { - Span::none() - }; - - let trace_id = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:query", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); async move { let handler = RequestHandler::new(engine.executor(), engine.query_schema(), engine.engine_protocol()); - let response = handler.handle(query, tx_id.map(TxId::from), trace_id).await; + let response = handler.handle(query, tx_id.map(TxId::from), traceparent).await; let serde_span = tracing::info_span!("prisma:engine:response_json_serialization", user_facing = true); Ok(serde_span.in_scope(|| serde_json::to_string(&response))?) @@ -338,55 +329,65 @@ impl QueryEngine { .await }) .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await } /// If connected, attempts to start a transaction in the core and returns its ID. #[napi] pub async fn start_transaction(&self, input: String, trace: String) -> napi::Result { + let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); + async_panic_to_js_error(async { let inner = self.inner.read().await; let engine = inner.as_engine()?; + let tx_opts: TransactionOptions = serde_json::from_str(&input)?; - let dispatcher = self.logger.dispatcher(); + let span = tracing::info_span!("prisma:engine:start_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); async move { - let span = tracing::info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); - telemetry::helpers::set_parent_context_from_json_str(&span, &trace); - - let tx_opts: TransactionOptions = serde_json::from_str(&input)?; match engine .executor() .start_tx(engine.query_schema().clone(), engine.engine_protocol(), tx_opts) - .instrument(span) .await { Ok(tx_id) => Ok(json!({ "id": tx_id.to_string() }).to_string()), Err(err) => Ok(map_known_error(err)?), } } - .with_subscriber(dispatcher) + .instrument(span) .await }) + .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await } /// If connected, attempts to commit a transaction with id `tx_id` in the core. 
#[napi] - pub async fn commit_transaction(&self, tx_id: String, _trace: String) -> napi::Result { + pub async fn commit_transaction(&self, tx_id: String, trace: String) -> napi::Result { async_panic_to_js_error(async { let inner = self.inner.read().await; let engine = inner.as_engine()?; let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); async move { - match engine.executor().commit_tx(TxId::from(tx_id)).await { + let span = tracing::info_span!("prisma:engine:commit_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); + + match engine.executor().commit_tx(TxId::from(tx_id)).instrument(span).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), } } .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await }) .await @@ -394,20 +395,26 @@ impl QueryEngine { /// If connected, attempts to roll back a transaction with id `tx_id` in the core. #[napi] - pub async fn rollback_transaction(&self, tx_id: String, _trace: String) -> napi::Result { + pub async fn rollback_transaction(&self, tx_id: String, trace: String) -> napi::Result { async_panic_to_js_error(async { let inner = self.inner.read().await; let engine = inner.as_engine()?; let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); async move { - match engine.executor().rollback_tx(TxId::from(tx_id)).await { + let span = tracing::info_span!("prisma:engine:rollback_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); + + match engine.executor().rollback_tx(TxId::from(tx_id)).instrument(span).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), } } .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await }) .await @@ -416,17 +423,25 @@ impl QueryEngine { /// Loads the query schema. Only available when connected. 
#[napi] pub async fn sdl_schema(&self) -> napi::Result { + let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); + async_panic_to_js_error(async move { let inner = self.inner.read().await; let engine = inner.as_engine()?; Ok(render_graphql_schema(engine.query_schema())) }) + .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await } #[napi] pub async fn metrics(&self, json_options: String) -> napi::Result { + let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); + async_panic_to_js_error(async move { let inner = self.inner.read().await; let engine = inner.as_engine()?; @@ -447,6 +462,8 @@ impl QueryEngine { .into()) } }) + .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await } } diff --git a/query-engine/query-engine-node-api/src/logger.rs b/query-engine/query-engine-node-api/src/logger.rs index b86343bb4a94..bd0fd6dd8b36 100644 --- a/query-engine/query-engine-node-api/src/logger.rs +++ b/query-engine/query-engine-node-api/src/logger.rs @@ -1,8 +1,7 @@ use core::fmt; use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode}; -use query_core::telemetry; +use prisma_metrics::{MetricRecorder, MetricRegistry}; use query_engine_common::logger::StringCallback; -use query_engine_metrics::MetricRegistry; use serde_json::Value; use std::{collections::BTreeMap, fmt::Display}; use tracing::{ @@ -21,6 +20,7 @@ pub(crate) type LogCallback = ThreadsafeFunction; pub(crate) struct Logger { dispatcher: Dispatch, metrics: Option, + recorder: Option, } impl Logger { @@ -63,16 +63,18 @@ impl Logger { let layer = log_callback.with_filter(filters); - let metrics = if enable_metrics { - query_engine_metrics::setup(); - Some(MetricRegistry::new()) + let (metrics, recorder) = if enable_metrics { + let registry = MetricRegistry::new(); + let recorder = MetricRecorder::new(registry.clone()).with_initialized_prisma_metrics(); + (Some(registry), Some(recorder)) } else { - None + (None, None) }; Self { - dispatcher: Dispatch::new(Registry::default().with(telemetry).with(layer).with(metrics.clone())), + dispatcher: Dispatch::new(Registry::default().with(telemetry).with(layer)), metrics, + recorder, } } @@ -83,6 +85,10 @@ impl Logger { pub fn metrics(&self) -> Option { self.metrics.clone() } + + pub fn recorder(&self) -> Option { + self.recorder.clone() + } } pub struct JsonVisitor<'a> { diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index c8aaf3205e24..40017c0270b3 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -44,6 +44,7 @@ request-handlers = { path = "../request-handlers", default-features = false, fea ] } query-core = { path = "../core" } driver-adapters = { path = "../driver-adapters" } +telemetry = { path = "../../libs/telemetry" } quaint.workspace = true connection-string.workspace = true js-sys.workspace = true @@ -57,15 +58,18 @@ wasm-rs-dbg.workspace = true thiserror = "1" url.workspace = true serde.workspace = true -tokio = { version = "1.25", features = ["macros", "sync", "io-util", "time"] } -futures = "0.3" +tokio = { version = "1", features = ["macros", "sync", "io-util", "time"] } +futures.workspace = true tracing.workspace = true tracing-subscriber = { version = "0.3" } -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-opentelemetry = "0.17.3" opentelemetry = { version = "0.17" } +[build-dependencies] +build-utils.path = "../../libs/build-utils" + 
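In `logger.rs` the metrics registry is no longer layered into the tracing subscriber. The logger now hands out two halves: the `MetricRegistry` (storage, still returned by `metrics()` for the `metrics` N-API method) and a `MetricRecorder` (returned by the new `recorder()` accessor and attached per future via `.with_optional_recorder`). A condensed sketch of that wiring, using only the `prisma-metrics` calls that appear in the diff; by contrast, the HTTP engine in `context.rs` further below installs its recorder globally because it owns the whole process:

```rust
use prisma_metrics::{MetricRecorder, MetricRegistry};

/// Condensed from the diff: the two metric halves are created together or not at all.
fn metrics_pair(enable_metrics: bool) -> (Option<MetricRegistry>, Option<MetricRecorder>) {
    if enable_metrics {
        let registry = MetricRegistry::new();
        // The recorder writes into the registry and pre-registers Prisma's standard metrics.
        let recorder = MetricRecorder::new(registry.clone()).with_initialized_prisma_metrics();
        (Some(registry), Some(recorder))
    } else {
        (None, None)
    }
}
```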
[package.metadata.wasm-pack.profile.release] wasm-opt = false # use wasm-opt explicitly in `./build.sh` diff --git a/query-engine/query-engine-wasm/build.rs b/query-engine/query-engine-wasm/build.rs index 2e8fe20c0503..33aded23a4a5 100644 --- a/query-engine/query-engine-wasm/build.rs +++ b/query-engine/query-engine-wasm/build.rs @@ -1,11 +1,3 @@ -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} - fn main() { - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); } diff --git a/query-engine/query-engine-wasm/pnpm-lock.yaml b/query-engine/query-engine-wasm/pnpm-lock.yaml deleted file mode 100644 index 89591aef9869..000000000000 --- a/query-engine/query-engine-wasm/pnpm-lock.yaml +++ /dev/null @@ -1,130 +0,0 @@ -lockfileVersion: '6.0' - -settings: - autoInstallPeers: true - excludeLinksFromLockfile: false - -dependencies: - '@neondatabase/serverless': - specifier: 0.6.0 - version: 0.6.0 - '@prisma/adapter-neon': - specifier: 5.6.0 - version: 5.6.0(@neondatabase/serverless@0.6.0) - '@prisma/driver-adapter-utils': - specifier: 5.6.0 - version: 5.6.0 - -packages: - - /@neondatabase/serverless@0.6.0: - resolution: {integrity: sha512-qXxBRYN0m2v8kVQBfMxbzNGn2xFAhTXFibzQlE++NfJ56Shz3m7+MyBBtXDlEH+3Wfa6lToDXf1MElocY4sJ3w==} - dependencies: - '@types/pg': 8.6.6 - dev: false - - /@prisma/adapter-neon@5.6.0(@neondatabase/serverless@0.6.0): - resolution: {integrity: sha512-IUkIE5NKyP2wCXMMAByM78fizfaJl7YeWDEajvyqQafXgRwmxl+2HhxsevvHly8jT4RlELdhjK6IP1eciGvXVA==} - peerDependencies: - '@neondatabase/serverless': ^0.6.0 - dependencies: - '@neondatabase/serverless': 0.6.0 - '@prisma/driver-adapter-utils': 5.6.0 - postgres-array: 3.0.2 - transitivePeerDependencies: - - supports-color - dev: false - - /@prisma/driver-adapter-utils@5.6.0: - resolution: {integrity: sha512-/TSrfCGLAQghNf+bwg5/e8iKAgecCYU/gMN0IyNra3183/VTQJneLFgbacuSK9bBXiIRUmpbuUIrJ6dhENzfjA==} - dependencies: - debug: 4.3.4 - transitivePeerDependencies: - - supports-color - dev: false - - /@types/node@20.9.1: - resolution: {integrity: sha512-HhmzZh5LSJNS5O8jQKpJ/3ZcrrlG6L70hpGqMIAoM9YVD0YBRNWYsfwcXq8VnSjlNpCpgLzMXdiPo+dxcvSmiA==} - dependencies: - undici-types: 5.26.5 - dev: false - - /@types/pg@8.6.6: - resolution: {integrity: sha512-O2xNmXebtwVekJDD+02udOncjVcMZQuTEQEMpKJ0ZRf5E7/9JJX3izhKUcUifBkyKpljyUM6BTgy2trmviKlpw==} - dependencies: - '@types/node': 20.9.1 - pg-protocol: 1.6.0 - pg-types: 2.2.0 - dev: false - - /debug@4.3.4: - resolution: {integrity: sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==} - engines: {node: '>=6.0'} - peerDependencies: - supports-color: '*' - peerDependenciesMeta: - supports-color: - optional: true - dependencies: - ms: 2.1.2 - dev: false - - /ms@2.1.2: - resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} - dev: false - - /pg-int8@1.0.1: - resolution: {integrity: sha512-WCtabS6t3c8SkpDBUlb1kjOs7l66xsGdKpIPZsg4wR+B3+u9UAum2odSsF9tnvxg80h4ZxLWMy4pRjOsFIqQpw==} - engines: {node: '>=4.0.0'} - dev: false - - /pg-protocol@1.6.0: - resolution: {integrity: sha512-M+PDm637OY5WM307051+bsDia5Xej6d9IR4GwJse1qA1DIhiKlksvrneZOYQq42OM+spubpcNYEo2FcKQrDk+Q==} - dev: false - - /pg-types@2.2.0: - resolution: {integrity: 
sha512-qTAAlrEsl8s4OiEQY69wDvcMIdQN6wdz5ojQiOy6YRMuynxenON0O5oCpJI6lshc6scgAY8qvJ2On/p+CXY0GA==} - engines: {node: '>=4'} - dependencies: - pg-int8: 1.0.1 - postgres-array: 2.0.0 - postgres-bytea: 1.0.0 - postgres-date: 1.0.7 - postgres-interval: 1.2.0 - dev: false - - /postgres-array@2.0.0: - resolution: {integrity: sha512-VpZrUqU5A69eQyW2c5CA1jtLecCsN2U/bD6VilrFDWq5+5UIEVO7nazS3TEcHf1zuPYO/sqGvUvW62g86RXZuA==} - engines: {node: '>=4'} - dev: false - - /postgres-array@3.0.2: - resolution: {integrity: sha512-6faShkdFugNQCLwucjPcY5ARoW1SlbnrZjmGl0IrrqewpvxvhSLHimCVzqeuULCbG0fQv7Dtk1yDbG3xv7Veog==} - engines: {node: '>=12'} - dev: false - - /postgres-bytea@1.0.0: - resolution: {integrity: sha512-xy3pmLuQqRBZBXDULy7KbaitYqLcmxigw14Q5sj8QBVLqEwXfeybIKVWiqAXTlcvdvb0+xkOtDbfQMOf4lST1w==} - engines: {node: '>=0.10.0'} - dev: false - - /postgres-date@1.0.7: - resolution: {integrity: sha512-suDmjLVQg78nMK2UZ454hAG+OAW+HQPZ6n++TNDUX+L0+uUlLywnoxJKDou51Zm+zTCjrCl0Nq6J9C5hP9vK/Q==} - engines: {node: '>=0.10.0'} - dev: false - - /postgres-interval@1.2.0: - resolution: {integrity: sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==} - engines: {node: '>=0.10.0'} - dependencies: - xtend: 4.0.2 - dev: false - - /undici-types@5.26.5: - resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} - dev: false - - /xtend@4.0.2: - resolution: {integrity: sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==} - engines: {node: '>=0.4'} - dev: false diff --git a/query-engine/query-engine-wasm/rust-toolchain.toml b/query-engine/query-engine-wasm/rust-toolchain.toml index 5048fd2e74a6..44e38c0b8707 100644 --- a/query-engine/query-engine-wasm/rust-toolchain.toml +++ b/query-engine/query-engine-wasm/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2024-05-25" +channel = "nightly-2024-09-01" components = ["clippy", "rustfmt", "rust-src"] targets = [ "wasm32-unknown-unknown", diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index ae6fe40f8728..837160e1bb04 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -13,15 +13,17 @@ use query_core::{ protocol::EngineProtocol, relation_load_strategy, schema::{self}, - telemetry, TransactionOptions, TxId, + TransactionOptions, TxId, }; use query_engine_common::engine::{map_known_error, ConnectedEngine, ConstructorOptions, EngineBuilder, Inner}; use request_handlers::ConnectorKind; use request_handlers::{load_executor, RequestBody, RequestHandler}; use serde_json::json; use std::{marker::PhantomData, sync::Arc}; +use telemetry::helpers::TraceParent; use tokio::sync::RwLock; -use tracing::{field, instrument::WithSubscriber, Instrument, Level, Span}; +use tracing::{instrument::WithSubscriber, Instrument, Level}; +use tracing_opentelemetry::OpenTelemetrySpanExt; use tracing_subscriber::filter::LevelFilter; use wasm_bindgen::prelude::wasm_bindgen; @@ -89,7 +91,8 @@ impl QueryEngine { async { let span = tracing::info_span!("prisma:engine:connect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); let mut inner = self.inner.write().await; let builder = inner.as_builder()?; @@ -111,7 +114,7 @@ impl QueryEngine { let conn_span = 
tracing::info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = connector.name(), + "db.system" = connector.name(), ); let conn = connector.get_connection().instrument(conn_span).await?; @@ -149,8 +152,9 @@ impl QueryEngine { let dispatcher = self.logger.dispatcher(); async { - let span = tracing::info_span!("prisma:engine:disconnect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:disconnect", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); async { let mut inner = self.inner.write().await; @@ -189,17 +193,14 @@ impl QueryEngine { let query = RequestBody::try_from_str(&body, engine.engine_protocol())?; async move { - let span = if tx_id.is_none() { - tracing::info_span!("prisma:engine", user_facing = true) - } else { - Span::none() - }; - - let trace_id = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:query", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); let handler = RequestHandler::new(engine.executor(), engine.query_schema(), engine.engine_protocol()); let response = handler - .handle(query, tx_id.map(TxId::from), trace_id) + .handle(query, tx_id.map(TxId::from), traceparent) .instrument(span) .await; @@ -219,13 +220,15 @@ impl QueryEngine { let dispatcher = self.logger.dispatcher(); async move { - let span = tracing::info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); + let span = tracing::info_span!("prisma:engine:start_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); let tx_opts: TransactionOptions = serde_json::from_str(&input)?; match engine .executor() .start_tx(engine.query_schema().clone(), engine.engine_protocol(), tx_opts) - .instrument(span) .await { Ok(tx_id) => Ok(json!({ "id": tx_id.to_string() }).to_string()), @@ -245,6 +248,11 @@ impl QueryEngine { let dispatcher = self.logger.dispatcher(); async move { + let span = tracing::info_span!("prisma:engine:commit_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); + match engine.executor().commit_tx(TxId::from(tx_id)).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), @@ -263,6 +271,11 @@ impl QueryEngine { let dispatcher = self.logger.dispatcher(); async move { + let span = tracing::info_span!("prisma:engine:rollback_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); + match engine.executor().rollback_tx(TxId::from(tx_id)).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), diff --git a/query-engine/query-engine/Cargo.toml b/query-engine/query-engine/Cargo.toml index e3fd4768ed76..7d56c891972a 100644 --- a/query-engine/query-engine/Cargo.toml +++ b/query-engine/query-engine/Cargo.toml @@ 
-32,11 +32,15 @@ tracing-opentelemetry = "0.17.3" tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } opentelemetry = { version = "0.17.0", features = ["rt-tokio"] } opentelemetry-otlp = { version = "0.10", features = ["tls", "tls-roots"] } -query-engine-metrics = { path = "../metrics" } +prisma-metrics.path = "../../libs/metrics" user-facing-errors = { path = "../../libs/user-facing-errors" } +telemetry = { path = "../../libs/telemetry", features = ["metrics"] } [dev-dependencies] serial_test = "*" quaint.workspace = true indoc.workspace = true + +[build-dependencies] +build-utils.path = "../../libs/build-utils" diff --git a/query-engine/query-engine/build.rs b/query-engine/query-engine/build.rs index 2e8fe20c0503..33aded23a4a5 100644 --- a/query-engine/query-engine/build.rs +++ b/query-engine/query-engine/build.rs @@ -1,11 +1,3 @@ -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} - fn main() { - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); } diff --git a/query-engine/query-engine/src/context.rs b/query-engine/query-engine/src/context.rs index 7a1138c411e5..f6e3896c17ab 100644 --- a/query-engine/query-engine/src/context.rs +++ b/query-engine/query-engine/src/context.rs @@ -1,6 +1,7 @@ use crate::features::{EnabledFeatures, Feature}; use crate::{logger::Logger, opt::PrismaOpt}; use crate::{PrismaError, PrismaResult}; +use prisma_metrics::{MetricRecorder, MetricRegistry}; use psl::PreviewFeature; use query_core::{ protocol::EngineProtocol, @@ -8,11 +9,10 @@ use query_core::{ schema::{self, QuerySchemaRef}, QueryExecutor, }; -use query_engine_metrics::setup as metric_setup; -use query_engine_metrics::MetricRegistry; use request_handlers::{load_executor, ConnectorKind}; use std::{env, fmt, sync::Arc}; use tracing::Instrument; +use tracing_opentelemetry::OpenTelemetrySpanExt; /// Prisma request context containing all immutable state of the process. /// There is usually only one context initialized per process. 
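The per-crate `build.rs` scripts (here, in `query-engine-wasm` above, and in `schema-engine/cli` below) are collapsed into a shared `build-utils` crate. Its helper is assumed to do exactly what the removed scripts did:

```rust
// libs/build-utils (sketch; the body mirrors the removed per-crate build scripts)
use std::process::Command;

pub fn store_git_commit_hash_in_env() {
    let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap();
    let git_hash = String::from_utf8(output.stdout).unwrap();
    // Consumers read this at compile time via `env!("GIT_HASH")`.
    println!("cargo:rustc-env=GIT_HASH={git_hash}");
}
```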
@@ -49,7 +49,8 @@ impl PrismaContext { // Construct query schema schema::build(arced_schema, enabled_features.contains(Feature::RawQueries)) }); - let executor_fut = tokio::spawn(async move { + + let executor_fut = async move { let config = &arced_schema_2.configuration; let preview_features = config.preview_features(); @@ -62,14 +63,22 @@ impl PrismaContext { let url = datasource.load_url(|key| env::var(key).ok())?; // Load executor let executor = load_executor(ConnectorKind::Rust { url, datasource }, preview_features).await?; - let conn = executor.primary_connector().get_connection().await?; + let connector = executor.primary_connector(); + + let conn_span = tracing::info_span!( + "prisma:engine:connection", + user_facing = true, + "db.system" = connector.name(), + ); + + let conn = connector.get_connection().instrument(conn_span).await?; let db_version = conn.version().await; PrismaResult::<_>::Ok((executor, db_version)) - }); + }; let (query_schema, executor_with_db_version) = tokio::join!(query_schema_fut, executor_fut); - let (executor, db_version) = executor_with_db_version.unwrap()?; + let (executor, db_version) = executor_with_db_version?; let query_schema = query_schema.unwrap().with_db_version_supports_join_strategy( relation_load_strategy::db_version_supports_joins_strategy(db_version)?, @@ -103,29 +112,29 @@ impl PrismaContext { } } -pub async fn setup( - opts: &PrismaOpt, - install_logger: bool, - metrics: Option, -) -> PrismaResult> { - let metrics = metrics.unwrap_or_default(); - - if install_logger { - Logger::new("prisma-engine-http", Some(metrics.clone()), opts) - .install() - .unwrap(); - } +pub async fn setup(opts: &PrismaOpt) -> PrismaResult> { + Logger::new("prisma-engine-http", opts).install().unwrap(); - if opts.enable_metrics || opts.dataproxy_metric_override { - metric_setup(); - } + let metrics = if opts.enable_metrics || opts.dataproxy_metric_override { + let metrics = MetricRegistry::new(); + let recorder = MetricRecorder::new(metrics.clone()); + recorder.install_globally().expect("setup must be called only once"); + recorder.init_prisma_metrics(); + Some(metrics) + } else { + None + }; let datamodel = opts.schema(false)?; let config = &datamodel.configuration; let protocol = opts.engine_protocol(); config.validate_that_one_datasource_is_provided()?; - let span = tracing::info_span!("prisma:engine:connect"); + let span = tracing::info_span!("prisma:engine:connect", user_facing = true); + if let Some(trace_context) = opts.trace_context.as_ref() { + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(trace_context); + span.set_parent(parent_context); + } let mut features = EnabledFeatures::from(opts); @@ -133,7 +142,7 @@ pub async fn setup( features |= Feature::Metrics } - let cx = PrismaContext::new(datamodel, protocol, features, Some(metrics)) + let cx = PrismaContext::new(datamodel, protocol, features, metrics) .instrument(span) .await?; diff --git a/query-engine/query-engine/src/logger.rs b/query-engine/query-engine/src/logger.rs index 10f6ced58b86..2bf2566fe963 100644 --- a/query-engine/query-engine/src/logger.rs +++ b/query-engine/query-engine/src/logger.rs @@ -3,8 +3,6 @@ use opentelemetry::{ KeyValue, }; use opentelemetry_otlp::WithExportConfig; -use query_core::telemetry; -use query_engine_metrics::MetricRegistry; use tracing::{dispatcher::SetGlobalDefaultError, subscriber}; use tracing_subscriber::{filter::filter_fn, layer::SubscriberExt, Layer}; @@ -19,7 +17,6 @@ pub(crate) struct Logger { log_format: LogFormat, log_queries: 
bool, tracing_config: TracingConfig, - metrics: Option, } // TracingConfig specifies how tracing will be exposed by the logger facility @@ -38,7 +35,7 @@ enum TracingConfig { impl Logger { /// Initialize a new global logger installer. - pub fn new(service_name: &'static str, metrics: Option, opts: &PrismaOpt) -> Self { + pub fn new(service_name: &'static str, opts: &PrismaOpt) -> Self { let enable_telemetry = opts.enable_open_telemetry; let enable_capturing = opts.enable_telemetry_in_response; let endpoint = if opts.open_telemetry_endpoint.is_empty() { @@ -58,7 +55,6 @@ impl Logger { service_name, log_format: opts.log_format(), log_queries: opts.log_queries(), - metrics, tracing_config, } } @@ -81,9 +77,7 @@ impl Logger { } }; - let subscriber = tracing_subscriber::registry() - .with(fmt_layer) - .with(self.metrics.clone()); + let subscriber = tracing_subscriber::registry().with(fmt_layer); match self.tracing_config { TracingConfig::Captured => { diff --git a/query-engine/query-engine/src/main.rs b/query-engine/query-engine/src/main.rs index 7c3a6f7a1db5..17900c4ad74e 100644 --- a/query-engine/query-engine/src/main.rs +++ b/query-engine/query-engine/src/main.rs @@ -1,8 +1,5 @@ #![allow(clippy::upper_case_acronyms)] -#[macro_use] -extern crate tracing; - use query_engine::cli::CliCommand; use query_engine::context; use query_engine::error::PrismaError; @@ -11,14 +8,13 @@ use query_engine::server; use query_engine::LogFormat; use std::{error::Error, process}; use structopt::StructOpt; -use tracing::Instrument; type AnyError = Box; #[tokio::main] async fn main() -> Result<(), AnyError> { return main().await.map_err(|err| { - info!("Encountered error during initialization:"); + tracing::info!("Encountered error during initialization:"); err.render_as_json().expect("error rendering"); process::exit(1) }); @@ -29,8 +25,7 @@ async fn main() -> Result<(), AnyError> { match CliCommand::from_opt(&opts)? { Some(cmd) => cmd.execute().await?, None => { - let span = tracing::info_span!("prisma:engine:connect"); - let cx = context::setup(&opts, true, None).instrument(span).await?; + let cx = context::setup(&opts).await?; set_panic_hook(opts.log_format()); server::listen(cx, &opts).await?; } diff --git a/query-engine/query-engine/src/opt.rs b/query-engine/query-engine/src/opt.rs index 83ee4bb7fdce..d2d9441f87f3 100644 --- a/query-engine/query-engine/src/opt.rs +++ b/query-engine/query-engine/src/opt.rs @@ -119,6 +119,11 @@ pub struct PrismaOpt { #[structopt(long, env = "PRISMA_ENGINE_PROTOCOL")] pub engine_protocol: Option, + /// The trace context (https://www.w3.org/TR/trace-context) for the engine initialization + /// as a JSON object with properties corresponding to the headers (e.g. `traceparent`). 
+ #[structopt(long, env)] + pub trace_context: Option, + #[structopt(subcommand)] pub subcommand: Option, } diff --git a/query-engine/query-engine/src/server/mod.rs b/query-engine/query-engine/src/server/mod.rs index 01b61a07b6b4..e39509962336 100644 --- a/query-engine/query-engine/src/server/mod.rs +++ b/query-engine/query-engine/src/server/mod.rs @@ -3,19 +3,19 @@ use crate::features::Feature; use crate::{opt::PrismaOpt, PrismaResult}; use hyper::service::{make_service_fn, service_fn}; use hyper::{header::CONTENT_TYPE, Body, HeaderMap, Method, Request, Response, Server, StatusCode}; -use opentelemetry::trace::TraceContextExt; +use opentelemetry::trace::{TraceContextExt, TraceId}; use opentelemetry::{global, propagation::Extractor}; -use query_core::helpers::*; -use query_core::telemetry::capturing::TxTraceExt; -use query_core::{telemetry, ExtendedTransactionUserFacingError, TransactionOptions, TxId}; +use query_core::{ExtendedUserFacingError, TransactionOptions, TxId}; use request_handlers::{dmmf, render_graphql_schema, RequestBody, RequestHandler}; use serde::Serialize; use serde_json::json; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; -use std::time::Instant; -use tracing::{field, Instrument, Span}; +use std::time::{Duration, Instant}; +use telemetry::capturing::Capturer; +use telemetry::helpers::TraceParent; +use tracing::{Instrument, Span}; use tracing_opentelemetry::OpenTelemetrySpanExt; /// Starts up the graphql query engine server @@ -111,71 +111,22 @@ async fn request_handler(cx: Arc, req: Request) -> Result { let handler = RequestHandler::new(cx.executor(), cx.query_schema(), cx.engine_protocol()); let mut result = handler.handle(body, tx_id, traceparent).instrument(span).await; - if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { + if let telemetry::capturing::Capturer::Enabled(capturer) = capturer { let telemetry = capturer.fetch_captures().await; if let Some(telemetry) = telemetry { result.set_extension("traces".to_owned(), json!(telemetry.traces)); @@ -202,7 +153,32 @@ async fn request_handler(cx: Arc, req: Request) -> Result tokio::time::sleep(timeout).await, + // Never return if timeout isn't set. + None => std::future::pending().await, + } + }; + + tokio::select! { + _ = query_timeout_fut => { + let captured_telemetry = if let telemetry::capturing::Capturer::Enabled(capturer) = capturer { + capturer.fetch_captures().await + } else { + None + }; + + // Note: this relies on the fact that client will rollback the transaction after the + // error. If the client continues using this transaction (and later commits it), data + // corruption might happen because some write queries (but not all of them) might be + // already executed by the database before the timeout is fired. + Ok(err_to_http_resp(query_core::CoreError::QueryTimeout, captured_telemetry)) + } + result = work => { + result + } + } } /// Expose the GraphQL playground if enabled. 
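The new `--trace-context` option (also settable through the corresponding environment variable) carries the same JSON shape as the `trace` argument in the engine APIs above: an object keyed by W3C trace-context header names. An illustrative value:

```rust
fn example_trace_context() -> String {
    // Illustrative only; the trace and span IDs are placeholders.
    serde_json::json!({
        "traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
    })
    .to_string()
}
```

Further down in the same hunk, `request_handler` races the work future against an optional sleep in `tokio::select!`; when no `X-query-timeout` header is present the timeout branch is `std::future::pending()`, so the request runs unbounded exactly as before.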
@@ -227,10 +203,7 @@ async fn metrics_handler(cx: Arc, req: Request) -> Result = match serde_json::from_slice(full_body.as_ref()) { - Ok(map) => map, - Err(_e) => HashMap::new(), - }; + let global_labels: HashMap = serde_json::from_slice(full_body.as_ref()).unwrap_or_default(); let response = if requested_json { let metrics = cx.metrics.to_json(global_labels); @@ -281,46 +254,22 @@ async fn transaction_start_handler(cx: Arc, req: Request) - let body_start = req.into_body(); let full_body = hyper::body::to_bytes(body_start).await?; - let mut tx_opts: TransactionOptions = serde_json::from_slice(full_body.as_ref()).unwrap(); - let tx_id = tx_opts.with_new_transaction_id(); - - // This is the span we use to instrument the execution of a transaction. This span will be open - // during the tx execution, and held in the ITXServer for that transaction (see ITXServer]) - let span = info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); - - // If telemetry needs to be captured, we use the span trace_id to correlate the logs happening - // during the different operations within a transaction. The trace_id is propagated in the - // traceparent header, but if it's not present, we need to synthetically create one for the - // transaction. This is needed, in case the client is interested in capturing logs and not - // traces, because: - // - The client won't send a traceparent header - // - A transaction initial span is created here (prisma:engine:itx_runner) and stored in the - // ITXServer for that transaction - // - When a query comes in, the graphql handler process it, but we need to tell the capturer - // to start capturing logs, and for that we need a trace_id. There are two places were we - // could get that information from: - // - First, it's the traceparent, but the client didn't send it, because they are only - // interested in logs. - // - Second, it's the root span for the transaction, but it's not in scope but instead - // stored in the ITXServer, in a different tokio task. - // - // For the above reasons, we need to create a trace_id that we can predict and use accross the - // different operations happening within a transaction. So we do it by converting the tx_id - // into a trace_id, leaning on the fact that the tx_id has more entropy, and there's no - // information loss. 
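A note on the `metrics_handler` change at the start of this hunk: the request body is an optional JSON object of global labels, and a malformed body now degrades to an empty label set via `unwrap_or_default()` instead of the previous explicit match. A self-contained sketch of that parse, assuming string-to-string labels:

```rust
use std::collections::HashMap;

/// Sketch of the new parse: garbage in, empty label map out.
fn parse_global_labels(body: &[u8]) -> HashMap<String, String> {
    serde_json::from_slice(body).unwrap_or_default()
}

fn main() {
    let labels = parse_global_labels(br#"{"service": "orders", "region": "eu-west-1"}"#);
    assert_eq!(labels.get("service").map(String::as_str), Some("orders"));
    // Malformed input no longer needs a dedicated error arm.
    assert!(parse_global_labels(b"not json").is_empty());
}
```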
- let capture_settings = capture_settings(&headers); - let traceparent = traceparent(&headers); - if traceparent.is_none() && capture_settings.logs_enabled() { - span.set_parent(tx_id.into_trace_context()) - } else { - span.set_parent(get_parent_span_context(&headers)) - } - let trace_id = span.context().span().span_context().trace_id(); - let capture_config = telemetry::capturing::capturer(trace_id, capture_settings); - if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.start_capturing().await; - } + let tx_opts = match serde_json::from_slice::(full_body.as_ref()) { + Ok(opts) => opts.with_new_transaction_id(), + Err(_) => { + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from("Invalid transaction options")) + .unwrap()) + } + }; + + let (span, _traceparent, capturer) = setup_telemetry( + info_span!("prisma:engine:start_transaction", user_facing = true), + &headers, + ) + .await; let result = cx .executor @@ -328,12 +277,7 @@ async fn transaction_start_handler(cx: Arc, req: Request) - .instrument(span) .await; - let telemetry = if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.fetch_captures().await - } else { - None - }; - + let telemetry = capturer.try_fetch_captures().await; match result { Ok(tx_id) => { let result = if let Some(telemetry) = telemetry { @@ -355,20 +299,15 @@ async fn transaction_commit_handler( req: Request, tx_id: TxId, ) -> Result, hyper::Error> { - let capture_config = capture_config(req.headers(), tx_id.clone()); - - if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.start_capturing().await; - } + let (span, _traceparent, capturer) = setup_telemetry( + info_span!("prisma:engine:commit_transaction", user_facing = true), + req.headers(), + ) + .await; - let result = cx.executor.commit_tx(tx_id).await; - - let telemetry = if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.fetch_captures().await - } else { - None - }; + let result = cx.executor.commit_tx(tx_id).instrument(span).await; + let telemetry = capturer.try_fetch_captures().await; match result { Ok(_) => Ok(empty_json_to_http_resp(telemetry)), Err(err) => Ok(err_to_http_resp(err, telemetry)), @@ -380,20 +319,15 @@ async fn transaction_rollback_handler( req: Request, tx_id: TxId, ) -> Result, hyper::Error> { - let capture_config = capture_config(req.headers(), tx_id.clone()); + let (span, _traceparent, capturer) = setup_telemetry( + info_span!("prisma:engine:rollback_transaction", user_facing = true), + req.headers(), + ) + .await; - if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.start_capturing().await; - } - - let result = cx.executor.rollback_tx(tx_id).await; - - let telemetry = if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.fetch_captures().await - } else { - None - }; + let result = cx.executor.rollback_tx(tx_id).instrument(span).await; + let telemetry = capturer.try_fetch_captures().await; match result { Ok(_) => Ok(empty_json_to_http_resp(telemetry)), Err(err) => Ok(err_to_http_resp(err, telemetry)), @@ -457,11 +391,13 @@ fn err_to_http_resp( query_core::TransactionError::Unknown { reason: _ } => StatusCode::INTERNAL_SERVER_ERROR, }, + query_core::CoreError::QueryTimeout => StatusCode::REQUEST_TIMEOUT, + // All other errors are treated as 500s, most of these paths should never be hit, only connector errors may occur. 
_ => StatusCode::INTERNAL_SERVER_ERROR, }; - let mut err: ExtendedTransactionUserFacingError = err.into(); + let mut err: ExtendedUserFacingError = err.into(); if let Some(telemetry) = captured_telemetry { err.set_extension("traces".to_owned(), json!(telemetry.traces)); err.set_extension("logs".to_owned(), json!(telemetry.logs)); @@ -470,57 +406,86 @@ fn err_to_http_resp( build_json_response(status, &err) } -fn capture_config(headers: &HeaderMap, tx_id: TxId) -> telemetry::capturing::Capturer { - let capture_settings = capture_settings(headers); - let mut traceparent = traceparent(headers); - - if traceparent.is_none() && capture_settings.is_enabled() { - traceparent = Some(tx_id.as_traceparent()) - } - - let trace_id = get_trace_id_from_traceparent(traceparent.as_deref()); +async fn setup_telemetry(span: Span, headers: &HeaderMap) -> (Span, Option, Capturer) { + let capture_settings = { + let settings = headers + .get("X-capture-telemetry") + .and_then(|value| value.to_str().ok()) + .unwrap_or_default(); + telemetry::capturing::Settings::from(settings) + }; - telemetry::capturing::capturer(trace_id, capture_settings) -} + // Parse parent trace_id and span_id from `traceparent` header and attach them to the current + // context. Internally, this relies on the fact that global text map propagator was installed that + // can handle `traceparent` header (for example, `TraceContextPropagator`). + let parent_context = { + let extractor = HeaderExtractor(headers); + let context = global::get_text_map_propagator(|propagator| propagator.extract(&extractor)); + if context.span().span_context().is_valid() { + Some(context) + } else { + None + } + }; -#[allow(clippy::bind_instead_of_map)] -fn capture_settings(headers: &HeaderMap) -> telemetry::capturing::Settings { - const CAPTURE_TELEMETRY_HEADER: &str = "X-capture-telemetry"; - let s = if let Some(hv) = headers.get(CAPTURE_TELEMETRY_HEADER) { - hv.to_str().unwrap_or("") + let traceparent = if let Some(parent_context) = parent_context { + let requester_traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); + requester_traceparent + } else if capture_settings.is_enabled() { + // If tracing is disabled on the client but capturing the logs is enabled, we construct an + // artificial traceparent. Although the span corresponding to this traceparent doesn't + // actually exist, it is okay because no spans will be returned to the client in this case + // anyway, so they don't have to be valid. The reason we need this is because capturing the + // logs requires a trace ID to correlate the events with. This is not the right design: it + // seems to be based on the wrong idea that trace ID uniquely identifies a request (which + // is not the case in reality), and it is prone to race conditions and losing spans and + // logs when there are multiple concurrent Prisma operations in a single trace. Ironically, + // this makes log capturing more reliable in the code path with the fake traceparent hack + // than when a real traceparent is present. The drawback is that the fake traceparent leaks + // to the query logs. 
We could of course add a custom flag to `TraceParent` to indicate + // that it is synthetic (we can't use the sampled trace flag for it as it would prevent it + // from being processed by the `SpanProcessor`) and check it when adding the traceparent + // comment if we wanted a quick fix for that, but this problem existed for as long as + // capturing was implemented, and the `DataProxyEngine` works around it by stripping the + // phony traceparent comments before emitting the logs on the `PrismaClient` instance. So + // instead, we will fix the root cause of the problem by reworking the capturer to collect + // all spans and events which have the `span` created above as an ancestor and not rely on + // trace IDs at all. This will happen in a follow-up PR as part of Tracing GA work. + let traceparent = { + #[allow(deprecated)] + TraceParent::new_random() + }; + span.set_parent(traceparent.to_remote_context()); + Some(traceparent) } else { - "" + None }; - telemetry::capturing::Settings::from(s) -} - -fn traceparent(headers: &HeaderMap) -> Option { - const TRACEPARENT_HEADER: &str = "traceparent"; + let trace_id = traceparent + .as_ref() + .map(TraceParent::trace_id) + .unwrap_or(TraceId::INVALID); - let value = headers - .get(TRACEPARENT_HEADER) - .and_then(|h| h.to_str().ok()) - .map(|s| s.to_owned()); + let capturer = telemetry::capturing::capturer(trace_id, capture_settings); + capturer.try_start_capturing().await; - let is_valid_traceparent = |s: &String| s.split_terminator('-').count() >= 4; - - value.filter(is_valid_traceparent) + (span, traceparent, capturer) } -fn transaction_id(headers: &HeaderMap) -> Option { - const TRANSACTION_ID_HEADER: &str = "X-transaction-id"; +fn try_get_transaction_id(headers: &HeaderMap) -> Option { headers - .get(TRANSACTION_ID_HEADER) + .get("X-transaction-id") .and_then(|h| h.to_str().ok()) .map(TxId::from) } -/// If the client sends us a trace and span id, extracting a new context if the -/// headers are set. If not, returns current context. 
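The new `setup_telemetry` helper, together with `try_get_transaction_id` and the `query_timeout` function that follows, consumes everything it needs from request headers. A hedged example of a request carrying all of them (the header names come from the diff; the values, including the `X-capture-telemetry` settings string, are illustrative):

```rust
use hyper::{Body, Method, Request};

fn example_request() -> Request<Body> {
    Request::builder()
        .method(Method::POST)
        .uri("http://localhost:4466/")
        // W3C parent for the `prisma:engine:*` span created by `setup_telemetry`.
        .header("traceparent", "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01")
        // Opt into returning captured logs/traces in the response extensions (value illustrative).
        .header("X-capture-telemetry", "tracing")
        // Run the operation inside an existing interactive transaction.
        .header("X-transaction-id", "itx-1234")
        // Abort the query and answer 408 Request Timeout after 5000 ms
        // (parsed by `query_timeout` just below).
        .header("X-query-timeout", "5000")
        .body(Body::from(r#"{"query":"..."}"#))
        .unwrap()
}
```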
-fn get_parent_span_context(headers: &HeaderMap) -> opentelemetry::Context { - let extractor = HeaderExtractor(headers); - global::get_text_map_propagator(|propagator| propagator.extract(&extractor)) +fn query_timeout(headers: &HeaderMap) -> Option { + headers + .get("X-query-timeout") + .and_then(|h| h.to_str().ok()) + .and_then(|value| value.parse::().ok()) + .map(Duration::from_millis) } fn build_json_response(status_code: StatusCode, value: &T) -> Response diff --git a/query-engine/query-engine/src/tests/dmmf.rs b/query-engine/query-engine/src/tests/dmmf.rs index 8151f25bf17b..443b2c81a194 100644 --- a/query-engine/query-engine/src/tests/dmmf.rs +++ b/query-engine/query-engine/src/tests/dmmf.rs @@ -96,6 +96,7 @@ fn test_dmmf_cli_command(schema: &str) -> PrismaResult<()> { enable_telemetry_in_response: false, dataproxy_metric_override: false, engine_protocol: None, + trace_context: None, }; let cli_cmd = CliCommand::from_opt(&prisma_opt)?.unwrap(); diff --git a/query-engine/query-engine/src/tracer.rs b/query-engine/query-engine/src/tracer.rs index 8763ba892f4f..75f6630931b5 100644 --- a/query-engine/query-engine/src/tracer.rs +++ b/query-engine/query-engine/src/tracer.rs @@ -8,7 +8,7 @@ use opentelemetry::{ }, trace::TracerProvider, }; -use query_core::telemetry; + use std::io::{stdout, Stdout}; use std::{fmt::Debug, io::Write}; diff --git a/query-engine/query-structure/Cargo.toml b/query-engine/query-structure/Cargo.toml index f990c48ffd4b..847d97bee2ff 100644 --- a/query-engine/query-structure/Cargo.toml +++ b/query-engine/query-structure/Cargo.toml @@ -22,4 +22,4 @@ features = ["js"] [features] # Support for generating default UUID, CUID, nanoid and datetime values. -default_generators = ["uuid/v4", "cuid", "nanoid"] +default_generators = ["uuid/v4", "uuid/v7", "cuid", "nanoid"] diff --git a/query-engine/query-structure/src/default_value.rs b/query-engine/query-structure/src/default_value.rs index 9eaf6828d1c8..605224909e31 100644 --- a/query-engine/query-structure/src/default_value.rs +++ b/query-engine/query-structure/src/default_value.rs @@ -45,7 +45,7 @@ impl DefaultKind { /// Does this match @default(uuid(_))? 
pub fn is_uuid(&self) -> bool { - matches!(self, DefaultKind::Expression(generator) if generator.name == "uuid") + matches!(self, DefaultKind::Expression(generator) if generator.name.starts_with("uuid")) } pub fn unwrap_single(self) -> PrismaValue { @@ -186,8 +186,8 @@ impl ValueGenerator { ValueGenerator::new("cuid".to_owned(), vec![]).unwrap() } - pub fn new_uuid() -> Self { - ValueGenerator::new("uuid".to_owned(), vec![]).unwrap() + pub fn new_uuid(version: u8) -> Self { + ValueGenerator::new(format!("uuid({version})"), vec![]).unwrap() } pub fn new_nanoid(length: Option) -> Self { @@ -238,7 +238,7 @@ impl ValueGenerator { #[derive(Clone, Copy, PartialEq)] pub enum ValueGeneratorFn { - Uuid, + Uuid(u8), Cuid, Nanoid(Option), Now, @@ -251,7 +251,8 @@ impl ValueGeneratorFn { fn new(name: &str) -> std::result::Result { match name { "cuid" => Ok(Self::Cuid), - "uuid" => Ok(Self::Uuid), + "uuid" | "uuid(4)" => Ok(Self::Uuid(4)), + "uuid(7)" => Ok(Self::Uuid(7)), "now" => Ok(Self::Now), "autoincrement" => Ok(Self::Autoincrement), "sequence" => Ok(Self::Autoincrement), @@ -265,7 +266,7 @@ impl ValueGeneratorFn { #[cfg(feature = "default_generators")] fn invoke(&self) -> Option { match self { - Self::Uuid => Some(Self::generate_uuid()), + Self::Uuid(version) => Some(Self::generate_uuid(*version)), Self::Cuid => Some(Self::generate_cuid()), Self::Nanoid(length) => Some(Self::generate_nanoid(length)), Self::Now => Some(Self::generate_now()), @@ -282,8 +283,12 @@ impl ValueGeneratorFn { } #[cfg(feature = "default_generators")] - fn generate_uuid() -> PrismaValue { - PrismaValue::Uuid(uuid::Uuid::new_v4()) + fn generate_uuid(version: u8) -> PrismaValue { + PrismaValue::Uuid(match version { + 4 => uuid::Uuid::new_v4(), + 7 => uuid::Uuid::now_v7(), + _ => panic!("Unknown UUID version: {}", version), + }) } #[cfg(feature = "default_generators")] @@ -337,8 +342,16 @@ mod tests { } #[test] - fn default_value_is_uuid() { - let uuid_default = DefaultValue::new_expression(ValueGenerator::new_uuid()); + fn default_value_is_uuidv4() { + let uuid_default = DefaultValue::new_expression(ValueGenerator::new_uuid(4)); + + assert!(uuid_default.is_uuid()); + assert!(!uuid_default.is_autoincrement()); + } + + #[test] + fn default_value_is_uuidv7() { + let uuid_default = DefaultValue::new_expression(ValueGenerator::new_uuid(7)); assert!(uuid_default.is_uuid()); assert!(!uuid_default.is_autoincrement()); diff --git a/query-engine/query-structure/src/field/scalar.rs b/query-engine/query-structure/src/field/scalar.rs index 8aeb9c7f47aa..5fc10acddd13 100644 --- a/query-engine/query-structure/src/field/scalar.rs +++ b/query-engine/query-structure/src/field/scalar.rs @@ -244,8 +244,15 @@ pub fn dml_default_kind(default_value: &ast::Expression, scalar_type: Option { DefaultKind::Expression(ValueGenerator::new_sequence(Vec::new())) } - ast::Expression::Function(funcname, _args, _) if funcname == "uuid" => { - DefaultKind::Expression(ValueGenerator::new_uuid()) + ast::Expression::Function(funcname, args, _) if funcname == "uuid" => { + let version = args + .arguments + .first() + .and_then(|arg| arg.value.as_numeric_value()) + .map(|(val, _)| val.parse::().unwrap()) + .unwrap_or(4); + + DefaultKind::Expression(ValueGenerator::new_uuid(version)) } ast::Expression::Function(funcname, _args, _) if funcname == "cuid" => { DefaultKind::Expression(ValueGenerator::new_cuid()) diff --git a/query-engine/query-structure/src/selection_result.rs b/query-engine/query-structure/src/selection_result.rs index 74b097506fb6..b31a77aae3df 
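With these changes `@default(uuid())` keeps defaulting to version 4, while `@default(uuid(7))` in the Prisma schema is parsed into `ValueGenerator::new_uuid(7)` and produces time-ordered UUIDv7 values. A standalone sketch of the two generator paths, assuming the `uuid` crate with the `v4` and `v7` features enabled as in the Cargo.toml change above:

```rust
fn generate_uuid(version: u8) -> uuid::Uuid {
    match version {
        4 => uuid::Uuid::new_v4(), // fully random
        7 => uuid::Uuid::now_v7(), // timestamp-prefixed, so values roughly sort by creation time
        other => panic!("Unknown UUID version: {other}"),
    }
}

fn main() {
    println!("v4: {}", generate_uuid(4));
    println!("v7: {}", generate_uuid(7)); // starts with the current Unix-millisecond timestamp
    assert_eq!(generate_uuid(7).get_version_num(), 7);
}
```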
100644 --- a/query-engine/query-structure/src/selection_result.rs +++ b/query-engine/query-structure/src/selection_result.rs @@ -12,8 +12,7 @@ impl std::fmt::Debug for SelectionResult { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_list() .entries( - &self - .pairs + self.pairs .iter() .map(|pair| (format!("{}", pair.0), pair.1.clone())) .collect_vec(), diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index fe9a66b449a5..8d7e8b4e2222 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -8,13 +8,14 @@ psl.workspace = true query-structure = { path = "../query-structure" } query-core = { path = "../core" } user-facing-errors = { path = "../../libs/user-facing-errors" } +telemetry = { path = "../../libs/telemetry" } quaint.workspace = true dmmf_crate = { path = "../dmmf", package = "dmmf" } itertools.workspace = true graphql-parser = { git = "https://github.com/prisma/graphql-parser", optional = true } serde.workspace = true serde_json.workspace = true -futures = "0.3" +futures.workspace = true indexmap.workspace = true bigdecimal = "0.3" thiserror = "1" @@ -75,7 +76,7 @@ all = [ graphql-protocol = ["query-core/graphql-protocol", "dep:graphql-parser"] [build-dependencies] -cfg_aliases = "0.2.0" +cfg_aliases = "0.2.1" [[bench]] name = "query_planning_bench" diff --git a/query-engine/request-handlers/src/handler.rs b/query-engine/request-handlers/src/handler.rs index a8e6ba8b8a9b..123af6541c45 100644 --- a/query-engine/request-handlers/src/handler.rs +++ b/query-engine/request-handlers/src/handler.rs @@ -13,6 +13,7 @@ use query_core::{ }; use query_structure::{parse_datetime, stringify_datetime, PrismaValue}; use std::{collections::HashMap, fmt, panic::AssertUnwindSafe, str::FromStr}; +use telemetry::helpers::TraceParent; type ArgsToResult = (HashMap, IndexMap); @@ -41,24 +42,34 @@ impl<'a> RequestHandler<'a> { } } - pub async fn handle(&self, body: RequestBody, tx_id: Option, trace_id: Option) -> PrismaResponse { + pub async fn handle( + &self, + body: RequestBody, + tx_id: Option, + traceparent: Option, + ) -> PrismaResponse { tracing::debug!("Incoming GraphQL query: {:?}", &body); match body.into_doc(self.query_schema) { - Ok(QueryDocument::Single(query)) => self.handle_single(query, tx_id, trace_id).await, + Ok(QueryDocument::Single(query)) => self.handle_single(query, tx_id, traceparent).await, Ok(QueryDocument::Multi(batch)) => match batch.compact(self.query_schema) { BatchDocument::Multi(batch, transaction) => { - self.handle_batch(batch, transaction, tx_id, trace_id).await + self.handle_batch(batch, transaction, tx_id, traceparent).await } - BatchDocument::Compact(compacted) => self.handle_compacted(compacted, tx_id, trace_id).await, + BatchDocument::Compact(compacted) => self.handle_compacted(compacted, tx_id, traceparent).await, }, Err(err) => PrismaResponse::Single(GQLError::from_handler_error(err).into()), } } - async fn handle_single(&self, query: Operation, tx_id: Option, trace_id: Option) -> PrismaResponse { - let gql_response = match AssertUnwindSafe(self.handle_request(query, tx_id, trace_id)) + async fn handle_single( + &self, + query: Operation, + tx_id: Option, + traceparent: Option, + ) -> PrismaResponse { + let gql_response = match AssertUnwindSafe(self.handle_request(query, tx_id, traceparent)) .catch_unwind() .await { @@ -75,14 +86,14 @@ impl<'a> RequestHandler<'a> { queries: Vec, transaction: Option, tx_id: Option, - trace_id: Option, + 
traceparent: Option, ) -> PrismaResponse { match AssertUnwindSafe(self.executor.execute_all( tx_id, queries, transaction, self.query_schema.clone(), - trace_id, + traceparent, self.engine_protocol, )) .catch_unwind() @@ -108,7 +119,7 @@ impl<'a> RequestHandler<'a> { &self, document: CompactedDocument, tx_id: Option, - trace_id: Option, + traceparent: Option, ) -> PrismaResponse { let plural_name = document.plural_name(); let singular_name = document.single_name(); @@ -117,7 +128,7 @@ impl<'a> RequestHandler<'a> { let arguments = document.arguments; let nested_selection = document.nested_selection; - match AssertUnwindSafe(self.handle_request(document.operation, tx_id, trace_id)) + match AssertUnwindSafe(self.handle_request(document.operation, tx_id, traceparent)) .catch_unwind() .await { @@ -200,14 +211,14 @@ impl<'a> RequestHandler<'a> { &self, query_doc: Operation, tx_id: Option, - trace_id: Option, + traceparent: Option, ) -> query_core::Result { self.executor .execute( tx_id, query_doc, self.query_schema.clone(), - trace_id, + traceparent, self.engine_protocol, ) .await @@ -224,20 +235,29 @@ impl<'a> RequestHandler<'a> { } fn compare_args(left: &HashMap, right: &HashMap) -> bool { - left.iter().all(|(key, left_value)| { - right + let (large, small) = if left.len() > right.len() { + (&left, &right) + } else { + (&right, &left) + }; + + small.iter().all(|(key, small_value)| { + large .get(key) - .map_or(false, |right_value| Self::compare_values(left_value, right_value)) + .map_or(false, |large_value| Self::compare_values(small_value, large_value)) }) } /// Compares two PrismaValues with special comparisons rules needed because user-inputted values are coerced differently than response values. + /// /// We need this when comparing user-inputted values with query response values in the context of compacted queries. + /// /// Here are the cases covered: /// - DateTime/String: User-input: DateTime / Response: String /// - Int/BigInt: User-input: Int / Response: BigInt /// - (JSON protocol only) Custom types (eg: { "$type": "BigInt", value: "1" }): User-input: Scalar / Response: Object /// - (JSON protocol only) String/Enum: User-input: String / Response: Enum + /// /// This should likely _not_ be used outside of this specific context. 
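`compare_args` previously iterated the left-hand argument map and required each of its keys to exist in the right-hand map; it now walks whichever map is smaller and looks its entries up in the larger one, so extra keys on either side no longer defeat batch compaction. The containment check in isolation, detached from the Prisma value types:

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// Standalone version of the check: every entry of the smaller map must have a
/// matching entry (per `eq`) in the larger map.
fn smaller_is_contained_in_larger<K: Hash + Eq, V>(
    left: &HashMap<K, V>,
    right: &HashMap<K, V>,
    eq: impl Fn(&V, &V) -> bool,
) -> bool {
    let (large, small) = if left.len() > right.len() { (left, right) } else { (right, left) };
    small
        .iter()
        .all(|(key, small_value)| large.get(key).map_or(false, |large_value| eq(small_value, large_value)))
}
```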
fn compare_values(left: &ArgumentValue, right: &ArgumentValue) -> bool { match (left, right) { diff --git a/query-engine/schema/src/identifier_type.rs b/query-engine/schema/src/identifier_type.rs index d4cf309a299f..0aa77140139c 100644 --- a/query-engine/schema/src/identifier_type.rs +++ b/query-engine/schema/src/identifier_type.rs @@ -186,7 +186,12 @@ impl std::fmt::Display for IdentifierType { IdentifierType::ToOneRelationFilterInput(related_model, arity) => { let nullable = if arity.is_optional() { "Nullable" } else { "" }; - write!(f, "{}{}RelationFilter", capitalize(related_model.name()), nullable) + write!( + f, + "{}{}ScalarRelationFilter", + capitalize(related_model.name()), + nullable + ) } IdentifierType::ToOneCompositeFilterInput(ct, arity) => { let nullable = if arity.is_optional() { "Nullable" } else { "" }; diff --git a/query-engine/schema/test-schemas/odoo.prisma b/query-engine/schema/test-schemas/odoo.prisma index a7410606f4c0..7da843cf2fd5 100644 --- a/query-engine/schema/test-schemas/odoo.prisma +++ b/query-engine/schema/test-schemas/odoo.prisma @@ -1,5 +1,5 @@ datasource db { - provider = "postgresql" + provider = "postgres" url = env("DB_URL") } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index fd5803cb5fba..8f2d5ed34666 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.78.0" +channel = "1.82.0" components = ["clippy", "rustfmt", "rust-src"] targets = [ # WASM target for serverless and edge environments. @@ -17,5 +17,6 @@ targets = [ # Server targets we support. "x86_64-unknown-linux-musl", "aarch64-unknown-linux-gnu", - "aarch64-unknown-linux-musl" + "aarch64-unknown-linux-musl", + "aarch64-apple-darwin" ] diff --git a/schema-engine/cli/Cargo.toml b/schema-engine/cli/Cargo.toml index bfb136f582df..65718f8c8cd5 100644 --- a/schema-engine/cli/Cargo.toml +++ b/schema-engine/cli/Cargo.toml @@ -36,6 +36,9 @@ connection-string.workspace = true expect-test = "1.4.0" quaint = { workspace = true, features = ["all-native"] } +[build-dependencies] +build-utils.path = "../../libs/build-utils" + [[bin]] name = "schema-engine" path = "src/main.rs" diff --git a/schema-engine/cli/build.rs b/schema-engine/cli/build.rs index 2e8fe20c0503..33aded23a4a5 100644 --- a/schema-engine/cli/build.rs +++ b/schema-engine/cli/build.rs @@ -1,11 +1,3 @@ -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} - fn main() { - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); } diff --git a/schema-engine/cli/tests/cli_tests.rs b/schema-engine/cli/tests/cli_tests.rs index 973345ac4033..fd89b9bce745 100644 --- a/schema-engine/cli/tests/cli_tests.rs +++ b/schema-engine/cli/tests/cli_tests.rs @@ -51,7 +51,7 @@ where .or_else(|| panic_payload.downcast_ref::().map(|s| s.to_owned())) .unwrap_or_default(); - panic!("Error: '{}'", res) + panic!("Error: '{}'", res); } } } diff --git a/schema-engine/connectors/mongodb-schema-connector/Cargo.toml b/schema-engine/connectors/mongodb-schema-connector/Cargo.toml index af21f75c5b9a..c5c8188b4399 100644 --- a/schema-engine/connectors/mongodb-schema-connector/Cargo.toml +++ b/schema-engine/connectors/mongodb-schema-connector/Cargo.toml @@ -14,14 +14,15 @@ user-facing-errors = { path = "../../../libs/user-facing-errors", features = [ ] } enumflags2.workspace = true -futures = "0.3" -mongodb = 
"2.8.0" +futures.workspace = true +mongodb.workspace = true +bson.workspace = true serde_json.workspace = true tokio.workspace = true tracing.workspace = true convert_case = "0.6.0" once_cell = "1.8.0" -regex = "1.7.3" +regex.workspace = true indoc.workspace = true [dev-dependencies] @@ -32,4 +33,4 @@ url.workspace = true expect-test = "1" names = { version = "0.12", default-features = false } itertools.workspace = true -indoc.workspace = true \ No newline at end of file +indoc.workspace = true diff --git a/schema-engine/connectors/mongodb-schema-connector/src/client_wrapper.rs b/schema-engine/connectors/mongodb-schema-connector/src/client_wrapper.rs index 239d6d218c55..b2da59b73986 100644 --- a/schema-engine/connectors/mongodb-schema-connector/src/client_wrapper.rs +++ b/schema-engine/connectors/mongodb-schema-connector/src/client_wrapper.rs @@ -48,11 +48,12 @@ impl Client { pub(crate) async fn drop_database(&self) -> ConnectorResult<()> { self.database() - .drop(Some( + .drop() + .with_options( mongodb::options::DropDatabaseOptions::builder() .write_concern(WriteConcern::builder().journal(true).build()) .build(), - )) + ) .await .map_err(mongo_error_to_connector_error) } diff --git a/schema-engine/connectors/mongodb-schema-connector/src/lib.rs b/schema-engine/connectors/mongodb-schema-connector/src/lib.rs index bab0531b61c4..7020fb8d5ebe 100644 --- a/schema-engine/connectors/mongodb-schema-connector/src/lib.rs +++ b/schema-engine/connectors/mongodb-schema-connector/src/lib.rs @@ -215,6 +215,13 @@ impl SchemaConnector for MongoDbSchemaConnector { fn extract_namespaces(&self, _schema: &DatabaseSchema) -> Option { None } + + fn introspect_sql( + &mut self, + _input: IntrospectSqlQueryInput, + ) -> BoxFuture<'_, ConnectorResult> { + unreachable!() + } } fn unsupported_command_error() -> ConnectorError { diff --git a/schema-engine/connectors/mongodb-schema-connector/src/migration_step_applier.rs b/schema-engine/connectors/mongodb-schema-connector/src/migration_step_applier.rs index 5d0b1888adff..eed4c39ff72b 100644 --- a/schema-engine/connectors/mongodb-schema-connector/src/migration_step_applier.rs +++ b/schema-engine/connectors/mongodb-schema-connector/src/migration_step_applier.rs @@ -3,7 +3,7 @@ use crate::{ migration::{MongoDbMigration, MongoDbMigrationStep}, MongoDbSchemaConnector, }; -use mongodb::bson::{self, Bson, Document}; +use bson::{self, Bson, Document}; use schema_connector::{ConnectorResult, Migration, SchemaConnector}; impl MongoDbSchemaConnector { @@ -24,7 +24,7 @@ impl MongoDbSchemaConnector { for step in migration.steps.iter() { match step { MongoDbMigrationStep::CreateCollection(id) => { - db.create_collection(migration.next.walk_collection(*id).name(), None) + db.create_collection(migration.next.walk_collection(*id).name()) .await .map_err(mongo_error_to_connector_error)?; } @@ -44,7 +44,7 @@ impl MongoDbSchemaConnector { index_options.unique = Some(index.is_unique()); index_model.options = Some(index_options); collection - .create_index(index_model, None) + .create_index(index_model) .await .map_err(mongo_error_to_connector_error)?; } @@ -52,7 +52,7 @@ impl MongoDbSchemaConnector { let index = migration.previous.walk_index(*index_id); let collection: mongodb::Collection = db.collection(index.collection().name()); collection - .drop_index(index.name(), None) + .drop_index(index.name()) .await .map_err(mongo_error_to_connector_error)?; } diff --git a/schema-engine/connectors/mongodb-schema-connector/src/sampler.rs 
b/schema-engine/connectors/mongodb-schema-connector/src/sampler.rs index ec2d0c55c6fa..bdbd8b80f38a 100644 --- a/schema-engine/connectors/mongodb-schema-connector/src/sampler.rs +++ b/schema-engine/connectors/mongodb-schema-connector/src/sampler.rs @@ -17,7 +17,7 @@ use std::borrow::Cow; /// maximum of SAMPLE_SIZE documents for their fields with the following rules: /// /// - If the same field differs in types between documents, takes the most -/// common type or if even, the latest type and adds a warning. +/// common type or if even, the latest type and adds a warning. /// - Missing fields count as null. /// - Indices are taken, but not if they are partial. pub(super) async fn sample( @@ -49,7 +49,8 @@ pub(super) async fn sample( let mut documents = database .collection::(collection.name()) - .aggregate(vec![doc! { "$sample": { "size": SAMPLE_SIZE } }], Some(options)) + .aggregate(vec![doc! { "$sample": { "size": SAMPLE_SIZE } }]) + .with_options(options) .await?; while let Some(document) = documents.try_next().await? { diff --git a/schema-engine/connectors/mongodb-schema-connector/src/sampler/field_type.rs b/schema-engine/connectors/mongodb-schema-connector/src/sampler/field_type.rs index be74e848a19b..6b9867d96548 100644 --- a/schema-engine/connectors/mongodb-schema-connector/src/sampler/field_type.rs +++ b/schema-engine/connectors/mongodb-schema-connector/src/sampler/field_type.rs @@ -1,5 +1,5 @@ use super::statistics::Name; -use mongodb::bson::Bson; +use bson::Bson; use std::fmt; #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] diff --git a/schema-engine/connectors/mongodb-schema-connector/src/sampler/statistics.rs b/schema-engine/connectors/mongodb-schema-connector/src/sampler/statistics.rs index f89050cd227b..c019132ca2c5 100644 --- a/schema-engine/connectors/mongodb-schema-connector/src/sampler/statistics.rs +++ b/schema-engine/connectors/mongodb-schema-connector/src/sampler/statistics.rs @@ -12,9 +12,9 @@ use schema_connector::{ }; use super::field_type::FieldType; +use bson::{Bson, Document}; use convert_case::{Case, Casing}; use datamodel_renderer as renderer; -use mongodb::bson::{Bson, Document}; use mongodb_schema_describer::{CollectionWalker, IndexWalker}; use once_cell::sync::Lazy; use psl::datamodel_connector::constraint_names::ConstraintNames; diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/basic/mod.rs b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/basic/mod.rs index 4770e25e7b27..51d0bc30c116 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/basic/mod.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/basic/mod.rs @@ -4,7 +4,7 @@ use mongodb::{bson::doc, options::CreateCollectionOptions}; #[test] fn empty_collection() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; Ok(()) }); @@ -22,7 +22,7 @@ fn empty_collection() { fn integer_id() { let res = introspect(|db| async move { let collection = db.collection("A"); - collection.insert_one(doc! { "_id": 12345 }, None).await.unwrap(); + collection.insert_one(doc! { "_id": 12345 }).await.unwrap(); Ok(()) }); @@ -39,17 +39,17 @@ fn integer_id() { #[test] fn multiple_collections_with_data() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! 
{"first": "Musti"}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); - db.create_collection("B", None).await?; + db.create_collection("B").await?; let collection = db.collection("B"); let docs = vec![doc! {"second": "Naukio"}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -72,9 +72,8 @@ fn multiple_collections_with_data() { #[test] fn collection_with_json_schema() { let res = introspect(|db| async move { - db.create_collection( - "A", - Some( + db.create_collection("A") + .with_options( CreateCollectionOptions::builder() .validator(Some(mongodb::bson::doc! { "$jsonSchema": { @@ -94,9 +93,8 @@ fn collection_with_json_schema() { } })) .build(), - ), - ) - .await?; + ) + .await?; Ok(()) }); @@ -122,16 +120,14 @@ fn collection_with_json_schema() { #[test] fn capped_collection() { let res = introspect(|db| async move { - db.create_collection( - "A", - Some( + db.create_collection("A") + .with_options( CreateCollectionOptions::builder() .capped(Some(true)) .size(Some(1024)) .build(), - ), - ) - .await?; + ) + .await?; Ok(()) }); diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/dirty_data/mod.rs b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/dirty_data/mod.rs index 67cee19115f5..b4a5d0e0eb83 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/dirty_data/mod.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/dirty_data/mod.rs @@ -1,14 +1,14 @@ use crate::introspection::test_api::*; -use mongodb::bson::{doc, Bson, DateTime, Timestamp}; +use bson::{doc, Bson, DateTime, Timestamp}; #[test] fn explicit_id_field() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"id": 1}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -27,10 +27,11 @@ fn explicit_id_field() { fn mixed_id_types() { let res = introspect(|db| async move { db.collection("A") - .insert_many( - vec![doc! { "_id": 12345 }, doc! { "_id": "foo" }, doc! { "foo": false }], - None, - ) + .insert_many(vec![ + doc! { "_id": 12345 }, + doc! { "_id": "foo" }, + doc! { "foo": false }, + ]) .await .unwrap(); @@ -51,11 +52,11 @@ fn mixed_id_types() { #[test] fn mixing_types() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"first": "Musti"}, doc! {"first": 1i32}, doc! {"first": null}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -83,7 +84,7 @@ fn mixing_types() { #[test] fn mixing_types_with_the_same_base_type() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![ @@ -92,7 +93,7 @@ fn mixing_types_with_the_same_base_type() { doc! 
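// The test updates in this file follow the same mongodb 3.x call shape as the connector code
// earlier in the diff: the trailing `Option<...Options>` argument is gone, and options are
// supplied through `with_options` on the returned action before awaiting. A small sketch,
// assuming driver 3.x; the collection names and option values are placeholders.

use mongodb::{options::CreateCollectionOptions, Database};

async fn create_collections(db: &Database) -> mongodb::error::Result<()> {
    // Without options: await the action directly.
    db.create_collection("A").await?;

    // With options: attach them via the fluent builder.
    db.create_collection("Capped")
        .with_options(
            CreateCollectionOptions::builder()
                .capped(Some(true))
                .size(Some(1024))
                .build(),
        )
        .await?;

    Ok(())
}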
{"first": null}, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/index/mod.rs b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/index/mod.rs index 584bc42b241f..0f8f13e1c549 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/index/mod.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/index/mod.rs @@ -9,11 +9,11 @@ use schema_connector::CompositeTypeDepth; #[test] fn single_column_normal_index() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -22,7 +22,7 @@ fn single_column_normal_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -43,11 +43,11 @@ fn single_column_normal_index() { #[test] fn single_column_composite_index() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti", "address": { "number": 27 } }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -56,7 +56,7 @@ fn single_column_composite_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -81,11 +81,11 @@ fn single_column_composite_index() { #[test] fn single_column_composite_array_index() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti", "addresses": [ { "number": 27 }, { "number": 28 } ] }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -94,7 +94,7 @@ fn single_column_composite_array_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -119,11 +119,11 @@ fn single_column_composite_array_index() { #[test] fn single_column_deep_composite_index() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! 
{"name": "Musti", "address": { "special": { "number": 27 } } }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -132,7 +132,7 @@ fn single_column_deep_composite_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -161,11 +161,11 @@ fn single_column_deep_composite_index() { #[test] fn single_column_descending_index() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -174,7 +174,7 @@ fn single_column_descending_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -195,11 +195,11 @@ fn single_column_descending_index() { #[test] fn single_column_descending_composite_index() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti", "address": { "number": 27 }}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -208,7 +208,7 @@ fn single_column_descending_composite_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -233,11 +233,11 @@ fn single_column_descending_composite_index() { #[test] fn single_column_fulltext_index() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -246,7 +246,7 @@ fn single_column_fulltext_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -267,11 +267,11 @@ fn single_column_fulltext_index() { #[test] fn single_column_fulltext_composite_index() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti", "address": { "street": "Meowallee" }}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -280,7 +280,7 @@ fn single_column_fulltext_composite_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -305,12 +305,12 @@ fn single_column_fulltext_composite_index() { #[test] fn single_array_column_fulltext_composite_index() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! 
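// The many `create_index` updates in this test module all reduce to one pattern; a condensed
// sketch of it, assuming driver 3.x. The document contents, key ordering, and the `unique`
// flag are placeholders in the style of these tests rather than any specific test case.

use mongodb::{
    bson::{doc, Document},
    options::IndexOptions,
    Collection, IndexModel,
};

async fn seed_and_index(collection: &Collection<Document>) -> mongodb::error::Result<()> {
    collection
        .insert_many(vec![doc! { "name": "Musti", "age": 9 }])
        .await?;

    let options = IndexOptions::builder().unique(Some(false)).build();

    let model = IndexModel::builder()
        .keys(doc! { "name": 1, "age": -1 })
        .options(Some(options))
        .build();

    // No trailing `None` argument in the 3.x API.
    collection.create_index(model).await?;

    Ok(())
}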
{"name": "Musti", "addresses": [ { "street": "Meowallee" }, { "street": "Purrstrasse" } ] }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -319,7 +319,7 @@ fn single_array_column_fulltext_composite_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -344,11 +344,11 @@ fn single_array_column_fulltext_composite_index() { #[test] fn multi_column_fulltext_index() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "title": "cat", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -357,7 +357,7 @@ fn multi_column_fulltext_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -379,11 +379,11 @@ fn multi_column_fulltext_index() { #[test] fn multi_column_fulltext_composite_index() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti", "address": { "street": "Meowallee", "city": "Derplin" } }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -392,7 +392,7 @@ fn multi_column_fulltext_composite_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -418,11 +418,11 @@ fn multi_column_fulltext_composite_index() { #[test] fn multi_column_fulltext_index_with_desc_in_end() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "title": "cat", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -431,7 +431,7 @@ fn multi_column_fulltext_index_with_desc_in_end() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -453,11 +453,11 @@ fn multi_column_fulltext_index_with_desc_in_end() { #[test] fn multi_column_fulltext_composite_index_with_desc_in_end() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! 
{"name": "Musti", "address": { "street": "Meowallee", "number": 69 }}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -466,7 +466,7 @@ fn multi_column_fulltext_composite_index_with_desc_in_end() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -492,11 +492,11 @@ fn multi_column_fulltext_composite_index_with_desc_in_end() { #[test] fn multi_column_fulltext_index_with_desc_in_beginning() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "title": "cat", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -505,7 +505,7 @@ fn multi_column_fulltext_index_with_desc_in_beginning() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -527,11 +527,11 @@ fn multi_column_fulltext_index_with_desc_in_beginning() { #[test] fn multi_column_fulltext_composite_index_with_desc_in_beginning() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti", "address": { "street": "Meowallee", "number": 69 }}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -540,7 +540,7 @@ fn multi_column_fulltext_composite_index_with_desc_in_beginning() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -566,11 +566,11 @@ fn multi_column_fulltext_composite_index_with_desc_in_beginning() { #[test] fn multi_column_fulltext_index_with_asc_in_end() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "title": "cat", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -579,7 +579,7 @@ fn multi_column_fulltext_index_with_asc_in_end() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -601,11 +601,11 @@ fn multi_column_fulltext_index_with_asc_in_end() { #[test] fn multi_column_fulltext_index_with_asc_in_beginning() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! 
{"name": "Musti", "title": "cat", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -614,7 +614,7 @@ fn multi_column_fulltext_index_with_asc_in_beginning() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -636,12 +636,12 @@ fn multi_column_fulltext_index_with_asc_in_beginning() { #[test] fn multi_column_fulltext_index_with_asc_in_beginning_desc_in_end() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! { "name": "Musti", "title": "cat", "age": 9, "weight": 5 }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder() .unique(Some(false)) @@ -653,7 +653,7 @@ fn multi_column_fulltext_index_with_asc_in_beginning_desc_in_end() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -678,11 +678,11 @@ fn fultext_index_without_preview_flag() { let depth = CompositeTypeDepth::Infinite; let res = introspect_features(depth, Default::default(), |db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -691,7 +691,7 @@ fn fultext_index_without_preview_flag() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -712,11 +712,11 @@ fn fultext_composite_index_without_preview_flag() { let depth = CompositeTypeDepth::Infinite; let res = introspect_features(depth, Default::default(), |db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti", "address": { "street": "Meowallee" } }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -725,7 +725,7 @@ fn fultext_composite_index_without_preview_flag() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -748,11 +748,11 @@ fn fultext_composite_index_without_preview_flag() { #[test] fn index_pointing_to_a_renamed_field() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! 
{"name": "Musti", "_age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -761,7 +761,7 @@ fn index_pointing_to_a_renamed_field() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -782,11 +782,11 @@ fn index_pointing_to_a_renamed_field() { #[test] fn composite_index_pointing_to_a_renamed_field() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! { "name": "Musti", "info": { "_age": 9} }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -795,7 +795,7 @@ fn composite_index_pointing_to_a_renamed_field() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -820,11 +820,11 @@ fn composite_index_pointing_to_a_renamed_field() { #[test] fn single_column_normal_index_default_name() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder() .unique(Some(false)) @@ -836,7 +836,7 @@ fn single_column_normal_index_default_name() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -857,11 +857,11 @@ fn single_column_normal_index_default_name() { #[test] fn single_column_normal_composite_index_default_name() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "info": { "age": 9} }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder() .unique(Some(false)) @@ -873,7 +873,7 @@ fn single_column_normal_composite_index_default_name() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -898,11 +898,11 @@ fn single_column_normal_composite_index_default_name() { #[test] fn multi_column_normal_index() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -911,7 +911,7 @@ fn multi_column_normal_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -932,11 +932,11 @@ fn multi_column_normal_index() { #[test] fn single_column_unique_index() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! 
{"name": "Musti", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(true)).build(); @@ -945,7 +945,7 @@ fn single_column_unique_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -964,11 +964,11 @@ fn single_column_unique_index() { #[test] fn single_column_unique_composite_index() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti", "info": { "age": 9 } }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(true)).build(); @@ -977,7 +977,7 @@ fn single_column_unique_composite_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1002,11 +1002,11 @@ fn single_column_unique_composite_index() { #[test] fn single_array_column_unique_composite_index() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti", "infos": [ { "age": 9 }, { "age": 10 } ] }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(true)).build(); @@ -1015,7 +1015,7 @@ fn single_array_column_unique_composite_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1040,11 +1040,11 @@ fn single_array_column_unique_composite_index() { #[test] fn single_column_unique_index_default_name() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder() .unique(Some(true)) @@ -1056,7 +1056,7 @@ fn single_column_unique_index_default_name() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1075,11 +1075,11 @@ fn single_column_unique_index_default_name() { #[test] fn single_column_unique_composite_index_default_name() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti", "info": { "age": 9 } }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder() .unique(Some(true)) @@ -1091,7 +1091,7 @@ fn single_column_unique_composite_index_default_name() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1116,11 +1116,11 @@ fn single_column_unique_composite_index_default_name() { #[test] fn multi_column_unique_index() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! 
{"name": "Musti", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(true)).build(); @@ -1129,7 +1129,7 @@ fn multi_column_unique_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1150,11 +1150,11 @@ fn multi_column_unique_index() { #[test] fn multi_column_unique_composite_index() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti", "info": { "age": 9 } }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(true)).build(); @@ -1163,7 +1163,7 @@ fn multi_column_unique_composite_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1188,11 +1188,11 @@ fn multi_column_unique_composite_index() { #[test] fn unsupported_types_in_a_unique_index() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"data": Bson::JavaScriptCode("let a = 1 + 1;".to_string())}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(true)).build(); @@ -1201,7 +1201,7 @@ fn unsupported_types_in_a_unique_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1219,11 +1219,11 @@ fn unsupported_types_in_a_unique_index() { #[test] fn unsupported_types_in_an_index() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"data": Bson::JavaScriptCode("let a = 1 + 1;".to_string())}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1232,7 +1232,7 @@ fn unsupported_types_in_an_index() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1261,11 +1261,11 @@ fn unsupported_types_in_an_index() { #[test] fn partial_indices_should_be_ignored() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder() .unique(Some(false)) @@ -1277,7 +1277,7 @@ fn partial_indices_should_be_ignored() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1296,11 +1296,11 @@ fn partial_indices_should_be_ignored() { #[test] fn partial_composite_indices_should_be_ignored() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! 
{"name": "Musti", "info": { "age": 9 }}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder() .unique(Some(false)) @@ -1312,7 +1312,7 @@ fn partial_composite_indices_should_be_ignored() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1335,11 +1335,11 @@ fn partial_composite_indices_should_be_ignored() { #[test] fn index_pointing_to_non_existing_field_should_add_the_field() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti"}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1348,7 +1348,7 @@ fn index_pointing_to_non_existing_field_should_add_the_field() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1379,11 +1379,11 @@ fn index_pointing_to_non_existing_field_should_add_the_field() { #[test] fn index_pointing_to_non_existing_composite_field_should_add_the_field_and_type() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti"}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1392,7 +1392,7 @@ fn index_pointing_to_non_existing_composite_field_should_add_the_field_and_type( .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1431,11 +1431,11 @@ fn index_pointing_to_non_existing_composite_field_should_add_the_field_and_type( #[test] fn deep_index_pointing_to_non_existing_composite_field_should_add_the_field_and_type() { let res = introspect(|db| async move { - db.create_collection("Cat", None).await?; + db.create_collection("Cat").await?; let collection = db.collection("Cat"); let docs = vec![doc! {"name": "Musti"}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1444,7 +1444,7 @@ fn deep_index_pointing_to_non_existing_composite_field_should_add_the_field_and_ .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1489,11 +1489,11 @@ fn deep_index_pointing_to_non_existing_composite_field_should_add_the_field_and_ #[test] fn index_pointing_to_mapped_non_existing_field_should_add_the_mapped_field() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! 
{"name": "Musti"}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1502,7 +1502,7 @@ fn index_pointing_to_mapped_non_existing_field_should_add_the_mapped_field() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1533,11 +1533,11 @@ fn index_pointing_to_mapped_non_existing_field_should_add_the_mapped_field() { #[test] fn composite_index_pointing_to_mapped_non_existing_field_should_add_the_mapped_field() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti"}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1546,7 +1546,7 @@ fn composite_index_pointing_to_mapped_non_existing_field_should_add_the_mapped_f .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1585,11 +1585,11 @@ fn composite_index_pointing_to_mapped_non_existing_field_should_add_the_mapped_f #[test] fn compound_index_pointing_to_non_existing_field_should_add_the_field() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti"}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1598,7 +1598,7 @@ fn compound_index_pointing_to_non_existing_field_should_add_the_field() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1632,11 +1632,11 @@ fn compound_index_pointing_to_non_existing_field_should_add_the_field() { #[test] fn composite_index_with_one_existing_field_should_add_missing_stuff_only() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "info": { "age": 9 } }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1645,7 +1645,7 @@ fn composite_index_with_one_existing_field_should_add_missing_stuff_only() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1681,11 +1681,11 @@ fn composite_index_with_one_existing_field_should_add_missing_stuff_only() { #[test] fn deep_composite_index_with_one_existing_field_should_add_missing_stuff_only() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! 
{"name": "Musti", "info": { "age": 9 } }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1694,7 +1694,7 @@ fn deep_composite_index_with_one_existing_field_should_add_missing_stuff_only() .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1736,11 +1736,11 @@ fn deep_composite_index_with_one_existing_field_should_add_missing_stuff_only() #[test] fn deep_composite_index_with_one_existing_field_should_add_missing_stuff_only_2() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "info": { "special": { "age": 9 } } }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1749,7 +1749,7 @@ fn deep_composite_index_with_one_existing_field_should_add_missing_stuff_only_2( .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1789,11 +1789,11 @@ fn deep_composite_index_with_one_existing_field_should_add_missing_stuff_only_2( #[test] fn deep_composite_index_should_add_missing_stuff_in_different_layers() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! { "name": "Musti" }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1802,7 +1802,7 @@ fn deep_composite_index_should_add_missing_stuff_in_different_layers() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1850,11 +1850,11 @@ fn deep_composite_index_should_add_missing_stuff_in_different_layers() { #[test] fn compound_index_with_one_existing_field_pointing_to_non_existing_field_should_add_the_field() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti", "age": 9}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(false)).build(); @@ -1863,7 +1863,7 @@ fn compound_index_with_one_existing_field_pointing_to_non_existing_field_should_ .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1895,11 +1895,11 @@ fn compound_index_with_one_existing_field_pointing_to_non_existing_field_should_ #[test] fn unique_index_pointing_to_non_existing_field_should_add_the_field() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! 
{"name": "Musti"}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(true)).build(); @@ -1908,7 +1908,7 @@ fn unique_index_pointing_to_non_existing_field_should_add_the_field() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1937,11 +1937,11 @@ fn unique_index_pointing_to_non_existing_field_should_add_the_field() { #[test] fn fulltext_index_pointing_to_non_existing_field_should_add_the_field() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! {"name": "Musti"}]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); let options = IndexOptions::builder().unique(Some(true)).build(); @@ -1950,7 +1950,7 @@ fn fulltext_index_pointing_to_non_existing_field_should_add_the_field() { .options(Some(options)) .build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -1979,20 +1979,20 @@ fn fulltext_index_pointing_to_non_existing_field_should_add_the_field() { #[test] fn composite_type_index_without_corresponding_data_should_not_crash() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection::("A"); let model = IndexModel::builder().keys(doc! { "foo": 1 }).build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; let model = IndexModel::builder().keys(doc! { "foo.bar": 1 }).build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; let model = IndexModel::builder().keys(doc! { "foo.baz.quux": 1 }).build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; Ok(()) }); @@ -2027,14 +2027,14 @@ fn composite_type_index_without_corresponding_data_should_not_crash() { #[test] fn composite_type_index_with_non_composite_fields_in_the_middle_should_not_crash() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection::("A"); let model = IndexModel::builder().keys(doc! { "a.b.c": 1 }).build(); - collection.create_index(model, None).await?; + collection.create_index(model).await?; let docs = vec![doc! { "a": { "b": 1, "d": { "c": 1 } } }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/model_renames/mod.rs b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/model_renames/mod.rs index d0e1767d9f94..c0283e34fe11 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/model_renames/mod.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/model_renames/mod.rs @@ -1,12 +1,12 @@ use crate::introspection::test_api::*; -use mongodb::bson::doc; +use bson::doc; #[test] fn a_model_with_reserved_name() { let res = introspect(|db| async move { - db.create_collection("PrismaClient", None).await.unwrap(); + db.create_collection("PrismaClient").await.unwrap(); db.collection("PrismaClient") - .insert_one(doc! {"data": 1}, None) + .insert_one(doc! 
{"data": 1}) .await .unwrap(); @@ -29,9 +29,9 @@ fn a_model_with_reserved_name() { #[test] fn reserved_names_case_sensitivity() { let res = introspect(|db| async move { - db.create_collection("prismalclient", None).await.unwrap(); + db.create_collection("prismalclient").await.unwrap(); db.collection("prismalclient") - .insert_one(doc! {"data": 1}, None) + .insert_one(doc! {"data": 1}) .await .unwrap(); diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/multi_file/mod.rs b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/multi_file/mod.rs index eb6d143f526c..12b1506e3bd1 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/multi_file/mod.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/multi_file/mod.rs @@ -1,5 +1,5 @@ use crate::introspection::test_api::*; -use mongodb::bson::doc; +use bson::doc; // Composite types // reintrospect_removed_model_single_file @@ -677,22 +677,19 @@ fn reintrospect_empty_multi_file() { async fn seed_model(name: &str, api: &TestApi) -> Result<(), mongodb::error::Error> { let db = &api.db; - db.create_collection(name, None).await?; + db.create_collection(name).await?; let collection = db.collection(name); - collection.insert_many(vec![doc! {"name": "John"}], None).await.unwrap(); + collection.insert_many(vec![doc! {"name": "John"}]).await.unwrap(); Ok(()) } async fn seed_composite(name: &str, api: &TestApi) -> Result<(), mongodb::error::Error> { let db = &api.db; - db.create_collection(name, None).await?; + db.create_collection(name).await?; let collection = db.collection(name); collection - .insert_many( - vec![doc! {"identity": { "firstName": "John", "lastName": "Doe" }}], - None, - ) + .insert_many(vec![doc! {"identity": { "firstName": "John", "lastName": "Doe" }}]) .await .unwrap(); diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/remapping_names/mod.rs b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/remapping_names/mod.rs index 69223f2e6788..f08ed71ed655 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/remapping_names/mod.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/remapping_names/mod.rs @@ -1,26 +1,23 @@ use crate::introspection::test_api::*; -use mongodb::bson::doc; +use bson::doc; #[test] fn remapping_fields_with_invalid_characters() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; db.collection("A") - .insert_one( - doc! { - "_a": 1, - "*b": 2, - "?c": 3, - "(d": 4, - ")e": 5, - "/f": 6, - "g a": 7, - "h-a": 8, - "h1": 9, - }, - None, - ) + .insert_one(doc! { + "_a": 1, + "*b": 2, + "?c": 3, + "(d": 4, + ")e": 5, + "/f": 6, + "g a": 7, + "h-a": 8, + "h1": 9, + }) .await?; Ok(()) @@ -47,8 +44,8 @@ fn remapping_fields_with_invalid_characters() { #[test] fn remapping_models_with_invalid_characters() { let res = introspect(|db| async move { - db.create_collection("?A", None).await?; - db.create_collection("A b c", None).await?; + db.create_collection("?A").await?; + db.create_collection("A b c").await?; Ok(()) }); @@ -74,14 +71,11 @@ fn remapping_models_with_invalid_characters() { fn remapping_composite_fields_with_numbers() { let res = introspect(|db| async move { db.collection("Outer") - .insert_one( - doc! { - "inner": { - "1": 1, - }, + .insert_one(doc! 
{ + "inner": { + "1": 1, }, - None, - ) + }) .await?; Ok(()) @@ -115,12 +109,9 @@ fn remapping_composite_fields_with_numbers() { fn remapping_model_fields_with_numbers() { let res = introspect(|db| async move { db.collection("Outer") - .insert_one( - doc! { - "1": 1, - }, - None, - ) + .insert_one(doc! { + "1": 1, + }) .await?; Ok(()) @@ -150,7 +141,7 @@ fn remapping_model_fields_with_numbers() { fn remapping_model_fields_with_numbers_dirty() { let res = introspect(|db| async move { let docs = vec![doc! {"1": "Musti"}, doc! {"1": 1}]; - db.collection("Outer").insert_many(docs, None).await.unwrap(); + db.collection("Outer").insert_many(docs).await.unwrap(); Ok(()) }); diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/test_api/mod.rs b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/test_api/mod.rs index 33b2e1ed8d48..bad7e9b8a24c 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/test_api/mod.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/test_api/mod.rs @@ -71,8 +71,6 @@ impl From for TestMultiResult { } pub struct TestApi { - pub connection_string: String, - pub database_name: String, pub db: Database, pub features: BitFlags, pub connector: MongoDbSchemaConnector, @@ -122,8 +120,6 @@ where let connector = MongoDbSchemaConnector::new(params); let api = TestApi { - connection_string, - database_name, db: database.clone(), features: preview_features, connector, @@ -131,7 +127,7 @@ where let res = setup(api).await; - database.drop(None).await.unwrap(); + database.drop().await.unwrap(); res }) diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/types/composite.rs b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/types/composite.rs index 20e7e13bdb12..3b105ba52603 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/types/composite.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/types/composite.rs @@ -1,5 +1,5 @@ use crate::introspection::test_api::*; -use mongodb::bson::{doc, oid::ObjectId, Bson}; +use bson::{doc, oid::ObjectId, Bson}; use schema_connector::CompositeTypeDepth; #[test] @@ -10,7 +10,7 @@ fn singular() { doc! { "name": "Naukio", "address": { "street": "Meowstrasse", "number": 123, "knock": true }}, ]; - db.collection("Cat").insert_many(docs, None).await?; + db.collection("Cat").insert_many(docs).await?; Ok(()) }); @@ -41,7 +41,7 @@ fn dirty_data() { doc! { "name": "Bob", "address": { "street": "Kantstrasse", "number": "123" }}, ]; - db.collection("Cat").insert_many(docs, None).await?; + db.collection("Cat").insert_many(docs).await?; Ok(()) }); @@ -80,7 +80,7 @@ fn array() { { "title": "Hello, world!", "content": "Like, whatever...", "published": true }, ]}]; - db.collection("Blog").insert_many(docs, None).await?; + db.collection("Blog").insert_many(docs).await?; Ok(()) }); @@ -110,7 +110,7 @@ fn deep_array() { [{ "title": "Hello, world!", "content": "Like, whatever...", "published": true }], ]}]; - db.collection("Blog").insert_many(docs, None).await?; + db.collection("Blog").insert_many(docs).await?; Ok(()) }); @@ -137,7 +137,7 @@ fn nullability() { doc! {"first": {"foo": 1}, "second": {"foo": 1}}, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -173,7 +173,7 @@ fn unsupported() { vec![doc! 
{ "dataType": "Code", "data": { "code": Bson::JavaScriptCode("let a = 1 + 1;".to_string()) }}]; db.collection("FrontendEngineerWritesBackendCode") - .insert_many(docs, None) + .insert_many(docs) .await?; Ok(()) @@ -198,7 +198,7 @@ fn unsupported() { fn underscores_in_names() { let res = introspect(|db| async move { let docs = vec![doc! { "name": "Musti", "home_address": { "street": "Meowstrasse", "number": 123 }}]; - db.collection("Cat").insert_many(docs, None).await?; + db.collection("Cat").insert_many(docs).await?; Ok(()) }); @@ -222,7 +222,7 @@ fn underscores_in_names() { fn depth_none() { let res = introspect_depth(CompositeTypeDepth::None, |db| async move { let docs = vec![doc! { "name": "Musti", "home_address": { "street": "Meowstrasse", "number": 123 }}]; - db.collection("Cat").insert_many(docs, None).await?; + db.collection("Cat").insert_many(docs).await?; Ok(()) }); @@ -241,7 +241,7 @@ fn depth_none() { fn depth_none_level_1_array() { let res = introspect_depth(CompositeTypeDepth::None, |db| async move { let docs = vec![doc! { "name": "Musti", "home_address": [{ "street": "Meowstrasse", "number": 123 }]}]; - db.collection("Cat").insert_many(docs, None).await?; + db.collection("Cat").insert_many(docs).await?; Ok(()) }); @@ -260,7 +260,7 @@ fn depth_none_level_1_array() { fn depth_1_level_1() { let res = introspect_depth(CompositeTypeDepth::Level(1), |db| async move { let docs = vec![doc! { "name": "Musti", "home_address": { "street": "Meowstrasse", "number": 123 }}]; - db.collection("Cat").insert_many(docs, None).await?; + db.collection("Cat").insert_many(docs).await?; Ok(()) }); @@ -286,7 +286,7 @@ fn depth_1_level_2() { let docs = vec![ doc! { "name": "Musti", "home_address": { "street": "Meowstrasse", "number": 123, "data": { "something": "other" } } }, ]; - db.collection("Cat").insert_many(docs, None).await?; + db.collection("Cat").insert_many(docs).await?; Ok(()) }); @@ -313,7 +313,7 @@ fn depth_1_level_2_array() { let docs = vec![ doc! { "name": "Musti", "home_address": [{ "street": "Meowstrasse", "number": 123, "data": [{ "something": "other" }] }] }, ]; - db.collection("Cat").insert_many(docs, None).await?; + db.collection("Cat").insert_many(docs).await?; Ok(()) }); @@ -340,7 +340,7 @@ fn depth_2_level_2_array() { let docs = vec![ doc! { "name": "Musti", "home_address": [{ "street": "Meowstrasse", "number": 123, "data": [{ "something": "other" }] }] }, ]; - db.collection("Cat").insert_many(docs, None).await?; + db.collection("Cat").insert_many(docs).await?; Ok(()) }); @@ -369,10 +369,10 @@ fn depth_2_level_2_array() { fn name_clashes() { let res = introspect(|db| async move { let docs = vec![doc! { "name": "Musti", "address": { "street": "Meowstrasse", "number": 123 }}]; - db.collection("Cat").insert_many(docs, None).await?; + db.collection("Cat").insert_many(docs).await?; let docs = vec![doc! { "knock": false, "number": 420, "street": "Meowstrasse" }]; - db.collection("CatAddress").insert_many(docs, None).await?; + db.collection("CatAddress").insert_many(docs).await?; Ok(()) }); @@ -407,7 +407,7 @@ fn non_id_object_ids() { doc! { "non_id_object_id": Bson::ObjectId(ObjectId::new()), "data": {"non_id_object_id": Bson::ObjectId(ObjectId::new())}}, ]; - db.collection("Test").insert_many(docs, None).await?; + db.collection("Test").insert_many(docs).await?; Ok(()) }); @@ -432,7 +432,7 @@ fn fields_named_id_in_composite() { let res = introspect(|db| async move { let docs = vec![doc! 
{"id": "test","data": {"id": "test"}, "data2": {"_id": "test", "id": "test"}}]; - db.collection("Test").insert_many(docs, None).await?; + db.collection("Test").insert_many(docs).await?; Ok(()) }); @@ -463,7 +463,7 @@ fn do_not_create_empty_types() { let res = introspect(|db| async move { let docs = vec![doc! { "data": {} }, doc! {}]; - db.collection("Test").insert_many(docs, None).await?; + db.collection("Test").insert_many(docs).await?; Ok(()) }); @@ -492,7 +492,7 @@ fn do_not_create_empty_types() { fn do_not_spam_empty_type_warnings() { let res = introspect(|db| async move { let docs = vec![doc! { "data": {} }, doc! {}, doc! { "data": {} }, doc! { "data": {} }]; - db.collection("Test").insert_many(docs, None).await?; + db.collection("Test").insert_many(docs).await?; Ok(()) }); @@ -522,7 +522,7 @@ fn do_not_create_empty_types_in_types() { let res = introspect(|db| async move { let docs = vec![doc! { "tost": { "data": {} } }]; - db.collection("Test").insert_many(docs, None).await?; + db.collection("Test").insert_many(docs).await?; Ok(()) }); @@ -557,7 +557,7 @@ fn no_empty_type_warnings_when_depth_is_reached() { let res = introspect_depth(depth, |db| async move { let docs = vec![doc! { "data": {} }, doc! {}]; - db.collection("Test").insert_many(docs, None).await?; + db.collection("Test").insert_many(docs).await?; Ok(()) }); @@ -584,7 +584,7 @@ fn kanji() { doc! { "name": "Naukio", "推荐点RichText": { "street": "Meowstrasse", "number": 123, "knock": true }}, ]; - db.collection("TheCollectionName").insert_many(docs, None).await?; + db.collection("TheCollectionName").insert_many(docs).await?; Ok(()) }); diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/types/mod.rs b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/types/mod.rs index d32b4a26367e..68d12764becb 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/types/mod.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/types/mod.rs @@ -1,12 +1,12 @@ mod composite; use crate::introspection::test_api::*; -use mongodb::bson::{doc, oid::ObjectId, Binary, Bson, DateTime, Decimal128, Timestamp}; +use bson::{doc, oid::ObjectId, Binary, Bson, DateTime, Decimal128, Timestamp}; #[test] fn string() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![ @@ -15,7 +15,7 @@ fn string() { doc! {"first": "Lol", "second": "Bar"}, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -35,7 +35,7 @@ fn string() { #[test] fn double() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![ @@ -44,7 +44,7 @@ fn double() { doc! {"first": Bson::Double(1.23), "second": Bson::Double(2.23)}, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -64,7 +64,7 @@ fn double() { #[test] fn bool() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![ @@ -73,7 +73,7 @@ fn bool() { doc! 
{"first": true, "second": false}, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -93,7 +93,7 @@ fn bool() { #[test] fn int() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![ @@ -102,7 +102,7 @@ fn int() { doc! {"first": Bson::Int32(1), "second": Bson::Int32(1)}, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -122,7 +122,7 @@ fn int() { #[test] fn bigint() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![ @@ -131,7 +131,7 @@ fn bigint() { doc! {"first": Bson::Int64(1), "second": Bson::Int64(1)}, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -151,7 +151,7 @@ fn bigint() { #[test] fn timestamp() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![ @@ -171,7 +171,7 @@ fn timestamp() { }, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -191,7 +191,7 @@ fn timestamp() { #[test] fn binary() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let bin = Binary { @@ -216,7 +216,7 @@ fn binary() { }, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -236,7 +236,7 @@ fn binary() { #[test] fn object_id() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![ @@ -256,7 +256,7 @@ fn object_id() { }, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -276,7 +276,7 @@ fn object_id() { #[test] fn date() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![ @@ -296,7 +296,7 @@ fn date() { }, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -316,7 +316,7 @@ fn date() { #[test] fn decimal() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![ @@ -336,7 +336,7 @@ fn decimal() { }, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -356,7 +356,7 @@ fn decimal() { #[test] fn array() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![ @@ -376,7 +376,7 @@ fn array() { }, ]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -396,14 +396,14 @@ fn array() { #[test] fn deep_array() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); let docs = vec![doc! 
{ "first": Bson::Array(vec![Bson::Array(vec![Bson::Int32(1)])]), }]; - collection.insert_many(docs, None).await.unwrap(); + collection.insert_many(docs).await.unwrap(); Ok(()) }); @@ -421,11 +421,11 @@ fn deep_array() { #[test] fn empty_arrays() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); collection - .insert_one(doc! { "data": Bson::Array(Vec::new()) }, None) + .insert_one(doc! { "data": Bson::Array(Vec::new()) }) .await .unwrap(); @@ -455,10 +455,10 @@ fn empty_arrays() { #[test] fn unknown_types() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; + db.create_collection("A").await?; let collection = db.collection("A"); - collection.insert_one(doc! { "data": Bson::Null }, None).await.unwrap(); + collection.insert_one(doc! { "data": Bson::Null }).await.unwrap(); Ok(()) }); diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/views/mod.rs b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/views/mod.rs index fbe2d281f641..4237ead2eea9 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/introspection/views/mod.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/introspection/views/mod.rs @@ -4,10 +4,9 @@ use mongodb::{bson::doc, options::CreateCollectionOptions}; #[test] fn collection_with_view() { let res = introspect(|db| async move { - db.create_collection("A", None).await?; - db.create_collection( - "myView".to_string(), - Some( + db.create_collection("A").await?; + db.create_collection("myView".to_string()) + .with_options( CreateCollectionOptions::builder() .view_on("A".to_string()) .pipeline(vec![doc! { @@ -19,9 +18,8 @@ fn collection_with_view() { }, }]) .build(), - ), - ) - .await?; + ) + .await?; Ok(()) }); diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/migrations/test_api.rs b/schema-engine/connectors/mongodb-schema-connector/tests/migrations/test_api.rs index eacd5f23b2c4..ba440a9ef29e 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/migrations/test_api.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/migrations/test_api.rs @@ -1,6 +1,6 @@ +use bson::{self, doc}; use enumflags2::BitFlags; use futures::TryStreamExt; -use mongodb::bson::{self, doc}; use mongodb_schema_connector::MongoDbSchemaConnector; use once_cell::sync::Lazy; use psl::{parser_database::SourceFile, PreviewFeature}; @@ -85,7 +85,7 @@ fn new_connector(preview_features: BitFlags) -> (String, MongoDb } async fn get_state(db: &mongodb::Database) -> State { - let collection_names = db.list_collection_names(None).await.unwrap(); + let collection_names = db.list_collection_names().await.unwrap(); let mut state = State::default(); for collection_name in collection_names { @@ -93,13 +93,13 @@ async fn get_state(db: &mongodb::Database) -> State { let mut documents = Vec::new(); let mut indexes = Vec::new(); - let mut cursor: mongodb::Cursor = collection.find(None, None).await.unwrap(); + let mut cursor: mongodb::Cursor = collection.find(bson::Document::default()).await.unwrap(); while let Some(doc) = cursor.try_next().await.unwrap() { documents.push(doc) } - let mut cursor = collection.list_indexes(None).await.unwrap(); + let mut cursor = collection.list_indexes().await.unwrap(); while let Some(index) = cursor.try_next().await.unwrap() { let options = index.options.unwrap(); @@ -133,11 +133,11 @@ async fn apply_state(db: &mongodb::Database, 
state: State) { model }); - collection.create_indexes(indexes, None).await.unwrap(); + collection.create_indexes(indexes).await.unwrap(); } if !documents.is_empty() { - collection.insert_many(documents, None).await.unwrap(); + collection.insert_many(documents).await.unwrap(); } } } @@ -177,7 +177,7 @@ pub(crate) fn test_scenario(scenario_name: &str) { let (db_name, mut connector) = new_connector(parsed_schema.configuration.preview_features()); let client = client().await; let db = client.database(&db_name); - db.drop(None).await.unwrap(); + db.drop().await.unwrap(); apply_state(&db, state).await; let from = connector diff --git a/schema-engine/connectors/mongodb-schema-connector/tests/migrations/tests.rs b/schema-engine/connectors/mongodb-schema-connector/tests/migrations/tests.rs index c8ef4de9e870..8298f8d3b8c1 100644 --- a/schema-engine/connectors/mongodb-schema-connector/tests/migrations/tests.rs +++ b/schema-engine/connectors/mongodb-schema-connector/tests/migrations/tests.rs @@ -5,7 +5,7 @@ //! Each test scenario folder must contain two files: //! //! - `state.json` must contain the initial state of the database. See examples and `State` in -//! `test_api.rs` for details. +//! `test_api.rs` for details. //! - `schema.prisma` must be the Prisma schema. //! //! On the first run, a `result` file will also be created. It is a snapshot test, do not edit it diff --git a/schema-engine/connectors/schema-connector/src/introspect_sql.rs b/schema-engine/connectors/schema-connector/src/introspect_sql.rs new file mode 100644 index 000000000000..06bbb34f54b4 --- /dev/null +++ b/schema-engine/connectors/schema-connector/src/introspect_sql.rs @@ -0,0 +1,67 @@ +#[allow(missing_docs)] +#[derive(Debug)] +pub struct IntrospectSqlContext { + pub queries: Vec, + pub force: bool, +} + +#[allow(missing_docs)] +#[derive(Debug)] +pub struct IntrospectSqlQueryInput { + pub name: String, + pub source: String, +} + +#[allow(missing_docs)] +pub struct IntrospectSqlResult { + pub queries: Vec, +} + +#[allow(missing_docs)] +#[derive(Debug)] +pub struct IntrospectSqlQueryOutput { + pub name: String, + pub source: String, + pub documentation: Option, + pub parameters: Vec, + pub result_columns: Vec, +} + +#[allow(missing_docs)] +#[derive(Debug)] +pub struct IntrospectSqlQueryParameterOutput { + pub documentation: Option, + pub name: String, + pub typ: String, + pub nullable: bool, +} + +#[allow(missing_docs)] +#[derive(Debug)] +pub struct IntrospectSqlQueryColumnOutput { + pub name: String, + pub typ: String, + pub nullable: bool, +} + +impl From for IntrospectSqlQueryColumnOutput { + fn from(item: quaint::connector::DescribedColumn) -> Self { + let nullable_override = parse_nullability_override(&item.name); + + Self { + name: item.name, + typ: item.enum_name.unwrap_or_else(|| item.typ.to_string()), + nullable: nullable_override.unwrap_or(item.nullable), + } + } +} + +fn parse_nullability_override(column_name: &str) -> Option { + if column_name.ends_with('?') { + Some(true) + } else if column_name.ends_with('!') { + Some(false) + } else { + None + } +} diff --git a/schema-engine/connectors/schema-connector/src/lib.rs b/schema-engine/connectors/schema-connector/src/lib.rs index ad3e836df9e2..3635e66fa379 100644 --- a/schema-engine/connectors/schema-connector/src/lib.rs +++ b/schema-engine/connectors/schema-connector/src/lib.rs @@ -9,6 +9,7 @@ mod database_schema; mod destructive_change_checker; mod diff; mod error; +mod introspect_sql; mod introspection_context; mod introspection_result; mod migration; @@ -29,6 +30,7 
@@ pub use destructive_change_checker::{ }; pub use diff::DiffTarget; pub use error::{ConnectorError, ConnectorResult}; +pub use introspect_sql::*; pub use introspection_context::{CompositeTypeDepth, IntrospectionContext}; pub use introspection_result::{IntrospectionResult, ViewDefinition}; pub use migration::Migration; diff --git a/schema-engine/connectors/schema-connector/src/schema_connector.rs b/schema-engine/connectors/schema-connector/src/schema_connector.rs index ddef88400081..0aacf3fd6aaa 100644 --- a/schema-engine/connectors/schema-connector/src/schema_connector.rs +++ b/schema-engine/connectors/schema-connector/src/schema_connector.rs @@ -5,8 +5,8 @@ use psl::ValidatedSchema; use crate::{ migrations_directory::MigrationDirectory, BoxFuture, ConnectorHost, ConnectorParams, ConnectorResult, - DatabaseSchema, DestructiveChangeChecker, DestructiveChangeDiagnostics, DiffTarget, IntrospectionContext, - IntrospectionResult, Migration, MigrationPersistence, Namespaces, + DatabaseSchema, DestructiveChangeChecker, DestructiveChangeDiagnostics, DiffTarget, IntrospectSqlQueryInput, + IntrospectSqlQueryOutput, IntrospectionContext, IntrospectionResult, Migration, MigrationPersistence, Namespaces, }; /// The top-level trait for connectors. This is the abstraction the schema engine core relies on to @@ -133,6 +133,12 @@ pub trait SchemaConnector: Send + Sync + 'static { ctx: &'a IntrospectionContext, ) -> BoxFuture<'a, ConnectorResult>; + /// Introspect queries and returns type information. + fn introspect_sql( + &mut self, + input: IntrospectSqlQueryInput, + ) -> BoxFuture<'_, ConnectorResult>; + /// If possible, check that the passed in migrations apply cleanly. fn validate_migrations<'a>( &'a mut self, diff --git a/schema-engine/connectors/sql-schema-connector/Cargo.toml b/schema-engine/connectors/sql-schema-connector/Cargo.toml index ae88c65279d4..a336495e7436 100644 --- a/schema-engine/connectors/sql-schema-connector/Cargo.toml +++ b/schema-engine/connectors/sql-schema-connector/Cargo.toml @@ -34,12 +34,16 @@ chrono.workspace = true connection-string.workspace = true enumflags2.workspace = true once_cell = "1.3" -regex = "1" +regex.workspace = true serde_json.workspace = true tracing.workspace = true -tracing-futures = "0.2" +tracing-futures.workspace = true url.workspace = true either = "1.6" sqlformat = "0.2.1" sqlparser = "0.32.0" versions = "6.1.0" +sqlx-sqlite = { version = "0.8.0" } +sqlx-core = "0.8.0" +[dev-dependencies] +expect-test = "1" diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour.rs b/schema-engine/connectors/sql-schema-connector/src/flavour.rs index 25b3bd3c605e..3b2e901bcd67 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour.rs @@ -175,6 +175,11 @@ pub(crate) trait SqlFlavour: self.describe_schema(namespaces) } + fn describe_query<'a>( + &'a mut self, + sql: &'a str, + ) -> BoxFuture<'a, ConnectorResult>; + fn load_migrations_table( &mut self, ) -> BoxFuture<'_, ConnectorResult, PersistenceNotInitializedError>>> { diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/mssql.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/mssql.rs index bd59c51de8ad..92843aae628e 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/mssql.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/mssql.rs @@ -501,6 +501,13 @@ impl SqlFlavour for MssqlFlavour { fn search_path(&self) -> &str { self.schema_name() } + + fn 
describe_query<'a>( + &'a mut self, + _sql: &str, + ) -> BoxFuture<'a, ConnectorResult> { + unimplemented!("SQL Server does not support describe_query yet.") + } } fn with_connection<'a, O, F, C>(state: &'a mut State, f: C) -> BoxFuture<'a, ConnectorResult> diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/mysql.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/mysql.rs index 148387bf43c4..e169d825e9d8 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/mysql.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/mysql.rs @@ -405,6 +405,15 @@ impl SqlFlavour for MysqlFlavour { fn search_path(&self) -> &str { self.database_name() } + + fn describe_query<'a>( + &'a mut self, + sql: &'a str, + ) -> BoxFuture<'a, ConnectorResult> { + with_connection(&mut self.state, move |conn_params, circumstances, conn| { + conn.describe_query(sql, &conn_params.url, circumstances) + }) + } } #[enumflags2::bitflags] diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/mysql/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/mysql/connection.rs index 34fe9ef42fdd..fd470fc298d1 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/mysql/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/mysql/connection.rs @@ -7,7 +7,7 @@ use quaint::{ mysql_async::{self as my, prelude::Query}, MysqlUrl, }, - prelude::{ConnectionInfo, NativeConnectionInfo, Queryable}, + prelude::{ColumnType, ConnectionInfo, NativeConnectionInfo, Queryable}, }; use schema_connector::{ConnectorError, ConnectorResult}; use sql_schema_describer::{DescriberErrorKind, SqlSchema}; @@ -115,6 +115,30 @@ impl Connection { self.0.query_raw(sql, params).await.map_err(quaint_err(url)) } + pub(super) async fn describe_query( + &self, + sql: &str, + url: &MysqlUrl, + circumstances: BitFlags, + ) -> ConnectorResult { + tracing::debug!(query_type = "describe_query", sql); + let mut parsed = self.0.describe_query(sql).await.map_err(quaint_err(url))?; + + if circumstances.contains(super::Circumstances::IsMysql56) + || circumstances.contains(super::Circumstances::IsMysql57) + { + parsed.parameters = parsed + .parameters + .into_iter() + .map(|p| p.set_typ(ColumnType::Unknown)) + .collect(); + + return Ok(parsed); + } + + Ok(parsed) + } + pub(super) async fn apply_migration_script( &mut self, migration_name: &str, diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs index b67adcb4d0c2..02752e491eeb 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs @@ -5,12 +5,17 @@ use self::connection::*; use crate::SqlFlavour; use enumflags2::BitFlags; use indoc::indoc; -use quaint::{connector::PostgresUrl, Value}; +use once_cell::sync::Lazy; +use quaint::{ + connector::{PostgresUrl, PostgresWebSocketUrl}, + prelude::NativeConnectionInfo, + Value, +}; use schema_connector::{ migrations_directory::MigrationDirectory, BoxFuture, ConnectorError, ConnectorParams, ConnectorResult, Namespaces, }; use sql_schema_describer::SqlSchema; -use std::{borrow::Cow, collections::HashMap, future, time}; +use std::{borrow::Cow, collections::HashMap, future, str::FromStr, time}; use url::Url; use user_facing_errors::{ common::{DatabaseAccessDenied, DatabaseDoesNotExist}, @@ -28,9 +33,70 @@ SET enable_experimental_alter_column_type_general = 
true; type State = super::State, Connection)>; +#[derive(Debug, Clone)] +struct MigratePostgresUrl(PostgresUrl); + +static MIGRATE_WS_BASE_URL: Lazy> = Lazy::new(|| { + std::env::var("PRISMA_SCHEMA_ENGINE_WS_BASE_URL") + .map(Cow::Owned) + .unwrap_or_else(|_| Cow::Borrowed("wss://migrations.prisma-data.net/websocket")) +}); + +impl MigratePostgresUrl { + const WEBSOCKET_SCHEME: &'static str = "prisma+postgres"; + const API_KEY_PARAM: &'static str = "api_key"; + const DBNAME_PARAM: &'static str = "dbname"; + + fn new(url: Url) -> ConnectorResult { + let postgres_url = if url.scheme() == Self::WEBSOCKET_SCHEME { + let ws_url = Url::from_str(&MIGRATE_WS_BASE_URL).map_err(ConnectorError::url_parse_error)?; + let Some((_, api_key)) = url.query_pairs().find(|(name, _)| name == Self::API_KEY_PARAM) else { + return Err(ConnectorError::url_parse_error( + "Required `api_key` query string parameter was not provided in a connection URL", + )); + }; + + let dbname_override = url.query_pairs().find(|(name, _)| name == Self::DBNAME_PARAM); + let mut ws_url = PostgresWebSocketUrl::new(ws_url, api_key.into_owned()); + if let Some((_, dbname_override)) = dbname_override { + ws_url.override_db_name(dbname_override.into_owned()); + } + + Ok(PostgresUrl::WebSocket(ws_url)) + } else { + PostgresUrl::new_native(url) + } + .map_err(ConnectorError::url_parse_error)?; + + Ok(Self(postgres_url)) + } + + pub(super) fn host(&self) -> &str { + self.0.host() + } + + pub(super) fn port(&self) -> u16 { + self.0.port() + } + + pub(super) fn dbname(&self) -> &str { + self.0.dbname() + } + + pub(super) fn schema(&self) -> &str { + self.0.schema() + } +} + +impl From for NativeConnectionInfo { + fn from(value: MigratePostgresUrl) -> Self { + NativeConnectionInfo::Postgres(value.0) + } +} + struct Params { connector_params: ConnectorParams, - url: PostgresUrl, + url: MigratePostgresUrl, } /// The specific provider that was requested by the user. @@ -221,6 +287,15 @@ impl SqlFlavour for PostgresFlavour { }) } + fn describe_query<'a>( + &'a mut self, + sql: &'a str, + ) -> BoxFuture<'a, ConnectorResult> { + with_connection(self, move |conn_params, _, conn| { + conn.describe_query(sql, &conn_params.url) + }) + } + fn apply_migration_script<'a>( &'a mut self, migration_name: &'a str, @@ -369,7 +444,7 @@ impl SqlFlavour for PostgresFlavour { .map_err(ConnectorError::url_parse_error)?; disable_postgres_statement_cache(&mut url)?; let connection_string = url.to_string(); - let url = PostgresUrl::new(url).map_err(ConnectorError::url_parse_error)?; + let url = MigratePostgresUrl::new(url)?; connector_params.connection_string = connection_string; let params = Params { connector_params, url }; self.state.set_params(params); @@ -451,7 +526,14 @@ impl SqlFlavour for PostgresFlavour { .connection_string .parse() .map_err(ConnectorError::url_parse_error)?; - shadow_database_url.set_path(&format!("/{shadow_database_name}")); + + if shadow_database_url.scheme() == MigratePostgresUrl::WEBSOCKET_SCHEME { + shadow_database_url + .query_pairs_mut() + .append_pair(MigratePostgresUrl::DBNAME_PARAM, &shadow_database_name); + } else { + shadow_database_url.set_path(&format!("/{shadow_database_name}")); + } let shadow_db_params = ConnectorParams { connection_string: shadow_database_url.to_string(), preview_features: params.connector_params.preview_features, @@ -501,7 +583,11 @@ impl SqlFlavour for PostgresFlavour { /// TL;DR, /// 1. pg >= 13 -> it works. /// 2. pg < 13 -> syntax error on WITH (FORCE), and then fail with db in use if pgbouncer is used. 
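The new `MigratePostgresUrl` above treats `prisma+postgres://` connection strings as websocket connections and pulls the API key and the optional database-name override out of the query string, with the shadow database reusing the same `dbname` parameter instead of the URL path. A minimal standalone sketch of that query-string handling, using only the `url` crate; the host, key and database name are made up for illustration:

use url::Url;

fn main() {
    // Hypothetical connection string; the scheme and parameter names follow the
    // constants on MigratePostgresUrl, everything else is illustrative.
    let url = Url::parse("prisma+postgres://accelerate.example.net/?api_key=my-key&dbname=shadow_db")
        .expect("valid URL");

    assert_eq!(url.scheme(), "prisma+postgres");

    // `api_key` is mandatory for the websocket flavour; its absence is a user-facing error.
    let api_key = url
        .query_pairs()
        .find(|(name, _)| name == "api_key")
        .map(|(_, value)| value.into_owned())
        .expect("missing `api_key` query string parameter");

    // `dbname` optionally overrides the database name, e.g. when pointing at a shadow database.
    let dbname = url
        .query_pairs()
        .find(|(name, _)| name == "dbname")
        .map(|(_, value)| value.into_owned());

    println!("api_key = {api_key}, dbname override = {dbname:?}");
}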
-async fn drop_db_try_force(conn: &mut Connection, url: &PostgresUrl, database_name: &str) -> ConnectorResult<()> { +async fn drop_db_try_force( + conn: &mut Connection, + url: &MigratePostgresUrl, + database_name: &str, +) -> ConnectorResult<()> { let drop_database = format!("DROP DATABASE IF EXISTS \"{database_name}\" WITH (FORCE)"); if let Err(err) = conn.raw_cmd(&drop_database, url).await { if let Some(msg) = err.message() { @@ -528,7 +614,7 @@ fn strip_schema_param_from_url(url: &mut Url) { /// Try to connect as an admin to a postgres database. We try to pick a default database from which /// we can create another database. -async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection, PostgresUrl)> { +async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection, MigratePostgresUrl)> { // "postgres" is the default database on most postgres installations, // "template1" is guaranteed to exist, and "defaultdb" is the only working // option on DigitalOcean managed postgres databases. @@ -538,7 +624,7 @@ async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection for database_name in CANDIDATE_DEFAULT_DATABASES { url.set_path(&format!("/{database_name}")); - let postgres_url = PostgresUrl::new(url.clone()).unwrap(); + let postgres_url = MigratePostgresUrl(PostgresUrl::new_native(url.clone()).unwrap()); match Connection::new(url.clone()).await { // If the database does not exist, try the next one. Err(err) => match &err.error_code() { diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs index c3bceb6fb381..3a8f9fb6517a 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs @@ -4,8 +4,8 @@ use enumflags2::BitFlags; use indoc::indoc; use psl::PreviewFeature; use quaint::{ - connector::{self, tokio_postgres::error::ErrorPosition, PostgresUrl}, - prelude::{ConnectionInfo, NativeConnectionInfo, Queryable}, + connector::{self, tokio_postgres::error::ErrorPosition, MakeTlsConnectorManager, PostgresUrl}, + prelude::{ConnectionInfo, Queryable}, }; use schema_connector::{ConnectorError, ConnectorResult, Namespaces}; use sql_schema_describer::{postgres::PostgresSchemaExt, SqlSchema}; @@ -13,19 +13,22 @@ use user_facing_errors::{schema_engine::ApplyMigrationError, schema_engine::Data use crate::sql_renderer::IteratorJoin; +use super::MigratePostgresUrl; + pub(super) struct Connection(connector::PostgreSql); impl Connection { pub(super) async fn new(url: url::Url) -> ConnectorResult { - let url = PostgresUrl::new(url).map_err(|err| { - ConnectorError::user_facing(user_facing_errors::common::InvalidConnectionString { - details: err.to_string(), - }) - })?; + let url = MigratePostgresUrl::new(url)?; - let quaint = connector::PostgreSql::new(url.clone()) - .await - .map_err(quaint_err(&url))?; + let quaint = match url.0 { + PostgresUrl::Native(ref native_url) => { + let tls_manager = MakeTlsConnectorManager::new(native_url.as_ref().clone()); + connector::PostgreSql::new(native_url.as_ref().clone(), &tls_manager).await + } + PostgresUrl::WebSocket(ref ws_url) => connector::PostgreSql::new_with_websocket(ws_url.clone()).await, + } + .map_err(quaint_err(&url))?; let version = quaint.version().await.map_err(quaint_err(&url))?; @@ -116,12 +119,12 @@ impl Connection { Ok(schema) } - pub(super) async fn 
raw_cmd(&mut self, sql: &str, url: &PostgresUrl) -> ConnectorResult<()> { + pub(super) async fn raw_cmd(&mut self, sql: &str, url: &MigratePostgresUrl) -> ConnectorResult<()> { tracing::debug!(query_type = "raw_cmd", sql); self.0.raw_cmd(sql).await.map_err(quaint_err(url)) } - pub(super) async fn version(&mut self, url: &PostgresUrl) -> ConnectorResult> { + pub(super) async fn version(&mut self, url: &MigratePostgresUrl) -> ConnectorResult> { tracing::debug!(query_type = "version"); self.0.version().await.map_err(quaint_err(url)) } @@ -129,7 +132,7 @@ impl Connection { pub(super) async fn query( &mut self, query: quaint::ast::Query<'_>, - url: &PostgresUrl, + url: &MigratePostgresUrl, ) -> ConnectorResult { use quaint::visitor::Visitor; let (sql, params) = quaint::visitor::Postgres::build(query).unwrap(); @@ -140,12 +143,21 @@ impl Connection { &self, sql: &str, params: &[quaint::prelude::Value<'_>], - url: &PostgresUrl, + url: &MigratePostgresUrl, ) -> ConnectorResult { tracing::debug!(query_type = "query_raw", sql, ?params); self.0.query_raw(sql, params).await.map_err(quaint_err(url)) } + pub(super) async fn describe_query( + &self, + sql: &str, + url: &MigratePostgresUrl, + ) -> ConnectorResult { + tracing::debug!(query_type = "describe_query", sql); + self.0.describe_query(sql).await.map_err(quaint_err(url)) + } + pub(super) async fn apply_migration_script(&mut self, migration_name: &str, script: &str) -> ConnectorResult<()> { tracing::debug!(query_type = "raw_cmd", script); let client = self.0.client(); @@ -228,11 +240,6 @@ fn normalize_sql_schema(schema: &mut SqlSchema, preview_features: BitFlags impl (Fn(quaint::error::Error) -> ConnectorError) + '_ { - |err| { - crate::flavour::quaint_error_to_connector_error( - err, - &ConnectionInfo::Native(NativeConnectionInfo::Postgres(url.clone())), - ) - } +fn quaint_err(url: &MigratePostgresUrl) -> impl (Fn(quaint::error::Error) -> ConnectorError) + '_ { + |err| crate::flavour::quaint_error_to_connector_error(err, &ConnectionInfo::Native(url.clone().into())) } diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite.rs index 6570682f4d95..85bff2133cb0 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite.rs @@ -255,6 +255,16 @@ impl SqlFlavour for SqliteFlavour { ready(with_connection(&mut self.state, |_, conn| conn.query_raw(sql, params))) } + fn describe_query<'a>( + &'a mut self, + sql: &'a str, + ) -> BoxFuture<'a, ConnectorResult> { + tracing::debug!(sql, query_type = "describe_query"); + ready(with_connection(&mut self.state, |params, conn| { + conn.describe_query(sql, params) + })) + } + fn introspect( &mut self, namespaces: Option, diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs index 959ed6de8632..995a86e87c97 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/sqlite/connection.rs @@ -2,9 +2,11 @@ pub(crate) use quaint::connector::rusqlite; -use quaint::connector::{GetRow, ToColumnNames}; +use quaint::connector::{ColumnType, DescribedColumn, DescribedParameter, GetRow, ToColumnNames}; use schema_connector::{ConnectorError, ConnectorResult}; use sql_schema_describer::{sqlite as describer, DescriberErrorKind, SqlSchema}; 
+use sqlx_core::{column::Column, type_info::TypeInfo}; +use sqlx_sqlite::SqliteColumn; use std::sync::Mutex; use user_facing_errors::schema_engine::ApplyMigrationError; @@ -56,6 +58,7 @@ impl Connection { let conn = self.0.lock().unwrap(); let mut stmt = conn.prepare_cached(sql).map_err(convert_error)?; + let column_types = stmt.columns().iter().map(ColumnType::from).collect::>(); let mut rows = stmt .query(rusqlite::params_from_iter(params.iter())) .map_err(convert_error)?; @@ -65,7 +68,60 @@ impl Connection { converted_rows.push(row.get_result_row().unwrap()); } - Ok(quaint::prelude::ResultSet::new(column_names, converted_rows)) + Ok(quaint::prelude::ResultSet::new( + column_names, + column_types, + converted_rows, + )) + } + + pub(super) fn describe_query( + &mut self, + sql: &str, + params: &super::Params, + ) -> ConnectorResult { + tracing::debug!(query_type = "describe_query", sql); + // SQLite only provides type information for _declared_ column types. That means any expression will not contain type information. + // Sqlx works around this by running an `EXPLAIN` query and inferring types by interpreting sqlite bytecode. + // If you're curious, here's the code: https://github.com/launchbadge/sqlx/blob/16e3f1025ad1e106d1acff05f591b8db62d688e2/sqlx-sqlite/src/connection/explain.rs#L557 + // We use SQLx's as a fallback for when quaint's infers Unknown. + let describe = sqlx_sqlite::describe_blocking(sql, ¶ms.file_path) + .map_err(|err| ConnectorError::from_source(err, "Error describing the query."))?; + let conn = self.0.lock().unwrap(); + let stmt = conn.prepare_cached(sql).map_err(convert_error)?; + + let parameters = (1..=stmt.parameter_count()) + .map(|idx| match stmt.parameter_name(idx) { + Some(name) => { + // SQLite parameter names are prefixed with a colon. We remove it here so that the js doc parser can match the names. + let name = name.strip_prefix(':').unwrap_or(name); + + DescribedParameter::new_named(name, ColumnType::Unknown) + } + None => DescribedParameter::new_unnamed(idx, ColumnType::Unknown), + }) + .collect(); + let columns = stmt + .columns() + .iter() + .zip(&describe.nullable) + .enumerate() + .map(|(idx, (col, nullable))| { + let typ = match ColumnType::from(col) { + // If the column type is unknown, we try to infer it from the describe. 
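In the SQLite `describe_query` above, named parameters come back from the driver with their leading `:`, which would never match the bare aliases used in `@param` doc comments, so the prefix is stripped; purely positional parameters fall back to an unnamed placeholder. A small standalone sketch of that normalization (the fallback name format here is illustrative, the real one comes from `DescribedParameter::new_unnamed`):

fn normalized_param_name(index: usize, raw: Option<&str>) -> String {
    match raw {
        // SQLite reports named parameters as ":userId"; drop the prefix so the
        // doc-comment parser can match the bare alias.
        Some(name) => name.strip_prefix(':').unwrap_or(name).to_owned(),
        // Purely positional parameters keep a synthetic, index-based name.
        None => format!("param_{index}"),
    }
}

fn main() {
    assert_eq!(normalized_param_name(1, Some(":userId")), "userId");
    assert_eq!(normalized_param_name(2, Some("userId")), "userId");
    assert_eq!(normalized_param_name(3, None), "param_3");
}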
+ ColumnType::Unknown => describe.column(idx).to_column_type(), + typ => typ, + }; + + DescribedColumn::new_named(col.name(), typ).is_nullable(nullable.unwrap_or(true)) + }) + .collect(); + + Ok(quaint::connector::DescribedQuery { + columns, + parameters, + enum_names: None, + }) } } @@ -97,3 +153,30 @@ pub(super) fn generic_apply_migration_script( fn convert_error(err: rusqlite::Error) -> ConnectorError { ConnectorError::from_source(err, "SQLite database error") } + +trait ToColumnTypeExt { + fn to_column_type(&self) -> ColumnType; +} + +impl ToColumnTypeExt for &SqliteColumn { + fn to_column_type(&self) -> ColumnType { + let ty = self.type_info(); + + match ty.name() { + "NULL" => ColumnType::Null, + "TEXT" => ColumnType::Text, + "REAL" => ColumnType::Double, + "BLOB" => ColumnType::Bytes, + "INTEGER" => ColumnType::Int64, + // Not supported by sqlx-sqlite + "NUMERIC" => ColumnType::Numeric, + + // non-standard extensions + "BOOLEAN" => ColumnType::Boolean, + "DATE" => ColumnType::Date, + "TIME" => ColumnType::Time, + "DATETIME" => ColumnType::DateTime, + _ => ColumnType::Unknown, + } + } +} diff --git a/schema-engine/connectors/sql-schema-connector/src/lib.rs b/schema-engine/connectors/sql-schema-connector/src/lib.rs index 97bbbde409c9..3641eeb96339 100644 --- a/schema-engine/connectors/sql-schema-connector/src/lib.rs +++ b/schema-engine/connectors/sql-schema-connector/src/lib.rs @@ -9,6 +9,7 @@ mod flavour; mod introspection; mod migration_pair; mod sql_destructive_change_checker; +mod sql_doc_parser; mod sql_migration; mod sql_migration_persistence; mod sql_renderer; @@ -19,8 +20,10 @@ use database_schema::SqlDatabaseSchema; use enumflags2::BitFlags; use flavour::{MssqlFlavour, MysqlFlavour, PostgresFlavour, SqlFlavour, SqliteFlavour}; use migration_pair::MigrationPair; -use psl::ValidatedSchema; +use psl::{datamodel_connector::NativeTypeInstance, parser_database::ScalarType, ValidatedSchema}; +use quaint::connector::DescribedQuery; use schema_connector::{migrations_directory::MigrationDirectory, *}; +use sql_doc_parser::{parse_sql_doc, sanitize_sql}; use sql_migration::{DropUserDefinedType, DropView, SqlMigration, SqlMigrationStep}; use sql_schema_describer as sql; use std::{future, sync::Arc}; @@ -149,6 +152,13 @@ impl SqlSchemaConnector { DiffTarget::Empty => Ok(self.flavour.empty_database_schema().into()), } } + + /// Returns the native types that can be used to represent the given scalar type. + pub fn scalar_type_for_native_type(&self, native_type: &NativeTypeInstance) -> ScalarType { + self.flavour + .datamodel_connector() + .scalar_type_for_native_type(native_type) + } } impl SchemaConnector for SqlSchemaConnector { @@ -346,6 +356,55 @@ impl SchemaConnector for SqlSchemaConnector { .collect::>(), ) } + + fn introspect_sql( + &mut self, + input: IntrospectSqlQueryInput, + ) -> BoxFuture<'_, ConnectorResult> { + Box::pin(async move { + let sanitized_sql = sanitize_sql(&input.source); + let DescribedQuery { + parameters, + columns, + enum_names, + } = self.flavour.describe_query(&sanitized_sql).await?; + let enum_names = enum_names.unwrap_or_default(); + let sql_source = input.source.clone(); + let parsed_doc = parse_sql_doc(&sql_source, enum_names.as_slice())?; + + let parameters = parameters + .into_iter() + .zip(1..) 
+ .map(|(param, idx)| { + let parsed_param = parsed_doc + .get_param_at(idx) + .or_else(|| parsed_doc.get_param_by_alias(¶m.name)); + + IntrospectSqlQueryParameterOutput { + typ: parsed_param + .and_then(|p| p.typ()) + .unwrap_or_else(|| param.typ.to_string()), + name: parsed_param + .and_then(|p| p.alias()) + .map(ToOwned::to_owned) + .unwrap_or_else(|| param.name), + documentation: parsed_param.and_then(|p| p.documentation()).map(ToOwned::to_owned), + // Params are required by default unless overridden by sql doc. + nullable: parsed_param.and_then(|p| p.nullable()).unwrap_or(false), + } + }) + .collect(); + let columns = columns.into_iter().map(IntrospectSqlQueryColumnOutput::from).collect(); + + Ok(IntrospectSqlQueryOutput { + name: input.name, + source: sanitized_sql, + documentation: parsed_doc.description().map(ToOwned::to_owned), + parameters, + result_columns: columns, + }) + }) + } } fn new_shadow_database_name() -> String { diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs b/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs new file mode 100644 index 000000000000..819621512b30 --- /dev/null +++ b/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs @@ -0,0 +1,1144 @@ +use psl::parser_database::ScalarType; +use quaint::prelude::ColumnType; +use schema_connector::{ConnectorError, ConnectorResult}; + +use crate::sql_renderer::IteratorJoin; + +#[derive(Debug, Default)] +pub(crate) struct ParsedSqlDoc<'a> { + parameters: Vec>, + description: Option<&'a str>, +} + +#[derive(Debug)] +pub enum ParsedParamType<'a> { + ColumnType(ColumnType), + Enum(&'a str), +} + +impl<'a> ParsedSqlDoc<'a> { + fn add_parameter(&mut self, param: ParsedParameterDoc<'a>) -> ConnectorResult<()> { + if self + .parameters + .iter() + .any(|p| p.position == param.position || p.alias == param.alias) + { + return Err(ConnectorError::from_msg( + "duplicate parameter (position or alias is already used)".to_string(), + )); + } + + self.parameters.push(param); + + Ok(()) + } + + fn set_description(&mut self, doc: Option<&'a str>) { + self.description = doc; + } + + pub(crate) fn get_param_at(&self, at: usize) -> Option<&ParsedParameterDoc<'a>> { + self.parameters.iter().find(|p| p.position == Some(at)) + } + + pub(crate) fn get_param_by_alias(&self, alias: &str) -> Option<&ParsedParameterDoc<'a>> { + self.parameters.iter().find(|p| p.alias == Some(alias)) + } + + pub(crate) fn description(&self) -> Option<&str> { + self.description + } +} + +#[derive(Debug, Default)] +pub(crate) struct ParsedParameterDoc<'a> { + alias: Option<&'a str>, + typ: Option>, + nullable: Option, + position: Option, + documentation: Option<&'a str>, +} + +impl<'a> ParsedParameterDoc<'a> { + fn set_alias(&mut self, name: Option<&'a str>) { + self.alias = name; + } + + fn set_typ(&mut self, typ: Option>) { + self.typ = typ; + } + + fn set_nullable(&mut self, nullable: Option) { + self.nullable = nullable; + } + + fn set_position(&mut self, position: Option) { + self.position = position; + } + + fn set_documentation(&mut self, doc: Option<&'a str>) { + self.documentation = doc; + } + + fn is_empty(&self) -> bool { + self.alias.is_none() + && self.position.is_none() + && self.typ.is_none() + && self.documentation.is_none() + && self.nullable.is_none() + } + + pub(crate) fn alias(&self) -> Option<&str> { + self.alias + } + + pub(crate) fn typ(&self) -> Option { + self.typ.as_ref().map(|typ| match typ { + ParsedParamType::ColumnType(ct) => ct.to_string(), + ParsedParamType::Enum(enm) => 
enm.to_string(), + }) + } + + pub(crate) fn documentation(&self) -> Option<&str> { + self.documentation + } + + pub(crate) fn nullable(&self) -> Option { + self.nullable + } +} + +#[derive(Debug, Clone, Copy)] +struct Input<'a>(&'a str); + +impl<'a> Input<'a> { + fn find(&self, pat: &[char]) -> Option { + self.0.find(pat) + } + + fn strip_prefix_char(&self, pat: char) -> Option { + self.0.strip_prefix(pat).map(Self) + } + + fn strip_prefix_str(&self, pat: &str) -> Option { + self.0.strip_prefix(pat).map(Self) + } + + fn strip_suffix_char(&self, pat: char) -> Option { + self.0.strip_suffix(pat).map(Self) + } + + fn starts_with(&self, pat: &str) -> bool { + self.0.starts_with(pat) + } + + fn is_empty(&self) -> bool { + self.0.is_empty() + } + + fn move_from(&self, n: usize) -> Input<'a> { + Self(&self.0[n..]) + } + + fn move_to(&self, n: usize) -> Input<'a> { + Self(&self.0[..n]) + } + + fn move_between(&self, start: usize, end: usize) -> Input<'a> { + Self(&self.0[start..end]) + } + + fn move_to_end(&self) -> Input<'a> { + Self(&self.0[self.0.len()..]) + } + + fn trim_start(&self) -> Input<'a> { + Self(self.0.trim_start()) + } + + fn trim_end(&self) -> Input<'a> { + Self(self.0.trim_end()) + } + + fn take_until_pattern_or_eol(&self, pattern: &[char]) -> (Input<'a>, Input<'a>) { + if let Some(end) = self.find(pattern) { + (self.move_from(end), self.move_to(end)) + } else { + (self.move_to_end(), *self) + } + } + + fn inner(&self) -> &'a str { + self.0 + } +} + +impl std::fmt::Display for Input<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[inline] +fn build_error(input: Input<'_>, msg: &str) -> ConnectorError { + ConnectorError::from_msg(format!("SQL documentation parsing: {msg} at '{input}'.")) +} + +fn render_enum_names(enum_names: &[String]) -> String { + if enum_names.is_empty() { + String::new() + } else { + format!( + ", {enum_names}", + enum_names = enum_names.iter().map(|name| format!("'{name}'")).join(", ") + ) + } +} + +fn parse_typ_opt<'a>( + input: Input<'a>, + enum_names: &'a [String], +) -> ConnectorResult<(Input<'a>, Option>)> { + if let Some(start) = input.find(&['{']) { + if let Some(end) = input.find(&['}']) { + let typ = input.move_between(start + 1, end); + + if typ.is_empty() { + return Err(build_error(input, "missing type (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal')")); + } + + let parsed_typ = ScalarType::try_from_str(typ.inner(), false) + .map(|st| match st { + ScalarType::Int => ColumnType::Int32, + ScalarType::BigInt => ColumnType::Int64, + ScalarType::Float => ColumnType::Float, + ScalarType::Boolean => ColumnType::Boolean, + ScalarType::String => ColumnType::Text, + ScalarType::DateTime => ColumnType::DateTime, + ScalarType::Json => ColumnType::Json, + ScalarType::Bytes => ColumnType::Bytes, + ScalarType::Decimal => ColumnType::Numeric, + }) + .map(ParsedParamType::ColumnType) + .or_else(|| { + enum_names.iter().any(|enum_name| *enum_name == typ.inner()) + .then(|| ParsedParamType::Enum(typ.inner())) + }) + .ok_or_else(|| build_error( + input, + &format!("invalid type: '{typ}' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal'{})", render_enum_names(enum_names)), + ))?; + + Ok((input.move_from(end + 1), Some(parsed_typ))) + } else { + Err(build_error(input, "missing closing bracket")) + } + } else { + Ok((input, None)) + } +} + +fn parse_position_opt(input: Input<'_>) -> 
ConnectorResult<(Input<'_>, Option)> { + if let Some((param_input, param_pos)) = input + .trim_start() + .strip_prefix_char('$') + .map(|input| input.take_until_pattern_or_eol(&[':', ' '])) + { + match param_pos.inner().parse::().map_err(|_| { + build_error( + input, + &format!("invalid position. Expected a number found: {param_pos}"), + ) + }) { + Ok(param_pos) => Ok((param_input, Some(param_pos))), + Err(err) => Err(err), + } + } else { + Ok((input, None)) + } +} + +fn parse_alias_opt(input: Input<'_>) -> ConnectorResult<(Input<'_>, Option<&'_ str>, Option)> { + if let Some((input, alias)) = input + .trim_start() + .strip_prefix_char(':') + .map(|input| input.take_until_pattern_or_eol(&[' '])) + { + if let Some(alias) = alias.strip_suffix_char('?') { + Ok((input, Some(alias.inner()), Some(true))) + } else { + Ok((input, Some(alias.inner()), None)) + } + } else { + Ok((input, None, None)) + } +} + +fn parse_rest(input: Input<'_>) -> ConnectorResult> { + let input = input.trim_start(); + + if input.is_empty() { + return Ok(None); + } + + Ok(Some(input.trim_end().inner())) +} + +fn validate_param(param: &ParsedParameterDoc<'_>, input: Input<'_>) -> ConnectorResult<()> { + if param.is_empty() { + return Err(build_error(input, "invalid parameter: could not parse any information")); + } + + if param.position.is_none() && param.alias().is_none() { + return Err(build_error(input, "missing position or alias (eg: $1:alias)")); + } + + Ok(()) +} + +fn parse_param<'a>(param_input: Input<'a>, enum_names: &'a [String]) -> ConnectorResult> { + let input = param_input.strip_prefix_str("@param").unwrap().trim_start(); + + let (input, typ) = parse_typ_opt(input, enum_names)?; + let (input, position) = parse_position_opt(input)?; + let (input, alias, nullable) = parse_alias_opt(input)?; + let documentation = parse_rest(input)?; + + let mut param = ParsedParameterDoc::default(); + + param.set_typ(typ); + param.set_nullable(nullable); + param.set_position(position); + param.set_alias(alias); + param.set_documentation(documentation); + + validate_param(¶m, param_input)?; + + Ok(param) +} + +fn parse_description(input: Input<'_>) -> ConnectorResult> { + let input = input.strip_prefix_str("@description").unwrap(); + + parse_rest(input) +} + +pub(crate) fn parse_sql_doc<'a>(sql: &'a str, enum_names: &'a [String]) -> ConnectorResult> { + let mut parsed_sql = ParsedSqlDoc::default(); + + let lines = sql.lines(); + + for line in lines { + let input = Input(line.trim()); + + if let Some(input) = input.strip_prefix_str("--") { + let input = input.trim_start(); + + if input.starts_with("@description") { + parsed_sql.set_description(parse_description(input)?); + } else if input.starts_with("@param") { + parsed_sql + .add_parameter(parse_param(input, enum_names)?) + .map_err(|err| build_error(input, err.message().unwrap()))?; + } + } + } + + Ok(parsed_sql) +} + +/// Mysql-async poorly parses the sql input to support named parameters, which conflicts with our own syntax for overriding query parameters type and nullability. +/// This function removes all single-line comments from the sql input to avoid conflicts. 
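The `@param` annotations parsed here only cover parameters; result columns get their own override through `parse_nullability_override` in `introspect_sql.rs` further up: a trailing `?` or `!` on the column name (presumably written as a quoted alias in the query, e.g. `SELECT email AS "email?" FROM ...`) forces the column to be reported as nullable or non-nullable. A standalone sketch of that rule:

fn nullability_override(column_name: &str) -> Option<bool> {
    // `?` forces nullable, `!` forces non-nullable, anything else defers to
    // whatever nullability the database describes for the column.
    if column_name.ends_with('?') {
        Some(true)
    } else if column_name.ends_with('!') {
        Some(false)
    } else {
        None
    }
}

fn main() {
    assert_eq!(nullability_override("email?"), Some(true));
    assert_eq!(nullability_override("id!"), Some(false));
    assert_eq!(nullability_override("name"), None);
}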
+pub(crate) fn sanitize_sql(sql: &str) -> String { + sql.lines() + .map(|line| line.trim()) + .filter(|line| !line.starts_with("--")) + .join("\n") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_param_1() { + use expect_test::expect; + + let res = parse_param(Input("@param $1:userId"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "userId", + ), + typ: None, + nullable: None, + position: Some( + 1, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_2() { + use expect_test::expect; + + let res = parse_param(Input("@param $1:userId valid user identifier"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "userId", + ), + typ: None, + nullable: None, + position: Some( + 1, + ), + documentation: Some( + "valid user identifier", + ), + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_3() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} :userId"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "userId", + ), + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: None, + position: None, + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_4() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} $1:userId"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "userId", + ), + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: None, + position: Some( + 1, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_5() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} $1:userId valid user identifier"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "userId", + ), + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: None, + position: Some( + 1, + ), + documentation: Some( + "valid user identifier", + ), + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_6() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} $1 valid user identifier"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: None, + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: None, + position: Some( + 1, + ), + documentation: Some( + "valid user identifier", + ), + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_7() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} $1f valid user identifier"), &[]); + + let expected = expect![[r#" + Err( + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "SQL documentation parsing: invalid position. Expected a number found: 1f at ' $1f valid user identifier'.", + ), + source: None, + context: SpanTrace [], + } + SQL documentation parsing: invalid position. Expected a number found: 1f at ' $1f valid user identifier'. 
+ , + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_8() { + use expect_test::expect; + + let res = parse_param(Input("@param {} valid user identifier"), &[]); + + let expected = expect![[r#" + Err( + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "SQL documentation parsing: missing type (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal') at '{} valid user identifier'.", + ), + source: None, + context: SpanTrace [], + } + SQL documentation parsing: missing type (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal') at '{} valid user identifier'. + , + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_9() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int $1f valid user identifier"), &[]); + + let expected = expect![[r#" + Err( + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "SQL documentation parsing: missing closing bracket at '{Int $1f valid user identifier'.", + ), + source: None, + context: SpanTrace [], + } + SQL documentation parsing: missing closing bracket at '{Int $1f valid user identifier'. + , + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_10() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} valid user identifier $10"), &[]); + + let expected = expect![[r#" + Err( + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "SQL documentation parsing: missing position or alias (eg: $1:alias) at '@param {Int} valid user identifier $10'.", + ), + source: None, + context: SpanTrace [], + } + SQL documentation parsing: missing position or alias (eg: $1:alias) at '@param {Int} valid user identifier $10'. + , + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_11() { + use expect_test::expect; + + let res = parse_param(Input("@param "), &[]); + + let expected = expect![[r#" + Err( + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "SQL documentation parsing: invalid parameter: could not parse any information at '@param '.", + ), + source: None, + context: SpanTrace [], + } + SQL documentation parsing: invalid parameter: could not parse any information at '@param '. 
+ , + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_12() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int}$1 some documentation"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: None, + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: None, + position: Some( + 1, + ), + documentation: Some( + "some documentation", + ), + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_13() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} $1 some documentation"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: None, + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: None, + position: Some( + 1, + ), + documentation: Some( + "some documentation", + ), + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_14() { + use expect_test::expect; + + let res = parse_param(Input("@param {Unknown} $1"), &[]); + + let expected = expect![[r#" + Err( + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "SQL documentation parsing: invalid type: 'Unknown' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal') at '{Unknown} $1'.", + ), + source: None, + context: SpanTrace [], + } + SQL documentation parsing: invalid type: 'Unknown' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal') at '{Unknown} $1'. + , + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_15() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} $1:alias!"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias!", + ), + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: None, + position: Some( + 1, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_16() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} $1:alias?"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias", + ), + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: Some( + true, + ), + position: Some( + 1, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_17() { + use expect_test::expect; + + let res = parse_param(Input("@param {Int} $1:alias!?"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias!", + ), + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: Some( + true, + ), + position: Some( + 1, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_18() { + use expect_test::expect; + + let res = parse_param(Input("@param $1:alias?"), &[]); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias", + ), + typ: None, + nullable: Some( + true, + ), + position: Some( + 1, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_19() { + use expect_test::expect; + + let enums = ["MyEnum".to_string()]; + let res = parse_param(Input("@param {MyEnum} $1:alias?"), &enums); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias", + ), + typ: Some( + Enum( + "MyEnum", + ), + ), + nullable: Some( + true, + ), + position: Some( + 1, + ), + documentation: None, + }, 
+ ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_20() { + use expect_test::expect; + + let enums = ["MyEnum".to_string()]; + let res = parse_param(Input("@param {MyEnum} $12567:alias"), &enums); + + let expected = expect![[r#" + Ok( + ParsedParameterDoc { + alias: Some( + "alias", + ), + typ: Some( + Enum( + "MyEnum", + ), + ), + nullable: None, + position: Some( + 12567, + ), + documentation: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_param_21() { + use expect_test::expect; + + let enums = ["MyEnum".to_string(), "MyEnum2".to_string()]; + let res = parse_param(Input("@param {UnknownTyp} $12567:alias"), &enums); + + let expected = expect![[r#" + Err( + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "SQL documentation parsing: invalid type: 'UnknownTyp' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal', 'MyEnum', 'MyEnum2') at '{UnknownTyp} $12567:alias'.", + ), + source: None, + context: SpanTrace [], + } + SQL documentation parsing: invalid type: 'UnknownTyp' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal', 'MyEnum', 'MyEnum2') at '{UnknownTyp} $12567:alias'. + , + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_sql_1() { + use expect_test::expect; + + let res = parse_sql_doc("-- @param {Int} $1 some documentation ", &[]); + + let expected = expect![[r#" + Ok( + ParsedSqlDoc { + parameters: [ + ParsedParameterDoc { + alias: None, + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: None, + position: Some( + 1, + ), + documentation: Some( + "some documentation", + ), + }, + ], + description: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_sql_2() { + use expect_test::expect; + + let res = parse_sql_doc( + r#" -- @description This query returns a user by it's id + -- @param {Int} $1:userId valid user identifier + -- @param {String} $2:parentId valid parent identifier + SELECT enum FROM "test_introspect_sql"."model" + WHERE + id = $1;"#, + &[], + ); + + let expected = expect![[r#" + Ok( + ParsedSqlDoc { + parameters: [ + ParsedParameterDoc { + alias: Some( + "userId", + ), + typ: Some( + ColumnType( + Int32, + ), + ), + nullable: None, + position: Some( + 1, + ), + documentation: Some( + "valid user identifier", + ), + }, + ParsedParameterDoc { + alias: Some( + "parentId", + ), + typ: Some( + ColumnType( + Text, + ), + ), + nullable: None, + position: Some( + 2, + ), + documentation: Some( + "valid parent identifier", + ), + }, + ], + description: Some( + "This query returns a user by it's id", + ), + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_sql_3() { + use expect_test::expect; + + let res = parse_sql_doc( + r#"-- @description This query returns a user by it's id + -- @param {Int} $1:userId valid user identifier + -- @param {String} $1:parentId valid parent identifier +SELECT enum FROM "test_introspect_sql"."model" WHERE id = $1;"#, + &[], + ); + + let expected = expect![[r#" + Err( + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "SQL documentation parsing: duplicate parameter (position or alias is already used) at '@param {String} $1:parentId valid parent identifier'.", + ), + source: None, + context: SpanTrace [], + } + SQL documentation parsing: duplicate parameter (position or alias is already used) at '@param {String} $1:parentId valid parent identifier'. 
+ , + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_sql_4() { + use expect_test::expect; + + let res = parse_sql_doc( + r#"-- @description This query returns a user by it's id +-- @param {Int} $1:userId valid user identifier +-- @param {String} $2:userId valid parent identifier +SELECT enum FROM "test_introspect_sql"."model" WHERE id = $1;"#, + &[], + ); + + let expected = expect![[r#" + Err( + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "SQL documentation parsing: duplicate parameter (position or alias is already used) at '@param {String} $2:userId valid parent identifier'.", + ), + source: None, + context: SpanTrace [], + } + SQL documentation parsing: duplicate parameter (position or alias is already used) at '@param {String} $2:userId valid parent identifier'. + , + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn parse_sql_5() { + use expect_test::expect; + + let res = parse_sql_doc( + r#" + /** + * Unhandled multi-line comment + */ + SELECT enum FROM "test_introspect_sql"."model" WHERE id = $1;"#, + &[], + ); + + let expected = expect![[r#" + Ok( + ParsedSqlDoc { + parameters: [], + description: None, + }, + ) + "#]]; + + expected.assert_debug_eq(&res); + } + + #[test] + fn sanitize_sql_test_1() { + use expect_test::expect; + + let sql = r#" + -- @description This query returns a user by it's id + -- @param {Int} $1:userId valid user identifier + -- @param {String} $2:parentId valid parent identifier + SELECT enum + FROM + "test_introspect_sql"."model" WHERE id = + $1; + "#; + + let expected = expect![[r#" + + SELECT enum + FROM + "test_introspect_sql"."model" WHERE id = + $1; + "#]]; + + expected.assert_eq(&sanitize_sql(sql)); + } +} diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_differ/index.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_differ/index.rs index 257d79b36df6..441960df8c58 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_differ/index.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_differ/index.rs @@ -1,6 +1,11 @@ use sql_schema_describer::walkers::{IndexWalker, TableWalker}; pub(super) fn index_covers_fk(table: TableWalker<'_>, index: IndexWalker<'_>) -> bool { + // Only normal indexes can cover foreign keys. + if index.index_type() != sql_schema_describer::IndexType::Normal { + return false; + } + table.foreign_keys().any(|fk| { let fk_cols = fk.constrained_columns().map(|col| col.name()); let index_cols = index.column_names(); diff --git a/schema-engine/core/Cargo.toml b/schema-engine/core/Cargo.toml index 6814bf60ed23..6fb22d8a98cb 100644 --- a/schema-engine/core/Cargo.toml +++ b/schema-engine/core/Cargo.toml @@ -21,7 +21,7 @@ serde_json.workspace = true tokio.workspace = true tracing.workspace = true tracing-subscriber = "0.3" -tracing-futures = "0.2" +tracing-futures.workspace = true url.workspace = true [build-dependencies] diff --git a/schema-engine/core/src/api.rs b/schema-engine/core/src/api.rs index dcea95e6757b..1c0cb5d59b53 100644 --- a/schema-engine/core/src/api.rs +++ b/schema-engine/core/src/api.rs @@ -52,6 +52,9 @@ pub trait GenericApi: Send + Sync + 'static { /// Introspect the database schema. async fn introspect(&self, input: IntrospectParams) -> CoreResult; + /// Introspects a SQL query and returns types information + async fn introspect_sql(&self, input: IntrospectSqlParams) -> CoreResult; + /// List the migration directories. 
async fn list_migration_directories( &self, diff --git a/schema-engine/core/src/commands.rs b/schema-engine/core/src/commands.rs index 22b92267204c..ded1f163d0c6 100644 --- a/schema-engine/core/src/commands.rs +++ b/schema-engine/core/src/commands.rs @@ -6,6 +6,7 @@ mod dev_diagnostic; mod diagnose_migration_history; mod diff; mod evaluate_data_loss; +mod introspect_sql; mod mark_migration_applied; mod mark_migration_rolled_back; mod schema_push; @@ -20,6 +21,7 @@ pub use dev_diagnostic::dev_diagnostic; pub use diagnose_migration_history::diagnose_migration_history; pub use diff::diff; pub use evaluate_data_loss::evaluate_data_loss; +pub use introspect_sql::introspect_sql; pub use mark_migration_applied::mark_migration_applied; pub use mark_migration_rolled_back::mark_migration_rolled_back; pub use schema_push::schema_push; diff --git a/schema-engine/core/src/commands/introspect_sql.rs b/schema-engine/core/src/commands/introspect_sql.rs new file mode 100644 index 000000000000..7cf323ce43a9 --- /dev/null +++ b/schema-engine/core/src/commands/introspect_sql.rs @@ -0,0 +1,28 @@ +use crate::json_rpc::types::IntrospectSqlParams; +use schema_connector::{IntrospectSqlQueryInput, IntrospectSqlResult, SchemaConnector}; + +pub async fn introspect_sql( + input: IntrospectSqlParams, + connector: &mut dyn SchemaConnector, +) -> crate::CoreResult { + let queries: Vec<_> = input + .queries + .into_iter() + .map(|q| IntrospectSqlQueryInput { + name: q.name, + source: q.source, + }) + .collect(); + + let mut parsed_queries = Vec::with_capacity(queries.len()); + + for q in queries { + let parsed_query = connector.introspect_sql(q).await?; + + parsed_queries.push(parsed_query); + } + + Ok(IntrospectSqlResult { + queries: parsed_queries, + }) +} diff --git a/schema-engine/core/src/lib.rs b/schema-engine/core/src/lib.rs index b367ab0bfff9..3c0a2bf6d6a1 100644 --- a/schema-engine/core/src/lib.rs +++ b/schema-engine/core/src/lib.rs @@ -41,7 +41,7 @@ fn connector_for_connection_string( preview_features: BitFlags, ) -> CoreResult> { match connection_string.split(':').next() { - Some("postgres") | Some("postgresql") => { + Some("postgres") | Some("postgresql") | Some("prisma+postgres") => { let params = ConnectorParams { connection_string, preview_features, diff --git a/schema-engine/core/src/rpc.rs b/schema-engine/core/src/rpc.rs index cf7bac51a8ec..b5467260f638 100644 --- a/schema-engine/core/src/rpc.rs +++ b/schema-engine/core/src/rpc.rs @@ -47,6 +47,7 @@ async fn run_command( EVALUATE_DATA_LOSS => render(executor.evaluate_data_loss(params.parse()?).await), GET_DATABASE_VERSION => render(executor.version(params.parse()?).await), INTROSPECT => render(executor.introspect(params.parse()?).await), + INTROSPECT_SQL => render(executor.introspect_sql(params.parse()?).await), LIST_MIGRATION_DIRECTORIES => render(executor.list_migration_directories(params.parse()?).await), MARK_MIGRATION_APPLIED => render(executor.mark_migration_applied(params.parse()?).await), MARK_MIGRATION_ROLLED_BACK => render(executor.mark_migration_rolled_back(params.parse()?).await), @@ -64,7 +65,7 @@ fn render(result: CoreResult) -> jsonrpc_core::Result JsonRpcError { - serde_json::to_value(&crate_error.to_user_facing()) + serde_json::to_value(crate_error.to_user_facing()) .map(|data| JsonRpcError { // We separate the JSON-RPC error code (defined by the JSON-RPC spec) from the // prisma error code, which is located in `data`. 
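The command above feeds each query through `SchemaConnector::introspect_sql`, whose SQL implementation (further up) zips the parameters returned by `describe_query` with their 1-based positions and lets the `@param` annotations win over the database-described values: the annotated type and alias replace the described type and name, and parameters count as required unless the annotation marks them nullable. A simplified standalone sketch of that precedence; the struct names are illustrative stand-ins, not the real `IntrospectSqlQueryParameterOutput`:

struct Described {
    name: String,
    typ: String,
}

struct Override {
    alias: Option<String>,
    typ: Option<String>,
    nullable: Option<bool>,
}

struct Merged {
    name: String,
    typ: String,
    nullable: bool,
}

fn merge(described: Described, doc: Option<&Override>) -> Merged {
    Merged {
        // The annotated alias wins over the name reported by the database.
        name: doc.and_then(|d| d.alias.clone()).unwrap_or(described.name),
        // The annotated type wins over the described type.
        typ: doc.and_then(|d| d.typ.clone()).unwrap_or(described.typ),
        // Parameters are required by default unless the annotation says otherwise.
        nullable: doc.and_then(|d| d.nullable).unwrap_or(false),
    }
}

fn main() {
    let described = Described { name: "_1".into(), typ: "text".into() };
    let doc = Override { alias: Some("userId".into()), typ: Some("int".into()), nullable: None };
    let merged = merge(described, Some(&doc));
    assert_eq!(
        (merged.name.as_str(), merged.typ.as_str(), merged.nullable),
        ("userId", "int", false)
    );
}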
diff --git a/schema-engine/core/src/state.rs b/schema-engine/core/src/state.rs index 60cf3a7598df..b0ae38d95e4f 100644 --- a/schema-engine/core/src/state.rs +++ b/schema-engine/core/src/state.rs @@ -409,6 +409,49 @@ impl GenericApi for EngineState { .await } + async fn introspect_sql(&self, params: IntrospectSqlParams) -> CoreResult { + self.with_connector_for_url( + params.url.clone(), + Box::new(move |conn| { + Box::pin(async move { + let res = crate::commands::introspect_sql(params, conn).await?; + + Ok(IntrospectSqlResult { + queries: res + .queries + .into_iter() + .map(|q| SqlQueryOutput { + name: q.name, + source: q.source, + documentation: q.documentation, + parameters: q + .parameters + .into_iter() + .map(|p| SqlQueryParameterOutput { + name: p.name, + typ: p.typ, + documentation: p.documentation, + nullable: p.nullable, + }) + .collect(), + result_columns: q + .result_columns + .into_iter() + .map(|c| SqlQueryColumnOutput { + name: c.name, + typ: c.typ, + nullable: c.nullable, + }) + .collect(), + }) + .collect(), + }) + }) + }), + ) + .await + } + async fn list_migration_directories( &self, input: ListMigrationDirectoriesInput, diff --git a/schema-engine/datamodel-renderer/Cargo.toml b/schema-engine/datamodel-renderer/Cargo.toml index ad1b0435d66b..b74352589e12 100644 --- a/schema-engine/datamodel-renderer/Cargo.toml +++ b/schema-engine/datamodel-renderer/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] once_cell = "1.15.0" psl.workspace = true -regex = "1.6.0" +regex.workspace = true base64 = "0.13.1" [dev-dependencies] diff --git a/schema-engine/json-rpc-api-build/methods/introspectSql.toml b/schema-engine/json-rpc-api-build/methods/introspectSql.toml new file mode 100644 index 000000000000..19eb888da327 --- /dev/null +++ b/schema-engine/json-rpc-api-build/methods/introspectSql.toml @@ -0,0 +1,68 @@ +[methods.introspectSql] +description = "Introspect a SQL query and returns type information" +requestShape = "introspectSqlParams" +responseShape = "introspectSqlResult" + +# Input + +[recordShapes.introspectSqlParams] +description = "Params type for the introspectSql method." + +[recordShapes.introspectSqlParams.fields.url] +shape = "string" + +[recordShapes.introspectSqlParams.fields.queries] +shape = "sqlQueryInput" +isList = true + +# Result + +[recordShapes.introspectSqlResult] +description = "Result type for the introspectSql method." 
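# For illustration: one entry of `queries` in a response could look like the JSON below
# (values taken from the Postgres tests added later in this diff; the camelCase field
# names follow the shape declarations in this file and are otherwise an assumption):
#   { "name": "test_1",
#     "source": "SELECT int FROM model WHERE int = $1;",
#     "documentation": null,
#     "parameters": [ { "name": "int4", "typ": "int", "documentation": null, "nullable": false } ],
#     "resultColumns": [ { "name": "int", "typ": "int", "nullable": false } ] }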
+ +[recordShapes.introspectSqlResult.fields.queries] +shape = "sqlQueryOutput" +isList = true + +# Containers + +[recordShapes.sqlQueryInput] +[recordShapes.sqlQueryInput.fields.name] +shape = "string" +[recordShapes.sqlQueryInput.fields.source] +shape = "string" + +[recordShapes.sqlQueryOutput] +[recordShapes.sqlQueryOutput.fields.name] +shape = "string" +[recordShapes.sqlQueryOutput.fields.source] +shape = "string" +[recordShapes.sqlQueryOutput.fields.documentation] +isNullable = true +shape = "string" +[recordShapes.sqlQueryOutput.fields.parameters] +shape = "sqlQueryParameterOutput" +isList = true +[recordShapes.sqlQueryOutput.fields.resultColumns] +shape = "sqlQueryColumnOutput" +isList = true + +[recordShapes.sqlQueryParameterOutput] +[recordShapes.sqlQueryParameterOutput.fields.name] +shape = "string" +[recordShapes.sqlQueryParameterOutput.fields.typ] +shape = "string" +[recordShapes.sqlQueryParameterOutput.fields.documentation] +isNullable = true +shape = "string" +[recordShapes.sqlQueryParameterOutput.fields.nullable] +shape = "bool" + +[recordShapes.sqlQueryColumnOutput] +[recordShapes.sqlQueryColumnOutput.fields.name] +shape = "string" +[recordShapes.sqlQueryColumnOutput.fields.typ] +shape = "string" +[recordShapes.sqlQueryColumnOutput.fields.nullable] +shape = "bool" + diff --git a/schema-engine/json-rpc-api-build/methods/markMigrationApplied.toml b/schema-engine/json-rpc-api-build/methods/markMigrationApplied.toml index f76ee11f3094..ce4522dbb52d 100644 --- a/schema-engine/json-rpc-api-build/methods/markMigrationApplied.toml +++ b/schema-engine/json-rpc-api-build/methods/markMigrationApplied.toml @@ -4,9 +4,9 @@ description = """Mark a migration as applied in the migrations table. There are two possible outcomes: - The migration is already in the table, but in a failed state. In this case, we will mark it -as rolled back, then create a new entry. + as rolled back, then create a new entry. - The migration is not in the table. We will create a new entry in the migrations table. The -`started_at` and `finished_at` will be the same. + `started_at` and `finished_at` will be the same. - If it is already applied, we return a user-facing error. """ requestShape = "markMigrationAppliedInput" diff --git a/schema-engine/mongodb-schema-describer/Cargo.toml b/schema-engine/mongodb-schema-describer/Cargo.toml index ffd917317487..e8885e990827 100644 --- a/schema-engine/mongodb-schema-describer/Cargo.toml +++ b/schema-engine/mongodb-schema-describer/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -mongodb = "2.8.0" -futures = "0.3" +mongodb.workspace = true +bson.workspace = true +futures.workspace = true serde.workspace = true diff --git a/schema-engine/mongodb-schema-describer/src/lib.rs b/schema-engine/mongodb-schema-describer/src/lib.rs index 27fcfb149abc..29186158cbc4 100644 --- a/schema-engine/mongodb-schema-describer/src/lib.rs +++ b/schema-engine/mongodb-schema-describer/src/lib.rs @@ -10,8 +10,8 @@ mod walkers; pub use schema::*; pub use walkers::*; +use bson::{Bson, Document}; use futures::stream::TryStreamExt; -use mongodb::bson::{Bson, Document}; /// Describe the contents of the given database. Only bothers about the schema, meaning the /// collection names and indexes created. 
Does a bit of magic to the indexes, so if having a @@ -22,7 +22,7 @@ use mongodb::bson::{Bson, Document}; pub async fn describe(client: &mongodb::Client, db_name: &str) -> mongodb::error::Result { let mut schema = MongoSchema::default(); let database = client.database(db_name); - let mut cursor = database.list_collections(None, None).await?; + let mut cursor = database.list_collections().await?; while let Some(collection) = cursor.try_next().await? { let collection_name = collection.name; @@ -45,7 +45,7 @@ pub async fn describe(client: &mongodb::Client, db_name: &str) -> mongodb::error let collection = database.collection::(&collection_name); let collection_id = schema.push_collection(collection_name, has_schema, is_capped); - let mut indexes_cursor = collection.list_indexes(None).await?; + let mut indexes_cursor = collection.list_indexes().await?; while let Some(index) = indexes_cursor.try_next().await? { let options = index.options.unwrap_or_default(); @@ -122,9 +122,9 @@ pub async fn describe(client: &mongodb::Client, db_name: &str) -> mongodb::error /// Get the version. pub async fn version(client: &mongodb::Client, db_name: &str) -> mongodb::error::Result { let database = client.database(db_name); - use mongodb::bson::doc; + use bson::doc; let version_cmd = doc! {"buildInfo": 1}; - let res = database.run_command(version_cmd, None).await?; + let res = database.run_command(version_cmd).await?; let version = res .get("versionArray") .unwrap() diff --git a/schema-engine/mongodb-schema-describer/src/schema.rs b/schema-engine/mongodb-schema-describer/src/schema.rs index 583fa0b6aee6..3243205e2cb3 100644 --- a/schema-engine/mongodb-schema-describer/src/schema.rs +++ b/schema-engine/mongodb-schema-describer/src/schema.rs @@ -1,4 +1,4 @@ -use mongodb::bson::Bson; +use bson::Bson; use serde::{Deserialize, Serialize}; use std::{ collections::BTreeMap, diff --git a/schema-engine/sql-introspection-tests/Cargo.toml b/schema-engine/sql-introspection-tests/Cargo.toml index 3d45c178f09f..d0891b0bbaaa 100644 --- a/schema-engine/sql-introspection-tests/Cargo.toml +++ b/schema-engine/sql-introspection-tests/Cargo.toml @@ -18,7 +18,7 @@ itertools.workspace = true enumflags2.workspace = true connection-string.workspace = true pretty_assertions = "1" -tracing-futures = "0.2" +tracing-futures.workspace = true tokio.workspace = true tracing.workspace = true indoc.workspace = true diff --git a/schema-engine/sql-migration-tests/Cargo.toml b/schema-engine/sql-migration-tests/Cargo.toml index 710ef1ebbd3b..6a345a365486 100644 --- a/schema-engine/sql-migration-tests/Cargo.toml +++ b/schema-engine/sql-migration-tests/Cargo.toml @@ -30,6 +30,9 @@ serde_json.workspace = true tempfile = "3.1.0" tokio.workspace = true tracing.workspace = true -tracing-futures = "0.2" +tracing-futures.workspace = true url.workspace = true quaint = { workspace = true, features = ["all-native"] } + +[dev-dependencies] +paste = "1" diff --git a/schema-engine/sql-migration-tests/src/assertions.rs b/schema-engine/sql-migration-tests/src/assertions.rs index e32554a451b5..6a2598edfb38 100644 --- a/schema-engine/sql-migration-tests/src/assertions.rs +++ b/schema-engine/sql-migration-tests/src/assertions.rs @@ -191,13 +191,11 @@ impl SchemaAssertion { } fn print_context(&self) { - match &self.context { - Some(context) => println!("Test failure with context <{}>", context.red()), - None => {} + if let Some(context) = &self.context { + println!("Test failure with context <{}>", context.red()) } - match &self.description { - Some(description) => 
println!("{}: {}", "Description".bold(), description.italic()), - None => {} + if let Some(description) = &self.description { + println!("{}: {}", "Description".bold(), description.italic()) } } @@ -325,13 +323,11 @@ pub struct TableAssertion<'a> { impl<'a> TableAssertion<'a> { fn print_context(&self) { - match &self.context { - Some(context) => println!("Test failure with context <{}>", context.red()), - None => {} + if let Some(context) = &self.context { + println!("Test failure with context <{}>", context.red()) } - match &self.description { - Some(description) => println!("{}: {}", "Description".bold(), description.italic()), - None => {} + if let Some(description) = &self.description { + println!("{}: {}", "Description".bold(), description.italic()) } } diff --git a/schema-engine/sql-migration-tests/src/assertions/quaint_result_set_ext.rs b/schema-engine/sql-migration-tests/src/assertions/quaint_result_set_ext.rs index 3f486f34163b..ad0e01a7fe29 100644 --- a/schema-engine/sql-migration-tests/src/assertions/quaint_result_set_ext.rs +++ b/schema-engine/sql-migration-tests/src/assertions/quaint_result_set_ext.rs @@ -112,4 +112,26 @@ impl<'a> RowAssertion<'a> { self } + + pub fn assert_bigint_value(self, column_name: &str, expected_value: i64) -> Self { + let actual_value = self.0.get(column_name).and_then(|col: &Value<'_>| (*col).as_i64()); + + assert!( + actual_value == Some(expected_value), + "Value assertion failed for {column_name}. Expected: {expected_value:?}, got: {actual_value:?}", + ); + + self + } + + pub fn assert_bytes_value(self, column_name: &str, expected_value: &[u8]) -> Self { + let actual_value = self.0.get(column_name).and_then(|col: &Value<'_>| (*col).as_bytes()); + + assert!( + actual_value == Some(expected_value), + "Value assertion failed for {column_name}. Expected: {expected_value:?}, got: {actual_value:?}", + ); + + self + } } diff --git a/schema-engine/sql-migration-tests/src/commands.rs b/schema-engine/sql-migration-tests/src/commands.rs index 2889ebf8ab31..035731485aa0 100644 --- a/schema-engine/sql-migration-tests/src/commands.rs +++ b/schema-engine/sql-migration-tests/src/commands.rs @@ -3,6 +3,7 @@ mod create_migration; mod dev_diagnostic; mod diagnose_migration_history; mod evaluate_data_loss; +mod introspect_sql; mod list_migration_directories; mod mark_migration_applied; mod mark_migration_rolled_back; @@ -14,6 +15,7 @@ pub(crate) use create_migration::*; pub(crate) use dev_diagnostic::*; pub(crate) use diagnose_migration_history::*; pub(crate) use evaluate_data_loss::*; +pub(crate) use introspect_sql::*; pub(crate) use list_migration_directories::*; pub(crate) use mark_migration_applied::*; pub(crate) use mark_migration_rolled_back::*; diff --git a/schema-engine/sql-migration-tests/src/commands/introspect_sql.rs b/schema-engine/sql-migration-tests/src/commands/introspect_sql.rs new file mode 100644 index 000000000000..d8800cb05804 --- /dev/null +++ b/schema-engine/sql-migration-tests/src/commands/introspect_sql.rs @@ -0,0 +1,95 @@ +use quaint::prelude::ColumnType; +use schema_core::{ + schema_connector::{IntrospectSqlQueryInput, IntrospectSqlQueryOutput, SchemaConnector}, + CoreError, CoreResult, +}; + +#[must_use = "This struct does nothing on its own. 
See ApplyMigrations::send()"] +pub struct IntrospectSql<'a> { + api: &'a mut dyn SchemaConnector, + name: &'a str, + source: String, +} + +impl<'a> IntrospectSql<'a> { + pub fn new(api: &'a mut dyn SchemaConnector, name: &'a str, source: String) -> Self { + Self { api, name, source } + } + + pub async fn send(self) -> CoreResult { + let res = self + .api + .introspect_sql(IntrospectSqlQueryInput { + name: self.name.to_owned(), + source: self.source, + }) + .await?; + + Ok(IntrospectSqlAssertion { output: res }) + } + + #[track_caller] + pub fn send_sync(self) -> IntrospectSqlAssertion { + test_setup::runtime::run_with_thread_local_runtime(self.send()).unwrap() + } + + #[track_caller] + pub fn send_unwrap_err(self) -> CoreError { + test_setup::runtime::run_with_thread_local_runtime(self.send()).unwrap_err() + } +} + +pub struct IntrospectSqlAssertion { + pub output: IntrospectSqlQueryOutput, +} + +impl std::fmt::Debug for IntrospectSqlAssertion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ApplyMigrationsAssertion {{ .. }}") + } +} + +impl IntrospectSqlAssertion { + #[track_caller] + pub fn expect_result(&self, expectation: expect_test::Expect) { + expectation.assert_debug_eq(&self.output) + } + + #[track_caller] + pub fn expect_param_type(self, idx: usize, ty: ColumnType) -> Self { + let param = &self + .output + .parameters + .get(idx) + .unwrap_or_else(|| panic!("parameter at index {idx} not found")); + let param_name = ¶m.name; + let actual_typ = ¶m.typ; + let expected_typ = &ty.to_string(); + + assert_eq!( + expected_typ, actual_typ, + "expected param {param_name} to be of type {expected_typ}, got: {actual_typ}", + ); + + self + } + + #[track_caller] + pub fn expect_column_type(self, idx: usize, ty: ColumnType) -> Self { + let column = &self + .output + .result_columns + .get(idx) + .unwrap_or_else(|| panic!("column at index {idx} not found")); + let column_name = &column.name; + let actual_typ = &column.typ; + let expected_typ = &ty.to_string(); + + assert_eq!( + expected_typ, actual_typ, + "expected column {column_name} to be of type {expected_typ}, got: {actual_typ}" + ); + + self + } +} diff --git a/schema-engine/sql-migration-tests/src/commands/schema_push.rs b/schema-engine/sql-migration-tests/src/commands/schema_push.rs index f7442b3a72c4..f20121f7b3aa 100644 --- a/schema-engine/sql-migration-tests/src/commands/schema_push.rs +++ b/schema-engine/sql-migration-tests/src/commands/schema_push.rs @@ -102,13 +102,11 @@ impl SchemaPushAssertion { } pub fn print_context(&self) { - match &self.context { - Some(context) => println!("Test failure with context <{}>", context.red()), - None => {} + if let Some(context) = &self.context { + println!("Test failure with context <{}>", context.red()) } - match &self.description { - Some(description) => println!("{}: {}", "Description".bold(), description.italic()), - None => {} + if let Some(description) = &self.description { + println!("{}: {}", "Description".bold(), description.italic()) } } diff --git a/schema-engine/sql-migration-tests/src/test_api.rs b/schema-engine/sql-migration-tests/src/test_api.rs index 7f2bc769f78b..6d67bbd98491 100644 --- a/schema-engine/sql-migration-tests/src/test_api.rs +++ b/schema-engine/sql-migration-tests/src/test_api.rs @@ -10,7 +10,10 @@ pub use test_macros::test_connector; pub use test_setup::{runtime::run_with_thread_local_runtime as tok, BitFlags, Capabilities, Tags}; use crate::{commands::*, multi_engine_test_api::TestApi as RootTestApi}; -use 
psl::parser_database::SourceFile; +use psl::{ + datamodel_connector::NativeTypeInstance, + parser_database::{ScalarType, SourceFile}, +}; use quaint::{ prelude::{ConnectionInfo, ResultSet}, Value, @@ -157,6 +160,59 @@ impl TestApi { EvaluateDataLoss::new(&mut self.connector, migrations_directory, files) } + pub fn introspect_sql<'a>(&'a mut self, name: &'a str, source: &'a str) -> IntrospectSql<'a> { + let sanitized = self.sanitize_sql(source); + + IntrospectSql::new(&mut self.connector, name, sanitized) + } + + // Replaces `?` with the appropriate positional parameter syntax for the current database. + pub fn sanitize_sql(&self, sql: &str) -> String { + let mut counter = 1; + + if self.is_mysql() || self.is_mariadb() || self.is_sqlite() { + return sql.to_string(); + } + + let mut out = String::with_capacity(sql.len()); + let mut lines = sql.lines().peekable(); + + while let Some(line) = lines.next() { + // Avoid replacing query params in comments + if line.trim_start().starts_with("--") { + out.push_str(line); + + if lines.peek().is_some() { + out.push('\n'); + } + } else { + let mut line = line.to_string(); + + while let Some(idx) = line.find('?') { + let replacer = if self.is_postgres() || self.is_cockroach() { + format!("${}", counter) + } else if self.is_mssql() { + format!("@P{}", counter) + } else { + unimplemented!() + }; + + line.replace_range(idx..idx + 1, &replacer); + + counter += 1; + } + + out.push_str(&line); + + if lines.peek().is_some() { + out.push('\n'); + } + } + } + + out + } + /// Returns true only when testing on MSSQL. pub fn is_mssql(&self) -> bool { self.root.is_mssql() @@ -449,6 +505,10 @@ impl TestApi { out } + + pub fn scalar_type_for_native_type(&self, typ: &NativeTypeInstance) -> ScalarType { + self.connector.scalar_type_for_native_type(typ) + } } pub struct SingleRowInsert<'a> { diff --git a/schema-engine/sql-migration-tests/tests/errors/error_tests.rs b/schema-engine/sql-migration-tests/tests/errors/error_tests.rs index 90d0f90c5250..ddcf1adb802a 100644 --- a/schema-engine/sql-migration-tests/tests/errors/error_tests.rs +++ b/schema-engine/sql-migration-tests/tests/errors/error_tests.rs @@ -49,7 +49,7 @@ fn authentication_failure_must_return_a_known_error_on_postgres(api: TestApi) { let user = db_url.username(); let host = db_url.host().unwrap().to_string(); - let json_error = serde_json::to_value(&error.to_user_facing()).unwrap(); + let json_error = serde_json::to_value(error.to_user_facing()).unwrap(); let expected = json!({ "is_panic": false, "message": format!("Authentication failed against database server at `{host}`, the provided database credentials for `postgres` are not valid.\n\nPlease make sure to provide valid database credentials for the database server at `{host}`."), @@ -83,7 +83,7 @@ fn authentication_failure_must_return_a_known_error_on_mysql(api: TestApi) { let user = url.username(); let host = url.host().unwrap().to_string(); - let json_error = serde_json::to_value(&error.to_user_facing()).unwrap(); + let json_error = serde_json::to_value(error.to_user_facing()).unwrap(); let expected = json!({ "is_panic": false, "message": format!("Authentication failed against database server at `{host}`, the provided database credentials for `{user}` are not valid.\n\nPlease make sure to provide valid database credentials for the database server at `{host}`."), @@ -118,7 +118,7 @@ fn authentication_failure_must_return_a_known_error_on_mssql(api: TestApi) { let error = tok(connection_error(dm)); - let json_error = 
serde_json::to_value(&error.to_user_facing()).unwrap(); + let json_error = serde_json::to_value(error.to_user_facing()).unwrap(); let expected = json!({ "is_panic": false, "message": format!("Authentication failed against database server at `{host}`, the provided database credentials for `{user}` are not valid.\n\nPlease make sure to provide valid database credentials for the database server at `{host}`."), @@ -156,7 +156,7 @@ fn unreachable_database_must_return_a_proper_error_on_mysql(api: TestApi) { let port = url.port().unwrap(); let host = url.host().unwrap().to_string(); - let json_error = serde_json::to_value(&error.to_user_facing()).unwrap(); + let json_error = serde_json::to_value(error.to_user_facing()).unwrap(); let expected = json!({ "is_panic": false, "message": format!("Can't reach database server at `{host}:{port}`\n\nPlease make sure your database server is running at `{host}:{port}`."), @@ -190,7 +190,7 @@ fn unreachable_database_must_return_a_proper_error_on_postgres(api: TestApi) { let host = url.host().unwrap().to_string(); let port = url.port().unwrap(); - let json_error = serde_json::to_value(&error.to_user_facing()).unwrap(); + let json_error = serde_json::to_value(error.to_user_facing()).unwrap(); let expected = json!({ "is_panic": false, "message": format!("Can't reach database server at `{host}:{port}`\n\nPlease make sure your database server is running at `{host}:{port}`."), @@ -222,7 +222,7 @@ fn database_does_not_exist_must_return_a_proper_error(api: TestApi) { let error = tok(connection_error(dm)); - let json_error = serde_json::to_value(&error.to_user_facing()).unwrap(); + let json_error = serde_json::to_value(error.to_user_facing()).unwrap(); let expected = json!({ "is_panic": false, "message": format!("Database `{database_name}` does not exist on the database server at `{database_host}:{database_port}`.", database_name = database_name, database_host = url.host().unwrap(), database_port = url.port().unwrap()), @@ -251,7 +251,7 @@ fn bad_datasource_url_and_provider_combinations_must_return_a_proper_error(api: let error = tok(connection_error(dm)); - let json_error = serde_json::to_value(&error.to_user_facing()).unwrap(); + let json_error = serde_json::to_value(error.to_user_facing()).unwrap(); let err_message: String = json_error["message"].as_str().unwrap().into(); @@ -293,7 +293,7 @@ fn connections_to_system_databases_must_be_rejected(api: TestApi) { let name = if name == &"" { "mysql" } else { name }; let error = tok(connection_error(dm)); - let json_error = serde_json::to_value(&error.to_user_facing()).unwrap(); + let json_error = serde_json::to_value(error.to_user_facing()).unwrap(); let expected = json!({ "is_panic": false, @@ -455,7 +455,7 @@ async fn connection_string_problems_give_a_nice_error() { .await .unwrap_err(); - let json_error = serde_json::to_value(&error.to_user_facing()).unwrap(); + let json_error = serde_json::to_value(error.to_user_facing()).unwrap(); let details = match provider.0 { "sqlserver" => { @@ -509,7 +509,7 @@ async fn bad_connection_string_in_datamodel_returns_nice_error() { Err(e) => e, }; - let json_error = serde_json::to_value(&error.to_user_facing()).unwrap(); + let json_error = serde_json::to_value(error.to_user_facing()).unwrap(); let expected_json_error = json!({ "is_panic": false, diff --git a/schema-engine/sql-migration-tests/tests/migration_tests.rs b/schema-engine/sql-migration-tests/tests/migration_tests.rs index be4727d4d433..de7b587cdf6d 100644 --- a/schema-engine/sql-migration-tests/tests/migration_tests.rs +++ 
b/schema-engine/sql-migration-tests/tests/migration_tests.rs @@ -8,4 +8,5 @@ mod introspection; mod list_migration_directories; mod migrations; mod native_types; +mod query_introspection; mod schema_push; diff --git a/schema-engine/sql-migration-tests/tests/migrations/diff.rs b/schema-engine/sql-migration-tests/tests/migrations/diff.rs index 0eadac39657e..b9225bd7a22d 100644 --- a/schema-engine/sql-migration-tests/tests/migrations/diff.rs +++ b/schema-engine/sql-migration-tests/tests/migrations/diff.rs @@ -7,6 +7,95 @@ use schema_core::{ use sql_migration_tests::{test_api::*, utils::to_schema_containers}; use std::sync::Arc; +#[test_connector(tags(Sqlite, Mysql, Postgres, CockroachDb, Mssql))] +fn from_unique_index_to_without(mut api: TestApi) { + let tempdir = tempfile::tempdir().unwrap(); + let host = Arc::new(TestConnectorHost::default()); + + api.connector.set_host(host.clone()); + + let from_schema = api.datamodel_with_provider( + r#" + model Post { + id Int @id + title String + author User? @relation(fields: [authorId], references: [id]) + authorId Int? @unique + // ^^^^^^^ this will be removed later + } + + model User { + id Int @id + name String? + posts Post[] + } + "#, + ); + + let to_schema = api.datamodel_with_provider( + r#" + model Post { + id Int @id + title String + author User? @relation(fields: [authorId], references: [id]) + authorId Int? + } + + model User { + id Int @id + name String? + posts Post[] + } + "#, + ); + + let from_file = write_file_to_tmp(&from_schema, &tempdir, "from"); + let to_file = write_file_to_tmp(&to_schema, &tempdir, "to"); + + api.diff(DiffParams { + exit_code: None, + from: DiffTarget::SchemaDatamodel(SchemasContainer { + files: vec![SchemaContainer { + path: from_file.to_string_lossy().into_owned(), + content: from_schema.to_string(), + }], + }), + shadow_database_url: None, + to: DiffTarget::SchemaDatamodel(SchemasContainer { + files: vec![SchemaContainer { + path: to_file.to_string_lossy().into_owned(), + content: to_schema.to_string(), + }], + }), + script: true, + }) + .unwrap(); + + let expected_printed_messages = if api.is_mysql() { + expect![[r#" + [ + "-- DropIndex\nDROP INDEX `Post_authorId_key` ON `Post`;\n", + ] + "#]] + } else if api.is_sqlite() || api.is_postgres() || api.is_cockroach() { + expect![[r#" + [ + "-- DropIndex\nDROP INDEX \"Post_authorId_key\";\n", + ] + "#]] + } else if api.is_mssql() { + expect![[r#" + [ + "BEGIN TRY\n\nBEGIN TRAN;\n\n-- DropIndex\nDROP INDEX [Post_authorId_key] ON [dbo].[Post];\n\nCOMMIT TRAN;\n\nEND TRY\nBEGIN CATCH\n\nIF @@TRANCOUNT > 0\nBEGIN\n ROLLBACK TRAN;\nEND;\nTHROW\n\nEND CATCH\n", + ] + "#]] + } else { + unreachable!() + }; + + expected_printed_messages.assert_debug_eq(&host.printed_messages.lock().unwrap()); +} + #[test_connector(tags(Sqlite))] fn diffing_postgres_schemas_when_initialized_on_sqlite(mut api: TestApi) { // We should get a postgres diff. 
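The `from_unique_index_to_without` test above ties in with the `index_covers_fk` change earlier in this diff: a unique index is no longer treated as covering the relation's foreign key, so removing `@unique` from `authorId` is expected to emit the per-connector DropIndex step asserted in `expected_printed_messages`. A hedged sketch (not the actual SQL renderers) of why the expectations differ between connectors:

// Illustration of the DROP INDEX flavours asserted above: MySQL and MSSQL qualify the
// index with its table, the other connectors drop it by name. The real statements come
// from the sql-schema-connector renderers, not from this helper.
fn drop_index_sql(connector: &str, table: &str, index: &str) -> String {
    match connector {
        "mysql" => format!("DROP INDEX `{index}` ON `{table}`;"),
        "mssql" => format!("DROP INDEX [{index}] ON [dbo].[{table}];"),
        // postgres, cockroachdb and sqlite
        _ => format!("DROP INDEX \"{index}\";"),
    }
}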
diff --git a/schema-engine/sql-migration-tests/tests/migrations/indexes.rs b/schema-engine/sql-migration-tests/tests/migrations/indexes.rs index dded1559ba2a..77badfd061e4 100644 --- a/schema-engine/sql-migration-tests/tests/migrations/indexes.rs +++ b/schema-engine/sql-migration-tests/tests/migrations/indexes.rs @@ -946,7 +946,7 @@ fn adding_fulltext_index_to_an_existing_column(api: TestApi) { } "#}; - api.schema_push(&api.datamodel_with_provider(dm)).send().assert_green(); + api.schema_push(api.datamodel_with_provider(dm)).send().assert_green(); api.assert_schema() .assert_table("A", |table| table.assert_indexes_count(0)); @@ -961,7 +961,7 @@ fn adding_fulltext_index_to_an_existing_column(api: TestApi) { } "#}; - api.schema_push(&api.datamodel_with_provider(dm)).send().assert_green(); + api.schema_push(api.datamodel_with_provider(dm)).send().assert_green(); api.assert_schema().assert_table("A", |table| { table.assert_index_on_columns(&["a", "b"], |index| index.assert_is_fulltext()) @@ -980,7 +980,7 @@ fn changing_normal_index_to_a_fulltext_index(api: TestApi) { } "#}; - api.schema_push(&api.datamodel_with_provider(dm)).send().assert_green(); + api.schema_push(api.datamodel_with_provider(dm)).send().assert_green(); api.assert_schema().assert_table("A", |table| { table.assert_indexes_count(1); @@ -997,7 +997,7 @@ fn changing_normal_index_to_a_fulltext_index(api: TestApi) { } "#}; - api.schema_push(&api.datamodel_with_provider(dm)).send().assert_green(); + api.schema_push(api.datamodel_with_provider(dm)).send().assert_green(); api.assert_schema().assert_table("A", |table| { table.assert_indexes_count(1); diff --git a/schema-engine/sql-migration-tests/tests/migrations/mssql.rs b/schema-engine/sql-migration-tests/tests/migrations/mssql.rs index fc543f99227a..12e8996cec91 100644 --- a/schema-engine/sql-migration-tests/tests/migrations/mssql.rs +++ b/schema-engine/sql-migration-tests/tests/migrations/mssql.rs @@ -158,7 +158,7 @@ fn mssql_apply_migrations_error_output(api: TestApi) { .split_terminator(" 0: ") .next() .unwrap() - .trim_end_matches(|c| c == '\n' || c == ' '); + .trim_end_matches(['\n', ' ']); expectation.assert_eq(first_segment) } diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/docs.rs b/schema-engine/sql-migration-tests/tests/query_introspection/docs.rs new file mode 100644 index 000000000000..33eb9f50161c --- /dev/null +++ b/schema-engine/sql-migration-tests/tests/query_introspection/docs.rs @@ -0,0 +1,353 @@ +use super::utils::*; +use sql_migration_tests::test_api::*; + +#[test_connector(tags(Postgres))] +fn parses_doc_complex_pg(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "\nSELECT int FROM model WHERE int = $1 and string = $2;\n", + documentation: Some( + "some fancy query", + ), + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: Some( + "some integer", + ), + name: "myInt", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: Some( + "some string", + ), + name: "myString", + typ: "string", + nullable: true, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + let sql = r#" + -- @description some fancy query + -- @param {Int} $1:myInt some integer + -- @param {String}$2:myString? some string + SELECT int FROM model WHERE int = ? 
and string = ?; + "#; + + api.introspect_sql("test_1", sql).send_sync().expect_result(expected) +} + +#[test_connector(tags(Mysql))] +fn parses_doc_complex_mysql(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "\nSELECT `int` FROM `model` WHERE `int` = ? and `string` = ?;\n", + documentation: Some( + "some fancy query", + ), + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: Some( + "some integer", + ), + name: "myInt", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: Some( + "some string", + ), + name: "myString", + typ: "string", + nullable: true, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + let sql = r#" + -- @description some fancy query + -- @param {Int} $1:myInt some integer + -- @param {String}$2:myString? some string + SELECT `int` FROM `model` WHERE `int` = ? and `string` = ?; + "#; + + let res = api.introspect_sql("test_1", sql).send_sync(); + + res.expect_result(expected); + + api.query_raw( + &res.output.source, + &[quaint::Value::int32(1), quaint::Value::text("hello")], + ); +} + +#[test_connector(tags(Sqlite))] +fn parses_doc_no_position(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "\nSELECT int FROM model WHERE int = :myInt and string = ?;\n", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: Some( + "some integer", + ), + name: "myInt", + typ: "string", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_2", + typ: "unknown", + nullable: false, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + let sql = r#" + -- @param {String} :myInt some integer + SELECT int FROM model WHERE int = :myInt and string = ?; + "#; + + api.introspect_sql("test_1", sql).send_sync().expect_result(expected) +} + +#[test_connector(tags(Postgres))] +fn parses_doc_no_alias(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "\nSELECT int FROM model WHERE int = $1 and string = $2;\n", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "int4", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: Some( + "some string", + ), + name: "text", + typ: "string", + nullable: false, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + let sql = r#" + -- @param {String} $2 some string + SELECT int FROM model WHERE int = $1 and string = $2; + "#; + + api.introspect_sql("test_1", sql).send_sync().expect_result(expected) +} + +#[test_connector(tags(Postgres))] +fn parses_doc_enum_name(api: TestApi) { + api.schema_push(ENUM_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "\nSELECT * FROM model WHERE id = $1;\n", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "int4", + typ: "MyFancyEnum", + nullable: false, + }, + ], + result_columns: [ + 
IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "enum", + typ: "MyFancyEnum", + nullable: false, + }, + ], + } + "#]]; + + let sql = r#" + -- @param {MyFancyEnum} $1 + SELECT * FROM model WHERE id = ?; + "#; + + api.introspect_sql("test_1", sql).send_sync().expect_result(expected) +} + +#[test_connector(tags(Postgres))] +fn invalid_position_fails(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let sql = r#" + -- @param {Int} $hello:myInt some integer + SELECT int FROM model WHERE int = ? and string = ?; + "#; + + let expected = expect![ + "SQL documentation parsing: invalid position. Expected a number found: hello at ' $hello:myInt some integer'." + ]; + + expected.assert_eq( + api.introspect_sql("test_1", sql) + .send_unwrap_err() + .message() + .unwrap_or_default(), + ); +} + +#[test_connector(tags(Postgres))] +fn unknown_type_fails(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let sql = r#" + -- @param {Hello} $hello:myInt some integer + SELECT int FROM model WHERE int = ? and string = ?; + "#; + + let expected = expect!["SQL documentation parsing: invalid type: 'Hello' (accepted types are: 'Int', 'BigInt', 'Float', 'Boolean', 'String', 'DateTime', 'Json', 'Bytes', 'Decimal') at '{Hello} $hello:myInt some integer'."]; + + expected.assert_eq( + api.introspect_sql("test_1", sql) + .send_unwrap_err() + .message() + .unwrap_or_default(), + ); +} + +#[test_connector(tags(Postgres))] +fn duplicate_param_position_fails(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let sql = r#" + -- @param {Int} $1:myInt + -- @param {String} $1:myString + SELECT int FROM model WHERE int = ? and string = ?; + "#; + + let expected = expect!["SQL documentation parsing: duplicate parameter (position or alias is already used) at '@param {String} $1:myString'."]; + + expected.assert_eq( + api.introspect_sql("test_1", sql) + .send_unwrap_err() + .message() + .unwrap_or_default(), + ); +} + +#[test_connector(tags(Postgres))] +fn duplicate_param_name_fails(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let sql = r#" + -- @param {Int} $1:myInt + -- @param {String} $2:myInt + SELECT int FROM model WHERE int = ? and string = ?; + "#; + + let expected = expect!["SQL documentation parsing: duplicate parameter (position or alias is already used) at '@param {String} $2:myInt'."]; + + expected.assert_eq( + api.introspect_sql("test_1", sql) + .send_unwrap_err() + .message() + .unwrap_or_default(), + ); +} + +#[test_connector(tags(Postgres))] +fn missing_param_position_or_alias_fails(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let sql = r#" + -- @param {Int} myInt + SELECT int FROM model WHERE int = ? and string = ?; + "#; + + let expected = + expect!["SQL documentation parsing: missing position or alias (eg: $1:alias) at '@param {Int} myInt'."]; + + expected.assert_eq( + api.introspect_sql("test_1", sql) + .send_unwrap_err() + .message() + .unwrap_or_default(), + ); +} + +#[test_connector(tags(Postgres))] +fn missing_everything_fails(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let sql = r#" + -- @param + SELECT int FROM model WHERE int = ? 
and string = ?; + "#; + + let expected = + expect!["SQL documentation parsing: invalid parameter: could not parse any information at '@param'."]; + + expected.assert_eq( + api.introspect_sql("test_1", sql) + .send_unwrap_err() + .message() + .unwrap_or_default(), + ); +} diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/mod.rs b/schema-engine/sql-migration-tests/tests/query_introspection/mod.rs new file mode 100644 index 000000000000..24e89cabfb9c --- /dev/null +++ b/schema-engine/sql-migration-tests/tests/query_introspection/mod.rs @@ -0,0 +1,6 @@ +mod mysql; +mod pg; +mod sqlite; + +mod docs; +mod utils; diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/mysql.rs b/schema-engine/sql-migration-tests/tests/query_introspection/mysql.rs new file mode 100644 index 000000000000..7b9df66d51ca --- /dev/null +++ b/schema-engine/sql-migration-tests/tests/query_introspection/mysql.rs @@ -0,0 +1,598 @@ +use super::utils::*; +use psl::{builtin_connectors::MySqlType, parser_database::ScalarType}; +use quaint::prelude::ColumnType; +use sql_migration_tests::test_api::*; + +#[test_connector(tags(Mysql8), exclude(Vitess))] +fn insert_mysql(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let query = + "INSERT INTO `model` (`int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt`) VALUES (?, ?, ?, ?, ?, ?, ?);"; + + let res = api.introspect_sql("test_1", query).send_sync(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "INSERT INTO `model` (`int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt`) VALUES (?, ?, ?, ?, ?, ?, ?);", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_0", + typ: "bigint", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_1", + typ: "string", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_2", + typ: "bigint", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_3", + typ: "double", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_4", + typ: "bytes", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_5", + typ: "bigint", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_6", + typ: "datetime", + nullable: false, + }, + ], + result_columns: [], + } + "#]]; + + res.expect_result(expected); +} + +#[test_connector(tags(Mysql, Mariadb))] +fn select_mysql(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let res = api + .introspect_sql( + "test_1", + "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + ) + .send_sync(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "bigint", + typ: "bigint", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "float", + typ: "double", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "bytes", + typ: "bytes", + nullable: false, 
+ }, + IntrospectSqlQueryColumnOutput { + name: "bool", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "dt", + typ: "datetime", + nullable: false, + }, + ], + } + "#]]; + + res.expect_result(expected); +} + +#[test_connector(tags(Mysql, Mariadb))] +fn select_nullable_mysql(api: TestApi) { + api.schema_push(SIMPLE_NULLABLE_SCHEMA).send().assert_green(); + + let res = api + .introspect_sql( + "test_1", + "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + ) + .send_sync(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bigint", + typ: "bigint", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "float", + typ: "double", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bytes", + typ: "bytes", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bool", + typ: "int", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "dt", + typ: "datetime", + nullable: true, + }, + ], + } + "#]]; + + res.expect_result(expected); +} + +#[test_connector(tags(Mysql8), exclude(Vitess))] +fn empty_result(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int` FROM model WHERE 1 = 0 AND `int` = ?;", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_0", + typ: "bigint", + nullable: false, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT `int` FROM model WHERE 1 = 0 AND `int` = ?;") + .send_sync() + .expect_result(expected) +} + +#[test_connector(tags(Mysql8))] +fn unnamed_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT 1 + 1;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "1 + 1", + typ: "bigint", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT 1 + 1;") + .send_sync() + .expect_result(expected) +} + +#[test_connector(tags(Mysql8))] +fn named_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT 1 + 1 as \"add\";", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "add", + typ: "bigint", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT 1 + 1 as \"add\";") + .send_sync() + .expect_result(expected) +} + +#[test_connector(tags(Mysql8))] +fn mixed_named_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int` + 1 as \"add\" FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "add", + typ: 
"bigint", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT `int` + 1 as \"add\" FROM `model`;") + .send_sync() + .expect_result(expected) +} + +#[test_connector(tags(Mysql8))] +fn mixed_unnamed_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int` + 1 FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "`int` + 1", + typ: "bigint", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT `int` + 1 FROM `model`;") + .send_sync() + .expect_result(expected) +} + +#[test_connector(tags(Mysql8))] +fn mixed_expr_cast(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT CAST(`int` + 1 AS CHAR) as test FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "test", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT CAST(`int` + 1 AS CHAR) as test FROM `model`;") + .send_sync() + .expect_result(expected) +} + +const DATASOURCE: &str = r#" + datasource db { + provider = "mysql" + url = "mysql://localhost:5432" +} +"#; + +macro_rules! test_native_types { + ( + $tag:ident; + + $( + $test_name:ident($nt:expr) => ($ct_input:ident, $ct_output:ident), + )* + ) => { + $( + paste::paste! { + #[test_connector(tags($tag), exclude(Vitess))] + fn [](api: TestApi) { + + let dm = render_native_type_datamodel::(&api, DATASOURCE, $nt.to_parts(), $nt); + + api.schema_push(&dm).send(); + + api.introspect_sql("test_1", "INSERT INTO test (field) VALUES (?);") + .send_sync() + .expect_param_type(0, ColumnType::$ct_input); + + api.introspect_sql("test_2", "SELECT field FROM test;") + .send_sync() + .expect_column_type(0, ColumnType::$ct_output); + } + } + )* + }; +} + +mod mysql8 { + use super::*; + + test_scalar_types!( + Mysql8; + + int(ScalarType::Int) => (Int64, Int32), + string(ScalarType::String) => (Text, Text), + bigint(ScalarType::BigInt) => (Int64, Int64), + float(ScalarType::Float) => (Double, Double), + bytes(ScalarType::Bytes) => (Bytes, Bytes), + bool(ScalarType::Boolean) => (Int64, Int32), + dt(ScalarType::DateTime) => (DateTime, DateTime), + decimal(ScalarType::Decimal) => (Numeric, Numeric), + ); + + test_native_types! 
{ + Mysql8; + + int(MySqlType::Int) => (Int64, Int32), + unsigned_int(MySqlType::UnsignedInt) => (Int64, Int64), + small_int(MySqlType::SmallInt) => (Int64, Int32), + unsigned_small_int(MySqlType::UnsignedSmallInt) => (Int64, Int32), + tiny_int(MySqlType::TinyInt) => (Int64, Int32), + unsigned_tiny_int(MySqlType::UnsignedTinyInt) => (Int64, Int32), + medium_int(MySqlType::MediumInt) => (Int64, Int32), + unsigned_medium_int(MySqlType::UnsignedMediumInt) => (Int64, Int64), + big_int(MySqlType::BigInt) => (Int64, Int64), + decimal(MySqlType::Decimal(Some((4, 4)))) => (Numeric, Numeric), + unsigned_big_int(MySqlType::UnsignedBigInt) => (Int64, Int64), + float(MySqlType::Float) => (Double, Float), + double(MySqlType::Double) => (Double, Double), + bit(MySqlType::Bit(1)) => (Bytes, Boolean), + char(MySqlType::Char(255)) => (Text, Text), + var_char(MySqlType::VarChar(255)) => (Text, Text), + binary(MySqlType::Binary(255)) => (Bytes, Bytes), + var_binary(MySqlType::VarBinary(255)) => (Bytes, Bytes), + tiny_blob(MySqlType::TinyBlob) => (Bytes, Bytes), + blob(MySqlType::Blob) => (Bytes, Bytes), + medium_blob(MySqlType::MediumBlob) => (Bytes, Bytes), + long_blob(MySqlType::LongBlob) => (Bytes, Bytes), + tiny_text(MySqlType::TinyText) => (Text, Text), + text(MySqlType::Text) => (Text, Text), + medium_text(MySqlType::MediumText) => (Text, Text), + long_text(MySqlType::LongText) => (Text, Text), + date(MySqlType::Date) => (Date, Date), + time(MySqlType::Time(Some(3))) => (Time, Time), + date_time(MySqlType::DateTime(Some(3))) => (DateTime, DateTime), + timestamp(MySqlType::Timestamp(Some(3))) => (DateTime, DateTime), + year(MySqlType::Year) => (Int32, Int32), + json(MySqlType::Json) => (Json, Json), + } +} + +mod mysql57 { + use super::*; + + test_scalar_types!( + Mysql57; + + int(ScalarType::Int) => (Unknown, Int32), + string(ScalarType::String) => (Unknown, Text), + bigint(ScalarType::BigInt) => (Unknown, Int64), + float(ScalarType::Float) => (Unknown, Double), + bytes(ScalarType::Bytes) => (Unknown, Bytes), + bool(ScalarType::Boolean) => (Unknown, Int32), + dt(ScalarType::DateTime) => (Unknown, DateTime), + decimal(ScalarType::Decimal) => (Unknown, Numeric), + ); + + test_native_types! 
{ + Mysql57; + + int(MySqlType::Int) => (Unknown, Int32), + unsigned_int(MySqlType::UnsignedInt) => (Unknown, Int64), + small_int(MySqlType::SmallInt) => (Unknown, Int32), + unsigned_small_int(MySqlType::UnsignedSmallInt) => (Unknown, Int32), + tiny_int(MySqlType::TinyInt) => (Unknown, Int32), + unsigned_tiny_int(MySqlType::UnsignedTinyInt) => (Unknown, Int32), + medium_int(MySqlType::MediumInt) => (Unknown, Int32), + unsigned_medium_int(MySqlType::UnsignedMediumInt) => (Unknown, Int64), + big_int(MySqlType::BigInt) => (Unknown, Int64), + decimal(MySqlType::Decimal(Some((4, 4)))) => (Unknown, Numeric), + unsigned_big_int(MySqlType::UnsignedBigInt) => (Unknown, Int64), + float(MySqlType::Float) => (Unknown, Float), + double(MySqlType::Double) => (Unknown, Double), + bit(MySqlType::Bit(1)) => (Unknown, Boolean), + char(MySqlType::Char(255)) => (Unknown, Text), + var_char(MySqlType::VarChar(255)) => (Unknown, Text), + binary(MySqlType::Binary(255)) => (Unknown, Bytes), + var_binary(MySqlType::VarBinary(255)) => (Unknown, Bytes), + tiny_blob(MySqlType::TinyBlob) => (Unknown, Bytes), + blob(MySqlType::Blob) => (Unknown, Bytes), + medium_blob(MySqlType::MediumBlob) => (Unknown, Bytes), + long_blob(MySqlType::LongBlob) => (Unknown, Bytes), + tiny_text(MySqlType::TinyText) => (Unknown, Text), + text(MySqlType::Text) => (Unknown, Text), + medium_text(MySqlType::MediumText) => (Unknown, Text), + long_text(MySqlType::LongText) => (Unknown, Text), + date(MySqlType::Date) => (Unknown, Date), + time(MySqlType::Time(Some(3))) => (Unknown, Time), + date_time(MySqlType::DateTime(Some(3))) => (Unknown, DateTime), + timestamp(MySqlType::Timestamp(Some(3))) => (Unknown, DateTime), + year(MySqlType::Year) => (Unknown, Int32), + json(MySqlType::Json) => (Unknown, Json), + } +} + +mod mysql56 { + use super::*; + + test_scalar_types!( + Mysql56; + + int(ScalarType::Int) => (Unknown, Int32), + string(ScalarType::String) => (Unknown, Text), + bigint(ScalarType::BigInt) => (Unknown, Int64), + float(ScalarType::Float) => (Unknown, Double), + bytes(ScalarType::Bytes) => (Unknown, Bytes), + bool(ScalarType::Boolean) => (Unknown, Int32), + dt(ScalarType::DateTime) => (Unknown, DateTime), + decimal(ScalarType::Decimal) => (Unknown, Numeric), + ); + + test_native_types! 
{ + Mysql56; + + int(MySqlType::Int) => (Unknown, Int32), + unsigned_int(MySqlType::UnsignedInt) => (Unknown, Int64), + small_int(MySqlType::SmallInt) => (Unknown, Int32), + unsigned_small_int(MySqlType::UnsignedSmallInt) => (Unknown, Int32), + tiny_int(MySqlType::TinyInt) => (Unknown, Int32), + unsigned_tiny_int(MySqlType::UnsignedTinyInt) => (Unknown, Int32), + medium_int(MySqlType::MediumInt) => (Unknown, Int32), + unsigned_medium_int(MySqlType::UnsignedMediumInt) => (Unknown, Int64), + big_int(MySqlType::BigInt) => (Unknown, Int64), + decimal(MySqlType::Decimal(Some((4, 4)))) => (Unknown, Numeric), + unsigned_big_int(MySqlType::UnsignedBigInt) => (Unknown, Int64), + float(MySqlType::Float) => (Unknown, Float), + double(MySqlType::Double) => (Unknown, Double), + bit(MySqlType::Bit(1)) => (Unknown, Boolean), + char(MySqlType::Char(255)) => (Unknown, Text), + var_char(MySqlType::VarChar(255)) => (Unknown, Text), + binary(MySqlType::Binary(255)) => (Unknown, Bytes), + var_binary(MySqlType::VarBinary(255)) => (Unknown, Bytes), + tiny_blob(MySqlType::TinyBlob) => (Unknown, Bytes), + blob(MySqlType::Blob) => (Unknown, Bytes), + medium_blob(MySqlType::MediumBlob) => (Unknown, Bytes), + long_blob(MySqlType::LongBlob) => (Unknown, Bytes), + tiny_text(MySqlType::TinyText) => (Unknown, Text), + text(MySqlType::Text) => (Unknown, Text), + medium_text(MySqlType::MediumText) => (Unknown, Text), + long_text(MySqlType::LongText) => (Unknown, Text), + date(MySqlType::Date) => (Unknown, Date), + time(MySqlType::Time(Some(3))) => (Unknown, Time), + date_time(MySqlType::DateTime(Some(3))) => (Unknown, DateTime), + timestamp(MySqlType::Timestamp(Some(3))) => (Unknown, DateTime), + year(MySqlType::Year) => (Unknown, Int32), + } +} + +mod mariadb { + use super::*; + + test_scalar_types!( + Mariadb; + + int(ScalarType::Int) => (Unknown, Int32), + string(ScalarType::String) => (Unknown, Text), + bigint(ScalarType::BigInt) => (Unknown, Int64), + float(ScalarType::Float) => (Unknown, Double), + bytes(ScalarType::Bytes) => (Unknown, Bytes), + bool(ScalarType::Boolean) => (Unknown, Int32), + dt(ScalarType::DateTime) => (Unknown, DateTime), + decimal(ScalarType::Decimal) => (Unknown, Numeric), + ); + + test_native_types! 
{ + Mariadb; + + int(MySqlType::Int) => (Unknown, Int32), + unsigned_int(MySqlType::UnsignedInt) => (Unknown, Int64), + small_int(MySqlType::SmallInt) => (Unknown, Int32), + unsigned_small_int(MySqlType::UnsignedSmallInt) => (Unknown, Int32), + tiny_int(MySqlType::TinyInt) => (Unknown, Int32), + unsigned_tiny_int(MySqlType::UnsignedTinyInt) => (Unknown, Int32), + medium_int(MySqlType::MediumInt) => (Unknown, Int32), + unsigned_medium_int(MySqlType::UnsignedMediumInt) => (Unknown, Int64), + big_int(MySqlType::BigInt) => (Unknown, Int64), + decimal(MySqlType::Decimal(Some((4, 4)))) => (Unknown, Numeric), + unsigned_big_int(MySqlType::UnsignedBigInt) => (Unknown, Int64), + float(MySqlType::Float) => (Unknown, Float), + double(MySqlType::Double) => (Unknown, Double), + bit(MySqlType::Bit(1)) => (Unknown, Boolean), + char(MySqlType::Char(255)) => (Unknown, Text), + var_char(MySqlType::VarChar(255)) => (Unknown, Text), + binary(MySqlType::Binary(255)) => (Unknown, Bytes), + var_binary(MySqlType::VarBinary(255)) => (Unknown, Bytes), + tiny_blob(MySqlType::TinyBlob) => (Unknown, Bytes), + blob(MySqlType::Blob) => (Unknown, Bytes), + medium_blob(MySqlType::MediumBlob) => (Unknown, Bytes), + long_blob(MySqlType::LongBlob) => (Unknown, Bytes), + tiny_text(MySqlType::TinyText) => (Unknown, Text), + text(MySqlType::Text) => (Unknown, Text), + medium_text(MySqlType::MediumText) => (Unknown, Text), + long_text(MySqlType::LongText) => (Unknown, Text), + date(MySqlType::Date) => (Unknown, Date), + time(MySqlType::Time(Some(3))) => (Unknown, Time), + date_time(MySqlType::DateTime(Some(3))) => (Unknown, DateTime), + timestamp(MySqlType::Timestamp(Some(3))) => (Unknown, DateTime), + year(MySqlType::Year) => (Unknown, Int32), + json(MySqlType::Json) => (Unknown, Text), + } +} diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/pg.rs b/schema-engine/sql-migration-tests/tests/query_introspection/pg.rs new file mode 100644 index 000000000000..b4b59251b9b4 --- /dev/null +++ b/schema-engine/sql-migration-tests/tests/query_introspection/pg.rs @@ -0,0 +1,1073 @@ +use super::utils::*; + +use psl::builtin_connectors::{CockroachType, PostgresType}; +use quaint::prelude::ColumnType; +use sql_migration_tests::test_api::*; + +mod common { + use super::*; + + #[test_connector(tags(Postgres))] + fn insert(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let query = "INSERT INTO model (int, string, bigint, float, bytes, bool, dt) VALUES (?, ?, ?, ?, ?, ?, ?) 
RETURNING int, string, bigint, float, bytes, bool, dt;"; + let res = api.introspect_sql("test_1", query).send_sync(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "INSERT INTO model (int, string, bigint, float, bytes, bool, dt) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING int, string, bigint, float, bytes, bool, dt;", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "int4", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "text", + typ: "string", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "int8", + typ: "bigint", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "float8", + typ: "double", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "bytea", + typ: "bytes", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "bool", + typ: "bool", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "timestamp", + typ: "datetime", + nullable: false, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "bigint", + typ: "bigint", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "float", + typ: "double", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "bytes", + typ: "bytes", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "bool", + typ: "bool", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "dt", + typ: "datetime", + nullable: false, + }, + ], + } + "#]]; + + res.expect_result(expected); + } + + #[test_connector(tags(Postgres))] + fn insert_nullable(api: TestApi) { + api.schema_push(SIMPLE_NULLABLE_SCHEMA).send().assert_green(); + + let query = "INSERT INTO model (int, string, bigint, float, bytes, bool, dt) VALUES (?, ?, ?, ?, ?, ?, ?) 
RETURNING int, string, bigint, float, bytes, bool, dt;"; + let res = api.introspect_sql("test_1", query).send_sync(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "INSERT INTO model (int, string, bigint, float, bytes, bool, dt) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING int, string, bigint, float, bytes, bool, dt;", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "int4", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "text", + typ: "string", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "int8", + typ: "bigint", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "float8", + typ: "double", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "bytea", + typ: "bytes", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "bool", + typ: "bool", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "timestamp", + typ: "datetime", + nullable: false, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bigint", + typ: "bigint", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "float", + typ: "double", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bytes", + typ: "bytes", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bool", + typ: "bool", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "dt", + typ: "datetime", + nullable: true, + }, + ], + } + "#]]; + + res.expect_result(expected); + } + + #[test_connector(tags(Postgres))] + fn empty_result(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT int FROM model WHERE 1 = 0 AND int = $1;", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "int4", + typ: "int", + nullable: false, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT int FROM model WHERE 1 = 0 AND int = ?;") + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(Postgres, CockroachDb))] + fn custom_enum(api: TestApi) { + api.schema_push(ENUM_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "INSERT INTO model (id, enum) VALUES ($1, $2) RETURNING id, enum;", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "int4", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "MyFancyEnum", + typ: "string", + nullable: false, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "enum", + typ: "MyFancyEnum", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql( + "test_1", + "INSERT INTO model (id, enum) VALUES (?, ?) 
RETURNING id, enum;", + ) + .send_sync() + .expect_result(expected) + } +} + +mod postgres { + use super::*; + + const PG_DATASOURCE: &str = r#" + datasource db { + provider = "postgres" + url = "postgresql://localhost:5432" + } + "#; + + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn named_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT 1 + 1 as \"add\";", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "add", + typ: "int", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT 1 + 1 as \"add\";") + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn mixed_named_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT \"int\" + 1 as \"add\" FROM \"model\";", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "add", + typ: "int", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT \"int\" + 1 as \"add\" FROM \"model\";") + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn mixed_unnamed_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "Invalid input provided to query: Invalid column name '?column?' for index 0. Your SQL query must explicitly alias that column name.", + ), + source: None, + context: SpanTrace [], + } + Invalid input provided to query: Invalid column name '?column?' for index 0. Your SQL query must explicitly alias that column name. 
+ + "#]]; + + expected.assert_debug_eq( + &api.introspect_sql("test_1", "SELECT \"int\" + 1 FROM \"model\";") + .send_unwrap_err(), + ); + } + + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn mixed_expr_cast(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT CAST(\"int\" + 1 as int) FROM model;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int4", + typ: "int", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT CAST(\"int\" + 1 as int) FROM model;") + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn subquery(api: TestApi) { + api.schema_push(SIMPLE_NULLABLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT int, foo.int, foo.string FROM (SELECT * FROM model) AS foo", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql( + "test_1", + "SELECT int, foo.int, foo.string FROM (SELECT * FROM model) AS foo", + ) + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn left_join(api: TestApi) { + api.schema_push(RELATION_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT parent.id as parentId, parent.nullable as parentNullable, child.id as childId, child.nullable as childNullable FROM parent LEFT JOIN child ON parent.id = child.parent_id", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "parentid", + typ: "int", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "parentnullable", + typ: "string", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "childid", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "childnullable", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT parent.id as parentId, parent.nullable as parentNullable, child.id as childId, child.nullable as childNullable FROM parent LEFT JOIN child ON parent.id = child.parent_id") + .send_sync() + .expect_result(expected) + } + + // test nullability inference for various joins + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn outer_join(api: TestApi) { + api.schema_push( + "model products { + product_no Int @id + name String? 
+ } + + model tweet { + id Int @id @default(autoincrement()) + text String + }", + ) + .send() + .assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "select tweet.id from (values (null)) vals(val) inner join tweet on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // inner join, nullability should not be overridden + api.introspect_sql( + "test_1", + "select tweet.id from (values (null)) vals(val) inner join tweet on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_2", + source: "select tweet.id from (values (null)) vals(val) left join tweet on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: true, + }, + ], + } + "#]]; + + // tweet.id is marked NOT NULL but it's brought in from a left-join here + // which should make it nullable + api.introspect_sql( + "test_2", + "select tweet.id from (values (null)) vals(val) left join tweet on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_3", + source: "select tweet1.id, tweet2.id from tweet tweet1 left join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: true, + }, + ], + } + "#]]; + + // make sure we don't mis-infer for the outer half of the join + api.introspect_sql( + "test_3", + "select tweet1.id, tweet2.id from tweet tweet1 left join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_4", + source: "select tweet1.id, tweet2.id from tweet tweet1 right join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // right join, nullability should be inverted + api.introspect_sql( + "test_4", + "select tweet1.id, tweet2.id from tweet tweet1 right join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_5", + source: "select tweet1.id, tweet2.id from tweet tweet1 full join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: true, + }, + ], + } + "#]]; + + // right join, nullability should be inverted + api.introspect_sql( + "test_5", + "select tweet1.id, tweet2.id from tweet tweet1 full join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + } + + macro_rules! test_native_types_pg { + ( + $($test_name:ident($nt:expr) => $ct:ident,)* + ) => { + $( + paste::paste! 
{ + #[test_connector(tags(Postgres), exclude(CockroachDb))] + fn $test_name(api: TestApi) { + let dm = render_native_type_datamodel::(&api, PG_DATASOURCE, $nt.to_parts(), $nt); + + if PostgresType::Citext == $nt { + api.raw_cmd("CREATE EXTENSION IF NOT EXISTS citext;"); + } + + api.schema_push(&dm).send(); + + let query = "INSERT INTO test (field) VALUES (?) RETURNING field;"; + + api.introspect_sql("test", query) + .send_sync() + .expect_param_type(0, ColumnType::$ct) + .expect_column_type(0, ColumnType::$ct); + } + } + )* + }; + } + + test_native_types_pg! { + small_int(PostgresType::SmallInt) => Int32, + integer(PostgresType::Integer) => Int32, + big_int(PostgresType::BigInt) => Int64, + nt_decimal(PostgresType::Decimal(Some((4, 4)))) => Numeric, + money(PostgresType::Money) => Numeric, + inet(PostgresType::Inet) => Text, + oid(PostgresType::Oid) => Int64, + citext(PostgresType::Citext) => Text, + real(PostgresType::Real) => Float, + double(PostgresType::DoublePrecision) => Double, + var_char(PostgresType::VarChar(Some(255))) => Text, + char(PostgresType::Char(Some(255))) => Text, + text(PostgresType::Text) => Text, + byte(PostgresType::ByteA) => Bytes, + timestamp(PostgresType::Timestamp(Some(1))) => DateTime, + timestamptz(PostgresType::Timestamptz(Some(1))) => DateTime, + date(PostgresType::Date) => Date, + time(PostgresType::Time(Some(1))) => Time, + timetz(PostgresType::Timetz(Some(1))) => Time, + boolean(PostgresType::Boolean) => Boolean, + bit(PostgresType::Bit(Some(1))) => Text, + var_bit(PostgresType::VarBit(Some(1))) => Text, + uuid(PostgresType::Uuid) => Uuid, + xml(PostgresType::Xml) => Xml, + json(PostgresType::Json) => Json, + json_b(PostgresType::JsonB) => Json, + } +} + +mod crdb { + use super::*; + + #[test_connector(tags(CockroachDb))] + fn named_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT 1 + 1 as \"add\";", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "add", + typ: "bigint", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT 1 + 1 as \"add\";") + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(CockroachDb))] + fn mixed_named_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT \"int\" + 1 as \"add\" FROM \"model\";", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "add", + typ: "bigint", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT \"int\" + 1 as \"add\" FROM \"model\";") + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(CockroachDb))] + fn mixed_unnamed_expr(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + ConnectorErrorImpl { + user_facing_error: None, + message: Some( + "Invalid input provided to query: Invalid column name '?column?' for index 0. Your SQL query must explicitly alias that column name.", + ), + source: None, + context: SpanTrace [], + } + Invalid input provided to query: Invalid column name '?column?' for index 0. Your SQL query must explicitly alias that column name. 
+ + "#]]; + + expected.assert_debug_eq( + &api.introspect_sql("test_1", "SELECT \"int\" + 1 FROM \"model\";") + .send_unwrap_err(), + ); + } + + #[test_connector(tags(CockroachDb))] + fn mixed_expr_cast(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT CAST(\"int\" + 1 as int) FROM model;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int8", + typ: "bigint", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT CAST(\"int\" + 1 as int) FROM model;") + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(CockroachDb))] + fn subquery(api: TestApi) { + api.schema_push(SIMPLE_NULLABLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT int, foo.int, foo.string FROM (SELECT * FROM model) AS foo", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql( + "test_1", + "SELECT int, foo.int, foo.string FROM (SELECT * FROM model) AS foo", + ) + .send_sync() + .expect_result(expected) + } + + #[test_connector(tags(CockroachDb))] + fn left_join(api: TestApi) { + api.schema_push(RELATION_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT parent.id as parentId, parent.nullable as parentNullable, child.id as childId, child.nullable as childNullable FROM parent LEFT JOIN child ON parent.id = child.parent_id", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "parentid", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "parentnullable", + typ: "string", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "childid", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "childnullable", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT parent.id as parentId, parent.nullable as parentNullable, child.id as childId, child.nullable as childNullable FROM parent LEFT JOIN child ON parent.id = child.parent_id") + .send_sync() + .expect_result(expected) + } + + // test nullability inference for various joins + #[test_connector(tags(CockroachDb))] + fn outer_join(api: TestApi) { + api.schema_push( + "model products { + product_no Int @id + name String? 
+ } + + model tweet { + id Int @id @default(autoincrement()) + text String + }", + ) + .send() + .assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "select tweet.id from (values (null)) vals(val) inner join tweet on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // inner join, nullability should not be overridden + api.introspect_sql( + "test_1", + "select tweet.id from (values (null)) vals(val) inner join tweet on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_2", + source: "select tweet.id from (values (null)) vals(val) left join tweet on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // tweet.id is marked NOT NULL but it's brought in from a left-join here + // which should make it nullable + api.introspect_sql( + "test_2", + "select tweet.id from (values (null)) vals(val) left join tweet on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_3", + source: "select tweet1.id, tweet2.id from tweet tweet1 left join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // make sure we don't mis-infer for the outer half of the join + api.introspect_sql( + "test_3", + "select tweet1.id, tweet2.id from tweet tweet1 left join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_4", + source: "select tweet1.id, tweet2.id from tweet tweet1 right join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // right join, nullability should be inverted + api.introspect_sql( + "test_4", + "select tweet1.id, tweet2.id from tweet tweet1 right join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_5", + source: "select tweet1.id, tweet2.id from tweet tweet1 full join tweet tweet2 on false", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "id", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + // right join, nullability should be inverted + api.introspect_sql( + "test_5", + "select tweet1.id, tweet2.id from tweet tweet1 full join tweet tweet2 on false", + ) + .send_sync() + .expect_result(expected); + } + + macro_rules! test_native_types_crdb { + ( + $($test_name:ident($nt:expr) => $ct:ident,)* + ) => { + $( + paste::paste! 
{ + #[test_connector(tags(CockroachDb))] + fn $test_name(api: TestApi) { + let dm = render_native_type_datamodel::(&api, CRDB_DATASOURCE, $nt.to_parts(), $nt); + + api.schema_push(&dm).send(); + + let query = "INSERT INTO test (id, field) VALUES (?, ?) RETURNING field;"; + + api.introspect_sql("test", query) + .send_sync() + .expect_param_type(1, ColumnType::$ct) + .expect_column_type(0, ColumnType::$ct); + } + } + )* + }; +} + + const CRDB_DATASOURCE: &str = r#" + datasource db { + provider = "cockroachdb" + url = "postgresql://localhost:5432" +} +"#; + + test_native_types_crdb! { + bit(CockroachType::Bit(Some(1))) => Text, + boolean(CockroachType::Bool) => Boolean, + nt_bytes(CockroachType::Bytes) => Bytes, + char(CockroachType::Char(Some(255))) => Text, + date(CockroachType::Date) => Date, + nt_decimal(CockroachType::Decimal(Some((4, 4)))) => Numeric, + float4(CockroachType::Float4) => Float, + float8(CockroachType::Float8) => Double, + inet(CockroachType::Inet) => Text, + int2(CockroachType::Int2) => Int32, + int4(CockroachType::Int4) => Int32, + int8(CockroachType::Int8) => Int64, + json_b(CockroachType::JsonB) => Json, + oid(CockroachType::Oid) => Int64, + catalog_single_char(CockroachType::CatalogSingleChar) => Char, + nt_string(CockroachType::String(Some(255))) => Text, + time(CockroachType::Time(Some(1))) => Time, + timestamp(CockroachType::Timestamp(Some(1))) => DateTime, + timestamptz(CockroachType::Timestamptz(Some(1))) => DateTime, + timetz(CockroachType::Timetz(Some(1))) => Time, + uuid(CockroachType::Uuid) => Uuid, + var_bit(CockroachType::VarBit(Some(1))) => Text, + } +} diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/sqlite.rs b/schema-engine/sql-migration-tests/tests/query_introspection/sqlite.rs new file mode 100644 index 000000000000..d018202da392 --- /dev/null +++ b/schema-engine/sql-migration-tests/tests/query_introspection/sqlite.rs @@ -0,0 +1,649 @@ +use super::utils::*; + +use sql_migration_tests::test_api::*; + +#[test_connector(tags(Sqlite))] +fn insert_sqlite(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let query = + "INSERT INTO `model` (`int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt`) VALUES (?, ?, ?, ?, ?, ?, ?);"; + + let res = api.introspect_sql("test_1", query).send_sync(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "INSERT INTO `model` (`int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt`) VALUES (?, ?, ?, ?, ?, ?, ?);", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_1", + typ: "unknown", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_2", + typ: "unknown", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_3", + typ: "unknown", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_4", + typ: "unknown", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_5", + typ: "unknown", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_6", + typ: "unknown", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_7", + typ: "unknown", + nullable: false, + }, + ], + result_columns: [], + } + "#]]; + + res.expect_result(expected); +} + +#[test_connector(tags(Sqlite))] +fn select_sqlite(api: TestApi) { + 
api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let res = api + .introspect_sql( + "test_1", + "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + ) + .send_sync(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "bigint", + typ: "bigint", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "float", + typ: "double", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "bytes", + typ: "bytes", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "bool", + typ: "bool", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "dt", + typ: "datetime", + nullable: false, + }, + ], + } + "#]]; + + res.expect_result(expected); +} + +#[test_connector(tags(Sqlite))] +fn select_nullable_sqlite(api: TestApi) { + api.schema_push(SIMPLE_NULLABLE_SCHEMA).send().assert_green(); + + let res = api + .introspect_sql( + "test_1", + "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + ) + .send_sync(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int`, `string`, `bigint`, `float`, `bytes`, `bool`, `dt` FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "string", + typ: "string", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bigint", + typ: "bigint", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "float", + typ: "double", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bytes", + typ: "bytes", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "bool", + typ: "bool", + nullable: true, + }, + IntrospectSqlQueryColumnOutput { + name: "dt", + typ: "datetime", + nullable: true, + }, + ], + } + "#]]; + + res.expect_result(expected); +} + +#[test_connector(tags(Sqlite))] +fn empty_result(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT int FROM model WHERE 1 = 0 AND int = ?;", + documentation: None, + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: None, + name: "_1", + typ: "unknown", + nullable: false, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT int FROM model WHERE 1 = 0 AND int = ?;") + .send_sync() + .expect_result(expected) +} + +#[test_connector(tags(Sqlite))] +fn unnamed_expr_int(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT 1 + 1;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "1 + 1", + typ: "bigint", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT 1 + 1;") + .send_sync() + .expect_result(expected); + + api.query_raw("SELECT 
1 + 1;", &[]) + .assert_single_row(|row| row.assert_bigint_value("1 + 1", 2)); +} + +#[test_connector(tags(Sqlite))] +fn named_expr_int(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT 1 + 1 as \"add\";", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "add", + typ: "bigint", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT 1 + 1 as \"add\";") + .send_sync() + .expect_result(expected); + + api.query_raw("SELECT 1 + 1 as \"add\";", &[]) + .assert_single_row(|row| row.assert_bigint_value("add", 2)); +} + +#[test_connector(tags(Sqlite))] +fn named_expr_int_optional(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT 1 + 1 as `add?`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "add?", + typ: "bigint", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT 1 + 1 as `add?`;") + .send_sync() + .expect_result(expected); + + api.query_raw("SELECT 1 + 1 as \"add?\";", &[]) + .assert_single_row(|row| row.assert_bigint_value("add?", 2)); +} + +#[test_connector(tags(Sqlite))] +fn mixed_named_expr_int(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int` + 1 as \"add\" FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "add", + typ: "bigint", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT `int` + 1 as \"add\" FROM `model`;") + .send_sync() + .expect_result(expected) +} + +#[test_connector(tags(Sqlite))] +fn mixed_unnamed_expr_int(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int` + 1 FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "`int` + 1", + typ: "bigint", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT `int` + 1 FROM `model`;") + .send_sync() + .expect_result(expected) +} + +#[test_connector(tags(Sqlite))] +fn mixed_expr_cast_int(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT CAST(`int` + 1 as int) FROM `model`;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "CAST(`int` + 1 as int)", + typ: "bigint", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT CAST(`int` + 1 as int) FROM `model`;") + .send_sync() + .expect_result(expected) +} + +#[test_connector(tags(Sqlite))] +fn unnamed_expr_string(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT 'hello world';", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "'hello world'", + typ: "string", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT 'hello world';") + .send_sync() + 
.expect_result(expected); + + api.query_raw("SELECT 'hello world' as `str`;", &[]) + .assert_single_row(|row| row.assert_text_value("str", "hello world")); +} + +#[test_connector(tags(Sqlite))] +fn unnamed_expr_bool(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT 1=1, 1=0;", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "1=1", + typ: "bigint", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "1=0", + typ: "bigint", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT 1=1, 1=0;") + .send_sync() + .expect_result(expected); + + api.query_raw("SELECT 1=1 as `true`, 1=0 AS `false`;", &[]) + .assert_single_row(|row| row.assert_int_value("true", 1).assert_int_value("false", 0)); +} + +#[test_connector(tags(Sqlite))] +fn unnamed_expr_real(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT 1.2, 2.34567891023, round(2.345);", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "1.2", + typ: "double", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "2.34567891023", + typ: "double", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "round(2.345)", + typ: "double", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT 1.2, 2.34567891023, round(2.345);") + .send_sync() + .expect_result(expected); + + api.query_raw("SELECT 1.2 AS a, 2.34567891023 AS b, round(2.345) AS c;", &[]) + .assert_single_row(|row| { + row.assert_float_value("a", 1.2) + .assert_float_value("b", 2.34567891023) + .assert_float_value("c", 2.0) + }); +} + +#[test_connector(tags(Sqlite))] +fn unnamed_expr_blob(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT unhex('537475666673');", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "unhex('537475666673')", + typ: "bytes", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT unhex('537475666673');") + .send_sync() + .expect_result(expected); + + api.query_raw("SELECT unhex('537475666673') as blob;", &[]) + .assert_single_row(|row| row.assert_bytes_value("blob", &[83, 116, 117, 102, 102, 115])); +} + +#[test_connector(tags(Sqlite))] +fn unnamed_expr_date(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT date('2025-05-29 14:16:00');", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "date('2025-05-29 14:16:00')", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT date('2025-05-29 14:16:00');") + .send_sync() + .expect_result(expected); + + api.query_raw("SELECT date('2025-05-29 14:16:00') as dt;", &[]) + .assert_single_row(|row| row.assert_text_value("dt", "2025-05-29")); +} + +#[test_connector(tags(Sqlite))] +fn unnamed_expr_time(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT time('2025-05-29 
14:16:00');", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "time('2025-05-29 14:16:00')", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT time('2025-05-29 14:16:00');") + .send_sync() + .expect_result(expected); + + api.query_raw("SELECT time('2025-05-29 14:16:00') as dt;", &[]) + .assert_single_row(|row| row.assert_text_value("dt", "14:16:00")); +} + +#[test_connector(tags(Sqlite))] +fn unnamed_expr_datetime(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT datetime('2025-05-29 14:16:00');", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "datetime('2025-05-29 14:16:00')", + typ: "string", + nullable: true, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT datetime('2025-05-29 14:16:00');") + .send_sync() + .expect_result(expected); + + api.query_raw("SELECT datetime('2025-05-29 14:16:00') as dt;", &[]) + .assert_single_row(|row| row.assert_text_value("dt", "2025-05-29 14:16:00")); +} + +#[test_connector(tags(Sqlite))] +fn subquery(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `int` FROM (SELECT * FROM `model`)", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT `int` FROM (SELECT * FROM `model`)") + .send_sync() + .expect_result(expected) +} + +#[test_connector(tags(Sqlite))] +fn left_join(api: TestApi) { + api.schema_push(RELATION_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "SELECT `parent`.`id` as `parentId`, `child`.`id` as `childId` FROM `parent` LEFT JOIN `child` ON `parent`.`id` = `child`.`parent_id`", + documentation: None, + parameters: [], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "parentId", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryColumnOutput { + name: "childId", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + api.introspect_sql("test_1", "SELECT `parent`.`id` as `parentId`, `child`.`id` as `childId` FROM `parent` LEFT JOIN `child` ON `parent`.`id` = `child`.`parent_id`") + .send_sync() + .expect_result(expected) +} diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/utils.rs b/schema-engine/sql-migration-tests/tests/query_introspection/utils.rs new file mode 100644 index 000000000000..f1666ee112a3 --- /dev/null +++ b/schema-engine/sql-migration-tests/tests/query_introspection/utils.rs @@ -0,0 +1,130 @@ +use std::any::Any; + +use psl::{datamodel_connector::NativeTypeInstance, parser_database::ScalarType}; +use sql_migration_tests::test_api::TestApi; + +pub(crate) const SIMPLE_SCHEMA: &str = r#" +model model { + int Int @id + string String + bigint BigInt + float Float + bytes Bytes + bool Boolean + dt DateTime +}"#; + +pub(crate) const SIMPLE_NULLABLE_SCHEMA: &str = r#" +model model { + int Int @id + string String? + bigint BigInt? + float Float? + bytes Bytes? + bool Boolean? + dt DateTime? 
+}"#; + +pub(crate) const ENUM_SCHEMA: &str = r#" +model model { + id Int @id + enum MyFancyEnum +} + +enum MyFancyEnum { + A + B + C +} +"#; + +pub(crate) const RELATION_SCHEMA: &str = r#" +model parent { + id Int @id + nullable String? + + children child[] +} + +model child { + id Int @id + nullable String? + + parent_id Int? + parent parent? @relation(fields: [parent_id], references: [id]) +} +"#; + +pub(crate) fn render_scalar_type_datamodel(datasource: &str, prisma_type: ScalarType) -> String { + let prisma_type = prisma_type.as_str(); + + format!( + r#" + {datasource} + model test {{ + id Int @id @default(autoincrement()) + field {prisma_type} + }} + "# + ) +} + +pub(crate) fn render_native_type_datamodel( + api: &TestApi, + datasource: &str, + nt_parts: (&str, Vec), + nt: T, +) -> String { + let (nt_name, rest) = nt_parts; + let args = if rest.is_empty() { + "".to_string() + } else { + format!("({})", rest.join(",")) + }; + + let instance = NativeTypeInstance::new::(nt); + let prisma_type = api.connector.scalar_type_for_native_type(&instance).as_str(); + + format!( + r#" + {datasource} + + model test {{ + id Int @id + field {prisma_type} @db.{nt_name}{args} + }} + "# + ) +} + +macro_rules! test_scalar_types { + ( + $tag:ident; + + $( + $test_name:ident($st:expr) => ($ct_input:ident, $ct_output:ident), + )* + ) => { + $( + paste::paste! { + #[test_connector(tags($tag), exclude(Vitess))] + fn [<$test_name _ $tag:lower>](api: TestApi) { + + let dm = render_scalar_type_datamodel(DATASOURCE, $st); + + api.schema_push(&dm).send(); + + api.introspect_sql("test_1", "INSERT INTO test (field) VALUES (?);") + .send_sync() + .expect_param_type(0, ColumnType::$ct_input); + + api.introspect_sql("test_2", "SELECT field FROM test;") + .send_sync() + .expect_column_type(0, ColumnType::$ct_output); + } + } + )* + }; +} + +pub(crate) use test_scalar_types; diff --git a/schema-engine/sql-schema-describer/Cargo.toml b/schema-engine/sql-schema-describer/Cargo.toml index 514eac9daecf..17b8eae63684 100644 --- a/schema-engine/sql-schema-describer/Cargo.toml +++ b/schema-engine/sql-schema-describer/Cargo.toml @@ -14,11 +14,11 @@ enumflags2.workspace = true indexmap.workspace = true indoc.workspace = true once_cell = "1.3" -regex = "1.2" +regex.workspace = true serde.workspace = true tracing.workspace = true tracing-error = "0.2" -tracing-futures = "0.2" +tracing-futures.workspace = true quaint = { workspace = true, features = [ "all-native", "pooled", diff --git a/schema-engine/sql-schema-describer/src/sqlite.rs b/schema-engine/sql-schema-describer/src/sqlite.rs index 51f75a90343a..bd82c52fce0e 100644 --- a/schema-engine/sql-schema-describer/src/sqlite.rs +++ b/schema-engine/sql-schema-describer/src/sqlite.rs @@ -9,7 +9,7 @@ use either::Either; use indexmap::IndexMap; use quaint::{ ast::{Value, ValueType}, - connector::{GetRow, ToColumnNames}, + connector::{ColumnType as QuaintColumnType, GetRow, ToColumnNames}, prelude::ResultRow, }; use std::{any::type_name, borrow::Cow, collections::BTreeMap, convert::TryInto, fmt::Debug, path::Path}; @@ -33,6 +33,7 @@ impl Connection for std::sync::Mutex { ) -> quaint::Result { let conn = self.lock().unwrap(); let mut stmt = conn.prepare_cached(sql)?; + let column_types = stmt.columns().iter().map(QuaintColumnType::from).collect::>(); let mut rows = stmt.query(quaint::connector::rusqlite::params_from_iter(params.iter()))?; let column_names = rows.to_column_names(); let mut converted_rows = Vec::new(); @@ -40,7 +41,11 @@ impl Connection for std::sync::Mutex { 
converted_rows.push(row.get_result_row().unwrap()); } - Ok(quaint::prelude::ResultSet::new(column_names, converted_rows)) + Ok(quaint::prelude::ResultSet::new( + column_names, + column_types, + converted_rows, + )) } }
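The final hunk above is the piece the new SQLite introspection tests lean on: the describer's test connection now collects a column type per result column from the prepared statement before executing it, and passes that list to `ResultSet::new`, so type information survives even when a query returns zero rows (as exercised by the `empty_result` tests). As a rough, standalone sketch of the same idea outside quaint, the snippet below uses `rusqlite` directly, assuming the `column_decltype` feature is enabled; the `users` table and the `kind_from_decl` helper are invented for the example and only approximate SQLite's type-affinity rules, they are not the types quaint actually produces.

use rusqlite::Connection;

// Coarse column kinds standing in for a real ColumnType enum (illustrative only).
#[derive(Debug)]
enum Kind {
    Integer,
    Real,
    Text,
    Blob,
    Unknown,
}

// Hypothetical helper: derive a coarse kind from the declared column type,
// loosely following SQLite's type-affinity rules.
fn kind_from_decl(decl: Option<&str>) -> Kind {
    match decl.map(str::to_ascii_uppercase) {
        Some(d) if d.contains("INT") => Kind::Integer,
        Some(d) if d.contains("CHAR") || d.contains("TEXT") || d.contains("CLOB") => Kind::Text,
        Some(d) if d.contains("BLOB") => Kind::Blob,
        Some(d) if d.contains("REAL") || d.contains("FLOA") || d.contains("DOUB") => Kind::Real,
        _ => Kind::Unknown,
    }
}

fn main() -> rusqlite::Result<()> {
    let conn = Connection::open_in_memory()?;
    conn.execute_batch("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);")?;

    // Preparing the statement is enough to learn the result columns and their
    // declared types; no rows have to be fetched, so even a query that matches
    // nothing still yields usable type information.
    let stmt = conn.prepare("SELECT id, name FROM users WHERE 1 = 0")?;
    for col in stmt.columns() {
        println!("{}: {:?}", col.name(), kind_from_decl(col.decl_type()));
    }
    Ok(())
}

Reading types from statement metadata rather than from the first returned row is what lets the introspection output above stay stable for queries that can legitimately return nothing.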