Skip to content

Commit

Permalink
Fix the Avro file reader when reading gzip files with individual exam…
Browse files Browse the repository at this point in the history
…ples larger than 1MB.

PiperOrigin-RevId: 703440788
  • Loading branch information
achoum authored and copybara-github committed Dec 6, 2024
1 parent 630eca7 commit f2839a7
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 13 deletions.
8 changes: 6 additions & 2 deletions yggdrasil_decision_forests/dataset/avro.cc
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,12 @@ absl::StatusOr<bool> AvroReader::ReadNextBlock() {
current_block_reader_ = utils::StringViewInputByteStream(current_block_);
break;
case AvroCodec::kDeflate:
// TODO: Multi-thread decompression.
zlib_working_buffer_.resize(1024 * 1024);
RETURN_IF_ERROR(utils::Inflate(
current_block_, &current_block_decompressed_, &zlib_working_buffer_));
current_block_decompressed_.clear();
RETURN_IF_ERROR(
utils::Inflate(current_block_, &current_block_decompressed_,
&zlib_working_buffer_, /*raw_deflate=*/true));
current_block_reader_ =
utils::StringViewInputByteStream(current_block_decompressed_);
break;
Expand Down Expand Up @@ -388,6 +391,7 @@ absl::StatusOr<bool> AvroReader::ReadNextRecord() {
}
}
next_object_in_current_block_++;
DCHECK_LE(next_object_in_current_block_, num_objects_in_current_block_);
return true;
}

Expand Down
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,7 @@ cc_test(
srcs = ["zlib_test.cc"],
data = ["//yggdrasil_decision_forests/test_data"],
deps = [
":bytestream",
":filesystem",
":logging",
":test",
Expand Down
18 changes: 11 additions & 7 deletions yggdrasil_decision_forests/utils/zlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ absl::Status GZipInputByteStream::CloseDeflateStream() {

absl::StatusOr<std::unique_ptr<GZipOutputByteStream>>
GZipOutputByteStream::Create(std::unique_ptr<utils::OutputByteStream>&& stream,
int compression_level, size_t buffer_size) {
int compression_level, size_t buffer_size,
bool raw_deflate) {
if (compression_level != Z_DEFAULT_COMPRESSION) {
STATUS_CHECK_GT(compression_level, Z_NO_COMPRESSION);
STATUS_CHECK_LT(compression_level, Z_BEST_COMPRESSION);
Expand All @@ -157,8 +158,10 @@ GZipOutputByteStream::Create(std::unique_ptr<utils::OutputByteStream>&& stream,
std::make_unique<GZipOutputByteStream>(std::move(stream), buffer_size);
std::memset(&gz_stream->deflate_stream_, 0,
sizeof(gz_stream->deflate_stream_));
// Note: A negative window size selects the raw deflate format (i.e., without
// a zlib or gzip wrapper).
if (deflateInit2(&gz_stream->deflate_stream_, compression_level, Z_DEFLATED,
MAX_WBITS + 16,
raw_deflate ? -15 : (MAX_WBITS + 16),
/*memLevel=*/8, // 8 is the recommended default
Z_DEFAULT_STRATEGY) != Z_OK) {
return absl::InternalError("Cannot initialize gzip stream");
Expand Down Expand Up @@ -212,6 +215,7 @@ absl::Status GZipOutputByteStream::WriteImpl(absl::string_view chunk,
const size_t compressed_bytes = buffer_size_ - deflate_stream_.avail_out;

if (compressed_bytes > 0) {
DCHECK(stream_);
RETURN_IF_ERROR(stream_->Write(absl::string_view{
reinterpret_cast<char*>(output_buffer_.data()), compressed_bytes}));
}
Expand Down Expand Up @@ -244,7 +248,7 @@ absl::Status GZipOutputByteStream::CloseInflateStream() {
}

absl::Status Inflate(absl::string_view input, std::string* output,
std::string* working_buffer) {
std::string* working_buffer, bool raw_deflate) {
if (working_buffer->size() < 1024) {
return absl::InvalidArgumentError(
"worker buffer should be at least 1024 bytes");
Expand All @@ -253,7 +257,7 @@ absl::Status Inflate(absl::string_view input, std::string* output,
std::memset(&stream, 0, sizeof(stream));
// Note: A negative window size selects the raw deflate format (i.e., without
// a zlib or gzip wrapper).
if (inflateInit2(&stream, -15) != Z_OK) {
if (inflateInit2(&stream, raw_deflate ? -15 : (MAX_WBITS + 16)) != Z_OK) {
return absl::InternalError("Cannot initialize gzip stream");
}
stream.next_in = reinterpret_cast<const Bytef*>(input.data());
Expand All @@ -267,10 +271,10 @@ absl::Status Inflate(absl::string_view input, std::string* output,
inflateEnd(&stream);
return absl::InternalError(absl::StrCat("Internal error", zlib_error));
}
if (stream.avail_out == 0) {
break;
}
const size_t produced_bytes = working_buffer->size() - stream.avail_out;
if (produced_bytes == 0 && zlib_error != Z_STREAM_END) {
continue;
}
absl::StrAppend(output,
absl::string_view{working_buffer->data(), produced_bytes});
if (zlib_error == Z_STREAM_END) {
Expand Down
13 changes: 11 additions & 2 deletions yggdrasil_decision_forests/utils/zlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,25 @@ class GZipInputByteStream : public utils::InputByteStream {

class GZipOutputByteStream : public utils::OutputByteStream {
public:
// Creates a gzip compression stream.
// Args:
// stream: Underlying stream that receives the compressed data.
// compression_level: Either Z_DEFAULT_COMPRESSION, or a value strictly
// between Z_NO_COMPRESSION (0) and Z_BEST_COMPRESSION (9).
// buffer_size: Size of the working buffer. The minimum size depends on the
// compressed data, but 1MB should work in most cases.
// raw_deflate: If true, uses the raw deflate format (i.e., without a zlib or
// gzip wrapper) instead of gzip.
static absl::StatusOr<std::unique_ptr<GZipOutputByteStream>> Create(
std::unique_ptr<utils::OutputByteStream>&& stream,
int compression_level = Z_DEFAULT_COMPRESSION,
size_t buffer_size = 1024 * 1024);
size_t buffer_size = 1024 * 1024, bool raw_deflate = false);

GZipOutputByteStream(std::unique_ptr<utils::OutputByteStream>&& stream,
size_t buffer_size);
~GZipOutputByteStream() override;

absl::Status Write(absl::string_view chunk) override;
absl::Status Close() override;
utils::OutputByteStream& stream() { return *stream_; }

private:
absl::Status CloseInflateStream();
Expand All @@ -96,8 +104,9 @@ class GZipOutputByteStream : public utils::OutputByteStream {
bool deflate_stream_is_allocated_ = false;
};

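// Inflates (i.e., decompresses) "input" and appends the result to "output".
// "working_buffer" must be at least 1024 bytes. If "raw_deflate" is true,
// "input" is decoded as raw deflate data; otherwise it is decoded as gzip.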
absl::Status Inflate(absl::string_view input, std::string* output,
std::string* working_buffer);
std::string* working_buffer, bool raw_deflate = false);

} // namespace yggdrasil_decision_forests::utils

Expand Down
42 changes: 40 additions & 2 deletions yggdrasil_decision_forests/utils/zlib_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "absl/log/log.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "yggdrasil_decision_forests/utils/bytestream.h"
#include "yggdrasil_decision_forests/utils/filesystem.h"
#include "yggdrasil_decision_forests/utils/logging.h"
#include "yggdrasil_decision_forests/utils/test.h"
Expand Down Expand Up @@ -139,14 +140,51 @@ TEST_P(GZipTestCaseTest, WriteAndRead) {
}
}

TEST(RawDeflate, Base) {
TEST(RawInflate, Base) {
const auto input =
absl::HexStringToBytes("05804109000008c4aa184ec1c7e0c08ff5c70ea43e470b");
std::string output;
std::string working_buffer(1024, 0);
ASSERT_OK(Inflate(input, &output, &working_buffer));
ASSERT_OK(Inflate(input, &output, &working_buffer, /*raw_deflate=*/true));
EXPECT_EQ(output, "hello world");
}

// Checks that "Inflate" restores a raw-deflate payload whose decompressed size
// (13MB) is larger than the working buffer handed to it (1MB).
TEST(RawInflate, ExceedBuffer) {
  // Create a large chunk of data (it needs to be larger than the decompression
  // working buffer for this test to make sense).
  std::string raw_data;
  raw_data.reserve(13'000'000);
  // Write 13MB of non-compressed data (1M repetitions of a 13-byte string).
  for (int i = 0; i < 1'000'000; i++) {
    absl::StrAppend(&raw_data, "13 characters");
  }

  std::string compressed_data;
  {
    // Compress the data using the raw deflate format.
    auto raw_stream = std::make_unique<StringOutputByteStream>();
    auto stream =
        GZipOutputByteStream::Create(std::move(raw_stream), 8, 1024 * 1024,
                                     /*raw_deflate=*/true)
            .value();
    EXPECT_OK(stream->Write(raw_data));
    EXPECT_OK(stream->Close());

    // TODO: Change "GZipOutputByteStream" so we don't need a dynamic cast
    // e.g. GZipOutputByteStream don't own the sub-stream.
    auto* string_stream =
        dynamic_cast<StringOutputByteStream*>(&stream->stream());
    // A failed cast yields nullptr; fail the test cleanly instead of
    // dereferencing it.
    ASSERT_NE(string_stream, nullptr);
    compressed_data = string_stream->ToString();
    LOG(INFO) << "Compressed data size:" << compressed_data.size();
  }

  std::string output;
  // The buffer is smaller than the decompressed data.
  std::string working_buffer(1024 * 1024, 0);
  ASSERT_OK(
      Inflate(compressed_data, &output, &working_buffer, /*raw_deflate=*/true));
  EXPECT_EQ(output, raw_data);
}

} // namespace
} // namespace yggdrasil_decision_forests::utils

0 comments on commit f2839a7

Please sign in to comment.