set_encoder.cc
// Copyright 2017, Beeri 15. All rights reserved.
// Author: Roman Gershman (romange@gmail.com)
//

#include "util/coding/set_encoder.h"

#define ZSTD_STATIC_LINKING_ONLY
#include <zstd.h>
#include <zdict.h>

#include <fcntl.h>
#include <sys/stat.h>

#include "base/endian.h"
#include "base/hash.h"
#include "base/logging.h"

#include "strings/join.h"
#include "base/flit.h"

#ifndef IS_LITTLE_ENDIAN
#error this file assumes little endian architecture
#endif

using namespace std;
using strings::ByteRange;
using namespace base;

namespace util {

namespace {


constexpr uint16_t kMagic = 0x2ca7;
constexpr uint32_t kMaxHeaderSize = 11;  // including kMagic + flags
constexpr uint32_t SEQ_BLOCK_LOG_SIZE = 17;
constexpr uint32_t SEQ_BLOCK_SIZE = 1U << SEQ_BLOCK_LOG_SIZE;

// We store 8 times more data for analysis than we are going to store in compressed blocks.
constexpr uint32_t MAX_BATCH_SIZE = SEQ_BLOCK_SIZE * 8;

constexpr uint32_t SEQ_DICT_MAX_SIZE = 1U << 17;


constexpr unsigned kArrLengthLimit = 1 << 14;  // to accommodate at least 1 array in 128KB.
constexpr unsigned kLenLimit = (1 << 16);

#define CHECK_ZSTDERR(res) do { auto foo = (res); \
    CHECK(!ZSTD_isError(foo)) << ZSTD_getErrorName(foo); } while(false)

inline uint32_t InlineCode(uint32_t len) { return (len << 1) | 0; }
inline uint32_t DictCode(uint32_t id) { return (id << 1) | 1; }
inline bool IsDictCode(uint32_t code) { return (code & 1) != 0; }
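
// A minimal worked example of the tagging scheme above (illustrative only): an inline
// sequence of 10 bytes is coded as InlineCode(10) == 20 (LSB 0), while a reference to
// dictionary entry 7 is coded as DictCode(7) == 15 (LSB 1); code >> 1 recovers the payload.
#if 0
  uint32_t inline_code = InlineCode(10);  // == 20, IsDictCode() is false.
  uint32_t dict_code = DictCode(7);       // == 15, IsDictCode() is true.
  CHECK_EQ(10, inline_code >> 1);
  CHECK_EQ(7, dict_code >> 1);
#endif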

constexpr uint32_t kDictHeaderSize = 6;

#if 0
static uint8_t *svb_encode_scalar(const uint16_t *in,
                                  uint8_t *__restrict__ key_ptr, uint8_t *__restrict__ data_ptr,
                                  uint32_t count) {
  if (count == 0)
    return data_ptr;  // exit immediately if no data

  uint8_t shift = 0;  // cycles 0, 1, ..., 7, 0, 1, ...
  uint8_t key = 0;
  for (uint32_t c = 0; c < count; c++) {
    if (shift == 8) {
      shift = 0;
      *key_ptr++ = key;
      key = 0;
    }
    uint32_t val = in[c];
    uint8_t code = _encode_data(val, &data_ptr);
    key |= code << shift;
    shift += 1;
  }

  *key_ptr = key;   // write last key (no increment needed)
  return data_ptr;  // pointer to first unused data byte
}

// Encodes an array of a given length read from in to out in streamvbyte format.
// Returns the number of bytes written.
// Taken from: https://github.com/lemire/streamvbyte/blob/master/src/streamvbyte.c
size_t streamvbyte_encode(const uint16_t *in, uint32_t count, uint8_t *out) {
  uint8_t *keyPtr = out;
  uint32_t keyLen = (count + 7) / 8;   // one key bit per value, rounded up to a full byte.
  uint8_t *dataPtr = keyPtr + keyLen;  // variable byte data after all keys
  return svb_encode_scalar(in, keyPtr, dataPtr, count) - out;
}


uint8* encode_pair(uint16 a, uint16 b, uint8* dest) {
  uint32 v = (a & 0xf) | ((b & 0xf) << 4);
  a >>= 4;
  b >>= 4;

  if (a || b) {
    v |= ((a & 0xf) | ((b & 0xf) << 4)) << 8;
  }

  return Varint::Encode32(dest, v);
}
#endif


}  // namespace


// We delta-encode an ordered set of unique numbers and then shrink the result with partial RLE.
// The RLE exploits the fact that the input numbers are smaller than 16K, and it sometimes uses
// the LSB to encode a repeat value. The rule is that if the destination symbol is less than
// kSmallNum, then the next destination symbol is shifted left by 1 and its LSB says whether
// it is a repeat or not. See the header for more info.
unsigned DeltaEncode16(const uint16* src, unsigned cnt, uint16* dest) {
  if (VLOG_IS_ON(3)) {
    string tmp = absl::StrJoin(src, src + cnt, ",");
    LOG(INFO) << "Adding " << cnt << ": " << tmp;
  }
  uint16* dest_orig = dest;

  *dest++ = *src;
  if (cnt == 1)
    return 1;

  const uint16* end = src + cnt;
  uint16 prev = *src++;
  uint16 prev_dest = prev;

  unsigned rep = 0;
  for (; src < end; ++src) {
    DCHECK_LT(prev, *src);

    unsigned is_prev_dest_small = (prev_dest < internal::kSmallNum);
    uint16 delta = *src - prev - 1;
    prev = *src;
    if (delta == prev_dest && is_prev_dest_small) {
      ++rep;
      continue;
    }

    if (rep) {
      if (rep > 1) {
        *dest++ = ((rep - 1) << 1) | 1;
        *dest++ = delta;
      } else {
        *dest++ = prev_dest << 1;
        *dest++ = delta << 1;
      }
      rep = 0;
    } else {
      // Store delta.
      *dest++ = (delta << is_prev_dest_small);
    }
    prev_dest = delta;
  }

  if (rep) {
    if (rep > 1)
      *dest++ = ((rep - 1) << 1) | 1;
    else
      *dest++ = prev_dest << 1;
  }

  return dest - dest_orig;
}
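
// A worked example of the delta layer (the RLE/LSB layer on top of it depends on
// internal::kSmallNum, which is defined in the header): for the sorted input
// {3, 7, 8, 12} the encoder first emits 3 and then the gaps minus one, i.e. the raw
// deltas 3, 0, 3, before any repeat tagging is applied.
#if 0
  uint16 src[] = {3, 7, 8, 12};
  uint16 dest[4];
  unsigned n = DeltaEncode16(src, 4, dest);
  CHECK_LE(n, 4);  // The output never exceeds the input length.
#endif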

// Returns the number of bytes written to dest.
uint32_t DeltaAndFlitEncode(LiteralDictBase::SymbId* src, uint32_t cnt,
                            LiteralDictBase::SymbId* tmp,
                            void* dest) {
  using SymbId = LiteralDictBase::SymbId;
  std::sort(src, src + cnt);

  uint32_t out_sz = DeltaEncode16(src, cnt, tmp);

  DCHECK_LE(out_sz, cnt);  // it always holds due to how DeltaEncode16 works.

  uint8* start = static_cast<uint8_t*>(dest);
  uint8* next = start;
  SymbId current = tmp[0];
  for (unsigned j = 1; j < out_sz; ++j) {
    // EncodeT unconditionally writes at least 4 bytes, so we buffer the next value.
    SymbId val = tmp[j];
    next += flit::EncodeT<uint32_t>(current, next);
    current = val;
  }
  next += flit::EncodeT<uint32_t>(current, next);

  return next - start;
}


template<typename T> void LiteralDict<T>::Build() {
  using iterator = typename decltype(freq_map_)::iterator;

  vector<pair<unsigned, iterator>> freq_arr;
  freq_arr.reserve(freq_map_.size());

  for (auto it = begin(freq_map_); it != end(freq_map_); ++it) {
    freq_arr.emplace_back(it->second.cnt, it);
  }
  // Sort literals by descending frequency so that frequent literals get small ids.
  std::sort(begin(freq_arr), end(freq_arr),
            [](const auto& val1, const auto& val2) { return val1.first > val2.first; });

  alphabet_.resize(freq_arr.size());

  for (unsigned i = 0; i < freq_arr.size(); ++i) {
    const iterator& it = freq_arr[i].second;
    T t = it->first;
    alphabet_[i] = t;
    it->second.id = i;
  }
}
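
// A sketch of the resulting id assignment (illustrative; Add()/Resolve() are declared
// in the header): after adding {5, 5, 5, 9, 9, 7} and calling Build(), the most
// frequent literal receives the smallest symbol id.
#if 0
  LiteralDict<uint32_t> dict;
  for (uint32_t v : {5u, 5u, 5u, 9u, 9u, 7u}) dict.Add(v);
  dict.Build();
  CHECK_EQ(0, dict.Resolve(5));  // 5 is the most frequent -> id 0.
#endif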

template<typename T> size_t LiteralDict<T>::GetMaxSerializedSize() const {
  return sizeof(T) * alphabet_.size();
}

template<typename T> size_t LiteralDict<T>::SerializeTo(uint8_t* dest) const {
  uint8_t* ptr = dest;

  for (const auto& t : alphabet_) {
    LittleEndian::StoreT(typename std::make_unsigned<T>::type(t), ptr);
    ptr += sizeof(T);
  }
  return ptr - dest;
}

template<typename T> bool LiteralDict<T>::Resolve(const T* src, uint32_t count, SymbId* dest) {
  for (uint32_t i = 0; i < count; ++i) {
    auto res = freq_map_.emplace(src[i], Record{});
    if (res.second) {
      if (alphabet_.size() >= kMaxAlphabetSize) {
        freq_map_.erase(res.first);
        return false;
      }

      res.first->second.id = alphabet_.size();
      alphabet_.push_back(src[i]);
    }

    LittleEndian::Store16(dest + i, res.first->second.id);
  }
  return true;
}


ostream& operator<<(ostream& os, const ZSTD_parameters& p) {
  os << "wlog: " << p.cParams.windowLog << ", clog: " << p.cParams.chainLog << ", strategy: "
     << p.cParams.strategy << ", slog: " << p.cParams.searchLog << ", cntflag: "
     << p.fParams.contentSizeFlag << ", hashlog: " << p.cParams.hashLog;
  return os;
}

struct SeqEncoderBase::ZstdCntx {
  ZSTD_CCtx* context;

  ZstdCntx(const ZstdCntx&) = delete;
  void operator=(const ZstdCntx&) = delete;

  ZstdCntx() {
    context = ZSTD_createCCtx();
  }

  ~ZstdCntx() {
    ZSTD_freeCCtx(context);
  }

  bool start = true;
};


void BlockHeader::Read(const uint8_t* src) {
  flags = *src;

  const uint8_t* next = src + 1;
  if (flags & kDictBit) {
    num_sequences = LittleEndian::Load16(next);
    byte_len_size_comprs = LittleEndian::Load24(next + 2);
    next += 5;
  } else {
    num_sequences = byte_len_size_comprs = 0;
  }
  sequence_size_comprs = LittleEndian::Load24(next);
}

uint8 BlockHeader::Write(uint8_t* dest) const {
  LittleEndian::Store16(dest, kMagic);

  dest[2] = flags;  // type
  uint8* next = dest + 3;
  if (flags & kDictBit) {
    LittleEndian::Store16(next, num_sequences);
    LittleEndian::Store24(next + 2, byte_len_size_comprs);
    next += 5;
  }
  LittleEndian::Store24(next, sequence_size_comprs);
  return next - dest + 3;
}

// Not including magic but including flags.
uint8_t BlockHeader::HeaderSize(uint8_t flags) {
  if (flags & kDictBit) {
    return 9;
  }
  return 4;
}
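
// On-wire block layout produced by BlockHeader::Write() (field sizes in bytes):
//
//   [magic:2][flags:1][num_sequences:2][byte_len_size_comprs:3][sequence_size_comprs:3]
//
// num_sequences and byte_len_size_comprs are present only when kDictBit is set,
// which yields the 11-byte kMaxHeaderSize above; without kDictBit the full header
// occupies the minimal 6 bytes.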

struct SeqDecoderBase::Zstd {
  ZSTD_DCtx* context;
  uint8_t start = 2;  // bit 1 - start, bit 0 - which destination buffer to fill.

  Zstd(const Zstd&) = delete;
  void operator=(const Zstd&) = delete;

  Zstd() {
    context = ZSTD_createDCtx();
  }

  ~Zstd() {
    ZSTD_freeDCtx(context);
  }

  size_t offset() const { return SEQ_BLOCK_SIZE * (start & 1); }
};

SeqEncoderBase::SeqEncoderBase() {
  seq_map_.set_empty_key(ByteRange());
  dict_seq_map_.set_empty_key(strings::ByteRange());

  prev_block_.reserve(SEQ_BLOCK_SIZE);

  zstd_cntx_.reset(new ZstdCntx);
  tmp_symb_.reserve(kArrLengthLimit);
}

SeqEncoderBase::~SeqEncoderBase() {
}

void SeqEncoderBase::AddCompressedBuf(const BlockHeader& bh) {
  VLOG(1) << "Adding compressed block: " << bh.num_sequences
          << "/" << bh.byte_len_size_comprs << "/" << bh.sequence_size_comprs
          << "/" << int(bh.flags);

  size_t blob_sz = bh.byte_len_size_comprs + bh.sequence_size_comprs;
  std::unique_ptr<uint8_t[]> cb(new uint8_t[blob_sz + kMaxHeaderSize]);

  uint8_t offset = bh.Write(cb.get());
  memcpy(cb.get() + offset, compress_data_.begin(), blob_sz);

  compressed_blocks_.emplace_back(cb.get(), blob_sz + offset);
  compressed_bufs_.emplace_back(std::move(cb));
}

bool SeqEncoderBase::LearnSeqDict(strings::ByteRange key) {
  auto res = seq_map_.emplace(key, EntryVal{});
  EntryVal& ev = res.first->second;
  ev.ref_cnt++;

  if (!res.second) {
    const uint8_t* ptr = res.first->first.data();
    DCHECK_GE(ptr, lit_data_.data());
    DCHECK_LT(ptr, lit_data_.data() + lit_data_.capacity());

    duplicate_seq_.push_back(ptr - lit_data_.data());
  }
  return res.second;
}


uint32 SeqEncoderBase::Cost() const {
  uint32_t cost = lit_data_.size();

  for (auto block : compressed_blocks_) {
    cost += block.size();
  }

  if (state_ == LIT_DICT) {
    cost += len_code_.size() * sizeof(uint16_t);
  }
  return cost;
}

void SeqEncoderBase::AnalyzePreDict() {
  uint32_t alphabet_size = PrepareDict();
  if (state_ == NO_LIT_DICT)
    return;

  uint32_t lit_cnt = lit_data_.size() / sizeof(SymbId);

  SymbId* symb_arr = reinterpret_cast<SymbId*>(lit_data_.data());

  uint8_t* dest = lit_data_.begin();
  uint32_t delta_flit_bytes = 0;

  for (size_t i = 0; i < len_code_.size(); ++i) {
    uint32_t lit_num = len_code_[i];

    uint32_t bytes_num = DeltaAndFlitEncode(symb_arr, lit_num, tmp_symb_.begin(), dest);

    // It always holds due to how DeltaEncode16 and flit encoding work with 14-bit numbers.
    DCHECK_LE(bytes_num, lit_num * sizeof(SymbId));

    bool new_entry = disable_seq_dict_ || LearnSeqDict(ByteRange(dest, bytes_num));

    if (new_entry) {
      dest += bytes_num;
      len_code_[i] = InlineCode(bytes_num);
    } else {
      len_code_[i] = DictCode(bytes_num);
    }

    symb_arr += lit_num;
    delta_flit_bytes += bytes_num;
  }

  size_t delta_flit_cost = alphabet_size * literal_size_ + delta_flit_bytes;

  // TODO: refine the state machine to allow variations of
  // delta-flit / sequence-dictionary encodings.
  CHECK_LT(delta_flit_cost, literal_size_ * lit_cnt) << "TBD to fallback.";

  size_t seq_size = dest - lit_data_.data();
  lit_data_.resize_assume_reserved(seq_size);

  // TODO: Is this the right place to do it? Maybe it is worth waiting until lit_data_ is
  // full and then analyzing it. The PRO of this check: a faster fallback to a cheaper
  // heuristic on incompressible data.
  if (seq_size + seq_map_.size() * 2 > delta_flit_bytes) {
    VLOG(1) << "Falling back to no sequence. Reason: " << seq_size + seq_map_.size() * 2
            << " vs " << delta_flit_bytes;
    for (auto& k_v : seq_map_) {
      k_v.second.ref_cnt = 0;
    }

    // Will flush and clear via AnalyzeSequenceDict(); with all ref counts reset,
    // no sequence dictionary entries will be selected.
    AnalyzeSequenceDict();
  }

  VLOG(1) << "original/flit: " << added_lit_cnt_ * literal_size_ << "/" << lit_data_.size();
}

void SeqEncoderBase::CompressRawLit(bool final) {
  DCHECK_EQ(state_, NO_LIT_DICT);

  size_t csz1 = ZSTD_compressBound(lit_data_.size());

  compress_data_.reserve(csz1);
  size_t res = ZSTD_compress(compress_data_.begin(), csz1,
                             lit_data_.data(), lit_data_.size(), 1);
  CHECK(!ZSTD_isError(res)) << ZSTD_getErrorName(res);
  VLOG(1) << "CompressRawLit: from " << lit_data_.size() << " to " << res;

  BlockHeader bh;
  bh.flags = (final ? BlockHeader::kFinalBit : 0);

  bh.sequence_size_comprs = res;

  AddCompressedBuf(bh);

  lit_data_.clear();
  len_code_.clear();
}

void SeqEncoderBase::CompressFlitSequences(bool final) {
  DCHECK_EQ(state_, LIT_DICT);
  CHECK_LT(len_code_.size(), kLenLimit);

  // We support arrays of up to kArrLengthLimit elements, so 3 bytes per len value is enough.
  tmp_space_.reserve(len_code_.size() * 3);

  size_t len_size = 0;
  for (const uint32_t val : len_code_) {
    len_size += flit::EncodeT<uint32_t>(val, tmp_space_.data() + len_size);
  }
  CHECK_LE(len_size, tmp_space_.capacity());

  BlockHeader bh;
  bh.flags = BlockHeader::kDictBit | (final ? BlockHeader::kFinalBit : 0);
  bh.num_sequences = len_code_.size();

  if (!dict_seq_.empty()) {
    bh.flags |= BlockHeader::kDictSeqBit;
  }

  size_t csz1 = ZSTD_compressBound(len_size);
  size_t csz2 = ZSTD_compressBound(lit_data_.size());

  size_t upper_bound = csz1 + csz2;
  compress_data_.reserve(upper_bound);

  uint8_t* compress_pos = compress_data_.begin();

  size_t res = ZSTD_compress(compress_pos, csz1,
                             tmp_space_.data(), len_size, 1);
  CHECK(!ZSTD_isError(res)) << ZSTD_getErrorName(res);
  bh.byte_len_size_comprs = res;

  compress_pos += res;
  if (zstd_cntx_->start) {
    CHECK(compressed_blocks_.empty() && compressed_bufs_.empty());

#if 0
    if (!dict_seq_.empty()) {
      zstd_dict_.reset(new uint8[1 << 17]);
      std::vector<size_t> ss(dict_seq_.len_array().begin(), dict_seq_.len_array().end());

      zstd_dict_size_ = ZDICT_trainFromBuffer(zstd_dict_.get(), 1 << 17,
                                              dict_seq_.data().data(), ss.data(), ss.size());
      VLOG(1) << "Zdict dictsize: " << zstd_dict_size_;
    } else {
      zstd_dict_size_ = 0;
    }
#endif
    ZSTD_parameters params{ZSTD_getCParams(6, 0 /* est src size */, zstd_dict_size_),
                           ZSTD_frameParameters()};

    params.cParams.windowLog = SEQ_BLOCK_LOG_SIZE + 1;
    params.cParams.hashLog = SEQ_BLOCK_LOG_SIZE - 2;
    // params.cParams.chainLog = SEQ_BLOCK_LOG_SIZE;

    VLOG(1) << "Using: " << params;

    CHECK_ZSTDERR(ZSTD_compressBegin_advanced(
        zstd_cntx_->context, zstd_dict_.get(), zstd_dict_size_, params, ZSTD_CONTENTSIZE_UNKNOWN));
    zstd_cntx_->start = false;
  }

  bool finish_frame = final;

  auto func = finish_frame ? ZSTD_compressEnd : ZSTD_compressContinue;
  res = func(zstd_cntx_->context, compress_pos, csz2, lit_data_.data(), lit_data_.size());
  CHECK_ZSTDERR(res);
  bh.sequence_size_comprs = res;
  compress_pos += res;

  size_t real_size = compress_pos - compress_data_.begin();
  VLOG(1) << "flit/compressed: " << lit_data_.size() << "/" << real_size + kMaxHeaderSize;

  CHECK_LE(real_size, upper_bound);

  AddCompressedBuf(bh);

  if (finish_frame) {
    VLOG(1) << "SeqDict referenced " << dict_ref_bytes_;
    zstd_cntx_->start = true;
    dict_ref_bytes_ = 0;
  }

  lit_data_.swap(prev_block_);
  lit_data_.clear();
  len_code_.clear();
  added_lit_cnt_ = 0;
}

void SeqEncoderBase::AnalyzeSequenceDict() {
  using iterator = decltype(seq_map_)::iterator;
  std::vector<pair<unsigned, iterator>> sorted_it;

  // Collect the entries that are worth deduplicating: longer than 1 byte and
  // referenced more than once.
  for (auto it = seq_map_.begin(); it != seq_map_.end(); ++it) {
    if (it->first.size() > 1 && it->second.ref_cnt > 1)
      sorted_it.emplace_back(it->second.ref_cnt, it);
    else
      it->second.ref_cnt = 0;  // just reset it.
  }

  std::sort(sorted_it.begin(), sorted_it.end(),
            [](const pair<unsigned, iterator>& p1, const pair<unsigned, iterator>& p2) {
              return p1.first > p2.first;
            });

  size_t selected_size = 0, ref_bytes = 0;

  auto it = sorted_it.begin();

  // Copy into dict_seq_ entries with a larger reference count first.
  for (; it != sorted_it.end(); ++it) {
    ByteRange seq_str = it->second->first;
    EntryVal& ev = it->second->second;

    if (selected_size + seq_str.size() >= SEQ_DICT_MAX_SIZE) {
      break;
    }
    VLOG(2) << "Chosen dict with factor " << ev.ref_cnt << ", size " << seq_str.size();

    // Returns the sequence id of the added sequence seq_str.
    ev.dict_id = dict_seq_.Add(seq_str.begin(), seq_str.end());

    selected_size += seq_str.size();

    // How many bytes reference this entry.
    ref_bytes += seq_str.size() * ev.ref_cnt;
  }
  dict_nominal_ratio_ = ref_bytes * 1.0 / (selected_size + 1);

  size_t dict_count = it - sorted_it.begin();

  VLOG(1) << "Dictionary will take " << selected_size << " bytes and will represent "
          << ref_bytes << " bytes with " << dict_count << " items, leaving "
          << sorted_it.size() - dict_count;
  VLOG(1) << "original/flit: " << added_lit_cnt_ * literal_size_ << "/" << lit_data_.size();

  // We are going to proceed with the sequence dictionary.
  // Entries that did not make it into the dictionary are reset.
  for (; it != sorted_it.end(); ++it) {
    it->second->second.ref_cnt = 0;
  }

  {
    // Create a reverse mapping to dict_seq_.
    unsigned dict_index = 0;
    CHECK(dict_seq_map_.empty());
    for (auto it = dict_seq_.begin(); it != dict_seq_.end(); ++it) {
      auto res = dict_seq_map_.emplace(*it, dict_index++);
      CHECK(res.second);
    }
  }

  uint32_t inline_index = 0, dest_index = 0, duplicate_index = 0;
  base::PODArray<uint8> seq_data(pmr::get_default_resource());
  seq_data.reserve(SEQ_BLOCK_SIZE);
  seq_data.swap(lit_data_);

  const uint8* key_src = seq_data.data();
  size_t len_size = len_code_.size();

  ByteRange ref_entry;

  // Now lit_data_ has only SEQ_BLOCK_SIZE capacity and seq_data has all the data.
  // We reuse len_code_ for both reading and writing, relying on the fact that the POD
  // array does not move its memory when shrinking.
  dict_ref_bytes_ = 0;
  for (size_t i = 0; i < len_size; ++i, ++dest_index) {
    uint32_t len_code = len_code_[i];
    uint32_t len = len_code >> 1;

    if (IsDictCode(len_code)) {
      ref_entry.reset(seq_data.data() + duplicate_seq_[duplicate_index++], len);
    } else {
      ref_entry.reset(key_src, len);
      key_src += len;
    }

    auto it = seq_map_.find(ref_entry);

    CHECK(it != seq_map_.end());
    const EntryVal& ev = it->second;

    // Entries that were not selected for the dictionary are inlined.
    if (ev.ref_cnt < 2) {
      if (len + lit_data_.size() > SEQ_BLOCK_SIZE ||
          dest_index >= kLenLimit) {
        len_code_.resize_assume_reserved(dest_index);
        CompressFlitSequences(false);
        dest_index = 0;
      }

      lit_data_.insert(ref_entry.begin(), ref_entry.end());
      ++inline_index;

      len_code_[dest_index] = InlineCode(len);
    } else {
      dict_ref_bytes_ += len;
      len_code_[dest_index] = DictCode(ev.dict_id);
    }
  }
  CHECK_EQ(duplicate_index, duplicate_seq_.size());
  CHECK_EQ(dict_ref_bytes_, ref_bytes);

  seq_map_.clear();
  len_code_.resize_assume_reserved(dest_index);
}

bool SeqEncoderBase::PrepareForSymbAvailability(uint32_t cnt) {
  DCHECK_LE(cnt * sizeof(SymbId), lit_data_.capacity());

  // lit_data_.capacity() changes depending on the state of the encoder.
  // It will be MAX_BATCH_SIZE when we analyze literal and sequence dictionaries.
  // It will be SEQ_BLOCK_SIZE once the dictionaries are determined.
  DCHECK(lit_data_.capacity() == SEQ_BLOCK_SIZE || lit_data_.capacity() == MAX_BATCH_SIZE);
  DCHECK(!seq_map_.empty() || lit_data_.capacity() == SEQ_BLOCK_SIZE);

  if (cnt * sizeof(SymbId) + lit_data_.size() > lit_data_.capacity()) {
    if (!seq_map_.empty()) {
      // Decide whether to use the sequence dictionary.
      AnalyzeSequenceDict();
      return false;
    }
    CHECK_EQ(LIT_DICT, state_);
    CompressFlitSequences(false);
  }
  return true;
}

void SeqEncoderBase::BacktrackToRaw() {
  if (lit_data_.size() >= SEQ_BLOCK_SIZE / 8) {
    // There is enough intermediate data to create a last dictionary-based block.
    CompressFlitSequences(false);
  } else {
    LOG(FATAL) << "TBD_INFLATE the current buffer back to raw literals";
  }
  state_ = NO_LIT_DICT;
}

void SeqEncoderBase::AddEncodedSymbols(SymbId* src, uint32 cnt) {
  uint8* next = end(lit_data_);

  uint32_t bytes_num = DeltaAndFlitEncode(src, cnt, tmp_symb_.begin(), next);

  ByteRange candidate(next, bytes_num);
  uint32_t code = InlineCode(bytes_num);

  if (!seq_map_.empty()) {
    bool new_entry = LearnSeqDict(candidate);

    if (!new_entry) {
      code = DictCode(bytes_num);
      bytes_num = 0;
    }
  } else if (!dict_seq_map_.empty()) {
    auto it = dict_seq_map_.find(candidate);
    if (it != dict_seq_map_.end()) {
      code = DictCode(it->second);
      dict_ref_bytes_ += bytes_num;
      bytes_num = 0;
    }
  } else {
    // just regular LIT_DICT, no seq learning or resolving.
  }
  len_code_.push_back(code);

  lit_data_.resize_assume_reserved(lit_data_.size() + bytes_num);
}

void SeqEncoderBase::Flush() {
  if (state_ == PRE_DICT) {
    AnalyzePreDict();
  }

  if (!seq_map_.empty()) {
    AnalyzeSequenceDict();
  }

  if (state_ == LIT_DICT) {
    CompressFlitSequences(true);
  } else {
    CompressRawLit(true);
  }
}
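
// A hypothetical end-to-end usage sketch. The exact accessor for the compressed
// blocks is declared in set_encoder.h and is not shown here; this only illustrates
// the call order of the public methods defined in this file.
#if 0
  SeqEncoder<4> encoder;
  uint32 arr[] = {1, 5, 9, 200};          // arrays must be strictly increasing.
  encoder.Add(arr, 4);                    // may be called many times.
  encoder.Flush();                        // emits the final compressed block.
  string dict;
  if (encoder.GetDictSerialized(&dict)) {
    // Pass 'dict' to SeqDecoder<4>::SetDict() before decompressing the blocks.
  }
#endif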


/*****************************************************************************
 SeqEncoder
*****************************************************************************/
template <size_t INT_SIZE> SeqEncoder<INT_SIZE>::SeqEncoder() {
  literal_size_ = INT_SIZE;
}


// TODO: fix reinterpret_cast issues (strict aliasing rules are broken).
template <size_t INT_SIZE> void SeqEncoder<INT_SIZE>::Add(const UT* src, unsigned cnt) {
  CHECK_LT(cnt, kArrLengthLimit);
  CHECK_GT(cnt, 0);

  size_t added_bytes = cnt * INT_SIZE;
  bool finished = true;
  do {
    switch (state_) {
      case PRE_DICT:
        if (added_bytes + lit_data_.size() <= lit_data_.capacity()) {
          uint8* dest = lit_data_.end();
          memcpy(dest, src, INT_SIZE * cnt);  // Assuming little endian.
          len_code_.push_back(cnt);
          lit_data_.resize(lit_data_.size() + added_bytes);
          finished = true;
        } else {
          if (lit_data_.capacity() == 0) {
            lit_data_.reserve(disable_seq_dict_ ? SEQ_BLOCK_SIZE : MAX_BATCH_SIZE);
          } else {
            // We have a full batch ready.
            AnalyzePreDict();
          }
          finished = false;
        }
        break;
      case LIT_DICT:
        finished = AddDictEncoded(src, cnt);
        break;
      case NO_LIT_DICT:
        LOG(FATAL) << "TBD";
        break;
    }
  } while (!finished);

  added_lit_cnt_ += cnt;
}

template <size_t INT_SIZE> bool SeqEncoder<INT_SIZE>::GetDictSerialized(std::string* dest) {
  if (state_ == NO_LIT_DICT)
    return false;

  size_t max_size = kDictHeaderSize + dict_.GetMaxSerializedSize();
  if (!dict_seq_.empty()) {
    max_size += dict_seq_.GetMaxSerializedSize();
  }

  dest->resize(max_size);

  uint8_t* start = reinterpret_cast<uint8_t*>(&dest->front());
  size_t dict_sz = dict_.SerializeTo(start + kDictHeaderSize);

  size_t seq_sz = 0;
  LittleEndian::Store24(start, dict_sz);
  if (!dict_seq_.empty()) {
    seq_sz = dict_seq_.SerializeTo(start + kDictHeaderSize + dict_sz);
  }
  LittleEndian::Store24(start + 3, seq_sz);

  dest->resize(dict_sz + seq_sz + kDictHeaderSize);

  return true;
}
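
// Serialized dictionary layout, as written above and parsed by SeqDecoderBase::SetDict()
// (field sizes in bytes):
//
//   [lit_dict_size:3][seq_dict_size:3][literal alphabet][sequence dictionary]
//
// kDictHeaderSize accounts for the two 3-byte size fields.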

template <size_t INT_SIZE> bool SeqEncoder<INT_SIZE>::AddDictEncoded(const UT* src, unsigned cnt) {
  if (!PrepareForSymbAvailability(cnt))
    return false;

  // TODO: resolve strict aliasing issues.
  // SymbId* begin = reinterpret_cast<SymbId*>(end(lit_data_));

  if (!dict_.Resolve(src, cnt, tmp_symb_.begin())) {
    // The dictionary is too large - we should fall back.
    BacktrackToRaw();
    return false;
  }
  AddEncodedSymbols(tmp_symb_.begin(), cnt);

  return true;
}

template <size_t INT_SIZE> uint32_t SeqEncoder<INT_SIZE>::PrepareDict() {
  CHECK_EQ(0, dict_.alphabet_size());

  constexpr size_t kLiteralSize = INT_SIZE;

  uint32_t lit_cnt = lit_data_.size() / kLiteralSize;

  const UT* src = reinterpret_cast<UT*>(lit_data_.data());
  for (uint32_t i = 0; i < lit_cnt; ++i) {
    dict_.Add(src[i]);
  }

  if (dict_.dict_size() >= LiteralDictBase::kMaxAlphabetSize || dict_.dict_size() >= lit_cnt / 2) {
    state_ = NO_LIT_DICT;
    dict_.Clear();
    return 0;
  }

  dict_.Build();
  state_ = LIT_DICT;

  SymbId* dest = reinterpret_cast<SymbId*>(lit_data_.data());

  for (uint32_t i = 0; i < lit_cnt; ++i) {
    SymbId id = dict_.Resolve(src[i]);

    CHECK(id != LiteralDictBase::kInvalidId) << src[i];

    LittleEndian::Store16(dest + i, id);
  }
  lit_data_.resize_assume_reserved(lit_cnt * sizeof(SymbId));

  return dict_.alphabet_size();
}


// Decoder
SeqDecoderBase::SeqDecoderBase() {
  data_buf_.reserve(SEQ_BLOCK_SIZE);
}

SeqDecoderBase::~SeqDecoderBase() {}

void SeqDecoderBase::SetDict(const uint8_t* src, unsigned cnt) {
  CHECK_GE(cnt, 6);

  uint32_t lit_dict_sz = LittleEndian::Load24(src);
  uint32_t seq_dict_sz = LittleEndian::Load24(src + 3);
  CHECK_EQ(cnt, lit_dict_sz + seq_dict_sz + 6);
  src += 6;

  SetLitDict(ByteRange(src, lit_dict_sz));

  src += lit_dict_sz;

  if (seq_dict_sz > 0) {
    seq_dict_.SerializeFrom(src, seq_dict_sz);
  }

  seq_dict_range_.clear();
  for (auto it = seq_dict_.begin(); it != seq_dict_.end(); ++it) {
    seq_dict_range_.push_back(*it);
  }
}

int SeqDecoderBase::Decompress(strings::ByteRange br, uint32_t* consumed) {
  *consumed = 0;

  if (!read_header_) {
    if (br.size() < 3) {
      return -6;
    }
    uint16_t magic = LittleEndian::Load16(br.data());
    CHECK_EQ(magic, kMagic);

    size_t hs = BlockHeader::HeaderSize(br[2]) + 2;
    if (br.size() < hs) {
      return -hs;
    }
    bh_.Read(br.data() + 2);
    *consumed = hs;
    br.advance(hs);

    read_header_ = true;
  }

  size_t total_sz = bh_.byte_len_size_comprs + bh_.sequence_size_comprs;
  if (br.size() < total_sz)
    return -total_sz;

  if (bh_.flags & BlockHeader::kDictBit) {
    if (!zstd_cntx_) {
      zstd_cntx_.reset(new Zstd);
      data_buf_.reserve(SEQ_BLOCK_SIZE * 2 + 8);  // +8 to allow safe flit parsing.
      code_buf_.reserve(SEQ_BLOCK_SIZE);
    }

    DecompressCodes(br.data());
    br.advance(bh_.byte_len_size_comprs);

    if (zstd_cntx_->start & 2) {
      VLOG(1) << "Decompressing new frame";

      ZSTD_decompressBegin(zstd_cntx_->context);
      zstd_cntx_->start &= 1;
    } else {
      zstd_cntx_->start ^= 1;  // flip the buffer.
    }

    next_flit_ptr_ = data_buf_.data() + zstd_cntx_->offset();
    uint32_t sz;
    while (true) {
      sz = ZSTD_nextSrcSizeToDecompress(zstd_cntx_->context);
      CHECK_LE(sz, br.size());
      if (sz == 0) {
        break;
      }

      size_t res = ZSTD_decompressContinue(zstd_cntx_->context, next_flit_ptr_,
                                           SEQ_BLOCK_SIZE, br.data(), sz);
      CHECK_ZSTDERR(res);

      br.advance(sz);

      if (res > 0) {
        data_buf_.resize_assume_reserved(next_flit_ptr_ + res - data_buf_.data());
        sz = ZSTD_nextSrcSizeToDecompress(zstd_cntx_->context);

        break;
      }
    }
    next_seq_id_ = 0;

    if (sz == 0) {
      zstd_cntx_->start |= 2;
      DCHECK(bh_.flags & BlockHeader::kFinalBit);
    }
  } else {
    size_t res = ZSTD_decompress(data_buf_.data(), data_buf_.capacity(), br.data(),
                                 bh_.sequence_size_comprs);
    CHECK(!ZSTD_isError(res)) << ZSTD_getErrorName(res);
    data_buf_.resize_assume_reserved(res);
  }

  *consumed += total_sz;
  read_header_ = false;
  return (bh_.flags & BlockHeader::kFinalBit) ? 0 : 1;
}
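
// Return-value contract of Decompress(), as implemented above: a negative value -N
// means the remaining input must contain at least N bytes, 1 means a block was
// consumed and more blocks follow, and 0 means the final block of the stream was
// consumed. A minimal consumption sketch (the buffering strategy is up to the caller):
#if 0
  SeqDecoder<4> decoder;
  uint32_t consumed = 0;
  int rc = decoder.Decompress(input, &consumed);  // 'input' is a caller-owned ByteRange.
  if (rc >= 0) {
    // Drain the decoded integers page by page until an empty range is returned.
    auto page = decoder.GetNextIntPage();
  }
#endif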

void SeqDecoderBase::DecompressCodes(const uint8_t* src) {
  len_code_.resize(bh_.num_sequences);

  size_t res = ZSTD_decompress(code_buf_.data(), code_buf_.capacity(), src,
                               bh_.byte_len_size_comprs);
  CHECK(!ZSTD_isError(res)) << ZSTD_getErrorName(res);

  uint8_t* next = code_buf_.data(), *end = next + res;
  size_t index = 0;
  while (next < end && index < len_code_.size()) {
    next += flit::ParseT(next, &len_code_[index]);
    index++;
  }
  CHECK_EQ(index, len_code_.size());
  CHECK_EQ(next, end);
}

void SeqDecoderBase::InflateSequences() {
  const uint8_t* end_capacity = data_buf_.data() + data_buf_.capacity();

  for (; next_seq_id_ < len_code_.size(); ++next_seq_id_) {
    uint32_t len_code = len_code_[next_seq_id_];
    uint32_t len = 0;
    bool success = false;
    if (IsDictCode(len_code)) {  // Dict
      CHECK(!seq_dict_.empty());

      uint32_t id = len_code >> 1;
      CHECK_LT(id, seq_dict_range_.size());
      const ByteRange& src = seq_dict_range_[id];
      success = AddFlitSeq(src);
    } else {
      len = len_code >> 1;

      CHECK_LE(next_flit_ptr_ + len, end_capacity);
      success = AddFlitSeq(ByteRange(next_flit_ptr_, len));
    }

    if (!success)
      break;

    // Must be after the success check.
    next_flit_ptr_ += len;
  }
}


template<size_t INT_SIZE> SeqDecoder<INT_SIZE>::SeqDecoder() {
  int_buf_.reserve(SEQ_BLOCK_SIZE / INT_SIZE);
}

template<size_t INT_SIZE> auto SeqDecoder<INT_SIZE>::GetNextIntPage() -> IntRange {
  if (!(bh_.flags & BlockHeader::kDictBit)) {
    CHECK_EQ(0, data_buf_.size() % INT_SIZE);

    IntRange res(reinterpret_cast<UT*>(data_buf_.data()), data_buf_.size() / INT_SIZE);

    // clear() does not really change the underlying memory, and it still belongs to data_buf_.
    // For non-dict encoding, the next call will return an empty page.
    data_buf_.clear();

    return res;
  }

  next_int_ptr_ = int_buf_.data();

  InflateSequences();

  int_buf_.resize_assume_reserved(next_int_ptr_ - int_buf_.data());

  return IntRange(int_buf_.data(), next_int_ptr_);
}

template<size_t INT_SIZE> void SeqDecoder<INT_SIZE>::SetLitDict(strings::ByteRange br) {
  CHECK_EQ(0, br.size() % INT_SIZE);

  lit_dict_.resize(br.size() / INT_SIZE);
  memcpy(lit_dict_.data(), br.data(), br.size());
}


template<size_t INT_SIZE> bool SeqDecoder<INT_SIZE>::AddFlitSeq(strings::ByteRange src) {
  size_t sz = next_int_ptr_ - int_buf_.begin();
  if (sz >= int_buf_.capacity())
    return false;

  uint32_t left_capacity = int_buf_.capacity() - sz;
  uint32_t expanded =
      internal::DeflateFlitAndMap(
          src.data(), src.size(),
          [this](unsigned id) { return lit_dict_[id]; },
          next_int_ptr_, left_capacity);
  if (expanded > left_capacity)
    return false;

  next_int_ptr_ += expanded;

  return true;
}

template class SeqEncoder<4>;
template class SeqEncoder<8>;

template class SeqDecoder<4>;
template class SeqDecoder<8>;

}  // namespace util