Skip to content

Commit

Permalink
perf: optimize RecordBatch to HttpOutput conversion (#4178)
Browse files Browse the repository at this point in the history
* add benchmark

Signed-off-by: Ruihang Xia <[email protected]>

* save 70ms

Signed-off-by: Ruihang Xia <[email protected]>

* add profiler

Signed-off-by: Ruihang Xia <[email protected]>

* save 50ms

Signed-off-by: Ruihang Xia <[email protected]>

* save 160ms

Signed-off-by: Ruihang Xia <[email protected]>

* format toml file

Signed-off-by: Ruihang Xia <[email protected]>

* fix license header

Signed-off-by: Ruihang Xia <[email protected]>

* fix windows build

Signed-off-by: Ruihang Xia <[email protected]>

---------

Signed-off-by: Ruihang Xia <[email protected]>
  • Loading branch information
waynexia authored Jun 20, 2024
1 parent 5bcd7a1 commit 21c89f3
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 30 deletions.
35 changes: 31 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 13 additions & 12 deletions src/common/base/src/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,17 @@ impl PartialEq<Bytes> for [u8] {
/// Now this buffer is restricted to only hold valid UTF-8 string (only allow constructing `StringBytes`
/// from String or str). We may support other encoding in the future.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StringBytes(bytes::Bytes);
pub struct StringBytes(String);

impl StringBytes {
/// View this string as UTF-8 string slice.
///
/// # Safety
/// We only allow constructing `StringBytes` from String/str, so the inner
/// buffer must holds valid UTF-8.
pub fn as_utf8(&self) -> &str {
unsafe { std::str::from_utf8_unchecked(&self.0) }
&self.0
}

/// Convert this string into owned UTF-8 string.
pub fn into_string(self) -> String {
self.0
}

pub fn len(&self) -> usize {
Expand All @@ -104,37 +105,37 @@ impl StringBytes {

impl From<String> for StringBytes {
fn from(string: String) -> StringBytes {
StringBytes(bytes::Bytes::from(string))
StringBytes(string)
}
}

impl From<&str> for StringBytes {
fn from(string: &str) -> StringBytes {
StringBytes(bytes::Bytes::copy_from_slice(string.as_bytes()))
StringBytes(string.to_string())
}
}

impl PartialEq<String> for StringBytes {
fn eq(&self, other: &String) -> bool {
self.0 == other.as_bytes()
&self.0 == other
}
}

impl PartialEq<StringBytes> for String {
fn eq(&self, other: &StringBytes) -> bool {
self.as_bytes() == other.0
self == &other.0
}
}

impl PartialEq<str> for StringBytes {
fn eq(&self, other: &str) -> bool {
self.0 == other.as_bytes()
self.0.as_str() == other
}
}

impl PartialEq<StringBytes> for str {
fn eq(&self, other: &StringBytes) -> bool {
self.as_bytes() == other.0
self == other.0
}
}

Expand Down
49 changes: 47 additions & 2 deletions src/datatypes/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use common_time::timestamp::{TimeUnit, Timestamp};
use common_time::{Duration, Interval, Timezone};
use datafusion_common::ScalarValue;
pub use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Serialize, Serializer};
use snafu::{ensure, ResultExt};

use crate::error::{self, ConvertArrowArrayToScalarsSnafu, Error, Result, TryFromValueSnafu};
Expand Down Expand Up @@ -695,7 +695,7 @@ impl TryFrom<Value> for serde_json::Value {
Value::Int64(v) => serde_json::Value::from(v),
Value::Float32(v) => serde_json::Value::from(v.0),
Value::Float64(v) => serde_json::Value::from(v.0),
Value::String(bytes) => serde_json::Value::String(bytes.as_utf8().to_string()),
Value::String(bytes) => serde_json::Value::String(bytes.into_string()),
Value::Binary(bytes) => serde_json::to_value(bytes)?,
Value::Date(v) => serde_json::Value::Number(v.val().into()),
Value::DateTime(v) => serde_json::Value::Number(v.val().into()),
Expand Down Expand Up @@ -1165,6 +1165,39 @@ impl<'a> From<Option<ListValueRef<'a>>> for ValueRef<'a> {
}
}

impl<'a> TryFrom<ValueRef<'a>> for serde_json::Value {
type Error = serde_json::Error;

fn try_from(value: ValueRef<'a>) -> serde_json::Result<serde_json::Value> {
let json_value = match value {
ValueRef::Null => serde_json::Value::Null,
ValueRef::Boolean(v) => serde_json::Value::Bool(v),
ValueRef::UInt8(v) => serde_json::Value::from(v),
ValueRef::UInt16(v) => serde_json::Value::from(v),
ValueRef::UInt32(v) => serde_json::Value::from(v),
ValueRef::UInt64(v) => serde_json::Value::from(v),
ValueRef::Int8(v) => serde_json::Value::from(v),
ValueRef::Int16(v) => serde_json::Value::from(v),
ValueRef::Int32(v) => serde_json::Value::from(v),
ValueRef::Int64(v) => serde_json::Value::from(v),
ValueRef::Float32(v) => serde_json::Value::from(v.0),
ValueRef::Float64(v) => serde_json::Value::from(v.0),
ValueRef::String(bytes) => serde_json::Value::String(bytes.to_string()),
ValueRef::Binary(bytes) => serde_json::to_value(bytes)?,
ValueRef::Date(v) => serde_json::Value::Number(v.val().into()),
ValueRef::DateTime(v) => serde_json::Value::Number(v.val().into()),
ValueRef::List(v) => serde_json::to_value(v)?,
ValueRef::Timestamp(v) => serde_json::to_value(v.value())?,
ValueRef::Time(v) => serde_json::to_value(v.value())?,
ValueRef::Interval(v) => serde_json::to_value(v.to_i128())?,
ValueRef::Duration(v) => serde_json::to_value(v.value())?,
ValueRef::Decimal128(v) => serde_json::to_value(v.to_string())?,
};

Ok(json_value)
}
}

/// Reference to a [ListValue].
///
/// Now comparison still requires some allocation (call of `to_value()`) and
Expand Down Expand Up @@ -1195,6 +1228,18 @@ impl<'a> ListValueRef<'a> {
}
}

impl<'a> Serialize for ListValueRef<'a> {
fn serialize<S: Serializer>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> {
match self {
ListValueRef::Indexed { vector, idx } => match vector.get(*idx) {
Value::List(v) => v.serialize(serializer),
_ => unreachable!(),
},
ListValueRef::Ref { val } => val.serialize(serializer),
}
}
}

impl<'a> PartialEq for ListValueRef<'a> {
fn eq(&self, other: &Self) -> bool {
self.to_value().eq(&other.to_value())
Expand Down
9 changes: 8 additions & 1 deletion src/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ catalog = { workspace = true, features = ["testing"] }
client = { workspace = true, features = ["testing"] }
common-base.workspace = true
common-test-util.workspace = true
criterion = "0.4"
criterion = "0.5"
mysql_async = { version = "0.33", default-features = false, features = [
"default-rustls",
] }
Expand All @@ -131,9 +131,16 @@ tokio-postgres = "0.7"
tokio-postgres-rustls = "0.11"
tokio-test = "0.4"

[target.'cfg(not(windows))'.dev-dependencies]
pprof = { version = "0.13", features = ["criterion", "flamegraph"] }

[build-dependencies]
common-version.workspace = true

[[bench]]
name = "bench_prom"
harness = false

[[bench]]
name = "to_http_output"
harness = false
81 changes: 81 additions & 0 deletions src/servers/benches/to_http_output.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
use std::time::Instant;

use arrow::array::StringArray;
use arrow_schema::{DataType, Field, Schema};
use common_recordbatch::RecordBatch;
use criterion::{criterion_group, criterion_main, Criterion};
use datatypes::schema::SchemaRef;
use datatypes::vectors::StringVector;
use servers::http::HttpRecordsOutput;

fn mock_schema() -> SchemaRef {
let mut fields = Vec::with_capacity(10);
for i in 0..10 {
fields.push(Field::new(format!("field{}", i), DataType::Utf8, true));
}
let arrow_schema = Arc::new(Schema::new(fields));
Arc::new(arrow_schema.try_into().unwrap())
}

fn mock_input_record_batch(batch_size: usize, num_batches: usize) -> Vec<RecordBatch> {
let mut result = Vec::with_capacity(num_batches);
for _ in 0..num_batches {
let mut vectors = Vec::with_capacity(10);
for _ in 0..10 {
let vector: StringVector = StringArray::from(
(0..batch_size)
.map(|_| String::from("Copyright 2024 Greptime Team"))
.collect::<Vec<_>>(),
)
.into();
vectors.push(Arc::new(vector) as _);
}

let schema = mock_schema();
let record_batch = RecordBatch::new(schema, vectors).unwrap();
result.push(record_batch);
}

result
}

fn bench_convert_record_batch_to_http_output(c: &mut Criterion) {
let record_batches = mock_input_record_batch(4096, 100);
c.bench_function("convert_record_batch_to_http_output", |b| {
b.iter_custom(|iters| {
let mut elapsed_sum = std::time::Duration::new(0, 0);
for _ in 0..iters {
let record_batches = record_batches.clone();
let start = Instant::now();
let _result = HttpRecordsOutput::try_new(mock_schema(), record_batches);
elapsed_sum += start.elapsed();
}
elapsed_sum
});
});
}

#[cfg(not(windows))]
criterion_group! {
name = benches;
config = Criterion::default().with_profiler(pprof::criterion::PProfProfiler::new(101, pprof::criterion::Output::Flamegraph(None)));
targets = bench_convert_record_batch_to_http_output
}
#[cfg(windows)]
criterion_group!(benches, bench_convert_record_batch_to_http_output);
criterion_main!(benches);
23 changes: 12 additions & 11 deletions src/servers/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ impl HttpRecordsOutput {
}

impl HttpRecordsOutput {
pub(crate) fn try_new(
pub fn try_new(
schema: SchemaRef,
recordbatches: Vec<RecordBatch>,
) -> std::result::Result<HttpRecordsOutput, Error> {
Expand All @@ -233,19 +233,20 @@ impl HttpRecordsOutput {
metrics: Default::default(),
})
} else {
let mut rows =
Vec::with_capacity(recordbatches.iter().map(|r| r.num_rows()).sum::<usize>());
let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
let mut rows = Vec::with_capacity(num_rows);
let num_cols = schema.column_schemas().len();
rows.resize_with(num_rows, || Vec::with_capacity(num_cols));

let mut finished_row_cursor = 0;
for recordbatch in recordbatches {
for row in recordbatch.rows() {
let value_row = row
.into_iter()
.map(Value::try_from)
.collect::<std::result::Result<Vec<Value>, _>>()
.context(ToJsonSnafu)?;

rows.push(value_row);
for col in recordbatch.columns() {
for row_idx in 0..recordbatch.num_rows() {
let value = Value::try_from(col.get_ref(row_idx)).context(ToJsonSnafu)?;
rows[row_idx + finished_row_cursor].push(value);
}
}
finished_row_cursor += recordbatch.num_rows();
}

Ok(HttpRecordsOutput {
Expand Down

0 comments on commit 21c89f3

Please sign in to comment.