file_printer.cc
1 // Copyright 2018, Beeri 15. All rights reserved.
2 // Author: Roman Gershman (roman@ubimo.com)
3 //
4 #include "util/pprint/file_printer.h"
5 
6 #include <google/protobuf/compiler/importer.h>
7 
8 #include "absl/strings/escaping.h"
9 #include "base/flags.h"
10 #include "base/hash.h"
11 #include "base/logging.h"
12 #include "base/map-util.h"
13 #include "file/list_file.h"
14 #include "file/proto_writer.h"
15 #include "util/pb2json.h"
16 #include "util/plang/plang_parser.hh"
17 #include "util/plang/plang_scanner.h"
18 #include "util/pprint/pprint_utils.h"
19 
20 DEFINE_string(protofiles, "", "");
21 DEFINE_string(proto_db_file, "s3://test/roman/proto_db.lst", "");
22 DEFINE_string(type, "", "");
23 
24 DEFINE_string(where, "", "boolean constraint in plang language");
25 DEFINE_bool(sizes, false, "Prints a rough estimation of the size of every field");
26 DEFINE_bool(json, false, "");
27 DEFINE_bool(raw, false, "");
28 DEFINE_string(sample_key, "", "");
29 DEFINE_int32(sample_factor, 0, "If bigger than 0 samples and outputs record once in k times");
30 DEFINE_bool(parallel, true, "");
31 DEFINE_bool(count, false, "");
32 
33 
34 namespace util {
35 namespace pprint {
36 
37 using namespace std;
38 using namespace file;
39 namespace gpc = gpb::compiler;
40 
41 using strings::AsString;
42 
43 class ErrorCollector : public gpc::MultiFileErrorCollector {
44  void AddError(const string& filenname, int line, int column, const string& message) {
45  std::cerr << "Error File : " << filenname << " : " << message << std::endl;
46  }
47 };
48 
49 static const gpb::Descriptor* FindDescriptor() {
50  CHECK(!FLAGS_type.empty()) << "type must be filled. For example: --type=foursquare.Category";
51  const gpb::DescriptorPool* gen_pool = gpb::DescriptorPool::generated_pool();
52  const gpb::Descriptor* descriptor = gen_pool->FindMessageTypeByName(FLAGS_type);
53  if (descriptor)
54  return descriptor;
55 
56  gpc::DiskSourceTree tree;
57  tree.MapPath("START_FILE", FLAGS_protofiles);
58  ErrorCollector collector;
59  gpc::Importer importer(&tree, &collector);
60  if (!FLAGS_protofiles.empty()) {
61  // TODO: to support multiple files some day.
62  CHECK(importer.Import("START_FILE"));
63  }
64  descriptor = importer.pool()->FindMessageTypeByName(FLAGS_type);
65  if (descriptor)
66  return descriptor;
67  static gpb::SimpleDescriptorDatabase proto_db;
68  static gpb::DescriptorPool proto_db_pool(&proto_db);
69  file::ListReader reader(FLAGS_proto_db_file);
70  string record_buf;
71  StringPiece record;
72  while (reader.ReadRecord(&record, &record_buf)) {
73  gpb::FileDescriptorProto* fdp = new gpb::FileDescriptorProto;
74  CHECK(fdp->ParseFromArray(record.data(), record.size()));
75  proto_db.AddAndOwn(fdp);
76  }
77  descriptor = proto_db_pool.FindMessageTypeByName(FLAGS_type);
78  CHECK(descriptor) << "Can not find " << FLAGS_type << " in the proto pool.";
79  return descriptor;
80 }
81 
82 static bool ShouldSkip(const gpb::Message& msg, const FdPath& fd_path) {
83  if (FLAGS_sample_factor <= 0 || FLAGS_sample_key.empty())
84  return false;
85  const string* val = nullptr;
86  string buf;
87  auto cb = [&val, &buf](const gpb::Message& msg, const gpb::FieldDescriptor* fd, int, int) {
88  const gpb::Reflection* refl = msg.GetReflection();
89  val = &refl->GetStringReference(msg, fd, &buf);
90  };
91  fd_path.ExtractValue(msg, cb);
92  CHECK(val);
93 
94  uint32 num = base::Fingerprint(*val);
95 
96  return (num % FLAGS_sample_factor) != 0;
97 }
98 
100  public:
101  typedef PrintSharedData* SharedData;
102 
103  void InitShared(SharedData d) {
104  shared_data_ = d;
105  }
106 
107  PrintTask(const gpb::Message* to_clone, const Pb2JsonOptions& opts) : options_(opts) {
108  if (to_clone) {
109  local_msg_.reset(to_clone->New());
110  }
111  if (!FLAGS_sample_key.empty()) {
112  fd_path_ = FdPath(to_clone->GetDescriptor(), FLAGS_sample_key);
113  CHECK(!fd_path_.IsRepeated());
114  CHECK_EQ(gpb::FieldDescriptor::CPPTYPE_STRING, fd_path_.path().back()->cpp_type());
115  }
116  }
117 
118  void operator()(const std::string& obj) {
119  if (FLAGS_raw) {
120  std::lock_guard<mutex> lock(shared_data_->m);
121  std::cout << absl::Utf8SafeCEscape(obj) << "\n";
122  return;
123  }
124  CHECK(local_msg_->ParseFromString(obj));
125  if (shared_data_->expr && !plang::EvaluateBoolExpr(*shared_data_->expr, *local_msg_))
126  return;
127 
128  if (ShouldSkip(*local_msg_, fd_path_))
129  return;
130  std::lock_guard<mutex> lock(shared_data_->m);
131 
132  if (FLAGS_sizes) {
133  shared_data_->size_summarizer->AddSizes(*local_msg_);
134  return;
135  }
136 
137  if (FLAGS_json) {
138  string str = Pb2Json(*local_msg_, options_);
139  std::cout << str << "\n";
140  } else {
141  shared_data_->printer->Output(*local_msg_);
142  }
143  }
144 
145  private:
146  std::unique_ptr<gpb::Message> local_msg_;
147  FdPath fd_path_;
148  SharedData shared_data_;
149  Pb2JsonOptions options_;
150 };
151 
152 FilePrinter::FilePrinter() {}
153 FilePrinter::~FilePrinter() {}
154 
155 void FilePrinter::Init(const string& fname) {
156  CHECK(!descr_msg_);
157 
158  if (!FLAGS_where.empty()) {
159  std::istringstream istr(FLAGS_where);
160  plang::Scanner scanner(&istr);
161  plang::Parser parser(&scanner, &test_expr_);
162  CHECK_EQ(0, parser.parse()) << "Could not parse " << FLAGS_where;
163  }
164 
165  LoadFile(fname);
166 
167  if (descr_msg_) {
168  if (FLAGS_sizes)
169  size_summarizer_.reset(new SizeSummarizer(descr_msg_->GetDescriptor()));
170  printer_.reset(new Printer(descr_msg_->GetDescriptor(), field_printer_cb_));
171  } else {
172  CHECK(!FLAGS_sizes);
173  }
174 
175  pool_.reset(new TaskPool("pool", 10));
176 
177  shared_data_.size_summarizer = size_summarizer_.get();
178  shared_data_.printer = printer_.get();
179  shared_data_.expr = test_expr_.get();
180  pool_->SetSharedData(&shared_data_);
181  pool_->Launch(descr_msg_.get(), options_);
182 
183  if (FLAGS_parallel) {
184  LOG(INFO) << "Running in parallel " << pool_->thread_count() << " threads";
185  }
186 }
187 
188 auto FilePrinter::GetDescriptor() const -> const Descriptor* {
189  return descr_msg_ ? descr_msg_->GetDescriptor() : nullptr;
190 }
191 
192 Status FilePrinter::Run() {
193  StringPiece record;
194  while (true) {
195  // Reads raw record from the file.
196  util::StatusObject<bool> res = Next(&record);
197  if (!res.ok())
198  return res.status;
199  if (!res.obj)
200  break;
201  if (FLAGS_count) {
202  ++count_;
203  } else {
204  if (FLAGS_parallel) {
205  pool_->RunTask(AsString(record));
206  } else {
207  pool_->RunInline(AsString(record));
208  }
209  }
210  }
211  pool_->WaitForTasksToComplete();
212 
213  if (size_summarizer_.get())
214  std::cout << *size_summarizer_ << "\n";
215  return Status::OK;
216 }
217 
218 
219 void ListReaderPrinter::LoadFile(const std::string& fname) {
220  auto corrupt_cb = [this](size_t bytes, const util::Status& status) { st_ = status; };
221 
222  reader_.reset(new ListReader(fname, false, corrupt_cb));
223 
224  if (!FLAGS_raw && !FLAGS_count) {
225  std::map<std::string, std::string> meta;
226  if (!reader_->GetMetaData(&meta)) {
227  LOG(FATAL) << "Error fetching metadata from " << fname;
228  }
229  string ptype = FindValueWithDefault(meta, file::kProtoTypeKey, string());
230  string fd_set = FindValueWithDefault(meta, file::kProtoSetKey, string());
231  if (!ptype.empty() && !fd_set.empty())
232  descr_msg_.reset(AllocateMsgByMeta(ptype, fd_set));
233  else
234  descr_msg_.reset(AllocateMsgFromDescr(FindDescriptor()));
235  }
236 }
237 
238 util::StatusObject<bool> ListReaderPrinter::Next(StringPiece* record) {
239  bool res = reader_->ReadRecord(record, &record_buf_);
240  if (!st_.ok())
241  return st_;
242 
243  return res;
244 }
245 
246 void ListReaderPrinter::PostRun() {
247  LOG(INFO) << "Data bytes: " << reader_->read_data_bytes()
248  << " header bytes: " << reader_->read_header_bytes();
249 }
250 
251 } // namespace pprint
252 } // namespace util