21 #include "kinetic/message_stream.h"
23 #include <arpa/inet.h>
26 #include "glog/logging.h"
28 #include "kinetic/incoming_value.h"
29 #include "kinetic/outgoing_value.h"
33 MessageStream::MessageStream(uint32_t max_message_size_bytes, ByteStreamInterface *byte_stream)
34 : max_message_size_bytes_(max_message_size_bytes), byte_stream_(byte_stream) {}
// Destroys the stream.
// NOTE(review): the body is not visible here. The constructor receives
// byte_stream as a raw pointer — confirm whether this destructor is what
// releases it.
36 MessageStream::~MessageStream() {
40 MessageStream::MessageStreamReadStatus MessageStream::ReadMessage(
41 ::google::protobuf::Message *message,
42 IncomingValueInterface** value) {
44 uint32_t message_size, value_size;
45 if (!ReadHeader(&message_size, &value_size)) {
46 return MessageStreamReadStatus_INTERNAL_ERROR;
51 if (message_size > max_message_size_bytes_) {
52 return MessageStreamReadStatus_TOO_LARGE;
56 char *message_bytes =
new char[message_size];
57 if (!byte_stream_->Read(message_bytes, message_size)) {
58 LOG(WARNING) <<
"Unable to read message";
59 delete[] message_bytes;
60 return MessageStreamReadStatus_INTERNAL_ERROR;
63 if (!message->ParseFromArray(message_bytes, message_size)) {
64 LOG(WARNING) <<
"Failed to parse protobuf message";
65 delete[] message_bytes;
66 return MessageStreamReadStatus_INTERNAL_ERROR;
69 delete[] message_bytes;
72 *value = byte_stream_->ReadValue(value_size);
74 return MessageStreamReadStatus_INTERNAL_ERROR;
77 return MessageStreamReadStatus_SUCCESS;
// Writes one framed message to byte_stream_: a header carrying the message
// and value lengths, then the serialized protobuf message, then the value.
// Each failing step logs a warning.
//
// @param message  protobuf message to serialize and send
// @param value    value written after the message; value.size() goes into
//                 the header
// @param err      forwarded to byte_stream_->WriteValue()
// @return int status. NOTE(review): the return statements are not visible
//         here — presumably 0 on success and a distinct nonzero code per
//         failing step; confirm against callers.
80 int MessageStream::WriteMessage(const ::google::protobuf::Message &message,
81 const OutgoingValueInterface& value,
int* err) {
// NOTE(review): the header is written before the message is serialized. If
// SerializeToString() fails below, a header with no matching body has already
// been sent. Serializing first and putting message_string.size() in the header
// would avoid that. It would also guarantee the advertised length matches the
// bytes actually written. In newer protobuf releases ByteSize() (int) is
// deprecated in favour of ByteSizeLong().
83 if (!WriteHeader(message.ByteSize(), value.size())) {
84 LOG(WARNING) <<
"Failed to write header";
// Serialize the protobuf and send its bytes.
89 std::string message_string;
90 if (!message.SerializeToString(&message_string)) {
91 LOG(WARNING) <<
"Failed to serialize protocol buffer";
94 if (!byte_stream_->Write(message_string.data(), message_string.size())) {
95 LOG(WARNING) <<
"Failed to write message";
// Finally send the value; err is filled in by the byte stream.
100 if (!byte_stream_->WriteValue(value, err)) {
101 LOG(WARNING) <<
"Failed to write value";
// Creates a factory that builds MessageStream objects over plain or SSL
// connections.
//
// @param ssl_context    SSL_CTX that NewMessageStream() passes to SSL_new()
// @param value_factory  handed to PlainByteStream for non-SSL connections.
//                       NOTE(review): it is an interface type bound from a
//                       reference, so presumably held by reference — it
//                       must outlive this factory; confirm.
108 MessageStreamFactory::MessageStreamFactory(SSL_CTX *ssl_context,
109 IncomingValueFactoryInterface &value_factory)
110 : ssl_context_(ssl_context), value_factory_(value_factory) {
// NOTE(review): presumably tracks whether ssl_ has been allocated; its other
// uses are not visible in this section.
111 ssl_created_ =
false;
// Destroys the factory.
// NOTE(review): the body is not visible here. NewMessageStream() allocates
// ssl_ with SSL_new(), so confirm the SSL object is freed exactly once (here
// or by its stream).
115 MessageStreamFactory::~MessageStreamFactory() {
// Builds a MessageStream for the connection on fd.
//
// On the SSL path this does the following:
//   1. Creates an SSL object from ssl_context_ into the member ssl_.
//   2. Enables SSL_MODE_AUTO_RETRY.
//   3. Binds the SSL object to fd.
//   4. Runs the server-side handshake (SSL_accept).
//   5. Wraps the connection in an SslByteStream.
// Otherwise fd is wrapped in a PlainByteStream that uses value_factory_.
//
// @param fd                      file descriptor of the connection
// @param use_ssl                 presumably selects the SSL path — the branch
//                                itself is not visible here; confirm
// @param ssl                     passed to SslByteStream (see note below)
// @param max_message_size_bytes  forwarded to the MessageStream constructor
// @param message_stream          out: receives the newly allocated stream
// @return bool. NOTE(review): the return statements are not visible —
//         presumably false after any logged SSL setup failure; confirm.
121 bool MessageStreamFactory::NewMessageStream(
int fd,
bool use_ssl, SSL *ssl, uint32_t max_message_size_bytes,
122 MessageStreamInterface **message_stream) {
// NOTE(review): ssl_ is a factory member, so each call overwrites the SSL
// object created by the previous call on the same factory.
125 ssl_ = SSL_new(ssl_context_);
131 LOG(ERROR) <<
"Failed to create new SSL object";
// On a blocking transport, let OpenSSL retry internally instead of surfacing
// SSL_ERROR_WANT_READ/WRITE to the caller.
134 SSL_set_mode(ssl_, SSL_MODE_AUTO_RETRY);
135 if (SSL_set_fd(ssl_, fd) != 1) {
136 LOG(ERROR) <<
"Failed to associate SSL object with file descriptor";
// Server-side handshake with the connecting client.
140 if (SSL_accept(ssl_) != 1) {
141 LOG(ERROR) <<
"Failed to perform SSL handshake";
142 LOG(ERROR) <<
"The client may have attempted to use an SSL/TLS version below TLSv1.1";
149 LOG(INFO) <<
"Successfully performed SSL handshake";
// NOTE(review): the handshake above ran on the member ssl_, but the stream is
// built from the parameter ssl. Unless ssl is reassigned in lines not shown
// here, this looks like it should be ssl_ — confirm.
150 *message_stream =
new MessageStream(max_message_size_bytes,
new SslByteStream(ssl));
// Non-SSL path: wrap the raw descriptor directly.
153 new MessageStream(max_message_size_bytes,
new PlainByteStream(fd, value_factory_));
159 bool MessageStream::ReadHeader(uint32_t *message_size, uint32_t *value_size) {
161 if (!byte_stream_->Read(header,
sizeof(header))) {
165 if (header[0] !=
'F') {
166 LOG(WARNING) <<
"Received invalid magic value " << header[0];
170 memcpy(reinterpret_cast<char *>(message_size), header + 1,
sizeof(*message_size));
171 memcpy(reinterpret_cast<char *>(value_size), header + 5,
sizeof(*value_size));
172 *message_size = ntohl(*message_size);
173 *value_size = ntohl(*value_size);
178 bool MessageStream::WriteHeader(uint32_t message_size, uint32_t value_size) {
181 message_size = htonl(message_size);
182 value_size = htonl(value_size);
183 memcpy(header + 1, reinterpret_cast<char *>(&message_size),
sizeof(message_size));
184 memcpy(header + 5, reinterpret_cast<char *>(&value_size),
sizeof(value_size));
185 return byte_stream_->Write(header,
sizeof(header));