Added more safety checks to del_virtual_socket(), new zts_shutdown() implementation

This commit is contained in:
Joseph Henry
2017-09-08 11:43:41 -07:00
parent 4fd2db7dd6
commit 05fec81757
8 changed files with 152 additions and 131 deletions

View File

@@ -575,8 +575,26 @@ namespace ZeroTier {
return err;
}
// Shuts down some aspect of a connection (Read/Write)
int VirtualTap::Shutdown(VirtualSocket *vs, int how)
{
int err = 0;
#if defined(STACK_PICO)
if(picostack) {
err = picostack->pico_Shutdown(vs, how);
}
return err;
#endif
#if defined(STACK_LWIP)
if(lwipstack) {
err = lwipstack->lwip_Shutdown(vs, how);
}
return err;
#endif
}
void VirtualTap::Housekeeping()
{/*
{
Mutex::Lock _l(_tcpconns_m);
std::time_t current_ts = std::time(nullptr);
if(current_ts > last_housekeeping_ts + ZT_HOUSEKEEPING_INTERVAL) {
@@ -635,11 +653,8 @@ namespace ZeroTier {
}
// TODO: Clean up VirtualSocket objects
last_housekeeping_ts = std::time(nullptr);
}
*/
}
/****************************************************************************/

View File

@@ -296,6 +296,11 @@ namespace ZeroTier {
*/
int Close(VirtualSocket *vs);
/*
* Shuts down some aspect of a VirtualSocket
*/
int Shutdown(VirtualSocket *vs, int how);
/*
* Disposes of previously-closed VirtualSockets
*/

View File

@@ -905,13 +905,14 @@ EPERM Firewall rules forbid VirtualSocket.
*/
int zts_setsockopt(ZT_SETSOCKOPT_SIG)
{
int err = errno = 0;
#if defined(STACK_PICO)
// TODO: Move stack-specific logic into stack driver section
//DEBUG_INFO("fd=%d", fd);
int err = errno = 0;
if(fd < 0 || fd >= ZT_MAX_SOCKETS) {
errno = EBADF;
return -1;
}
#if defined(STACK_PICO)
// Disable Nagle's algorithm
struct pico_socket *p = NULL;
err = zts_get_pico_socket(fd, &p);
@@ -1074,12 +1075,12 @@ int zts_close(ZT_CLOSE_SIG)
errno = EBADF;
return -1;
}
del_virtual_socket(fd);
if(vs->tap) {
vs->tap->Close(vs);
}
delete vs;
vs = NULL;
del_virtual_socket(fd);
return err;
/*
@@ -1216,7 +1217,7 @@ Linux:
ssize_t zts_sendto(ZT_SENDTO_SIG)
{
DEBUG_TRANS("fd=%d", fd);
//DEBUG_TRANS("fd=%d", fd);
int err = errno = 0;
if(fd < 0 || fd >= ZT_MAX_SOCKETS) {
errno = EBADF;
@@ -1293,7 +1294,6 @@ ssize_t zts_sendto(ZT_SENDTO_SIG)
err = -1;
errno = EINVAL;
}
//err = sendto(fd, buf, len, flags, addr, addrlen);
}
return err;
}
@@ -1473,6 +1473,10 @@ ssize_t zts_recv(ZT_RECV_SIG)
{
DEBUG_TRANS("fd=%d", fd);
int err = errno = 0;
if(fd < 0 || fd >= ZT_MAX_SOCKETS) {
errno = EBADF;
return -1;
}
ZeroTier::VirtualSocket *vs = get_virtual_socket(fd);
if(!vs) {
DEBUG_ERROR("invalid vs for fd=%d", fd);
@@ -1544,18 +1548,18 @@ ssize_t zts_recv(ZT_RECV_SIG)
allows either error to be returned for this case, and does
not require these constants to have the same value, so a
portable application should check for both possibilities.
[ ] [EBADF] The argument sockfd is an invalid descriptor.
[--] [EBADF] The argument sockfd is an invalid descriptor.
[ ] [ECONNREFUSED] A remote host refused to allow the network connection
(typically because it is not running the requested service).
[ ] [EFAULT] The receive buffer pointer(s) point outside the process's
address space.
[ ] [EINTR] The receive was interrupted by delivery of a signal before any
data were available; see signal(7).
[ ] [EINVAL] Invalid argument passed.
[--] [EINVAL] Invalid argument passed.
[ ] [ENOMEM] Could not allocate memory for recvmsg().
[ ] [ENOTCONN] The socket is associated with a connection-oriented protocol
and has not been connected (see connect(2) and accept(2)).
[ ] [ENOTSOCK] The argument sockfd does not refer to a socket.
[NA] [ENOTSOCK] The argument sockfd does not refer to a socket.
ZT_RECVFROM_SIG int fd, void *buf, size_t len, int flags, struct sockaddr *addr, socklen_t *addrlen
*/
@@ -1652,93 +1656,46 @@ int zts_write(ZT_WRITE_SIG) {
return write(fd, buf, len);
}
/*
Linux:
[--] [EBADF] The socket argument is not a valid file descriptor.
[--] [EINVAL] The how argument is invalid.
[--] [ENOTCONN] The socket is not connected.
[NA] [ENOTSOCK] The socket argument does not refer to a socket.
[NA] [ENOBUFS] Insufficient resources were available in the system to perform the operation.
ZT_SHUTDOWN_SIG int fd, int how
*/
int zts_shutdown(ZT_SHUTDOWN_SIG)
{
/*
int err = errno = 0;
#if defined(STACK_PICO)
DEBUG_INFO("fd = %d", fd);
int mode = 0;
if(how == SHUT_RD) mode = PICO_SHUT_RD;
if(how == SHUT_WR) mode = PICO_SHUT_WR;
if(how == SHUT_RDWR) mode = PICO_SHUT_RDWR;
if(fd < 0) {
if(fd < 0 || fd >= ZT_MAX_SOCKETS) {
errno = EBADF;
err = -1;
return -1;
}
else
{
if(!ZeroTier::zt1Service) {
DEBUG_ERROR("cannot shutdown socket. service not started. call zts_start(path) first");
errno = EBADF;
err = -1;
}
else
{
ZeroTier::_multiplexer_lock.lock();
// First, look for for unassigned VirtualSockets
ZeroTier::VirtualSocket *vs = ZeroTier::unmap[fd];
// Since we found an unassigned VirtualSocket, we don't need to consult the stack or tap
// during closure - it isn't yet stitched into the clockwork
if(vs) // unassigned
{
DEBUG_ERROR("unassigned shutdown");
// PICO_SHUT_RD
// PICO_SHUT_WR
// PICO_SHUT_RDWR
if((err = pico_socket_shutdown(vs->picosock, mode)) < 0)
DEBUG_ERROR("error calling pico_socket_shutdown()");
DEBUG_ERROR("vs=%p", vs);
delete vs;
vs = NULL;
ZeroTier::unmap.erase(fd);
// FIXME: Is deleting this correct behaviour?
}
else // assigned
{
std::pair<ZeroTier::VirtualSocket*, ZeroTier::VirtualTap*> *p = ZeroTier::fdmap[fd];
if(!p)
{
DEBUG_ERROR("unable to locate VirtualSocket pair.");
errno = EBADF;
err = -1;
}
else // found everything, begin closure
{
vs = p->first;
int f_err, blocking = 1;
if ((f_err = fcntl(fd, F_GETFL, 0)) < 0) {
DEBUG_ERROR("fcntl error, err = %s, errno = %d", f_err, errno);
err = -1;
}
else {
blocking = !(f_err & O_NONBLOCK);
}
if(blocking) {
DEBUG_INFO("blocking, waiting for write operations before shutdown...");
for(int i=0; i<ZT_SDK_CLTIME; i++) {
if(vs->TXbuf->count() == 0)
break;
nanosleep((const struct timespec[]){{0, (ZT_API_CHECK_INTERVAL * 1000000)}}, NULL);
}
}
if((err = pico_socket_shutdown(vs->picosock, mode)) < 0)
DEBUG_ERROR("error calling pico_socket_shutdown()");
}
}
ZeroTier::_multiplexer_lock.unlock();
}
if(how != SHUT_RD && how != SHUT_WR && how != SHUT_RDWR) {
errno = EINVAL;
return -1;
}
ZeroTier::VirtualSocket *vs = get_virtual_socket(fd);
if(!vs) {
DEBUG_ERROR("invalid vs for fd=%d", fd);
errno = EBADF;
return -1;
}
if(vs->state != ZT_SOCK_STATE_CONNECTED || vs->socket_type != SOCK_STREAM) {
DEBUG_ERROR("the socket is either not in a connected state, or isn't connection-based, fd=%d", fd);
errno = ENOTCONN;
return -1;
}
if(vs->tap) {
err = vs->tap->Shutdown(vs, how);
}
return err;
#endif
*/
return 0;
}
int zts_add_dns_nameserver(struct sockaddr *addr)
@@ -2081,7 +2038,7 @@ bool can_provision_new_socket()
int zts_nsockets()
{
ZeroTier::_multiplexer_lock.unlock();
ZeroTier::_multiplexer_lock.lock();
int num = ZeroTier::unmap.size() + ZeroTier::fdmap.size();
ZeroTier::_multiplexer_lock.unlock();
return num;
@@ -2228,52 +2185,62 @@ ZeroTier::VirtualSocket *get_virtual_socket(int fd)
void del_virtual_socket(int fd)
{
ZeroTier::_multiplexer_lock.lock();
std::map<int, ZeroTier::VirtualSocket*>::iterator fd_iter = ZeroTier::unmap.find(fd);
if(fd_iter != ZeroTier::unmap.end()) {
ZeroTier::unmap.erase(fd_iter);
std::map<int, ZeroTier::VirtualSocket*>::iterator un_iter = ZeroTier::unmap.find(fd);
if(un_iter != ZeroTier::unmap.end()) {
ZeroTier::unmap.erase(un_iter);
}
//ZeroTier::unmap.erase(fd);
std::map<int, std::pair<ZeroTier::VirtualSocket*,ZeroTier::VirtualTap*>*>::iterator un_iter = ZeroTier::fdmap.find(fd);
if(un_iter != ZeroTier::fdmap.end()) {
ZeroTier::fdmap.erase(un_iter);
std::map<int, std::pair<ZeroTier::VirtualSocket*,ZeroTier::VirtualTap*>*>::iterator fd_iter = ZeroTier::fdmap.find(fd);
if(fd_iter != ZeroTier::fdmap.end()) {
ZeroTier::fdmap.erase(fd_iter);
}
//ZeroTier::fdmap.erase(fd);
ZeroTier::_multiplexer_lock.unlock();
}
void add_unassigned_virtual_socket(int fd, ZeroTier::VirtualSocket *vs)
{
ZeroTier::_multiplexer_lock.lock();
ZeroTier::unmap[fd] = vs;
std::map<int, ZeroTier::VirtualSocket*>::iterator un_iter = ZeroTier::unmap.find(fd);
if(un_iter == ZeroTier::unmap.end()) {
ZeroTier::unmap[fd] = vs;
}
else {
DEBUG_ERROR("fd=%d already contained in <fd:vs> map", fd);
handle_general_failure();
}
ZeroTier::_multiplexer_lock.unlock();
}
void del_unassigned_virtual_socket(int fd)
{
ZeroTier::_multiplexer_lock.lock();
std::map<int, ZeroTier::VirtualSocket*>::iterator iter = ZeroTier::unmap.find(fd);
if(iter != ZeroTier::unmap.end()) {
ZeroTier::unmap.erase(iter);
}
//ZeroTier::unmap.erase(fd);
std::map<int, ZeroTier::VirtualSocket*>::iterator un_iter = ZeroTier::unmap.find(fd);
if(un_iter != ZeroTier::unmap.end()) {
ZeroTier::unmap.erase(un_iter);
}
ZeroTier::_multiplexer_lock.unlock();
}
void add_assigned_virtual_socket(ZeroTier::VirtualTap *tap, ZeroTier::VirtualSocket *vs, int fd)
{
ZeroTier::_multiplexer_lock.lock();
ZeroTier::fdmap[fd] = new std::pair<ZeroTier::VirtualSocket*,ZeroTier::VirtualTap*>(vs, tap);
std::map<int, std::pair<ZeroTier::VirtualSocket*,ZeroTier::VirtualTap*>*>::iterator fd_iter = ZeroTier::fdmap.find(fd);
if(fd_iter == ZeroTier::fdmap.end()) {
ZeroTier::fdmap[fd] = new std::pair<ZeroTier::VirtualSocket*,ZeroTier::VirtualTap*>(vs, tap);
}
else {
DEBUG_ERROR("fd=%d already contained in <fd,<vs,vt>> map", fd);
handle_general_failure();
}
ZeroTier::_multiplexer_lock.unlock();
}
void del_assigned_virtual_socket(ZeroTier::VirtualTap *tap, ZeroTier::VirtualSocket *vs, int fd)
{
ZeroTier::_multiplexer_lock.lock();
std::map<int, std::pair<ZeroTier::VirtualSocket*,ZeroTier::VirtualTap*>*>::iterator iter = ZeroTier::fdmap.find(fd);
if(iter != ZeroTier::fdmap.end()) {
ZeroTier::fdmap.erase(iter);
}
//ZeroTier::fdmap.erase(fd);
std::map<int, std::pair<ZeroTier::VirtualSocket*,ZeroTier::VirtualTap*>*>::iterator fd_iter = ZeroTier::fdmap.find(fd);
if(fd_iter != ZeroTier::fdmap.end()) {
ZeroTier::fdmap.erase(fd_iter);
}
ZeroTier::_multiplexer_lock.unlock();
}

View File

@@ -638,6 +638,25 @@ namespace ZeroTier
return err;
}
int lwIP::lwip_Shutdown(VirtualSocket *vs, int how)
{
int err=0, shut_rx=0, shut_tx=0;
if(how == SHUT_RD) {
shut_rx = 1;
}
if(how == SHUT_WR) {
shut_tx = 1;
}
if(how == SHUT_RDWR) {
shut_rx = 1;
shut_tx = 1;
}
if((err = tcp_shutdown((tcp_pcb*)(vs->pcb), shut_rx, shut_tx) < 0)) {
DEBUG_ERROR("error while shutting down socket, fd=%d", vs->app_fd);
}
return err;
}
/****************************************************************************/
/* Callbacks from lwIP stack */
/****************************************************************************/
@@ -716,15 +735,7 @@ namespace ZeroTier
return ERR_OK;
}
/*
NSLWIP network_stack_lwip
NSPICO network_stack_pico
NSRXBF network_stack_pico guarded frame buffer RX
ZTVIRT zt_virtual_wire
APPFDS app_fd
VSRXBF app_fd TX buf
VSTXBF app_fd RX buf
*/
// callback from stack to notify driver of the successful acceptance of a connection
err_t lwIP::lwip_cb_accept(void *arg, struct tcp_pcb *newPCB, err_t err)
{
//DEBUG_INFO();

View File

@@ -68,7 +68,6 @@ struct netif;
#define LWIP_PBUF_ALLOC_SIG pbuf_layer layer, u16_t length, pbuf_type type
#define LWIP_HTONS_SIG u16_t x
#define LWIP_NTOHS_SIG u16_t x
#define LWIP_UDP_NEW_SIG void
#define LWIP_UDP_CONNECT_SIG struct udp_pcb * pcb, const ip_addr_t * ipaddr, u16_t port
#define LWIP_UDP_SEND_SIG struct udp_pcb * pcb, struct pbuf * p
@@ -77,7 +76,6 @@ struct netif;
#define LWIP_UDP_RECVED_SIG struct udp_pcb * pcb, u16_t len
#define LWIP_UDP_BIND_SIG struct udp_pcb * pcb, const ip_addr_t * ipaddr, u16_t port
#define LWIP_UDP_REMOVE_SIG struct udp_pcb *pcb
#define LWIP_TCP_WRITE_SIG struct tcp_pcb *pcb, const void *arg, u16_t len, u8_t apiflags
#define LWIP_TCP_SENT_SIG struct tcp_pcb * pcb, err_t (* sent)(void * arg, struct tcp_pcb * tpcb, u16_t len)
#define LWIP_TCP_NEW_SIG void
@@ -96,15 +94,13 @@ struct netif;
#define LWIP_TCP_LISTEN_WITH_BACKLOG_SIG struct tcp_pcb * pcb, u8_t backlog
#define LWIP_TCP_BIND_SIG struct tcp_pcb * pcb, const ip_addr_t * ipaddr, u16_t port
#define LWIP_TCP_INPUT_SIG struct pbuf *p, struct netif *inp
#define LWIP_ETHERNET_INPUT_SIG struct pbuf *p, struct netif *netif
#define LWIP_IP_INPUT_SIG struct pbuf *p, struct netif *inp
#define LWIP_NETIF_SET_DEFAULT_SIG struct netif *netif
#define LWIP_NETIF_SET_UP_SIG struct netif *netif
#define LWIP_NETIF_POLL_SIG struct netif *netif
#define NETIF_SET_STATUS_CALLBACK struct netif *netif, netif_status_callback_fn status_callback
#define LWIP_TCP_SHUTDOWN_SIG struct tcp_pcb *pcb, int shut_rx, int shut_tx
#if defined(LIBZT_IPV4)
extern "C" err_t etharp_output(LWIP_ETHARP_OUTPUT_SIG);
@@ -160,11 +156,9 @@ extern "C" u16_t lwip_htons(LWIP_HTONS_SIG);
extern "C" u16_t lwip_ntohs(LWIP_NTOHS_SIG);
extern "C" void tcp_input(LWIP_TCP_INPUT_SIG);
extern "C" err_t ip_input(LWIP_IP_INPUT_SIG);
extern "C" err_t tcp_shutdown(LWIP_TCP_SHUTDOWN_SIG);
//extern "C" void netif_set_status_callback(NETIF_SET_STATUS_CALLBACK);
namespace ZeroTier {
class VirtualTap;
@@ -239,6 +233,10 @@ namespace ZeroTier {
*/
int lwip_Close(VirtualSocket *vs);
/*
* Shuts down some aspect of a VirtualSocket - Called from VirtualTap
*/
int lwip_Shutdown(VirtualSocket *vs, int how);
// --- Callbacks from network stack ---

View File

@@ -1014,6 +1014,24 @@ namespace ZeroTier {
return err;
}
int picoTCP::pico_Shutdown(VirtualSocket *vs, int how)
{
int err = 0, mode = 0;
if(how == SHUT_RD) {
mode = PICO_SHUT_RD;
}
if(how == SHUT_WR) {
mode = PICO_SHUT_WR;
}
if(how == SHUT_RDWR) {
mode = PICO_SHUT_RDWR;
}
if((err = pico_socket_shutdown(vs->picosock, mode)) < 0) {
DEBUG_ERROR("error while shutting down socket, fd=%d, pico_err=%d, %s", vs->app_fd, pico_err, beautify_pico_error(pico_err));
}
return err;
}
int picoTCP::map_pico_err_to_errno(int err)
{
if(err == PICO_ERR_NOERR) { errno = 0; return 0; } //

View File

@@ -192,6 +192,11 @@ namespace ZeroTier
*/
int pico_Close(VirtualSocket *vs);
/*
* Shuts down some aspect of a VirtualSocket - Called from VirtualTap
*/
int pico_Shutdown(VirtualSocket *vs, int how);
/*
* Converts a pico_err to its most closely-related errno, and sets errno
*/

View File

@@ -1847,6 +1847,7 @@ int random_api_test()
// PASSED implies we didn't segfault or hang anywhere
// variables which will be populated with random values
/*
int socket_family;
int socket_type;
int protocol;
@@ -1858,6 +1859,7 @@ int random_api_test()
struct sockaddr_storage;
struct sockaddr_in addr;
struct sockaddr_in6 addr6;
*/
/*
int num_operations = 100;