zstd_sinksource.cc
1 // Copyright 2017, Beeri 15. All rights reserved.
2 // Author: Roman Gershman (romange@gmail.com)
3 //
4 #define ZSTD_STATIC_LINKING_ONLY
5 
6 #include <zstd.h>
7 
8 #include "util/zstd_sinksource.h"
9 
10 #include "base/logging.h"
11 
12 namespace util {
13 
// Reinterprets the zstd_handle_ member as the compression stream that the
// ZStdSink constructor obtains from ZSTD_createCStream().
#define HANDLE (reinterpret_cast<ZSTD_CStream*>(zstd_handle_))
15 
16 inline Status ZstdStatus(size_t res) {
17  return Status(StatusCode::IO_ERROR, ZSTD_getErrorName(res));
18 }
19 
20 
21 size_t ZStdSink::CompressBound(size_t src_size) {
22  return ZSTD_compressBound(src_size);
23 }
24 
25 ZStdSink::ZStdSink(Sink* upstream) : upstream_(upstream) {
26  buf_sz_ = ZSTD_CStreamOutSize();
27  buf_.reset(new uint8_t[buf_sz_]);
28 
29  // ZSTD_customMem mem_params{my_alloc, my_free, nullptr};
30 
31  zstd_handle_ = ZSTD_createCStream();
32  VLOG(1) << "Allocated " << buf_sz_ << " bytes";
33 }
34 
35 ZStdSink::~ZStdSink() {
36  ZSTD_freeCStream(HANDLE);
37 }
38 
39 
40 Status ZStdSink::Init(int level) {
41  size_t const res = ZSTD_initCStream_srcSize(HANDLE, level, 0);
42  if (ZSTD_isError(res)) {
43  return ZstdStatus(res);
44  }
45  VLOG(1) << "allocated " << ZSTD_sizeof_CStream(HANDLE);
46  return Status::OK;
47 }
48 
49 Status ZStdSink::Append(const strings::ByteRange& slice) {
50  ZSTD_inBuffer input = { slice.data(), slice.size(), 0 };
51  while (input.pos < input.size) {
52  ZSTD_outBuffer out_buf{ buf_.get(), buf_sz_, 0};
53  size_t res = ZSTD_compressStream(HANDLE, &out_buf , &input);
54  if (ZSTD_isError(res)) {
55  return ZstdStatus(res);
56  }
57  RETURN_IF_ERROR(upstream_->Append(strings::ByteRange(buf_.get(), out_buf.pos)));
58  }
59  return Status::OK;
60 }
61 
62 Status ZStdSink::Flush() {
63  ZSTD_outBuffer out_buf{buf_.get(), buf_sz_, 0};
64 
65  size_t res = ZSTD_endStream(HANDLE, &out_buf);
66  if (ZSTD_isError(res)) {
67  return ZstdStatus(res);
68  }
69  CHECK_EQ(0, res);
70  if (out_buf.pos) {
71  RETURN_IF_ERROR(upstream_->Append(strings::ByteRange(buf_.get(), out_buf.pos)));
72  }
73  return upstream_->Flush();
74 }
75 
76 
// Reinterprets the zstd_handle_ member as the decompression stream that the
// ZStdSource constructor obtains from ZSTD_createDStream().
#define DC_HANDLE (reinterpret_cast<ZSTD_DStream*>(zstd_handle_))
78 
79 bool ZStdSource::HasValidHeader(Source* upstream) {
80  uint8_t buf[4];
81  auto res = upstream->Read(strings::MutableByteRange(buf, arraysize(buf)));
82  if (!res.ok() || res.obj != arraysize(buf))
83  return false;
84  upstream->Prepend(strings::ByteRange(buf, arraysize(buf)));
85 
86  return ZSTD_isFrame(buf, arraysize(buf));
87 }
88 
89 
// Size in bytes of ZStdSource's staging buffer for compressed input (4 KiB).
constexpr unsigned kReadBuf = 1u << 12;
91 
92 ZStdSource::ZStdSource(Source* upstream)
93  : sub_stream_(upstream) {
94  CHECK(upstream);
95  zstd_handle_ = ZSTD_createDStream();
96  size_t const res = ZSTD_initDStream(DC_HANDLE);
97  CHECK(!ZSTD_isError(res)) << ZSTD_getErrorName(res);
98  buf_.reset(new uint8_t[kReadBuf]);
99 }
100 
101 ZStdSource::~ZStdSource() {
102  ZSTD_freeDStream(DC_HANDLE);
103 }
104 
105 
106 StatusObject<size_t> ZStdSource::ReadInternal(const strings::MutableByteRange& range) {
107  ZSTD_outBuffer output = { range.begin(), range.size(), 0 };
108  do {
109  if (buf_range_.empty()) {
110  auto res = sub_stream_->Read(strings::MutableByteRange(buf_.get(), kReadBuf));
111  if (!res.ok())
112  return res;
113  if (res.obj == 0)
114  break;
115  buf_range_.reset(buf_.get(), res.obj);
116  }
117 
118  ZSTD_inBuffer input{buf_range_.begin(), buf_range_.size(), 0 };
119 
120  size_t to_read = ZSTD_decompressStream(DC_HANDLE, &output , &input);
121  if (ZSTD_isError(to_read)) {
122  return ZstdStatus(to_read);
123  }
124 
125  buf_range_.advance(input.pos);
126  if (input.pos < input.size) {
127  CHECK_EQ(output.pos, output.size);
128  break;
129  }
130  } while (output.pos < output.size);
131  return output.pos;
132 }
133 
134 } // namespace util