root/trunk/whisperlib/net/rpc/lib/server/rpc_server_connection.cc

Revision 7, 16.2 kB (checked in by whispercastorg, 2 years ago)

version 0.2.0

Line 
1 // Copyright (c) 2009, Whispersoft s.r.l.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are
6 // met:
7 //
8 // * Redistributions of source code must retain the above copyright
9 // notice, this list of conditions and the following disclaimer.
10 // * Redistributions in binary form must reproduce the above
11 // copyright notice, this list of conditions and the following disclaimer
12 // in the documentation and/or other materials provided with the
13 // distribution.
14 // * Neither the name of Whispersoft s.r.l. nor the names of its
15 // contributors may be used to endorse or promote products derived from
16 // this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 //
30 // Author: Cosmin Tudorache
31
32 #include "common/base/common.h"
33 #include "common/base/errno.h"
34 #include "common/base/log.h"
35 #include "common/base/scoped_ptr.h"
36
37 #include "net/rpc/lib/server/rpc_server_connection.h"
38 #include "net/rpc/lib/codec/rpc_codec.h"
39 #include "net/rpc/lib/codec/binary/rpc_binary_codec.h"
40 #include "net/rpc/lib/codec/json/rpc_json_codec.h"
41 #include "net/rpc/lib/rpc_version.h"
42
43 namespace rpc {
44
45 void rpc::ServerConnection::ProtocolHandleHandshake(io::MemoryStream& in) {
46   // It's a 3 way handshake. The involved packets are as follows:
47   //  1. client -> server
48   //    - 3 bytes: "rpc"
49   //    - 2 bytes: rpc-protocol-version (hi,lo)
50   //    - 1 byte: codec identifier.
51   //    - 32 bytes: client random generated data.
52   // 2. server -> client
53   //    - 3 bytes: "rpc"
54   //    - 2 bytes: rpc-protocol-version. Should match the client version.
55   //               Otherwise don't send this packet and drop the handshake.
56   //    - 1 byte: codec identifier. Should be identical with the client
57   //               indicated value. Otherwise don't send this packet and
58   //               drop the handshake.
59   //    - 32 bytes: server random generated data. Different from the
60   //                client data.
61   //    - 32 bytes: repeat client data.
62   // 3. client -> server
63   //    - 3 bytes: "rpc"
64   //    - 2 bytes: rpc-protocol-version. Should match the server version.
65   //               Otherwise don't send this packet and drop the handshake.
66   //    - 1 byte: codec identifier. Identical with the first packet.
67   //    - 32 bytes: repeat server data.
68   //
69
70   CHECK_NE(handshakeState_, HS_CONNECTED) << "Handshake already done!";
71   CHECK_NE(handshakeState_, HS_FAILURE) << "Handshake already failed!";
72
73   // protocol constant = 3 + 2 + 1 + HANDSHAKE_RANDOM_SIZE
74   if ( in.Size() < 38 ) {
75     // not enough data. Wait for more.
76     return;
77   }
78
79   // decode the handshake
80   char mark[4] = { 0, };
81   in.Read(mark, 3);
82   const uint8 versHi = io::NumStreamer::ReadByte(&in);
83   const uint8 versLo = io::NumStreamer::ReadByte(&in);
84   const uint8 codecID = io::NumStreamer::ReadByte(&in);
85   char data[HANDSHAKE_RANDOM_SIZE] = { 0, };
86   uint32 readSize = in.Read(data, HANDSHAKE_RANDOM_SIZE);
87   CHECK_EQ(readSize, HANDSHAKE_RANDOM_SIZE);
88
89   if ( !strutil::StrEql(mark, "rpc") ) {
90     // the handshake packet should start with "rpc"
91     LOG_DEBUG << "ERROR: Handshake does not start with \"rpc\" but with: "
92               << strutil::PrintableDataBuffer(mark, 3);
93     handshakeState_ = HS_FAILURE;
94     return;
95   }
96
97   if ( versHi != RPC_VERSION_MAJOR ||
98        versLo != RPC_VERSION_MINOR ) {
99     // different version
100     LOG_WARNING << "handshake attempt"
101                     " from " << net_connection_->remote_address()
102                 << " with version " << versHi << "." << versLo
103                 << "; Server version is " << RPC_VERSION_STR
104                 << ";";
105     handshakeState_ = HS_FAILURE;
106     return;
107   }
108
109   // create codec according to client codecID
110   //
111   if ( !codec_ ) {
112     switch ( codecID ) {
113       case rpc::CID_BINARY:
114         codec_ = codec_ ? codec_ : new rpc::BinaryCodec();
115         break;
116       case rpc::CID_JSON:
117         codec_ = codec_ ? codec_ : new rpc::JsonCodec();
118         break;
119       default:
120         LOG_WARNING << "handshake attempt"
121                        " from " << net_connection_->remote_address()
122                     << " with invalid codecID: " << codecID;
123         handshakeState_ = HS_FAILURE;
124         break;
125     };
126   }
127
128   switch ( handshakeState_ ) {
129     case HS_WAITING_REQUEST: {
130       // compose handshake response
131       uint8 handResponse[3+2+1+HANDSHAKE_RANDOM_SIZE+HANDSHAKE_RANDOM_SIZE];
132       memcpy(handResponse, "rpc", 3);          // "rpc" head
133       handResponse[3] = RPC_VERSION_MAJOR;     // version HI
134       handResponse[4] = RPC_VERSION_MINOR;     // version LO
135       handResponse[5] = codec_->GetCodecID();  // codec identifier
136       memcpy(handResponse + 6,
137              handshakeServerRandomData_,
138              HANDSHAKE_RANDOM_SIZE);           // 32 bytes server random data
139       memcpy(handResponse + 6 + HANDSHAKE_RANDOM_SIZE,
140              data,
141              HANDSHAKE_RANDOM_SIZE);     // 32 bytes the received client data
142       // send response
143       net_connection_->Write(
144           reinterpret_cast<const void*>(handResponse), sizeof(handResponse));
145       handshakeState_ = HS_WAITING_RESPONSE;
146       return;
147     }
148       break;
149     case HS_WAITING_RESPONSE: {
150       // check replied data. The client should have replied with our
151       // random data.
152       if ( memcmp(data, handshakeServerRandomData_, HANDSHAKE_RANDOM_SIZE) ) {
153         LOG_WARNING << "handshake failed: client replied different random data";
154         handshakeState_ = HS_FAILURE;
155         return;
156       }
157       CHECK(codec_);
158       if ( codecID != codec_->GetCodecID() ) {
159         LOG_WARNING << "handshake failed: client replied "
160                     << "different codecid. First: "
161                     << codec_->GetCodecID() << "  second: " << codecID;
162         handshakeState_ = HS_FAILURE;
163         return;
164       }
165       handshakeState_ = HS_CONNECTED;
166       return;
167     }
168       break;
169     case HS_FAILURE: {
170       // hanshake failed already
171       return;
172     }
173       break;
174     default:
175       LOG_FATAL << "Invalid Handshake state "
176                 << (static_cast<uint32>(handshakeState_));
177       return;
178   }
179 }
180
181 void rpc::ServerConnection::ProtocolHandleMessage(const rpc::Message& p) {
182   LOG_DEBUG << "Handle received packet: " << p;
183
184   if ( p.header_.msgType_ != RPC_CALL ) {
185     LOG_ERROR << "Received a non-CALL message! ignoring: " << p;
186     return;
187   }
188
189   // extract transport
190   rpc::Transport transport(rpc::Transport::TCP,
191                            net_connection_->local_address(),
192                            net_connection_->remote_address());
193   // extract call: service, method and arguments
194   CHECK_EQ(p.header_.msgType_, RPC_CALL);
195   const uint32 xid = p.header_.xid_;
196   const string service = p.cbody_.service_.StdStr();
197   const string method = p.cbody_.method_.StdStr();
198   io::MemoryStream& params =
199       const_cast<io::MemoryStream&>(p.cbody_.params_);
200
201   // create an internal query. Use xid as qid, because inside a connection
202   //  xid s are unique.
203   rpc::Query* query = new rpc::Query(transport, xid, service, method,
204                                      params, *codec_, GetResultHandlerID());
205
206   // send the query to the executor. Specify us as the result collector.
207   //
208   if ( asyncQueryExecutor_.QueueRPC(query) ) {
209     return;  // Success
210   }
211   // on error, send error message
212   //
213   delete query;
214   query = NULL;
215   LOG_ERROR << "Async queue execution failed:"
216             << " service=" << service
217             << " method=" << method
218             << " reason=" << GetLastSystemErrorDescription();
219   io::MemoryStream ms;
220   codec_->Encode(ms, rpc::String(GetLastSystemErrorDescription()));
221   WriteReply(xid, RPC_SYSTEM_ERR, ms);
222 }
223
224 rpc::ServerConnection::ServerConnection(
225     net::Selector* selector,
226     bool auto_delete_on_close,
227     net::NetConnection * net_connection,
228     rpc::IAsyncQueryExecutor& queryExecutor)
229     : selector_(selector),
230       net_connection_(net_connection),
231       cachedPacketBuffer_(),
232       syncCachedPacketBuffer_(),
233       handshakeState_(HS_WAITING_REQUEST),
234       asyncQueryExecutor_(queryExecutor),
235       registeredToQueryExecutor_(false),
236       codec_(NULL),
237       expectedWriteReplyCalls_(0),
238       accessExpectedWriteReplyCalls_(),
239       auto_delete_on_close_(auto_delete_on_close) {
240   net_connection_->SetReadHandler(NewPermanentCallback(
241       this, &ServerConnection::ConnectionReadHandler), true);
242   net_connection_->SetWriteHandler(NewPermanentCallback(
243       this, &ServerConnection::ConnectionWriteHandler), true);
244   net_connection_->SetCloseHandler(NewPermanentCallback(
245       this, &ServerConnection::ConnectionCloseHandler), true);
246   // generate random data
247   for ( uint32 i = 0; i < HANDSHAKE_RANDOM_SIZE; i++ ) {
248     handshakeServerRandomData_[i] = '0' + i % 10;
249   }
250 }
251
252 rpc::ServerConnection::~ServerConnection() {
253   // stop RPC result receiving
254   if ( registeredToQueryExecutor_ ) {
255     asyncQueryExecutor_.UnregisterResultHandler(*this);
256     registeredToQueryExecutor_ = false;
257   }
258   delete codec_;
259   codec_ = NULL;
260   delete net_connection_;
261   net_connection_ = NULL;
262   CHECK_EQ(expectedWriteReplyCalls_, 0);
263 }
264
265 void rpc::ServerConnection::IncExpectedWriteReplyCalls() {
266   synch::MutexLocker lock(&accessExpectedWriteReplyCalls_);
267   expectedWriteReplyCalls_++;
268 }
269 void
270 rpc::ServerConnection::DecExpectedWriteReplyCallsAndPossiblyDeleteConnection() {
271   bool deleteConnection = false;
272   {
273     synch::MutexLocker lock(&accessExpectedWriteReplyCalls_);
274     CHECK_GT(expectedWriteReplyCalls_, 0);
275     expectedWriteReplyCalls_--;
276     deleteConnection =
277        (expectedWriteReplyCalls_ == 0) &&
278        (net_connection_->state() == net::NetConnection::DISCONNECTED) &&
279        auto_delete_on_close_;
280   }
281   if ( deleteConnection ) {
282     selector_->DeleteInSelectLoop(this);
283   }
284 }
285
286 //////////////////////////////////////////////////////////////
287 //
288 //   Methods available to any external thread (worker threads).
289 //
290 void rpc::ServerConnection::WriteReply(uint32 xid,
291                                        rpc::REPLY_STATUS status,
292                                        const io::MemoryStream& result) {
293   rpc::Message p;
294
295   rpc::Message::Header& header = p.header_;
296   header.xid_ = xid;
297   header.msgType_ = RPC_REPLY;
298
299   rpc::Message::ReplyBody& body = p.rbody_;
300   body.replyStatus_ = status;
301   body.result_.AppendStreamNonDestructive(&result);
302
303   LOG_DEBUG << "WriteReply sending packet: " << p;
304
305   WriteWithEncodeNow(p);
306 }
307
308 void rpc::ServerConnection::WriteWithEncodeInSelector(const rpc::Message* p) {
309   CHECK_NOT_NULL(p);
310
311   IncExpectedWriteReplyCalls();
312   selector_->RunInSelectLoop(
313       NewCallback(this, &rpc::ServerConnection::CallbackSendRPCPacket, p));
314 }
315
316 void rpc::ServerConnection::WriteWithEncodeNow(const rpc::Message& p) {
317   io::MemoryStream* ms = new io::MemoryStream();
318   CHECK_NOT_NULL(ms);
319   codec_->EncodePacket(*ms, p);
320
321   IncExpectedWriteReplyCalls();
322   selector_->RunInSelectLoop(
323       NewCallback(this,
324                   &rpc::ServerConnection::CallbackSendData,
325                   (const io::MemoryStream *)ms));
326 }
327
328 //////////////////////////////////////////////////////////////
329 //
330 //     Methods available only from the selector thread.
331 //
332 void rpc::ServerConnection::CallbackSendRPCPacket(const rpc::Message* p) {
333   CHECK_NOT_NULL(p);
334   // the message was dynamically allocated
335   scoped_ptr<const rpc::Message> autoDelP(p);
336   if ( net_connection_->state() != net::NetConnection::CONNECTED ) {
337     LOG_DEBUG << "Bad connection state: " << net_connection_->StateName()
338               << " Cannot SendRPCPacket: " << *p;
339     DecExpectedWriteReplyCallsAndPossiblyDeleteConnection();
340     return;
341   }
342
343   LOG_DEBUG << "SendRPCPacket: " << *p;
344
345   // serialize the given RPC packet in the cache buffer.
346   //
347   cachedPacketBuffer_.Clear();
348   codec_->EncodePacket(cachedPacketBuffer_, *p);
349
350   // send reply over network
351   //
352   net_connection_->outbuf()->AppendStream(&cachedPacketBuffer_);
353   net_connection_->RequestWriteEvents(true);
354   DecExpectedWriteReplyCallsAndPossiblyDeleteConnection();
355 }
356
357 void rpc::ServerConnection::CallbackSendData(const io::MemoryStream* ms) {
358   CHECK_NOT_NULL(ms);
359   // the stream was dynamically allocated
360   scoped_ptr<const io::MemoryStream> autoDelStream(ms);
361
362   if ( net_connection_->state() != net::NetConnection::CONNECTED ) {
363     DLOG_DEBUG << "Bad connection state: " << net_connection_->StateName()
364                << " Cannot SendData: " << ms->DebugString();
365     DecExpectedWriteReplyCallsAndPossiblyDeleteConnection();
366     return;
367   }
368
369   // send reply over network
370   //
371   net_connection_->outbuf()->AppendStreamNonDestructive(ms);
372   net_connection_->RequestWriteEvents(true);
373   DecExpectedWriteReplyCallsAndPossiblyDeleteConnection();
374 }
375
376 //////////////////////////////////////////////////////////////////////
377 //
378 //             net::BufferedConnection methods
379 //
380 bool rpc::ServerConnection::ConnectionReadHandler() {
381   // get the BufferedConnection's internal memory stream
382   io::MemoryStream* in = net_connection_->inbuf();
383
384   if ( handshakeState_ != HS_CONNECTED ) {
385     ProtocolHandleHandshake(*in);
386     if ( handshakeState_ == HS_FAILURE ) {
387       LOG_ERROR << "Handshake failed. Closing connection.";
388       return false;
389     }
390     if ( handshakeState_ != HS_CONNECTED ) {
391       // handshake not finished yet
392       return true;
393     }
394
395     // handshake completed, register to the executor to be able to
396     // receive RPC results (as we're about to receive RPC queries)
397     //
398     asyncQueryExecutor_.RegisterResultHandler(*this);
399     registeredToQueryExecutor_ = true;
400   }
401
402   // the codec is estabilished in handshake. It should be known by now.
403   CHECK(codec_);
404
405   while ( !in->IsEmpty() ) {
406     // set a marker, to be able to restore read data on incomplete packets.
407     in->MarkerSet();
408
409     // try decoding a RPC message
410     rpc::Message p;
411     DECODE_RESULT result = codec_->DecodePacket(*in, p);
412
413     if ( result == DECODE_RESULT_NOT_ENOUGH_DATA ) {
414       // not enough data to read the entire packet.
415       // restore read data
416       in->MarkerRestore();
417       return true;
418     } else if ( result == DECODE_RESULT_SUCCESS ) {
419       // RPC packet read & decoded successfully
420       // a complete packet was read, cancel the marker.
421       in->MarkerClear();
422
423       // handle message (queue RPC query for execution)
424       ProtocolHandleMessage(p);
425
426       // try read & handle next packet in input stream
427       continue;
428     } else if ( result == DECODE_RESULT_ERROR ) {
429       // decoder error. Data is not a packet.
430       // no need to restore data, we're closing the connection.
431       in->MarkerClear();
432       LOG_ERROR << "Invalid data. Closing Connection.";
433       net_connection_->FlushAndClose();
434       return false;
435     }
436
437     // no such result
438     in->MarkerClear();
439     LOG_FATAL << "Invalid result from rpc::Message::SerializeLoad: " << result;
440     return false;
441   }
442   return true;
443 }
444 bool rpc::ServerConnection::ConnectionWriteHandler() {
445   return true;
446 }
447 void rpc::ServerConnection::ConnectionCloseHandler(
448     int err, net::NetConnection::CloseWhat what) {
449   if ( what != net::NetConnection::CLOSE_READ_WRITE ) {
450     net_connection_->FlushAndClose();
451     return;
452   }
453   LOG_INFO << "Closed connection to " << net_connection_->remote_address()
454            << " err: " << GetSystemErrorDescription(err);
455
456   // stop RPC result receiving
457   if ( registeredToQueryExecutor_ ) {
458     asyncQueryExecutor_.UnregisterResultHandler(*this);
459     registeredToQueryExecutor_ = false;
460   }
461 }
462
463 void rpc::ServerConnection::HandleRPCResult(const rpc::Query& q) {
464   WriteReply(q.qid(), q.status(), q.result());
465 }
466
467 // Returns a description of this connection. Good for logging.
468 string rpc::ServerConnection::ToString() const {
469   std::ostringstream oss;
470   oss << "rpc::ServerConnection to " << net_connection_->remote_address();
471   return oss.str();
472 }
473 }
Note: See TracBrowser for help on using the browser.