diff --git a/include/ZeroTierSockets.h b/include/ZeroTierSockets.h index 0710896..0af96a8 100644 --- a/include/ZeroTierSockets.h +++ b/include/ZeroTierSockets.h @@ -2623,6 +2623,24 @@ ZTS_API int ZTCALL zts_udp_client(const char* remote_ipstr); */ ZTS_API int ZTCALL zts_set_no_delay(int fd, int enabled); +/** + * @brief Get the last error for the given socket + * + * @param fd Socket file descriptor + * @return Error number defined in `zts_errno_t`. `ZTS_ERR_SERVICE` if the node + * experiences a problem, `ZTS_ERR_ARG` if invalid argument. Sets `zts_errno` + */ +ZTS_API int ZTCALL zts_get_last_socket_error(int fd); + +/** + * @brief Return amount of data available to read from socket + * + * @param fd Socket file descriptor + * @return Number of bytes available to read. `ZTS_ERR_SERVICE` if the node + * experiences a problem, `ZTS_ERR_ARG` if invalid argument. Sets `zts_errno` + */ +ZTS_API size_t ZTCALL zts_get_data_available(int fd); + /** * @brief Return whether `TCP_NODELAY` is enabled * diff --git a/pkg/nuget/ZeroTier.Sockets/ZeroTier.Sockets.x64.nuspec b/pkg/nuget/ZeroTier.Sockets/ZeroTier.Sockets.x64.nuspec index 5502015..f8c5106 100644 --- a/pkg/nuget/ZeroTier.Sockets/ZeroTier.Sockets.x64.nuspec +++ b/pkg/nuget/ZeroTier.Sockets/ZeroTier.Sockets.x64.nuspec @@ -9,7 +9,7 @@ LICENSE.txt icon.png false - Namespace adjustments, additions to Socket API, memory leak fixes. + Add NetworkStream. Fix errno bug Encrypted P2P SD-WAN networking layer (Managed C# API) [x64] Encrypted P2P SD-WAN networking layer (Managed C# API) [x64] Copyright 2021 ZeroTier, Inc. @@ -26,7 +26,7 @@ - + diff --git a/src/Sockets.cpp b/src/Sockets.cpp index 9b2bb2e..0049ba4 100644 --- a/src/Sockets.cpp +++ b/src/Sockets.cpp @@ -606,6 +606,30 @@ int zts_udp_client(const char* remote_ipstr) return fd; } +int zts_get_last_socket_error(int fd) +{ + int optval = 0; + zts_socklen_t optlen = sizeof(optval); + int err = ZTS_ERR_OK; + if ((err = zts_bsd_getsockopt(fd, ZTS_SOL_SOCKET, ZTS_SO_ERROR, &optval, &optlen)) < 0) { + return err; + } + return optval; +} + +size_t zts_get_data_available(int fd) +{ + if (! transport_ok()) { + return ZTS_ERR_SERVICE; + } + int err = ZTS_ERR_OK; + size_t bytes_available = 0; + if ((err = zts_bsd_ioctl(fd, ZTS_FIONREAD, &bytes_available)) < 0) { + return err; + } + return bytes_available; +} + int zts_set_no_delay(int fd, int enabled) { if (! transport_ok()) { diff --git a/src/bindings/csharp/CSharpSockets.cxx b/src/bindings/csharp/CSharpSockets.cxx index 5ba3765..2baf66b 100644 --- a/src/bindings/csharp/CSharpSockets.cxx +++ b/src/bindings/csharp/CSharpSockets.cxx @@ -398,6 +398,11 @@ SWIGINTERN void SWIG_CSharpException(int code, const char* msg) extern "C" { #endif +SWIGEXPORT int SWIGSTDCALL CSharp_zts_errno_get() +{ + return zts_errno; +} + #ifndef ZTS_DISABLE_CENTRAL_API SWIGEXPORT int SWIGSTDCALL CSharp_zts_central_set_access_mode(char jarg1) @@ -1574,6 +1579,11 @@ SWIGEXPORT int SWIGSTDCALL CSharp_zts_bsd_shutdown(int jarg1, int jarg2) return jresult; } +SWIGEXPORT size_t SWIGSTDCALL CSharp_zts_get_data_available(int fd) +{ + return zts_get_data_available(fd); +} + SWIGEXPORT int SWIGSTDCALL CSharp_zts_set_no_delay(int jarg1, int jarg2) { int jresult; diff --git a/src/bindings/csharp/NetworkStream.cs b/src/bindings/csharp/NetworkStream.cs new file mode 100644 index 0000000..2b87e6c --- /dev/null +++ b/src/bindings/csharp/NetworkStream.cs @@ -0,0 +1,372 @@ +/* + * Copyright (c)2013-2021 ZeroTier, Inc. + * + * Use of this software is governed by the Business Source License included + * in the LICENSE.TXT file in the project's root directory. + * + * Change Date: 2026-01-01 + * + * On the date above, in accordance with the Business Source License, use + * of this software will be governed by version 2.0 of the Apache License. + */ +/****/ + +using System; +using System.Threading; +using System.IO; +using System.Runtime.InteropServices; +using System.Net.Sockets; + +using ZeroTier; + +namespace ZeroTier.Sockets +{ + public class NetworkStream : Stream { + private ZeroTier.Sockets.Socket _streamSocket; + + private bool _isReadable; + private bool _isWriteable; + private bool _ownsSocket; + private volatile bool _isDisposed = false; + + internal NetworkStream() + { + _ownsSocket = true; + } + + public NetworkStream(ZeroTier.Sockets.Socket socket) + { + if (socket == null) { + throw new ArgumentNullException("socket"); + } + InitNetworkStream(socket, FileAccess.ReadWrite); + } + + public NetworkStream(ZeroTier.Sockets.Socket socket, bool ownsSocket) + { + if (socket == null) { + throw new ArgumentNullException("socket"); + } + InitNetworkStream(socket, FileAccess.ReadWrite); + _ownsSocket = ownsSocket; + } + + public NetworkStream(ZeroTier.Sockets.Socket socket, FileAccess accessMode) + { + if (socket == null) { + throw new ArgumentNullException("socket"); + } + InitNetworkStream(socket, accessMode); + } + + public NetworkStream(ZeroTier.Sockets.Socket socket, FileAccess accessMode, bool ownsSocket) + { + if (socket == null) { + throw new ArgumentNullException("socket"); + } + InitNetworkStream(socket, accessMode); + _ownsSocket = ownsSocket; + } + + internal NetworkStream(NetworkStream networkStream, bool ownsSocket) + { + ZeroTier.Sockets.Socket socket = networkStream.Socket; + if (socket == null) { + throw new ArgumentNullException("networkStream"); + } + InitNetworkStream(socket, FileAccess.ReadWrite); + _ownsSocket = ownsSocket; + } + + protected ZeroTier.Sockets.Socket Socket + { + get { + return _streamSocket; + } + } + + internal void ConvertToNotSocketOwner() + { + _ownsSocket = false; + GC.SuppressFinalize(this); + } + + public override int ReadTimeout + { + get { + return _streamSocket.ReceiveTimeout; + } + set { + if (value <= 0) { + throw new ArgumentOutOfRangeException("Timeout value must be greater than zero"); + } + _streamSocket.ReceiveTimeout = value; + } + } + + public override int WriteTimeout + { + get { + return _streamSocket.SendTimeout; + } + set { + if (value <= 0) { + throw new ArgumentOutOfRangeException("Timeout value must be greater than zero"); + } + _streamSocket.SendTimeout = value; + } + } + + protected bool Readable + { + get { + return _isReadable; + } + set { + _isReadable = value; + } + } + + protected bool Writeable + { + get { + return _isWriteable; + } + set { + _isWriteable = value; + } + } + + public override bool CanRead + { + get { + return _isReadable; + } + } + + public override bool CanSeek + { + get { + return false; + } + } + + public override bool CanWrite + { + get { + return _isWriteable; + } + } + + public override bool CanTimeout + { + get { + return true; + } + } + + public virtual bool DataAvailable + { + get { + if (_streamSocket == null) { + throw new IOException("ZeroTier socket is null"); + } + if (_isDisposed) { + throw new ObjectDisposedException("ZeroTier.Sockets.Socket"); + } + return _streamSocket.Available != 0; + } + } + + internal void InitNetworkStream(ZeroTier.Sockets.Socket socket, FileAccess accessMode) + { + if (! socket.Connected) { + throw new IOException("ZeroTier socket must be connected"); + } + if (! socket.Blocking) { + throw new IOException("ZeroTier socket must be in blocking mode"); + } + if (socket.SocketType != SocketType.Stream) { + throw new IOException("ZeroTier socket must by stream type"); + } + + _streamSocket = socket; + + switch (accessMode) { + case FileAccess.Write: + _isWriteable = true; + break; + case FileAccess.Read: + _isReadable = true; + break; + case FileAccess.ReadWrite: + default: + _isReadable = true; + _isWriteable = true; + break; + } + } + + public override int Read([In, Out] byte[] buffer, int offset, int size) + { + bool canRead = CanRead; + if (_isDisposed) { + throw new ObjectDisposedException("ZeroTier.Sockets.Socket"); + } + if (! canRead) { + throw new InvalidOperationException("Cannot read from ZeroTier socket"); + } + + if (buffer == null) { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) { + throw new ArgumentOutOfRangeException("size"); + } + + if (_streamSocket == null) { + throw new IOException("ZeroTier socket is null"); + } + + try { + int bytesTransferred = _streamSocket.Receive(buffer, offset, size, 0); + return bytesTransferred; + } + catch (Exception exception) { + throw new IOException("Cannot read from ZeroTier socket", exception); + } + } + + public override void Write(byte[] buffer, int offset, int size) + { + bool canWrite = CanWrite; + if (_isDisposed) { + throw new ObjectDisposedException("ZeroTier.Sockets.Socket"); + } + if (! canWrite) { + throw new InvalidOperationException("Cannot write to ZeroTier socket"); + } + if (buffer == null) { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) { + throw new ArgumentOutOfRangeException("offset"); + } + if (size < 0 || size > buffer.Length - offset) { + throw new ArgumentOutOfRangeException("size"); + } + if (_streamSocket == null) { + throw new IOException("ZeroTier socket is null"); + } + + try { + _streamSocket.Send(buffer, offset, size, SocketFlags.None); + } + catch (Exception exception) { + throw new IOException("Cannot write to ZeroTier socket", exception); + } + } + + internal bool Poll(int microSeconds, SelectMode mode) + { + if (_streamSocket == null) { + throw new IOException("ZeroTier socket is null"); + } + if (_isDisposed) { + throw new ObjectDisposedException("ZeroTier.Sockets.Socket"); + } + return _streamSocket.Poll(microSeconds, mode); + } + + internal bool PollRead() + { + if (_streamSocket == null) { + return false; + } + if (_isDisposed) { + return false; + } + return _streamSocket.Poll(0, SelectMode.SelectRead); + } + + public override void Flush() + { + // Not applicable + } + + public override void SetLength(long value) + { + throw new NotSupportedException("Not supported"); + } + + public override long Length + { + get { + throw new NotSupportedException("Not supported"); + } + } + + public override long Position + { + get { + throw new NotSupportedException("Not supported"); + } + + set { + throw new NotSupportedException("Not supported"); + } + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException("Not supported"); + } + + public void Close(int timeout) + { + if (timeout < 0) { + throw new ArgumentOutOfRangeException("Timeout value must be greater than zero"); + } + _streamSocket.Close(timeout); + } + + internal bool Connected + { + get { + if (! _isDisposed && _streamSocket != null && _streamSocket.Connected) { + return true; + } + else { + return false; + } + } + } + + protected override void Dispose(bool disposing) + { + bool cleanedUp = _isDisposed; + _isDisposed = true; + if (! cleanedUp && disposing) { + if (_streamSocket != null) { + _isWriteable = false; + _isReadable = false; + if (_ownsSocket) { + if (_streamSocket != null) { + _streamSocket.Shutdown(SocketShutdown.Both); + _streamSocket.Close(); + } + } + } + } + base.Dispose(disposing); + } + + ~NetworkStream() + { + Dispose(false); + } + } +} diff --git a/src/bindings/csharp/Socket.cs b/src/bindings/csharp/Socket.cs index ea6d49e..89a4d6a 100644 --- a/src/bindings/csharp/Socket.cs +++ b/src/bindings/csharp/Socket.cs @@ -230,6 +230,12 @@ namespace ZeroTier.Sockets public void Close() { + Close(0); + } + + public void Close(int timeout) + { + // TODO: Timeout needs to be implemented if (_isClosed) { throw new ObjectDisposedException("Socket has already been closed"); } @@ -291,21 +297,10 @@ namespace ZeroTier.Sockets public Int32 Send(Byte[] buffer) { - if (_isClosed) { - throw new ObjectDisposedException("Socket has been closed"); - } - if (_fd < 0) { - throw new ZeroTier.Sockets.SocketException((int)ZeroTier.Constants.ERR_SOCKET); - } - if (buffer == null) { - throw new ArgumentNullException("buffer"); - } - int flags = 0; - IntPtr bufferPtr = Marshal.UnsafeAddrOfPinnedArrayElement(buffer, 0); - return zts_bsd_send(_fd, bufferPtr, (uint)Buffer.ByteLength(buffer), (int)flags); + return Send(buffer, 0, buffer != null ? buffer.Length : 0, SocketFlags.None); } - public Int32 Receive(Byte[] buffer) + public Int32 Send(Byte[] buffer, int offset, int size, SocketFlags socketFlags) { if (_isClosed) { throw new ObjectDisposedException("Socket has been closed"); @@ -316,9 +311,49 @@ namespace ZeroTier.Sockets if (buffer == null) { throw new ArgumentNullException("buffer"); } + if (size < 0 || size > buffer.Length - offset) { + throw new ArgumentOutOfRangeException("size"); + } + if (offset < 0 || offset > buffer.Length) { + throw new ArgumentOutOfRangeException("offset"); + } int flags = 0; IntPtr bufferPtr = Marshal.UnsafeAddrOfPinnedArrayElement(buffer, 0); - return zts_bsd_recv(_fd, bufferPtr, (uint)Buffer.ByteLength(buffer), (int)flags); + return zts_bsd_send(_fd, bufferPtr + offset, (uint)Buffer.ByteLength(buffer), (int)flags); + } + + public int Available + { + get { + return zts_get_data_available(_fd); + } + } + + public Int32 Receive(Byte[] buffer) + { + return Receive(buffer, 0, buffer != null ? buffer.Length : 0, SocketFlags.None); + } + + public Int32 Receive(byte[] buffer, int offset, int size, SocketFlags socketFlags) + { + if (_isClosed) { + throw new ObjectDisposedException("Socket has been closed"); + } + if (_fd < 0) { + throw new ZeroTier.Sockets.SocketException((int)ZeroTier.Constants.ERR_SOCKET); + } + if (buffer == null) { + throw new ArgumentNullException("buffer"); + } + if (size < 0 || size > buffer.Length - offset) { + throw new ArgumentOutOfRangeException("size"); + } + if (offset < 0 || offset > buffer.Length) { + throw new ArgumentOutOfRangeException("offset"); + } + int flags = 0; + IntPtr bufferPtr = Marshal.UnsafeAddrOfPinnedArrayElement(buffer, 0); + return zts_bsd_recv(_fd, bufferPtr + offset, (uint)Buffer.ByteLength(buffer), (int)flags); } public int ReceiveTimeout @@ -579,6 +614,9 @@ namespace ZeroTier.Sockets [DllImport("libzt", EntryPoint = "CSharp_zts_bsd_shutdown")] static extern int zts_bsd_shutdown(int arg1, int arg2); + [DllImport("libzt", EntryPoint = "CSharp_zts_get_data_available")] + static extern int zts_get_data_available(int fd); + [DllImport("libzt", EntryPoint = "CSharp_zts_set_no_delay")] static extern int zts_set_no_delay(int fd, int enabled); diff --git a/test/selftest.c b/test/selftest.c index f4a634e..ce9eafd 100644 --- a/test/selftest.c +++ b/test/selftest.c @@ -1146,6 +1146,15 @@ void test_server_socket_usage(uint16_t port4, uint16_t port6) // Read message memset(dstbuf, 0, buflen); + + // Test zts_get_data_available + while (1) { + int av = zts_get_data_available(acc4); + zts_util_delay(50); + if (av > 0) { + break; + } + } bytes_read = zts_bsd_read(acc4, dstbuf, buflen); DEBUG_INFO("server4: read (%d) bytes", bytes_read); assert(bytes_read == msglen && zts_errno == 0);