output.h
1 // Copyright 2019, Beeri 15. All rights reserved.
2 // Author: Roman Gershman (romange@gmail.com)
3 //
4 
5 #pragma once
6 
7 #include "base/type_traits.h"
8 
9 #include "mr/mr3.pb.h"
10 #include "mr/mr_types.h"
11 
12 namespace mr3 {
13 
14 namespace detail {
15  template <typename OutT> class TableImplT;
16  inline bool IsBinary(pb::WireFormat::Type tp) { return tp == pb::WireFormat::LST; }
17 }
18 
19 class OutputBase {
20  public:
21  pb::Output* mutable_msg() { return out_; }
22  const pb::Output& msg() const { return *out_; }
23 
24  bool is_binary() const { return detail::IsBinary(out_->format().type()); }
25 
26  protected:
27  pb::Output* out_;
28 
29  OutputBase(pb::Output* out) : out_(out) {}
30 
31  void SetCompress(pb::Output::CompressType ct, int level);
32  void SetShardSpec(pb::ShardSpec::Type st, unsigned modn = 0);
33  void FailUndefinedShard() const;
34 };
35 
36 template <typename T> class Output : public OutputBase {
37  friend class detail::TableImplT<T>; // To allow the instantiation of Output<T>;
38 
39  using CustomShardingFunc = std::function<std::string(const T&)>;
40  using ModNShardingFunc = std::function<unsigned(const T&)>;
41 
42  absl::variant<absl::monostate, ShardId, ModNShardingFunc, CustomShardingFunc> shard_op_;
43  unsigned modn_ = 0;
44 
45  struct Visitor {
46  const T& t_;
47  unsigned modn_;
48 
49  Visitor(const T& t, unsigned modn) :t_(t), modn_(modn) {}
50 
51  ShardId operator()(const ShardId& id) const { return id; }
52  ShardId operator()(const ModNShardingFunc& func) const { return ShardId{func(t_) % modn_}; }
53  ShardId operator()(const CustomShardingFunc& func) const { return ShardId{func(t_)}; }
54  ShardId operator()(absl::monostate ms) const {
55  return ms;
56  }
57 
58  };
59 
60  public:
61  Output() : OutputBase(nullptr) {}
62 
63  template <typename U> Output& WithCustomSharding(U&& func) {
64  static_assert(base::is_invocable_r<std::string, U, const T&>::value, "");
65  shard_op_ = std::forward<U>(func);
66  SetShardSpec(pb::ShardSpec::USER_DEFINED);
67 
68  return *this;
69  }
70 
71  template <typename U> Output& WithModNSharding(unsigned modn, U&& func) {
72  static_assert(base::is_invocable_r<unsigned, U, const T&>::value, "");
73  shard_op_ = std::forward<U>(func);
74  SetShardSpec(pb::ShardSpec::MODN, modn);
75  modn_ = modn;
76 
77  return *this;
78  }
79 
80  Output& AndCompress(pb::Output::CompressType ct, int level = -10000);
81 
82  ShardId Shard(const T& t) const {
83  auto res = absl::visit(Visitor{t, modn_}, shard_op_);
84  if (absl::holds_alternative<absl::monostate>(res)) {
85  FailUndefinedShard();
86  }
87  return res;
88  }
89 
90  // TODO: to expose it for friends.
91  void SetConstantShard(ShardId sid) { shard_op_ = std::move(sid); }
92 
93  private:
94  Output(pb::Output* out) : OutputBase(out) {}
95 };
96 
97 template <typename OutT>
98 Output<OutT>& Output<OutT>::AndCompress(pb::Output::CompressType ct, int level) {
99  SetCompress(ct, level);
100  return *this;
101 }
102 
103 } // namespace mr3