How to (repeatedly) read from .NET SslStream with a timeout?

悲&欢浪女 2021-02-05 12:15

I just need to read up to N bytes from an SslStream, but if no byte has been received before a timeout, cancel the read while leaving the stream in a valid state.

2 Answers
  • 2021-02-05 12:25

    I also encountered this problem with an SslStream returning five bytes of garbage data on the read after a timeout, and I independently came up with a solution that is similar to OP's Update #3.

    I created a wrapper class that wraps the TCP NetworkStream object as it is passed into the SslStream constructor. The wrapper class passes all calls on to the underlying NetworkStream, except that its Read() method adds a try...catch that suppresses the timeout exception and returns 0 bytes instead.

    SslStream works correctly in this instance, including raising the appropriate IOException if the socket is closed. Note that our Stream returning 0 from a Read() is different from a TcpClient or Socket returning 0 from a Read() (which typically means a socket disconnect).

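    using System;
    using System.IO;
    using System.Net.Sockets;
    using System.Threading;
    using System.Threading.Tasks;
    
    // Wraps a NetworkStream and turns a read timeout into "0 bytes read", so SslStream never sees the exception
    class SocketTimeoutSuppressedStream : Stream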
    {
        NetworkStream mStream;
    
        public SocketTimeoutSuppressedStream(NetworkStream pStream)
        {
            mStream = pStream;
        }
    
        public override int Read(byte[] buffer, int offset, int count)
        {
            try
            {
                return mStream.Read(buffer, offset, count);
            }
            catch (IOException lException)
            {
                SocketException lInnerException = lException.InnerException as SocketException;
                if (lInnerException != null && lInnerException.SocketErrorCode == SocketError.TimedOut)
                {
                    // Normally, a simple TimeOut on the read will cause SslStream to flip its lid
                    // However, if we suppress the IOException and just return 0 bytes read, this is ok.
                    // Note that this is not a "Socket.Read() returning 0 means the socket closed",
                    // this is a "Stream.Read() returning 0 means that no data is available"
                    return 0;
                }
                throw;
            }
        }
    
    
        public override bool CanRead => mStream.CanRead;
        public override bool CanSeek => mStream.CanSeek;
        public override bool CanTimeout => mStream.CanTimeout;
        public override bool CanWrite => mStream.CanWrite;
        public virtual bool DataAvailable => mStream.DataAvailable;
        public override long Length => mStream.Length;
        public override IAsyncResult BeginRead(byte[] buffer, int offset, int size, AsyncCallback callback, object state) => mStream.BeginRead(buffer, offset, size, callback, state);
        public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback callback, object state) => mStream.BeginWrite(buffer, offset, size, callback, state);
        public void Close(int timeout) => mStream.Close(timeout);
        public override int EndRead(IAsyncResult asyncResult) => mStream.EndRead(asyncResult);
        public override void EndWrite(IAsyncResult asyncResult) => mStream.EndWrite(asyncResult);
        public override void Flush() => mStream.Flush();
        public override Task FlushAsync(CancellationToken cancellationToken) => mStream.FlushAsync(cancellationToken);
        public override long Seek(long offset, SeekOrigin origin) => mStream.Seek(offset, origin);
        public override void SetLength(long value) => mStream.SetLength(value);
        public override void Write(byte[] buffer, int offset, int count) => mStream.Write(buffer, offset, count);
    
        public override long Position
        {
            get { return mStream.Position; }
            set { mStream.Position = value; }
        }
    
        public override int ReadTimeout
        {
            get { return mStream.ReadTimeout; }
            set { mStream.ReadTimeout = value; }
        }
    
        public override int WriteTimeout
        {
            get { return mStream.WriteTimeout; }
            set { mStream.WriteTimeout = value; }
        }
    }
    

    This can then be used by wrapping the TcpClient's NetworkStream object before it is passed to the SslStream, as follows:

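    NetworkStream lTcpStream = lTcpClient.GetStream();
    lTcpStream.ReadTimeout = 5000; // example value only: a blocked Read() times out after 5 seconds
    SocketTimeoutSuppressedStream lSuppressedStream = new SocketTimeoutSuppressedStream(lTcpStream);
    using (SslStream lSslStream = new SslStream(lSuppressedStream, true, ServerCertificateValidation, SelectLocalCertificate, EncryptionPolicy.RequireEncryption))
    {
        // authenticate (e.g. AuthenticateAsClient), then read from lSslStream as usual
    }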
    

    The problem comes down to SslStream corrupting its internal state on any exception from the underlying stream, even a harmless timeout. Oddly, the five (or so) bytes of data that the next Read() returns are actually the start of the TLS-encrypted data from the wire.
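    For context: every TLS record on the wire begins with a fixed five-byte header (one content-type byte, a two-byte protocol version, and a two-byte big-endian length), so the five stray bytes described above are plausibly such a record header. The C# sketch below decodes a header as a diagnostic aid for recognizing those bytes; it is not a fix, and TlsRecordHeader and TryParse are hypothetical names, not part of SslStream's API.

```csharp
using System;

// Diagnostic sketch: decodes the fixed 5-byte TLS record header (RFC 5246 section 6.2.1,
// RFC 8446 section 5.1). TlsRecordHeader and TryParse are hypothetical names.
public readonly struct TlsRecordHeader
{
    public readonly byte ContentType;  // 20 change_cipher_spec, 21 alert, 22 handshake, 23 application_data
    public readonly byte MajorVersion; // 3 for SSL 3.0 and every TLS version
    public readonly byte MinorVersion; // 1 = TLS 1.0, 2 = TLS 1.1, 3 = TLS 1.2 (TLS 1.3 records also carry 3)
    public readonly int Length;        // size of the record body that follows the header

    public TlsRecordHeader(byte contentType, byte majorVersion, byte minorVersion, int length)
    {
        ContentType = contentType;
        MajorVersion = majorVersion;
        MinorVersion = minorVersion;
        Length = length;
    }

    // Returns true if the first five bytes look like a TLS record header.
    public static bool TryParse(ReadOnlySpan<byte> data, out TlsRecordHeader header)
    {
        header = default;
        if (data.Length < 5) return false;               // the header is always exactly 5 bytes
        if (data[0] < 20 || data[0] > 23) return false;  // not one of the four RFC 5246 record types
        if (data[1] != 3) return false;                  // major version byte is 3 for SSL 3.0/TLS
        int length = (data[3] << 8) | data[4];           // 16-bit big-endian body length
        header = new TlsRecordHeader(data[0], data[1], data[2], length);
        return true;
    }
}
```

    Feeding the five "garbage" bytes into TryParse and getting ContentType 23 (application_data) back would support the explanation that the next read is returning a raw record header.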

    Hope this helps

  • 2021-02-05 12:27

    You can certainly make approach #1 work. You simply need to keep track of the Task and continue waiting without calling ReadAsync again. So, very roughly:

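    private Task<int> readTask;     // class-level field: the outstanding read, if any
    ...
      // Start a new read only if the previous one has completed and been consumed
      if (readTask == null) readTask = stream.ReadAsync(buffer, 0, buffer.Length);
      if (readTask.Wait(timeout_ms)) {
         try {
             count = readTask.Result;
             ...
         }
         finally {
             readTask = null;   // this read is finished; the next call may start a new one
         }
      }
      // Otherwise the read is still pending: leave readTask (and buffer) alone and wait again later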
    

    This needs to be fleshed out a bit so that the caller can tell the read hasn't completed yet, but the snippet is too small for me to give more concrete advice.
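    One way to flesh this out is a small wrapper whose TryRead method returns false when the timeout elapses and leaves the original ReadAsync pending, so the next call keeps waiting on the same Task instead of issuing a second read. This is a sketch under assumptions: PendingReadWrapper and TryRead are hypothetical names, the wrapper owns a fixed buffer that nothing else touches while a read is pending, and the caller copies any returned bytes before calling again. It waits with Task.WaitAny so that a faulted read is only surfaced inside the try/finally that clears the pending task.

```csharp
using System;
using System.IO;
using System.Threading.Tasks;

// Sketch of the "keep the pending Task" idea. PendingReadWrapper and TryRead
// are hypothetical names; the wrapped stream can be any Stream, e.g. an SslStream.
public sealed class PendingReadWrapper
{
    private readonly Stream stream;
    private readonly byte[] buffer;   // owned by the outstanding read until it completes
    private Task<int> readTask;       // the outstanding ReadAsync, or null if none

    public PendingReadWrapper(Stream stream, int bufferSize)
    {
        this.stream = stream ?? throw new ArgumentNullException(nameof(stream));
        buffer = new byte[bufferSize];
    }

    // true:  the read completed; data holds the bytes (Count == 0 means end of stream).
    //        data points into the internal buffer, so copy it before calling TryRead again.
    // false: the timeout elapsed; the same read stays pending and is waited on next time.
    public bool TryRead(int timeoutMs, out ArraySegment<byte> data)
    {
        // Only start a new read when no earlier read is still outstanding.
        if (readTask == null)
            readTask = stream.ReadAsync(buffer, 0, buffer.Length);

        // WaitAny returns -1 on timeout and, unlike Task.Wait, does not throw if the read faulted.
        if (Task.WaitAny(new Task[] { readTask }, timeoutMs) < 0)
        {
            data = default;
            return false;
        }

        try
        {
            // Rethrows the read's original exception (e.g. IOException) if it faulted.
            int count = readTask.GetAwaiter().GetResult();
            data = new ArraySegment<byte>(buffer, 0, count);
            return true;
        }
        finally
        {
            readTask = null; // the read is finished either way; the next call starts a fresh one
        }
    }
}
```

    A caller that gets false back can do other work and call TryRead again later. Because no second ReadAsync is issued while one is outstanding, the wrapped SslStream never sees overlapping reads, which it does not allow.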
