4 #include "util/plang/plang.h" 8 #include "absl/strings/ascii.h" 9 #include "base/logging.h" 10 #include "strings/hash.h" 11 #include "util/math/mathutil.h" 13 #include <google/protobuf/message.h> 14 #include <google/protobuf/repeated_field.h> 17 using namespace std::placeholders;
// Brings AsString into scope; used below to obtain std::string copies of
// string-piece values (a path component, an expression string).
18 using strings::AsString;
// A (message, field descriptor) pair: the message being inspected and the
// descriptor of one of its fields.
22 typedef std::pair<const gpb::Message*, const gpb::FieldDescriptor*> MsgDscrPair;
// Stack of pending iteration frames used by RetrieveNode/AdvanceState when a
// field path crosses repeated message fields. Each tuple holds:
//   <0> the path offset saved when the frame was pushed (parsing resumes there),
//   <1> the index of the repeated element currently being visited,
//   <2> the (parent message, repeated field) pair being iterated.
23 typedef std::vector<std::tuple<uint32, int, MsgDscrPair>> PathState;
// Steps the innermost frame of `state` to its next repeated element.
//
// Increments the frame's element index; if the index is still within the
// repeated field's size, returns a pointer to that element message.
// NOTE(review): the handling of an exhausted frame and the function's tail are
// not visible in this chunk; presumably the frame is popped, the loop moves on
// to the enclosing frame, and nullptr is returned once `state` is empty —
// callers do test the result against nullptr. Confirm.
25 static const gpb::Message* AdvanceState(PathState* state) {
26 while (!state->empty()) {
27 auto& b = state->back();
// Pre-increment: the stored index is advanced in place before it is compared.
28 int index = ++std::get<1>(b);
29 MsgDscrPair& result = std::get<2>(b);
// Frames are only pushed for repeated fields, so this is an invariant check.
30 CHECK(result.second->is_repeated());
31 const gpb::Reflection* refl = result.first->GetReflection();
32 if (index >= refl->FieldSize(*result.first, result.second)) {
35 return &refl->GetRepeatedMessage(*result.first, result.second, index);
// Resolves the dot-separated field `path` (e.g. "a.b.c") starting at `msg` and
// reports the leaf (message, field descriptor) pair through `cb`.
//
// Each component is looked up by name on the current message's descriptor.
// Every component except the last must be a message-typed field. When an
// intermediate component is a repeated message field, its elements are visited
// one after another:
//   - a PathState frame is pushed;
//   - traversal descends into element 0;
//   - AdvanceState later resumes with the next element.
// CHECK-fails on an unknown field name or on a non-message intermediate
// component.
// NOTE(review): the declarations of `start`/`state` and the statement that
// invokes `cb` on the leaf are not visible in this chunk. Presumably
// `cb(result)` runs once per reached leaf, so it may run zero or many times.
41 static void RetrieveNode(
const gpb::Message* msg, StringPiece path,
42 std::function<
void(
const MsgDscrPair&)> cb) {
43 MsgDscrPair result(msg,
nullptr);
// Runs until AdvanceState reports no further pending repeated element
// (result.first becomes null).
48 while (result.first !=
nullptr) {
49 size_t next = path.find(
'.', start);
// Current component: the text between `start` and the next '.' (or the end).
50 StringPiece part = path.substr(start, next - start);
51 VLOG(2) <<
"Looking for " << part <<
" in " << result.first->GetDescriptor()->name()
53 result.second = result.first->GetDescriptor()->FindFieldByName(AsString(part));
54 CHECK(result.second !=
nullptr) <<
"Could not find field " << part;
// Last component reached (the leaf is reported on a line not visible here);
// move on to the next pending repeated element, if any.
55 if (next == string::npos) {
57 if ((result.first = AdvanceState(&state)) !=
nullptr) {
// Resume parsing from the path offset saved in that element's frame.
58 start = std::get<0>(state.back());
// Intermediate components must be messages so traversal can descend into them.
62 CHECK_EQ(result.second->cpp_type(), gpb::FieldDescriptor::CPPTYPE_MESSAGE)
63 << part <<
" is not a message.";
64 const gpb::Reflection* refl = result.first->GetReflection();
66 if (result.second->is_repeated()) {
// Non-empty repeated field: save a frame so the remaining elements are
// visited later, then descend into element 0.
67 if (refl->FieldSize(*result.first, result.second) > 0) {
68 state.push_back(std::make_tuple(start, 0, result));
69 result.first = &refl->GetRepeatedMessage(*result.first, result.second, 0);
// Presumably the empty-repeated-field branch: nothing to descend into, so
// resume the next pending element instead.
71 if ((result.first = AdvanceState(&state)) !=
nullptr) {
72 start = std::get<0>(state.back());
// Singular message field (protobuf's GetMessage returns the default instance
// when the field is unset).
76 result.first = &refl->GetMessage(*result.first, result.second);
// Returns this value converted to double. ExprValue::Less uses it to compare
// values in floating point when either side is DOUBLE.
// The per-type conversions are not visible in this chunk; the LOG(FATAL) below
// is presumably the fallback for unsupported types.
81 double ExprValue::PromoteToDouble()
const {
90 LOG(FATAL) <<
"Not supported " << type;
// Returns whether this value equals `other`. Per the visible branches:
//  - Same type: compares the stored payloads directly. Doubles use
//    MathUtil::AlmostEquals (approximate); enums compare `enum_val` pointers.
//  - ENUM vs integer: compares the enum's number. ENUM vs string: compares the
//    enum's name. Any other type paired with an enum LOG(FATAL)s.
//  - DOUBLE vs INT64: approximate comparison after converting the int to double.
//  - If `other` is the ENUM or DOUBLE side, the comparison is re-dispatched with
//    the operands swapped.
//  - Mixed signed/unsigned integers: compared as uint64, with a negative signed
//    value handled separately (that branch body is not visible here).
// Any other combination LOG(FATAL)s.
// The switch/case labels between these statements are not visible in this chunk.
95 bool ExprValue::Equal(
const ExprValue& other)
const {
96 CppType t1 = type, t2 = other.type;
100 return val.int_val == other.val.int_val;
102 return val.str == other.val.str;
104 return MathUtil::AlmostEquals(val.d_val, other.val.d_val);
106 return val.enum_val == other.val.enum_val;
108 LOG(FATAL) <<
"Not supported " << type;
111 if (t1 == CPPTYPE_ENUM) {
114 return val.enum_val->number() == other.val.int_val;
116 return val.enum_val->name() == AsString(other.val.str);
118 LOG(FATAL) <<
"Unsupported type for comparing with enum " << t2;
121 if (t1 == CPPTYPE_DOUBLE && t2 == CPPTYPE_INT64) {
122 return MathUtil::AlmostEquals(val.d_val,
double(other.val.int_val));
// Let the ENUM/DOUBLE-side logic above handle the swapped pair.
124 if (t2 == CPPTYPE_ENUM || t2 == CPPTYPE_DOUBLE) {
125 return other.Equal(*
this);
// NOTE(review): magic bound 4 presumably limits `other` to the integer CppType
// values — confirm; a named constant would make the intent explicit.
129 CHECK_LE(other.type, 4);
133 if (t1 == CPPTYPE_INT64) {
136 return uint64(val.int_val) == other.val.uint_val;
// Here this value is the unsigned side. A negative signed value cannot equal any
// uint64, so it is presumably rejected before the cast below.
138 DCHECK_EQ(CPPTYPE_INT64, t2);
139 if (other.val.int_val < 0)
141 return uint64(other.val.int_val) == val.uint_val;
144 LOG(FATAL) <<
"Unsupported combination " << t1 <<
" and " << t2;
// Returns whether this value is strictly less than `other`.
//  - If either side is DOUBLE, both are promoted with PromoteToDouble. The
//    comparison of d1/d2 is not visible in this chunk.
//  - Mixed INT64/UINT64 values are compared as uint64 after a sign check.
//  - Same-type INT64, UINT64 and DOUBLE values compare their payloads directly.
// Other types LOG(FATAL).
148 bool ExprValue::Less(
const ExprValue& other)
const {
149 CppType t1 = type, t2 = other.type;
153 if (t1 == CPPTYPE_DOUBLE || t2 == CPPTYPE_DOUBLE) {
154 double d1 = PromoteToDouble();
155 double d2 = other.PromoteToDouble();
158 if (t1 == CPPTYPE_INT64) {
161 return uint64(val.int_val) < other.val.uint_val;
// This value is the unsigned side. A uint64 is never less than a non-positive
// int64; that branch body is not visible here (presumably returns false).
163 DCHECK_EQ(CPPTYPE_INT64, t2);
164 if (other.val.int_val <= 0)
166 return val.uint_val < uint64(other.val.int_val);
170 return val.int_val < other.val.int_val;
172 return val.uint_val < other.val.uint_val;
174 return val.d_val < other.val.d_val;
176 LOG(FATAL) <<
"Not supported " << type;
// Returns whether this string fully matches the regular expression held in
// `other`. Uses std::regex_match, so the whole string must match, with the
// default ECMAScript grammar. Both operands must be strings (CHECK-fails
// otherwise).
// NOTE(review): the pattern is compiled on every call, which is expensive when
// evaluated per message; consider caching the compiled regex. An invalid
// pattern makes the std::regex constructor throw std::regex_error.
181 bool ExprValue::RLike(
const ExprValue& other)
const {
182 CppType t1 = type, t2 = other.type;
183 CHECK(t1 == CPPTYPE_STRING && t2 == CPPTYPE_STRING);
187 return std::regex_match(val.str.begin(), val.str.end(),
188 std::regex(other.val.str.begin(), other.val.str.end()));
// Reads the field `msg_dscr.second` of message `msg_dscr.first` through protobuf
// reflection, wraps it in an ExprValue and passes it to `cb`.
//
// Repeated fields: `cb` is called once per element. Only INT32 and UINT32
// element types are supported; others LOG(FATAL).
// Singular fields:
//   - INT32 / INT64 / BOOL -> fromInt
//   - UINT32 / UINT64 -> fromUInt
//   - FLOAT / DOUBLE -> fromDouble
//   - STRING and ENUM -> ExprValue constructors
//   - other types LOG(FATAL).
// NOTE(review): repeated UINT32 elements go through fromInt while a singular
// UINT32 uses fromUInt, and BOOL is wrapped with fromInt rather than fromBool.
// Confirm both are intended.
191 static void EvalField(Expr::ExprValueCb cb, MsgDscrPair msg_dscr) {
192 const gpb::Message* pmsg = msg_dscr.first;
193 const gpb::FieldDescriptor* fd = msg_dscr.second;
194 const gpb::Reflection* refl = pmsg->GetReflection();
196 typedef gpb::FieldDescriptor FD;
197 if (fd->is_repeated()) {
198 switch (fd->cpp_type()) {
199 case FD::CPPTYPE_INT32: {
200 const auto& arr = refl->GetRepeatedField<int32>(*pmsg, fd);
201 for (int32 val : arr) {
202 cb(ExprValue::fromInt(val));
206 case FD::CPPTYPE_UINT32: {
207 const auto& arr = refl->GetRepeatedField<uint32>(*pmsg, fd);
208 for (uint32 val : arr) {
209 cb(ExprValue::fromInt(val));
214 LOG(FATAL) <<
"Not supported repeated " << fd->cpp_type_name();
217 switch (fd->cpp_type()) {
218 case FD::CPPTYPE_INT32:
219 cb(ExprValue::fromInt(refl->GetInt32(*pmsg, fd)));
221 case FD::CPPTYPE_UINT32:
222 cb(ExprValue::fromUInt(refl->GetUInt32(*pmsg, fd)));
224 case FD::CPPTYPE_INT64:
225 cb(ExprValue::fromInt(refl->GetInt64(*pmsg, fd)));
227 case FD::CPPTYPE_UINT64:
228 cb(ExprValue::fromUInt(refl->GetUInt64(*pmsg, fd)));
230 case FD::CPPTYPE_STRING: {
// `tmp` (declared on a line not visible here) is scratch storage:
// GetStringReference may return a reference to it instead of the field's own
// string. If ExprValue keeps a non-owning view, the value is only valid while
// `tmp` is alive, which holds for the synchronous `cb` call.
232 cb(ExprValue(refl->GetStringReference(*pmsg, fd, &tmp)));
235 case FD::CPPTYPE_FLOAT:
236 cb(ExprValue::fromDouble(refl->GetFloat(*pmsg, fd)));
238 case FD::CPPTYPE_DOUBLE:
239 cb(ExprValue::fromDouble(refl->GetDouble(*pmsg, fd)));
241 case FD::CPPTYPE_BOOL:
242 cb(ExprValue::fromInt(refl->GetBool(*pmsg, fd)));
244 case FD::CPPTYPE_ENUM:
245 cb(ExprValue(refl->GetEnum(*pmsg, fd)));
248 LOG(FATAL) <<
"Not supported yet " << fd->cpp_type_name();
// Emits the value(s) of this string term through `cb`.
// CONST terms are handled by a branch not visible in this chunk (presumably it
// emits the literal). Otherwise `val_` is treated as a dot-separated field path,
// and every resolved field value is emitted via EvalField.
252 void StringTerm::eval(
const gpb::Message& msg, ExprValueCb cb)
const {
253 if (type_ == CONST) {
257 RetrieveNode(&msg, val_, std::bind(&EvalField, cb, _1));
// Returns true if `t` compares equal (via operator==) to `u`.
// Base case of the variadic overload below.
template <typename T, typename U>
bool IsOneOf(const T& t, const U& u) {
  return t == u;
}

// Returns true if `t` compares equal to at least one of `u`, `rest...`,
// stopping at the first match.
// Parameters are taken by const reference. A `const U&&` base case cannot bind
// the lvalue pack elements forwarded by the recursive call below. With that
// signature, every call with an lvalue candidate, or with two or more
// candidates, recursed into a non-existent one-argument IsOneOf and failed to
// compile.
template <typename T, typename U1, typename... U2>
bool IsOneOf(const T& t, const U1& u, const U2&... rest) {
  return t == u || IsOneOf(t, rest...);
}
// Evaluates this binary operator against `msg` and emits exactly one bool via `cb`.
//
// For the selected operator, `local_cb` is installed. It is then fed every value
// produced by the left operand, which may be several if the left side is a
// repeated field. Where the right operand is needed, it is evaluated per left
// value, and the pairwise outcome updates `res`.
// Not visible in this chunk: the case labels, the declaration of `res`, and the
// statements that update it. The operator behind each lambda is therefore
// inferred from the ExprValue call it makes: Equal, RLike, Less, Less || Equal,
// and several bool-combining variants that CHECK both sides are BOOL.
// The lambdas capture locals by reference (two use a blanket [&]). That is safe
// only because left_->eval is presumably synchronous and returns before `res`
// is read at the end.
269 void BinOp::eval(
const gpb::Message& msg, ExprValueCb cb)
const {
271 Expr::ExprValueCb local_cb;
274 local_cb = [
this, &res, &msg](
const ExprValue& val_left) {
275 right_->eval(msg, [&val_left, &res](
const ExprValue& val_right) {
276 if (val_left.Equal(val_right))
284 local_cb = [
this, &res, &msg](
const ExprValue& val_left) {
285 right_->eval(msg, [&val_left, &res](
const ExprValue& val_right) {
286 if (val_left.RLike(val_right))
294 local_cb = [
this, &res, &msg](
const ExprValue& val_left) {
295 CHECK_EQ(ExprValue::CPPTYPE_BOOL, val_left.type);
296 if (!val_left.val.bool_val)
299 right_->eval(msg, [&res](
const ExprValue& val_right) {
300 CHECK_EQ(ExprValue::CPPTYPE_BOOL, val_right.type);
301 if (!val_right.val.bool_val)
310 local_cb = [
this, &res, &msg](
const ExprValue& val_left) {
311 CHECK_EQ(ExprValue::CPPTYPE_BOOL, val_left.type);
312 if (val_left.val.bool_val) {
316 right_->eval(msg, [&res](
const ExprValue& val_right) {
317 CHECK_EQ(ExprValue::CPPTYPE_BOOL, val_right.type);
318 if (!val_right.val.bool_val)
327 local_cb = [
this, &res, &msg](
const ExprValue& val_left) {
328 right_->eval(msg, [&val_left, &res](
const ExprValue& val_right) {
329 if (val_left.Less(val_right))
337 local_cb = [&](
const ExprValue& val_left) {
338 right_->eval(msg, [&val_left, &res](
const ExprValue& val_right) {
339 if (val_left.Less(val_right) || val_left.Equal(val_right))
347 local_cb = [&](
const ExprValue& val_left) {
348 CHECK_EQ(ExprValue::CPPTYPE_BOOL, val_left.type);
349 if (!val_left.val.bool_val) {
356 left_->eval(msg, local_cb);
357 cb(ExprValue::fromBool(res));
// Stores the function name lowercased, since eval() matches it against
// lowercase names such as "hash". Moves in the argument list.
360 FunctionTerm::FunctionTerm(
const std::string& name, ArgList&& lst)
361 : name_(name), args_(std::move(lst)) {
362 absl::AsciiStrToLower(&name_);
// Destructor; its body is not visible in this chunk.
365 FunctionTerm::~FunctionTerm() {
// Evaluates a built-in function call and emits its result through `cb`.
// Only "hash" is visible here:
//   - it requires exactly one argument, which must evaluate to a STRING;
//   - it emits std::hash<StringPiece> of that string as an unsigned value;
//   - if the argument yields several values, each overwrites `res`, so the last
//     one wins (the declaration and initial value of `res` are not visible here).
// Unknown names LOG(FATAL).
// NOTE(review): the "Unknown function" message omits `name_`; including it
// would make failures easier to diagnose.
370 void FunctionTerm::eval(
const gpb::Message& msg, ExprValueCb cb)
const {
371 if (name_ ==
"hash") {
372 if (args_.size() != 1)
373 LOG(FATAL) <<
"hash() accepts a single argument";
375 args_[0]->eval(msg, [&res](
const ExprValue& val) {
376 CHECK_EQ(ExprValue::CPPTYPE_STRING, val.type);
377 res = std::hash<StringPiece>()(val.val.str);
380 cb(ExprValue::fromUInt(res));
382 LOG(FATAL) <<
"Unknown function";
// Emits (as a bool ExprValue) whether the field in `msg_dscr` is set:
//   - a repeated field counts as set when it has at least one element;
//   - a singular field counts as set when Reflection::HasField() is true.
386 static void IsDefField(Expr::ExprValueCb cb, MsgDscrPair msg_dscr) {
387 const gpb::Message* pmsg = msg_dscr.first;
388 const gpb::FieldDescriptor* fd = msg_dscr.second;
389 const gpb::Reflection* refl = pmsg->GetReflection();
390 bool res = fd->is_repeated() ? refl->FieldSize(*pmsg, fd) > 0 : refl->HasField(*pmsg, fd);
391 cb(ExprValue::fromBool(res));
// Resolves `name_` as a dot-separated field path and emits, for each resolved
// field, whether it is set (see IsDefField).
// NOTE(review): RetrieveNode CHECK-fails on an unknown field name. It may also
// reach no leaf at all when an intermediate repeated field is empty; in that
// case `cb` is presumably never called. Confirm that callers tolerate that.
394 void IsDefFun::eval(
const gpb::Message& msg, ExprValueCb cb)
const {
395 VLOG(1) <<
"IsDefFun " << name_;
396 RetrieveNode(&msg, name_, std::bind(&IsDefField, cb, _1));
// Evaluates `e` against `msg` and returns its boolean result. The value emitted
// by the expression must be a BOOL (CHECK-fails otherwise). The rest of this
// function is not visible in this chunk.
399 bool EvaluateBoolExpr(
const Expr& e,
const gpb::Message& msg) {
402 CHECK_EQ(ExprValue::CPPTYPE_BOOL, val.type);
403 res = val.val.bool_val;