feat(parquet): Add next_row_group API for ParquetRecordBatchStream #6907

Merged · 2 commits · Dec 24, 2024
132 changes: 132 additions & 0 deletions parquet/src/arrow/async_reader/mod.rs
@@ -613,6 +613,9 @@ impl<T> std::fmt::Debug for StreamState<T> {

/// An asynchronous [`Stream`](https://docs.rs/futures/latest/futures/stream/trait.Stream.html) of [`RecordBatch`]
/// for a parquet file that can be constructed using [`ParquetRecordBatchStreamBuilder`].
///
/// `ParquetRecordBatchStream` also provides [`ParquetRecordBatchStream::next_row_group`] for fetching row groups,
/// allowing users to decode record batches separately from I/O.
pub struct ParquetRecordBatchStream<T> {
metadata: Arc<ParquetMetaData>,

@@ -654,6 +657,70 @@ impl<T> ParquetRecordBatchStream<T> {
}
}

impl<T> ParquetRecordBatchStream<T>
where
T: AsyncFileReader + Unpin + Send + 'static,
{
/// Fetches the next row group from the stream.
///
/// Users can continue to call this function to get row groups and decode them concurrently.
///
/// ## Notes
///
/// ParquetRecordBatchStream should be used either as a `Stream` or with `next_row_group`; they should not be used simultaneously.
///
/// ## Returns
///
/// - `Ok(None)` if the stream has ended.
/// - `Err(error)` if the stream has errored. All subsequent calls will return `Ok(None)`.
/// - `Ok(Some(reader))` which holds all the data for the row group.
pub async fn next_row_group(&mut self) -> Result<Option<ParquetRecordBatchReader>> {
Member Author: I'm not sure if `next_row_group` is the best name, open to other options.

Contributor: I think it is a good, clear name as it clearly explains what it does.

loop {
match &mut self.state {
StreamState::Decoding(_) | StreamState::Reading(_) => {
return Err(ParquetError::General(
"Cannot combine the use of next_row_group with the Stream API".to_string(),
))
}
StreamState::Init => {
let row_group_idx = match self.row_groups.pop_front() {
Some(idx) => idx,
None => return Ok(None),
};

let row_count = self.metadata.row_group(row_group_idx).num_rows() as usize;

let selection = self.selection.as_mut().map(|s| s.split_off(row_count));

let reader_factory = self.reader.take().expect("lost reader");

let (reader_factory, maybe_reader) = reader_factory
.read_row_group(
row_group_idx,
selection,
self.projection.clone(),
self.batch_size,
)
.await
.map_err(|err| {
self.state = StreamState::Error;
err
})?;
self.reader = Some(reader_factory);

if let Some(reader) = maybe_reader {
return Ok(Some(reader));
} else {
// All rows skipped, read next row group
continue;
}
}
StreamState::Error => return Ok(None), // Stream has already errored; subsequent calls return Ok(None).
}
}
}
}
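
For orientation, here is a minimal sketch (not part of this PR's diff) of how a caller might use the new `next_row_group` API to overlap fetching with decoding, as the doc comment above describes. It assumes the parquet crate's `arrow` and `async` features plus tokio; the function name `read_overlapped`, the batch size, and the use of `tokio::task::spawn_blocking` are illustrative choices, not code from this change.

```rust
use arrow_array::RecordBatch;
use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder;
use parquet::errors::Result;
use tokio::fs::File;

/// Decode every row group of `file`, overlapping I/O with decoding:
/// `next_row_group` fetches the bytes for one row group at a time, and the
/// CPU-bound decode for each fetched group runs on a blocking thread.
async fn read_overlapped(file: File) -> Result<Vec<RecordBatch>> {
    let mut stream = ParquetRecordBatchStreamBuilder::new(file)
        .await?
        .with_batch_size(8192)
        .build()?;

    let mut handles = Vec::new();
    // `next_row_group` returns Ok(None) once every row group has been read.
    while let Some(reader) = stream.next_row_group().await? {
        // `reader` is a synchronous ParquetRecordBatchReader that already holds
        // all of the data for one row group; decode it off the async executor.
        handles.push(tokio::task::spawn_blocking(move || {
            reader.collect::<Result<Vec<_>, _>>()
        }));
    }

    let mut batches = Vec::new();
    for handle in handles {
        batches.extend(handle.await.expect("decode task panicked")?);
    }
    Ok(batches)
}
```

Because each returned `ParquetRecordBatchReader` already holds the data for its row group, decoding one group can proceed on a blocking thread while the stream fetches the next, which is the separation of decode from I/O that this API is meant to enable.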

impl<T> Stream for ParquetRecordBatchStream<T>
where
T: AsyncFileReader + Unpin + Send + 'static,
@@ -1020,6 +1087,71 @@ mod tests {
);
}

#[tokio::test]
async fn test_async_reader_with_next_row_group() {
let testdata = arrow::util::test_util::parquet_test_data();
let path = format!("{testdata}/alltypes_plain.parquet");
let data = Bytes::from(std::fs::read(path).unwrap());

let metadata = ParquetMetaDataReader::new()
.parse_and_finish(&data)
.unwrap();
let metadata = Arc::new(metadata);

assert_eq!(metadata.num_row_groups(), 1);

let async_reader = TestReader {
data: data.clone(),
metadata: metadata.clone(),
requests: Default::default(),
};

let requests = async_reader.requests.clone();
let builder = ParquetRecordBatchStreamBuilder::new(async_reader)
.await
.unwrap();

let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![1, 2]);
let mut stream = builder
.with_projection(mask.clone())
.with_batch_size(1024)
.build()
.unwrap();

let mut readers = vec![];
while let Some(reader) = stream.next_row_group().await.unwrap() {
readers.push(reader);
}

let async_batches: Vec<_> = readers
.into_iter()
.flat_map(|r| r.map(|v| v.unwrap()).collect::<Vec<_>>())
.collect();

let sync_batches = ParquetRecordBatchReaderBuilder::try_new(data)
.unwrap()
.with_projection(mask)
.with_batch_size(104)
.build()
.unwrap()
.collect::<ArrowResult<Vec<_>>>()
.unwrap();

assert_eq!(async_batches, sync_batches);

let requests = requests.lock().unwrap();
let (offset_1, length_1) = metadata.row_group(0).column(1).byte_range();
let (offset_2, length_2) = metadata.row_group(0).column(2).byte_range();

assert_eq!(
&requests[..],
&[
offset_1 as usize..(offset_1 + length_1) as usize,
offset_2 as usize..(offset_2 + length_2) as usize
]
);
}

#[tokio::test]
async fn test_async_reader_with_index() {
let testdata = arrow::util::test_util::parquet_test_data();