// Includes and runtime-configuration flags for the RPC client Channel.
4 #include <boost/asio/write.hpp> 5 #include <boost/fiber/future.hpp> 8 #include "util/rpc/channel.h" 10 #include "base/logging.h" 11 #include "util/asio/asio_utils.h" 12 #include "util/rpc/frame_format.h" 13 #include "util/rpc/rpc_envelope.h" 18 DEFINE_uint32(rpc_client_pending_limit, 1 << 17,
// Back-pressure knob: PresendChecks() rejects new calls with
// no_buffer_space once pending_calls_size_ reaches this limit.
19 "How many outgoing requests we are ready to accommodate before rejecting " 22 DEFINE_uint32(rpc_client_queue_size, 128,
// Size of the outgoing batch queue; FlushSends()/PresendChecks() throttle
// against this bound.
23 "The size of the outgoing batch queue that contains envelopes waiting to send.");
25 using namespace boost;
28 using folly::RWSpinLock;
29 namespace error = asio::error;
33 bool IsExpectedFinish(system::error_code ec) {
34 return ec == error::eof || ec == error::operation_aborted || ec == error::connection_reset ||
35 ec == error::not_connected;
// Granularity, in milliseconds, of the expiry timer wheel: the periodic
// task started in Connect() ticks every kTickPrecision ms, and Send()
// converts a caller deadline into ticks with ceil(deadline / kTickPrecision).
38 constexpr uint32_t kTickPrecision = 3;
// NOTE(review): the enclosing function header is outside this excerpt.
// From the joinable() precondition and the "After ReadFiberJoin" log this
// appears to be teardown code (presumably the Channel destructor) that
// joins read_fiber_ before destroying the object — confirm upstream.
45 CHECK(read_fiber_.joinable());
48 VLOG(1) <<
"After ReadFiberJoin";
51 void Channel::Shutdown() {
53 socket_->Shutdown(ec);
// Establishes the client connection, waiting up to `ms` milliseconds, and
// then spins up the channel's service machinery on the socket's IoContext
// thread:
//   - read_fiber_:  consumes incoming response envelopes (ReadFiber),
//   - flush_fiber_: periodically flushes the outgoing batch (FlushFiber),
//   - expiry_task_: a PeriodicTask ticking every kTickPrecision ms that
//     advances expire_timer_ so scheduled call deadlines can fire.
// NOTE(review): the function's tail (presumably `return ec;` and closing
// braces) is missing from this excerpt — original line numbering jumps.
56 auto Channel::Connect(uint32_t ms) -> error_code {
57 CHECK(!read_fiber_.joinable());
58 error_code ec = socket_->ClientWaitToConnect(ms);
60 IoContext& context = socket_->context();
// Fibers and the periodic task must be created on the IoContext thread
// (ReadFiber CHECKs InContextThread()); Await() runs the lambda there.
61 context.Await([
this] {
62 read_fiber_ = fibers::fiber(&Channel::ReadFiber,
this);
63 flush_fiber_ = fibers::fiber(&Channel::FlushFiber,
this);
65 expiry_task_.reset(
new PeriodicTask(context, chrono::milliseconds(kTickPrecision)));
// Each tick advances the expiry timer wheel by the number of elapsed ticks.
67 expiry_task_->Start([
this](
int ticks) {
69 this->expire_timer_.advance(ticks);
70 DVLOG(3) <<
"Advancing expiry to " << this->expire_timer_.now();
// Validates that a new outgoing request may be issued right now.
// Returns a non-success error_code when:
//   - the socket is closed                      -> shut_down,
//   - the socket already reports an error       -> that status,
//   - too many calls are in flight              -> no_buffer_space
//     (bounded by --rpc_client_pending_limit),
//   - the outgoing batch queue is at capacity (--rpc_client_queue_size).
// NOTE(review): the body of the last `if` and the success-path return are
// missing from this excerpt (original lines 92+ dropped).
76 auto Channel::PresendChecks() -> error_code {
77 if (!socket_->is_open()) {
78 return asio::error::shut_down;
81 if (socket_->status()) {
82 return socket_->status();
// relaxed load: a stale value only makes the limit check approximate.
85 if (pending_calls_size_.load(std::memory_order_relaxed) >= FLAGS_rpc_client_pending_limit) {
86 return asio::error::no_buffer_space;
91 if (outgoing_buf_size_.load(std::memory_order_relaxed) >= FLAGS_rpc_client_queue_size) {
// Asynchronously sends `envelope` as a unary RPC. Returns a future that is
// fulfilled with the final status (by ReadEnvelope on response, by
// ExpirePending on deadline, or by CancelPendingCalls on error/shutdown).
// `deadline_msec` must be positive; Connect() must have been called first.
// NOTE(review): the handling of a non-success PresendChecks() result and
// the final return are missing from this excerpt (numbering gaps).
97 auto Channel::Send(uint32 deadline_msec, Envelope* envelope) -> future_code_t {
98 DCHECK(read_fiber_.joinable()) <<
"Call Channel::Connect(), stupid.";
99 DCHECK_GT(deadline_msec, 0);
102 fibers::promise<error_code> p;
103 fibers::future<error_code> res = p.get_future();
104 error_code ec = PresendChecks();
// Ceiling division: round the deadline up to whole timer-wheel ticks.
111 uint32_t ticks = (deadline_msec + kTickPrecision - 1) / kTickPrecision;
112 std::unique_ptr<ExpiryEvent> ev(
new ExpiryEvent(
this));
// The outgoing buffer is guarded by buf_lock_; OutgoingBufLock() reports
// whether the lock was taken exclusively so it can be released in kind.
119 bool lock_exclusive = OutgoingBufLock();
120 RpcId
id = next_send_rpc_id_++;
// Schedule the expiry event before enqueueing so the deadline clock
// starts ticking even if the flush is delayed.
123 base::Tick at = expire_timer_.schedule(ev.get(), ticks);
124 DVLOG(2) <<
"Scheduled expiry at " << at <<
" for rpcid " << id;
126 outgoing_buf_.emplace_back(SendItem(
id, PendingCall{std::move(p), envelope}));
// The pending call owns its expiry event for the rest of its lifetime.
127 outgoing_buf_.back().second.expiry_event = std::move(ev);
128 outgoing_buf_size_.store(outgoing_buf_.size(), std::memory_order_relaxed);
130 OutgoingBufUnlock(lock_exclusive);
// Sends `msg` as a streaming RPC: `cb` is invoked (by HandleStreamResponse)
// for every response envelope until the stream finishes. Blocks until the
// final status is known and returns it.
// NOTE(review): the early-out on a failed PresendChecks(), the flush/wait
// on `future`, and the return are missing from this excerpt.
135 auto Channel::SendAndReadStream(Envelope* msg, MessageCallback cb) -> error_code {
136 DCHECK(read_fiber_.joinable());
139 error_code ec = PresendChecks();
144 fibers::promise<error_code> p;
145 fibers::future<error_code> future = p.get_future();
// Same enqueue protocol as Send(), but the PendingCall also carries the
// per-message callback that marks it as a stream (see ReadEnvelope).
152 bool exclusive = OutgoingBufLock();
154 RpcId
id = next_send_rpc_id_++;
156 outgoing_buf_.emplace_back(SendItem(
id, PendingCall{std::move(p), msg, std::move(cb)}));
157 outgoing_buf_size_.store(outgoing_buf_.size(), std::memory_order_relaxed);
159 OutgoingBufUnlock(exclusive);
// Fiber body that consumes incoming envelopes for as long as the socket is
// open. On a read error it fails all in-flight calls (unexpected errors are
// logged at WARNING; "expected finish" codes like eof/reset are not), then
// backs off briefly before retrying. On exit any remaining pending calls
// are completed with a success code.
// Must run on the socket's IoContext thread.
165 void Channel::ReadFiber() {
166 CHECK(socket_->context().InContextThread());
168 VLOG(1) <<
"Start ReadFiber on socket " << socket_->native_handle();
// Slightly de-prioritized relative to default fibers.
169 this_fiber::properties<IoFiberProperties>().SetNiceLevel(1);
171 while (socket_->is_open()) {
172 error_code ec = ReadEnvelope();
174 LOG_IF(WARNING, !IsExpectedFinish(ec))
175 <<
"Error reading envelope " << ec <<
" " << ec.message();
// A broken read invalidates every in-flight call on this connection.
177 CancelPendingCalls(ec);
// Back off before the next read attempt (socket may be reconnecting).
182 this_fiber::sleep_for(10ms);
// Socket closed: release any callers still waiting, with a clean status.
186 CancelPendingCalls(error_code{});
187 VLOG(1) <<
"Finish ReadFiber on socket " << socket_->native_handle();
// Low-priority fiber that periodically (every 300us) batches out whatever
// has accumulated in outgoing_buf_. Skips a round when there is nothing to
// send or when another flusher already holds send_mu_ (try_lock).
// NOTE(review): the surrounding loop header, the guarded flush call itself
// and the send_mu_ unlock are missing from this excerpt (numbering gaps).
192 void Channel::FlushFiber() {
193 using namespace std::chrono_literals;
194 CHECK(socket_->context().get_executor().running_in_this_thread());
// Near-lowest scheduling priority: flushing yields to request processing.
195 this_fiber::properties<IoFiberProperties>().SetNiceLevel(IoFiberProperties::MAX_NICE_LEVEL - 1);
198 this_fiber::sleep_for(300us);
199 if (!socket_->is_open())
// acquire pairs with the release store below / in producers.
202 if (outgoing_buf_size_.load(std::memory_order_acquire) == 0 || !send_mu_.try_lock())
204 VLOG(1) <<
"FlushFiber::FlushSendsGuarded";
206 outgoing_buf_size_.store(outgoing_buf_.size(), std::memory_order_release);
// Synchronous flush entry point: under send_mu_, repeatedly drains the
// outgoing buffer while it is at/over --rpc_client_queue_size, then
// publishes the new buffer size.
// NOTE(review): the declaration of `ec`, loop/body closers and the return
// are missing from this excerpt (numbering gaps).
212 auto Channel::FlushSends() -> error_code {
215 std::lock_guard<fibers::mutex> guard(send_mu_);
221 while (outgoing_buf_.size() >= FLAGS_rpc_client_queue_size) {
222 ec = FlushSendsGuarded();
224 outgoing_buf_size_.store(outgoing_buf_.size(), std::memory_order_relaxed);
// Writes the whole outgoing buffer to the socket as one batched write.
// Caller must hold send_mu_ (hence "Guarded"). If the socket already has
// an error status, the buffered (unsent) calls are failed immediately via
// CancelSentBufferGuarded. Otherwise each SendItem is serialized as three
// scattered buffers — frame header, envelope header, envelope letter —
// and every item is moved into pending_calls_ before the blocking write,
// presumably so replies arriving mid-write can already be matched (TODO
// confirm against ReadFiber scheduling).
// NOTE(review): `ec` declaration, several returns and closing braces are
// missing from this excerpt (numbering gaps).
229 auto Channel::FlushSendsGuarded() -> error_code {
232 if (outgoing_buf_.empty())
235 ec = socket_->status();
237 CancelSentBufferGuarded(ec);
// Shared (read) lock on buf_lock_: producers use the same lock via
// OutgoingBufLock()/OutgoingBufUnlock().
243 RWSpinLock::ReadHolder holder(buf_lock_);
245 size_t count = outgoing_buf_.size();
// Three asio buffers per item: frame, header blob, letter blob.
246 write_seq_.resize(count * 3);
247 frame_buf_.resize(count);
248 for (
size_t i = 0; i < count; ++i) {
249 auto& p = outgoing_buf_[i];
250 Frame f(p.first, p.second.envelope->header.size(), p.second.envelope->letter.size());
251 size_t sz = f.Write(frame_buf_[i].data());
253 write_seq_[3 * i] = asio::buffer(frame_buf_[i].data(), sz);
254 write_seq_[3 * i + 1] = asio::buffer(p.second.envelope->header);
255 write_seq_[3 * i + 2] = asio::buffer(p.second.envelope->letter);
261 pending_calls_size_.fetch_add(count, std::memory_order_relaxed);
262 for (
size_t i = 0; i < count; ++i) {
263 auto& item = outgoing_buf_[i];
// RPC ids are unique (next_send_rpc_id_ is monotonically increasing),
// so an insertion collision indicates a bug.
264 auto emplace_res = pending_calls_.emplace(item.first, std::move(item.second));
265 CHECK(emplace_res.second);
267 outgoing_buf_.clear();
// Single gathered write of all frames; ec receives the outcome.
273 asio::write(*socket_, write_seq_, ec);
// A failed write dooms everything just moved into pending_calls_.
276 CancelPendingCalls(ec);
// Deadline handler for a single RPC: if the call is still pending, removes
// it and completes its promise with timed_out. The promise is moved out
// and fulfilled only after the map erase, so a continuation running inline
// on set_value cannot observe (or re-enter on) a half-removed entry.
// NOTE(review): the body of the not-found branch (likely just a return)
// is missing from this excerpt.
283 void Channel::ExpirePending(RpcId
id) {
284 DVLOG(1) <<
"Expire rpc id " << id;
286 auto it = this->pending_calls_.find(
id);
// Already answered or cancelled — nothing to expire.
287 if (it == this->pending_calls_.end()) {
292 EcPromise pr = std::move(it->second.promise);
293 this->pending_calls_.erase(it);
294 pr.set_value(asio::error::timed_out);
297 void Channel::CancelSentBufferGuarded(error_code ec) {
298 std::vector<SendItem> tmp;
300 buf_lock_.lock_shared();
301 tmp.swap(outgoing_buf_);
302 buf_lock_.unlock_shared();
304 for (
auto& item : tmp) {
305 auto promise = std::move(item.second.promise);
306 promise.set_value(ec);
// Reads one response from the socket: first the frame (rpc id + header and
// letter sizes), then the payload. Dispatch:
//   - unknown rpc id: drain the payload into a throwaway Envelope so the
//     wire stays in sync, and log it;
//   - stream call (has a callback): read into the caller's envelope and
//     hand it to HandleStreamResponse;
//   - unary call: detach the promise, erase the pending entry, read the
//     payload into the caller-owned envelope, and fulfill the promise with
//     the read status.
// NOTE(review): the Frame declaration, several `if (ec) return ec;` guards
// and returns are missing from this excerpt (numbering gaps).
310 auto Channel::ReadEnvelope() -> error_code {
312 error_code ec = f.Read(socket_.get());
316 VLOG(2) <<
"Got rpc_id " << f.rpc_id <<
" from socket " << socket_->native_handle();
318 auto it = pending_calls_.find(f.rpc_id);
319 if (it == pending_calls_.end()) {
322 VLOG(1) <<
"Unknown id " << f.rpc_id;
// Throwaway buffer sized from the frame: consume the orphan payload.
324 Envelope envelope(f.header_size, f.letter_size);
327 asio::read(*socket_, envelope.buf_seq(), ec);
332 PendingCall& call = it->second;
333 Envelope* env = call.envelope;
// Resize the caller's envelope to fit this response's payload.
334 env->Resize(f.header_size, f.letter_size);
// A callback distinguishes streaming calls from unary ones.
335 bool is_stream = static_cast<bool>(call.cb);
338 VLOG(1) <<
"Processing stream";
339 asio::read(*socket_, env->buf_seq(), ec);
341 HandleStreamResponse(f.rpc_id);
// Unary path: take ownership of the promise before mutating the map;
// `env` stays valid after erase because the Envelope is caller-owned.
347 fibers::promise<error_code> promise = std::move(call.promise);
350 pending_calls_.erase(it);
351 pending_calls_size_.fetch_sub(1, std::memory_order_relaxed);
354 asio::read(*socket_, env->buf_seq(), ec);
355 promise.set_value(ec);
// Delivers one streaming response to the call's user callback. A callback
// returning error::eof signals a normal end of stream and is translated to
// a success code before the call is finalized (promise fulfilled, entry
// erased).
// NOTE(review): the not-found branch body and the lines between the
// callback invocation and the eof check are missing from this excerpt —
// presumably the "stream continues" early-return lives there; confirm.
360 void Channel::HandleStreamResponse(RpcId rpc_id) {
361 auto it = pending_calls_.find(rpc_id);
362 if (it == pending_calls_.end()) {
365 PendingCall& call = it->second;
366 error_code ec = call.cb(*call.envelope);
// eof from the user callback means "stream finished cleanly".
371 if (ec == error::eof) {
372 ec = system::error_code{};
// Detach the promise before erasing so an inline continuation cannot
// observe a half-removed entry.
378 auto promise = std::move(call.promise);
379 pending_calls_.erase(it);
380 promise.set_value(ec);
// Completes every in-flight call with `ec`. The pending map is swapped into
// a local first so promises are fulfilled outside the member map; after the
// loop the map is checked again — presumably because promise continuations
// may have re-registered calls while we were notifying (TODO confirm).
// NOTE(review): the early return, the declaration of `tmp`, and the tail of
// the function fall outside this excerpt (numbering gaps / truncation).
383 void Channel::CancelPendingCalls(error_code ec) {
384 if (pending_calls_.empty())
388 tmp.swap(pending_calls_);
389 pending_calls_size_.store(0, std::memory_order_relaxed);
394 for (
auto& c : tmp) {
395 c.second.promise.set_value(ec);
// Re-check: new entries may have appeared during the set_value() loop.
398 if (pending_calls_.empty()) {
399 tmp.swap(pending_calls_);