From 21c89f324752290f7b84894b832436561fd2bef1 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Thu, 20 Jun 2024 20:33:58 +0800 Subject: [PATCH] perf: optimize RecordBatch to HttpOutput conversion (#4178) * add benchmark Signed-off-by: Ruihang Xia * save 70ms Signed-off-by: Ruihang Xia * add profiler Signed-off-by: Ruihang Xia * save 50ms Signed-off-by: Ruihang Xia * save 160ms Signed-off-by: Ruihang Xia * format toml file Signed-off-by: Ruihang Xia * fix license header Signed-off-by: Ruihang Xia * fix windows build Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- Cargo.lock | 35 ++++++++++-- src/common/base/src/bytes.rs | 25 +++++---- src/datatypes/src/value.rs | 49 +++++++++++++++- src/servers/Cargo.toml | 9 ++- src/servers/benches/to_http_output.rs | 81 +++++++++++++++++++++++++++ src/servers/src/http.rs | 23 ++++---- 6 files changed, 192 insertions(+), 30 deletions(-) create mode 100644 src/servers/benches/to_http_output.rs diff --git a/Cargo.lock b/Cargo.lock index e2cdd67ebe03..9bda000949a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1877,7 +1877,7 @@ dependencies = [ "common-runtime", "common-telemetry", "common-time", - "criterion", + "criterion 0.4.0", "dashmap", "datatypes", "flatbuffers", @@ -2385,6 +2385,32 @@ dependencies = [ "walkdir", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap 4.5.7", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + [[package]] name = "criterion-plot" version = "0.5.0" @@ -5973,7 +5999,7 @@ dependencies = [ "common-time", "common-wal", "crc32fast", - "criterion", + "criterion 0.4.0", "crossbeam-utils", "datafusion 38.0.0", 
"datafusion-common 38.0.0", @@ -7633,6 +7659,7 @@ checksum = "ef5c97c51bd34c7e742402e216abdeb44d415fbe6ae41d56b114723e953711cb" dependencies = [ "backtrace", "cfg-if", + "criterion 0.5.1", "findshlibs", "inferno", "libc", @@ -9556,7 +9583,7 @@ dependencies = [ "common-test-util", "common-time", "console", - "criterion", + "criterion 0.4.0", "crossbeam-utils", "datafusion 38.0.0", "datafusion-common 38.0.0", @@ -9844,7 +9871,7 @@ dependencies = [ "common-test-util", "common-time", "common-version", - "criterion", + "criterion 0.5.1", "dashmap", "datafusion 38.0.0", "datafusion-common 38.0.0", diff --git a/src/common/base/src/bytes.rs b/src/common/base/src/bytes.rs index 7f757917b80a..aec2dfd9edbf 100644 --- a/src/common/base/src/bytes.rs +++ b/src/common/base/src/bytes.rs @@ -81,16 +81,17 @@ impl PartialEq for [u8] { /// Now this buffer is restricted to only hold valid UTF-8 string (only allow constructing `StringBytes` /// from String or str). We may support other encoding in the future. #[derive(Debug, Default, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct StringBytes(bytes::Bytes); +pub struct StringBytes(String); impl StringBytes { /// View this string as UTF-8 string slice. - /// - /// # Safety - /// We only allow constructing `StringBytes` from String/str, so the inner - /// buffer must holds valid UTF-8. pub fn as_utf8(&self) -> &str { - unsafe { std::str::from_utf8_unchecked(&self.0) } + &self.0 + } + + /// Convert this string into owned UTF-8 string. 
+ pub fn into_string(self) -> String { + self.0 } pub fn len(&self) -> usize { @@ -104,37 +105,37 @@ impl StringBytes { impl From<String> for StringBytes { fn from(string: String) -> StringBytes { - StringBytes(bytes::Bytes::from(string)) + StringBytes(string) } } impl From<&str> for StringBytes { fn from(string: &str) -> StringBytes { - StringBytes(bytes::Bytes::copy_from_slice(string.as_bytes())) + StringBytes(string.to_string()) } } impl PartialEq<String> for StringBytes { fn eq(&self, other: &String) -> bool { - self.0 == other.as_bytes() + &self.0 == other } } impl PartialEq<StringBytes> for String { fn eq(&self, other: &StringBytes) -> bool { - self.as_bytes() == other.0 + self == &other.0 } } impl PartialEq<str> for StringBytes { fn eq(&self, other: &str) -> bool { - self.0 == other.as_bytes() + self.0.as_str() == other } } impl PartialEq<StringBytes> for str { fn eq(&self, other: &StringBytes) -> bool { - self.as_bytes() == other.0 + self == other.0 } } diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index 225420cbc27c..4e9db7e2e52d 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -29,7 +29,7 @@ use common_time::timestamp::{TimeUnit, Timestamp}; use common_time::{Duration, Interval, Timezone}; use datafusion_common::ScalarValue; pub use ordered_float::OrderedFloat; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Serialize, Serializer}; use snafu::{ensure, ResultExt}; use crate::error::{self, ConvertArrowArrayToScalarsSnafu, Error, Result, TryFromValueSnafu}; @@ -695,7 +695,7 @@ impl TryFrom<Value> for serde_json::Value { Value::Int64(v) => serde_json::Value::from(v), Value::Float32(v) => serde_json::Value::from(v.0), Value::Float64(v) => serde_json::Value::from(v.0), - Value::String(bytes) => serde_json::Value::String(bytes.as_utf8().to_string()), + Value::String(bytes) => serde_json::Value::String(bytes.into_string()), Value::Binary(bytes) => serde_json::to_value(bytes)?, Value::Date(v) => serde_json::Value::Number(v.val().into()),
Value::DateTime(v) => serde_json::Value::Number(v.val().into()), @@ -1165,6 +1165,39 @@ impl<'a> From<Option<ListValueRef<'a>>> for ValueRef<'a> { } } +impl<'a> TryFrom<ValueRef<'a>> for serde_json::Value { + type Error = serde_json::Error; + + fn try_from(value: ValueRef<'a>) -> serde_json::Result<serde_json::Value> { + let json_value = match value { + ValueRef::Null => serde_json::Value::Null, + ValueRef::Boolean(v) => serde_json::Value::Bool(v), + ValueRef::UInt8(v) => serde_json::Value::from(v), + ValueRef::UInt16(v) => serde_json::Value::from(v), + ValueRef::UInt32(v) => serde_json::Value::from(v), + ValueRef::UInt64(v) => serde_json::Value::from(v), + ValueRef::Int8(v) => serde_json::Value::from(v), + ValueRef::Int16(v) => serde_json::Value::from(v), + ValueRef::Int32(v) => serde_json::Value::from(v), + ValueRef::Int64(v) => serde_json::Value::from(v), + ValueRef::Float32(v) => serde_json::Value::from(v.0), + ValueRef::Float64(v) => serde_json::Value::from(v.0), + ValueRef::String(bytes) => serde_json::Value::String(bytes.to_string()), + ValueRef::Binary(bytes) => serde_json::to_value(bytes)?, + ValueRef::Date(v) => serde_json::Value::Number(v.val().into()), + ValueRef::DateTime(v) => serde_json::Value::Number(v.val().into()), + ValueRef::List(v) => serde_json::to_value(v)?, + ValueRef::Timestamp(v) => serde_json::to_value(v.value())?, + ValueRef::Time(v) => serde_json::to_value(v.value())?, + ValueRef::Interval(v) => serde_json::to_value(v.to_i128())?, + ValueRef::Duration(v) => serde_json::to_value(v.value())?, + ValueRef::Decimal128(v) => serde_json::to_value(v.to_string())?, + }; + + Ok(json_value) + } +} + /// Reference to a [ListValue]. 
/// /// Now comparison still requires some allocation (call of `to_value()`) and @@ -1195,6 +1228,18 @@ impl<'a> ListValueRef<'a> { } } +impl<'a> Serialize for ListValueRef<'a> { + fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> { + match self { + ListValueRef::Indexed { vector, idx } => match vector.get(*idx) { + Value::List(v) => v.serialize(serializer), + _ => unreachable!(), + }, + ListValueRef::Ref { val } => val.serialize(serializer), + } + } +} + impl<'a> PartialEq for ListValueRef<'a> { fn eq(&self, other: &Self) -> bool { self.to_value().eq(&other.to_value()) diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 755d59bfacd6..b217089c97b6 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -116,7 +116,7 @@ catalog = { workspace = true, features = ["testing"] } client = { workspace = true, features = ["testing"] } common-base.workspace = true common-test-util.workspace = true -criterion = "0.4" +criterion = "0.5" mysql_async = { version = "0.33", default-features = false, features = [ "default-rustls", ] } @@ -131,9 +131,16 @@ tokio-postgres = "0.7" tokio-postgres-rustls = "0.11" tokio-test = "0.4" +[target.'cfg(not(windows))'.dev-dependencies] +pprof = { version = "0.13", features = ["criterion", "flamegraph"] } + [build-dependencies] common-version.workspace = true [[bench]] name = "bench_prom" harness = false + +[[bench]] +name = "to_http_output" +harness = false diff --git a/src/servers/benches/to_http_output.rs b/src/servers/benches/to_http_output.rs new file mode 100644 index 000000000000..9e7881bb0fa7 --- /dev/null +++ b/src/servers/benches/to_http_output.rs @@ -0,0 +1,81 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use std::time::Instant; + +use arrow::array::StringArray; +use arrow_schema::{DataType, Field, Schema}; +use common_recordbatch::RecordBatch; +use criterion::{criterion_group, criterion_main, Criterion}; +use datatypes::schema::SchemaRef; +use datatypes::vectors::StringVector; +use servers::http::HttpRecordsOutput; + +fn mock_schema() -> SchemaRef { + let mut fields = Vec::with_capacity(10); + for i in 0..10 { + fields.push(Field::new(format!("field{}", i), DataType::Utf8, true)); + } + let arrow_schema = Arc::new(Schema::new(fields)); + Arc::new(arrow_schema.try_into().unwrap()) +} + +fn mock_input_record_batch(batch_size: usize, num_batches: usize) -> Vec<RecordBatch> { + let mut result = Vec::with_capacity(num_batches); + for _ in 0..num_batches { + let mut vectors = Vec::with_capacity(10); + for _ in 0..10 { + let vector: StringVector = StringArray::from( + (0..batch_size) + .map(|_| String::from("Copyright 2024 Greptime Team")) + .collect::<Vec<_>>(), + ) + .into(); + vectors.push(Arc::new(vector) as _); + } + + let schema = mock_schema(); + let record_batch = RecordBatch::new(schema, vectors).unwrap(); + result.push(record_batch); + } + + result +} + +fn bench_convert_record_batch_to_http_output(c: &mut Criterion) { + let record_batches = mock_input_record_batch(4096, 100); + c.bench_function("convert_record_batch_to_http_output", |b| { + b.iter_custom(|iters| { + let mut elapsed_sum = std::time::Duration::new(0, 0); + for _ in 0..iters { + let record_batches = record_batches.clone(); + let start = Instant::now(); + let _result = 
HttpRecordsOutput::try_new(mock_schema(), record_batches); + elapsed_sum += start.elapsed(); + } + elapsed_sum + }); + }); +} + +#[cfg(not(windows))] +criterion_group! { + name = benches; + config = Criterion::default().with_profiler(pprof::criterion::PProfProfiler::new(101, pprof::criterion::Output::Flamegraph(None))); + targets = bench_convert_record_batch_to_http_output +} +#[cfg(windows)] +criterion_group!(benches, bench_convert_record_batch_to_http_output); +criterion_main!(benches); diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 3f7f71653f73..e8465a6649de 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -221,7 +221,7 @@ impl HttpRecordsOutput { } impl HttpRecordsOutput { - pub(crate) fn try_new( + pub fn try_new( schema: SchemaRef, recordbatches: Vec<RecordBatch>, ) -> std::result::Result<HttpRecordsOutput, Error> { @@ -233,19 +233,20 @@ impl HttpRecordsOutput { metrics: Default::default(), }) } else { - let mut rows = - Vec::with_capacity(recordbatches.iter().map(|r| r.num_rows()).sum::<usize>()); + let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>(); + let mut rows = Vec::with_capacity(num_rows); + let num_cols = schema.column_schemas().len(); + rows.resize_with(num_rows, || Vec::with_capacity(num_cols)); + let mut finished_row_cursor = 0; for recordbatch in recordbatches { - for row in recordbatch.rows() { - let value_row = row - .into_iter() - .map(Value::try_from) - .collect::<Result<Vec<Value>, _>>() - .context(ToJsonSnafu)?; - - rows.push(value_row); + for col in recordbatch.columns() { + for row_idx in 0..recordbatch.num_rows() { + let value = Value::try_from(col.get_ref(row_idx)).context(ToJsonSnafu)?; + rows[row_idx + finished_row_cursor].push(value); + } } + finished_row_cursor += recordbatch.num_rows(); } Ok(HttpRecordsOutput {