sequence_array.cc
1 // Copyright 2017, Beeri 15. All rights reserved.
2 // Author: Roman Gershman (romange@gmail.com)
3 //
4 #include "util/coding/sequence_array.h"
5 
6 #define ZSTD_STATIC_LINKING_ONLY
7 #include <zstd.h>
8 
9 #include "base/endian.h"
10 #include "base/flit.h"
11 #include "base/logging.h"
12 
13 namespace flit = base::flit;
14 
15 namespace util {
16 
17 namespace {
18 
// Returns the number of bytes occupied by the elements of `container`:
// its element count multiplied by the size of a single element.
template <typename C>
size_t byte_size(const C& container) {
  using Element = typename C::value_type;
  const auto element_count = container.size();
  return element_count * sizeof(Element);
}
22 
23 inline size_t GetContentSizeChecked(const uint8_t* src, uint32_t size) {
24  size_t content_size = ZSTD_getFrameContentSize(src, size);
25 
26  CHECK_NE(ZSTD_CONTENTSIZE_ERROR, content_size);
27  CHECK_NE(ZSTD_CONTENTSIZE_UNKNOWN, content_size);
28  return content_size;
29 }
30 
31 } // namespace
32 
33 size_t SequenceArray::GetMaxSerializedSize() const {
34  size_t sz1 = len_.size() * flit::Traits<uint32_t>::max_size;
35  return ZSTD_compressBound(data_.size()) + ZSTD_compressBound(sz1) + 4;
36 }
37 
38 size_t SequenceArray::SerializeTo(uint8* dest) const {
39  base::PODArray<uint8_t> len_buf;
40  len_buf.reserve(len_.size() * flit::Traits<uint32_t>::max_size);
41  size_t len_size = 0;
42 
43  for (const uint32_t val : len_) {
44  len_size += flit::EncodeT<uint32_t>(val, len_buf.data() + len_size);
45  }
46  CHECK_LE(len_size, len_buf.capacity());
47 
48  uint8* next = dest + 4;
49  size_t res1 = ZSTD_compress(next, ZSTD_compressBound(len_size), len_buf.data(), len_size, 1);
50  CHECK(!ZSTD_isError(res1)) << ZSTD_getErrorName(res1);
51 
52  next += res1;
53  LittleEndian::Store32(dest, res1);
54 
55  size_t res2 = ZSTD_compress(next, ZSTD_compressBound(data_.size()), data_.data(),
56  data_.size(), 1);
57  CHECK(!ZSTD_isError(res2)) << ZSTD_getErrorName(res2);
58 
59  VLOG(1) << "SequenceArray::SerializeTo: from " << byte_size(len_) << "/" << data_size()
60  << " to " << res1 << "/" << res2 << " bytes";
61  return res1 + res2 + 4;
62 }
63 
64 void SequenceArray::SerializeFrom(const uint8_t* src, uint32_t count) {
65  clear();
66 
67  CHECK_GT(count, 4);
68  uint32_t len_sz = LittleEndian::Load32(src);
69 
70  CHECK_LT(4 + len_sz, count);
71  src += 4;
72 
73  size_t len_content_size = GetContentSizeChecked(src, len_sz);
74 
75  // +8 to allow fast and safe flit parsing.
76  std::unique_ptr<uint8_t[]> len_buf(new uint8_t[len_content_size + 8]);
77 
78  size_t res = ZSTD_decompress(len_buf.get(), len_content_size, src, len_sz);
79  CHECK(!ZSTD_isError(res)) << ZSTD_getErrorName(res);
80  CHECK_EQ(res, len_content_size);
81 
82  const uint8_t* next = len_buf.get();
83  while (next < len_buf.get() + res) {
84  uint32_t val;
85  next += flit::ParseT(next, &val);
86  len_.push_back(val);
87  }
88  CHECK_EQ(next, len_buf.get() + res);
89  uint32_t buf_sz = count - 4 - len_sz;
90  src += len_sz;
91  size_t buf_content_size = GetContentSizeChecked(src, buf_sz);
92  data_.resize(buf_content_size);
93 
94  res = ZSTD_decompress(data_.data(), buf_content_size, src, buf_sz);
95  CHECK(!ZSTD_isError(res)) << ZSTD_getErrorName(res);
96  CHECK_EQ(buf_content_size, res);
97 }
98 
99 
100 } // namespace util
101