#include #include #include #include #include #include #include #include #include #include "connection.h" #include "connection_p.h" #include "json_p.h" #include "exceptions.h" #include "term.h" #include "cursor_p.h" #include "rapidjson-config.h" #include "rapidjson/rapidjson.h" #include "rapidjson/encodedstream.h" #include "rapidjson/document.h" namespace RethinkDB { using QueryType = Protocol::Query::QueryType; // constants const int debug_net = 0; const uint32_t version_magic = static_cast(Protocol::VersionDummy::Version::V0_4); const uint32_t json_magic = static_cast(Protocol::VersionDummy::Protocol::JSON); std::unique_ptr connect(std::string host, int port, std::string auth_key) { struct addrinfo hints; memset(&hints, 0, sizeof hints); hints.ai_family = AF_UNSPEC; hints.ai_socktype = SOCK_STREAM; char port_str[16]; snprintf(port_str, 16, "%d", port); struct addrinfo *servinfo; int ret = getaddrinfo(host.c_str(), port_str, &hints, &servinfo); if (ret) throw Error("getaddrinfo: %s\n", gai_strerror(ret)); struct addrinfo *p; Error error; int sockfd; for (p = servinfo; p != NULL; p = p->ai_next) { sockfd = socket(p->ai_family, p->ai_socktype, p->ai_protocol); if (sockfd == -1) { error = Error::from_errno("socket"); continue; } if (connect(sockfd, p->ai_addr, p->ai_addrlen) == -1) { ::close(sockfd); error = Error::from_errno("connect"); continue; } break; } if (p == NULL) { throw error; } freeaddrinfo(servinfo); std::unique_ptr conn_private(new ConnectionPrivate(sockfd)); WriteLock writer(conn_private.get()); { size_t size = auth_key.size(); char buf[12 + size]; memcpy(buf, &version_magic, 4); uint32_t n = size; memcpy(buf + 4, &n, 4); memcpy(buf + 8, auth_key.data(), size); memcpy(buf + 8 + size, &json_magic, 4); writer.send(buf, sizeof buf); } ReadLock reader(conn_private.get()); { const size_t max_response_length = 1024; char buf[max_response_length + 1]; size_t len = reader.recv_cstring(buf, max_response_length); if (len == max_response_length || strcmp(buf, "SUCCESS")) { buf[len] = 0; ::close(sockfd); throw Error("Server rejected connection with message: %s", buf); } } return std::unique_ptr(new Connection(conn_private.release())); } Connection::Connection(ConnectionPrivate *dd) : d(dd) { } Connection::~Connection() { // close(); if (d->guarded_sockfd >= 0) ::close(d->guarded_sockfd); } size_t ReadLock::recv_some(char* buf, size_t size, double wait) { if (wait != FOREVER) { while (true) { fd_set readfds; struct timeval tv; FD_ZERO(&readfds); FD_SET(conn->guarded_sockfd, &readfds); tv.tv_sec = (int)wait; tv.tv_usec = (int)((wait - (int)wait) / MICROSECOND); int rv = select(conn->guarded_sockfd + 1, &readfds, NULL, NULL, &tv); if (rv == -1) { throw Error::from_errno("select"); } else if (rv == 0) { throw TimeoutException(); } if (FD_ISSET(conn->guarded_sockfd, &readfds)) { break; } } } ssize_t numbytes = ::recv(conn->guarded_sockfd, buf, size, 0); if (numbytes <= 0) throw Error::from_errno("recv"); if (debug_net > 1) { fprintf(stderr, "<< %s\n", write_datum(std::string(buf, numbytes)).c_str()); } return numbytes; } void ReadLock::recv(char* buf, size_t size, double wait) { while (size) { size_t numbytes = recv_some(buf, size, wait); buf += numbytes; size -= numbytes; } } size_t ReadLock::recv_cstring(char* buf, size_t max_size){ size_t size = 0; for (; size < max_size; size++) { recv(buf, 1, FOREVER); if (*buf == 0) { break; } buf++; } return size; } void WriteLock::send(const char* buf, size_t size) { while (size) { ssize_t numbytes = ::write(conn->guarded_sockfd, buf, size); if (numbytes == -1) throw Error::from_errno("write"); if (debug_net > 1) { fprintf(stderr, ">> %s\n", write_datum(std::string(buf, numbytes)).c_str()); } buf += numbytes; size -= numbytes; } } void WriteLock::send(const std::string data) { send(data.data(), data.size()); } std::string ReadLock::recv(size_t size) { char buf[size]; recv(buf, size, FOREVER); return buf; } void Connection::close() { CacheLock guard(d.get()); for (auto& it : d->guarded_cache) { stop_query(it.first); } int ret = ::close(d->guarded_sockfd); if (ret == -1) { throw Error::from_errno("close"); } d->guarded_sockfd = -1; } Response ConnectionPrivate::wait_for_response(uint64_t token_want, double wait) { CacheLock guard(this); ConnectionPrivate::TokenCache& cache = guarded_cache[token_want]; while (true) { if (!cache.responses.empty()) { Response response(std::move(cache.responses.front())); cache.responses.pop(); if (cache.closed && cache.responses.empty()) { guarded_cache.erase(token_want); } return response; } if (cache.closed) { throw Error("Trying to read from a closed token"); } if (guarded_loop_active) { cache.cond.wait(guard.inner_lock); } else { break; } } ReadLock reader(this); return reader.read_loop(token_want, std::move(guard), wait); } Response ReadLock::read_loop(uint64_t token_want, CacheLock&& guard, double wait) { if (!guard.inner_lock) { guard.lock(); } if (conn->guarded_loop_active) { throw Error("Cannot run more than one read loop on the same connection"); } conn->guarded_loop_active = true; guard.unlock(); try { while (true) { char buf[12]; bzero(buf, sizeof(buf)); recv(buf, 12, wait); uint64_t token_got; memcpy(&token_got, buf, 8); uint32_t length; memcpy(&length, buf + 8, 4); std::unique_ptr bufmem(new char[length + 1]); char *buffer = bufmem.get(); bzero(buffer, length + 1); recv(buffer, length, wait); buffer[length] = '\0'; rapidjson::Document json; json.ParseInsitu(buffer); if (json.HasParseError()) { fprintf(stderr, "json parse error, code: %d, position: %d\n", (int)json.GetParseError(), (int)json.GetErrorOffset()); } else if (json.IsNull()) { fprintf(stderr, "null value, read: %s\n", buffer); } Datum datum = read_datum(json); if (debug_net > 0) { fprintf(stderr, "[%" PRIu64 "] << %s\n", token_got, write_datum(datum).c_str()); } Response response(std::move(datum)); if (token_got == token_want) { guard.lock(); if (response.type != Protocol::Response::ResponseType::SUCCESS_PARTIAL) { auto it = conn->guarded_cache.find(token_got); if (it != conn->guarded_cache.end()) { it->second.closed = true; it->second.cond.notify_all(); } conn->guarded_cache.erase(it); } conn->guarded_loop_active = false; for (auto& it : conn->guarded_cache) { it.second.cond.notify_all(); } return response; } else { guard.lock(); auto it = conn->guarded_cache.find(token_got); if (it == conn->guarded_cache.end()) { // drop the response } else if (!it->second.closed) { it->second.responses.emplace(std::move(response)); if (response.type != Protocol::Response::ResponseType::SUCCESS_PARTIAL) { it->second.closed = true; } } it->second.cond.notify_all(); guard.unlock(); } } } catch (const TimeoutException &e) { if (!guard.inner_lock){ guard.lock(); } conn->guarded_loop_active = false; throw e; } } void ConnectionPrivate::run_query(Query query, bool no_reply) { WriteLock writer(this); writer.send(query.serialize()); } Cursor Connection::start_query(Term *term, OptArgs&& opts) { bool no_reply = false; auto it = opts.find("noreply"); if (it != opts.end()) { no_reply = *(it->second.datum.get_boolean()); } uint64_t token = d->new_token(); { CacheLock guard(d.get()); d->guarded_cache[token]; } d->run_query(Query{QueryType::START, token, term->datum, std::move(opts)}); if (no_reply) { return Cursor(new CursorPrivate(token, this, Nil())); } Cursor cursor(new CursorPrivate(token, this)); Response response = d->wait_for_response(token, FOREVER); cursor.d->add_response(std::move(response)); return cursor; } void Connection::stop_query(uint64_t token) { const auto& it = d->guarded_cache.find(token); if (it != d->guarded_cache.end() && !it->second.closed) { d->run_query(Query{QueryType::STOP, token}, true); } } void Connection::continue_query(uint64_t token) { d->run_query(Query{QueryType::CONTINUE, token}, true); } Error Response::as_error() { std::string repr; if (result.size() == 1) { std::string* string = result[0].get_string(); if (string) { repr = *string; } else { repr = write_datum(result[0]); } } else { repr = write_datum(Datum(result)); } std::string err; using RT = Protocol::Response::ResponseType; using ET = Protocol::Response::ErrorType; switch (type) { case RT::SUCCESS_SEQUENCE: err = "unexpected response: SUCCESS_SEQUENCE"; break; case RT::SUCCESS_PARTIAL: err = "unexpected response: SUCCESS_PARTIAL"; break; case RT::SUCCESS_ATOM: err = "unexpected response: SUCCESS_ATOM"; break; case RT::WAIT_COMPLETE: err = "unexpected response: WAIT_COMPLETE"; break; case RT::SERVER_INFO: err = "unexpected response: SERVER_INFO"; break; case RT::CLIENT_ERROR: err = "ReqlDriverError"; break; case RT::COMPILE_ERROR: err = "ReqlCompileError"; break; case RT::RUNTIME_ERROR: switch (error_type) { case ET::INTERNAL: err = "ReqlInternalError"; break; case ET::RESOURCE_LIMIT: err = "ReqlResourceLimitError"; break; case ET::QUERY_LOGIC: err = "ReqlQueryLogicError"; break; case ET::NON_EXISTENCE: err = "ReqlNonExistenceError"; break; case ET::OP_FAILED: err = "ReqlOpFailedError"; break; case ET::OP_INDETERMINATE: err = "ReqlOpIndeterminateError"; break; case ET::USER: err = "ReqlUserError"; break; case ET::PERMISSION_ERROR: err = "ReqlPermissionError"; break; default: err = "ReqlRuntimeError"; break; } } throw Error("%s: %s", err.c_str(), repr.c_str()); } Protocol::Response::ResponseType response_type(double t) { int n = static_cast(t); using RT = Protocol::Response::ResponseType; switch (n) { case static_cast(RT::SUCCESS_ATOM): return RT::SUCCESS_ATOM; case static_cast(RT::SUCCESS_SEQUENCE): return RT::SUCCESS_SEQUENCE; case static_cast(RT::SUCCESS_PARTIAL): return RT::SUCCESS_PARTIAL; case static_cast(RT::WAIT_COMPLETE): return RT::WAIT_COMPLETE; case static_cast(RT::CLIENT_ERROR): return RT::CLIENT_ERROR; case static_cast(RT::COMPILE_ERROR): return RT::COMPILE_ERROR; case static_cast(RT::RUNTIME_ERROR): return RT::RUNTIME_ERROR; default: throw Error("Unknown response type"); } } Protocol::Response::ErrorType runtime_error_type(double t) { int n = static_cast(t); using ET = Protocol::Response::ErrorType; switch (n) { case static_cast(ET::INTERNAL): return ET::INTERNAL; case static_cast(ET::RESOURCE_LIMIT): return ET::RESOURCE_LIMIT; case static_cast(ET::QUERY_LOGIC): return ET::QUERY_LOGIC; case static_cast(ET::NON_EXISTENCE): return ET::NON_EXISTENCE; case static_cast(ET::OP_FAILED): return ET::OP_FAILED; case static_cast(ET::OP_INDETERMINATE): return ET::OP_INDETERMINATE; case static_cast(ET::USER): return ET::USER; default: throw Error("Unknown error type"); } } }