aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorThorsten von Eicken <tve@voneicken.com>2020-04-02 10:01:16 -0700
committerDamien George <damien@micropython.org>2021-02-17 11:50:54 +1100
commit2c1299b0071c2c528cc01e3cde9eb22743820176 (patch)
tree0679eb4daf9522f30cec65b3d7bce494029482b9
parent2eed9780ba7074de9e464a2bc771ad14f0332a6c (diff)
extmod/modussl: Fix ussl read/recv/send/write errors when non-blocking.
Also fix related problems with socket on esp32, improve docs for wrap_socket, and add more tests.
-rw-r--r--docs/library/ussl.rst22
-rw-r--r--extmod/modussl_axtls.c31
-rw-r--r--extmod/modussl_mbedtls.c3
-rw-r--r--ports/esp32/modsocket.c8
-rw-r--r--tests/net_hosted/accept_timeout.py6
-rw-r--r--tests/net_hosted/connect_nonblock_xfer.py147
-rw-r--r--tests/net_inet/ssl_errors.py51
-rw-r--r--tests/net_inet/test_tls_nonblock.py116
-rw-r--r--tests/net_inet/test_tls_sites.py6
-rw-r--r--tests/net_inet/test_tls_sites.py.exp3
10 files changed, 373 insertions, 20 deletions
diff --git a/docs/library/ussl.rst b/docs/library/ussl.rst
index ffe146331..14e3f3ad1 100644
--- a/docs/library/ussl.rst
+++ b/docs/library/ussl.rst
@@ -13,16 +13,23 @@ facilities for network sockets, both client-side and server-side.
Functions
---------
-.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None)
-
+.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None, do_handshake=True)
Takes a `stream` *sock* (usually usocket.socket instance of ``SOCK_STREAM`` type),
and returns an instance of ssl.SSLSocket, which wraps the underlying stream in
an SSL context. Returned object has the usual `stream` interface methods like
- ``read()``, ``write()``, etc. In MicroPython, the returned object does not expose
- socket interface and methods like ``recv()``, ``send()``. In particular, a
- server-side SSL socket should be created from a normal socket returned from
+ ``read()``, ``write()``, etc.
+ A server-side SSL socket should be created from a normal socket returned from
:meth:`~usocket.socket.accept()` on a non-SSL listening server socket.
+ - *do_handshake* determines whether the handshake is done as part of the ``wrap_socket``
+ or whether it is deferred to be done as part of the initial reads or writes
+ (there is no ``do_handshake`` method as in CPython).
+ For blocking sockets doing the handshake immediately is standard. For non-blocking
+ sockets (i.e. when the *sock* passed into ``wrap_socket`` is in non-blocking mode)
+ the handshake should generally be deferred because otherwise ``wrap_socket`` blocks
+ until it completes. Note that in AXTLS the handshake can be deferred until the first
+ read or write but it then blocks until completion.
+
Depending on the underlying module implementation in a particular
:term:`MicroPython port`, some or all keyword arguments above may be not supported.
@@ -31,6 +38,11 @@ Functions
Some implementations of ``ussl`` module do NOT validate server certificates,
which makes an SSL connection established prone to man-in-the-middle attacks.
+ CPython's ``wrap_socket`` returns an ``SSLSocket`` object which has methods typical
+ for sockets, such as ``send``, ``recv``, etc. MicroPython's ``wrap_socket``
+ returns an object more similar to CPython's ``SSLObject`` which does not have
+ these socket methods.
+
Exceptions
----------
diff --git a/extmod/modussl_axtls.c b/extmod/modussl_axtls.c
index da5941a55..9d5934206 100644
--- a/extmod/modussl_axtls.c
+++ b/extmod/modussl_axtls.c
@@ -167,10 +167,15 @@ STATIC mp_obj_ssl_socket_t *ussl_socket_new(mp_obj_t sock, struct ssl_args *args
o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext);
if (args->do_handshake.u_bool) {
- int res = ssl_handshake_status(o->ssl_sock);
-
- if (res != SSL_OK) {
- ussl_raise_error(res);
+ int r = ssl_handshake_status(o->ssl_sock);
+
+ if (r != SSL_OK) {
+ if (r == SSL_CLOSE_NOTIFY) { // EOF
+ r = MP_ENOTCONN;
+ } else if (r == SSL_EAGAIN) {
+ r = MP_EAGAIN;
+ }
+ ussl_raise_error(r);
}
}
@@ -242,8 +247,24 @@ STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t siz
return MP_STREAM_ERROR;
}
- mp_int_t r = ssl_write(o->ssl_sock, buf, size);
+ mp_int_t r;
+eagain:
+ r = ssl_write(o->ssl_sock, buf, size);
+ if (r == 0) {
+ // see comment in ussl_socket_read above
+ if (o->blocking) {
+ goto eagain;
+ } else {
+ r = SSL_EAGAIN;
+ }
+ }
if (r < 0) {
+ if (r == SSL_CLOSE_NOTIFY || r == SSL_ERROR_CONN_LOST) {
+ return 0; // EOF
+ }
+ if (r == SSL_EAGAIN) {
+ r = MP_EAGAIN;
+ }
*errcode = r;
return MP_STREAM_ERROR;
}
diff --git a/extmod/modussl_mbedtls.c b/extmod/modussl_mbedtls.c
index 1677dc6e1..277af37c7 100644
--- a/extmod/modussl_mbedtls.c
+++ b/extmod/modussl_mbedtls.c
@@ -133,6 +133,7 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
}
}
+// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
mp_obj_t sock = *(mp_obj_t *)ctx;
@@ -171,7 +172,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
mbedtls_pk_init(&o->pkey);
mbedtls_ctr_drbg_init(&o->ctr_drbg);
#ifdef MBEDTLS_DEBUG_C
- // Debug level (0-4)
+ // Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
mbedtls_debug_set_threshold(0);
#endif
diff --git a/ports/esp32/modsocket.c b/ports/esp32/modsocket.c
index 61761d819..5135e3163 100644
--- a/ports/esp32/modsocket.c
+++ b/ports/esp32/modsocket.c
@@ -558,7 +558,8 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
MP_THREAD_GIL_EXIT();
int r = lwip_write(sock->fd, data + sentlen, datalen - sentlen);
MP_THREAD_GIL_ENTER();
- if (r < 0 && errno != EWOULDBLOCK) {
+ // lwip returns EINPROGRESS when trying to send right after a non-blocking connect
+ if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) {
mp_raise_OSError(errno);
}
if (r > 0) {
@@ -567,7 +568,7 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
check_for_exceptions();
}
if (sentlen == 0) {
- mp_raise_OSError(MP_ETIMEDOUT);
+ mp_raise_OSError(sock->retries == 0 ? MP_EWOULDBLOCK : MP_ETIMEDOUT);
}
return sentlen;
}
@@ -650,7 +651,8 @@ STATIC mp_uint_t socket_stream_write(mp_obj_t self_in, const void *buf, mp_uint_
if (r > 0) {
return r;
}
- if (r < 0 && errno != EWOULDBLOCK) {
+ // lwip returns MP_EINPROGRESS when trying to write right after a non-blocking connect
+ if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) {
*errcode = errno;
return MP_STREAM_ERROR;
}
diff --git a/tests/net_hosted/accept_timeout.py b/tests/net_hosted/accept_timeout.py
index ff989110a..5f528d557 100644
--- a/tests/net_hosted/accept_timeout.py
+++ b/tests/net_hosted/accept_timeout.py
@@ -1,9 +1,9 @@
# test that socket.accept() on a socket with timeout raises ETIMEDOUT
try:
- import usocket as socket
+ import uerrno as errno, usocket as socket
except:
- import socket
+ import errno, socket
try:
socket.socket.settimeout
@@ -18,5 +18,5 @@ s.listen(1)
try:
s.accept()
except OSError as er:
- print(er.args[0] in (110, "timed out")) # 110 is ETIMEDOUT; CPython uses a string
+ print(er.args[0] in (errno.ETIMEDOUT, "timed out")) # CPython uses a string instead of errno
s.close()
diff --git a/tests/net_hosted/connect_nonblock_xfer.py b/tests/net_hosted/connect_nonblock_xfer.py
new file mode 100644
index 000000000..feb648ea0
--- /dev/null
+++ b/tests/net_hosted/connect_nonblock_xfer.py
@@ -0,0 +1,147 @@
+# test that socket.connect() on a non-blocking socket raises EINPROGRESS
+# and that an immediate write/send/read/recv does the right thing
+
+try:
+ import sys, time
+ import uerrno as errno, usocket as socket, ussl as ssl
+except:
+ import socket, errno, ssl
+isMP = sys.implementation.name == "micropython"
+
+
+def dp(e):
+ # uncomment next line for development and testing, to print the actual exceptions
+ # print(repr(e))
+ pass
+
+
+# do_connect establishes the socket and wraps it if tls is True.
+# If handshake is true, the initial connect (and TLS handshake) is
+# allowed to be performed before returning.
+def do_connect(peer_addr, tls, handshake):
+ s = socket.socket()
+ s.setblocking(False)
+ try:
+ # print("Connecting to", peer_addr)
+ s.connect(peer_addr)
+ except OSError as er:
+ print("connect:", er.args[0] == errno.EINPROGRESS)
+ if er.args[0] != errno.EINPROGRESS:
+ print(" got", er.args[0])
+ # wrap with ssl/tls if desired
+ if tls:
+ try:
+ if sys.implementation.name == "micropython":
+ s = ssl.wrap_socket(s, do_handshake=handshake)
+ else:
+ s = ssl.wrap_socket(s, do_handshake_on_connect=handshake)
+ print("wrap: True")
+ except Exception as e:
+ dp(e)
+ print("wrap:", e)
+ elif handshake:
+ # just sleep a little bit, this allows any connect() errors to happen
+ time.sleep(0.2)
+ return s
+
+
+# test runs the test against a specific peer address.
+def test(peer_addr, tls=False, handshake=False):
+ # MicroPython plain sockets have read/write, but CPython's don't
+ # MicroPython TLS sockets and CPython's have read/write
+ # hasRW captures this wonderful state of affairs
+ hasRW = isMP or tls
+
+ # MicroPython plain sockets and CPython's have send/recv
+ # MicroPython TLS sockets don't have send/recv, but CPython's do
+ # hasSR captures this wonderful state of affairs
+ hasSR = not (isMP and tls)
+
+ # connect + send
+ if hasSR:
+ s = do_connect(peer_addr, tls, handshake)
+ # send -> 4 or EAGAIN
+ try:
+ ret = s.send(b"1234")
+ print("send:", handshake and ret == 4)
+ except OSError as er:
+ #
+ dp(er)
+ print("send:", er.args[0] in (errno.EAGAIN, errno.EINPROGRESS))
+ s.close()
+ else: # fake it...
+ print("connect:", True)
+ if tls:
+ print("wrap:", True)
+ print("send:", True)
+
+ # connect + write
+ if hasRW:
+ s = do_connect(peer_addr, tls, handshake)
+ # write -> None
+ try:
+ ret = s.write(b"1234")
+ print("write:", ret in (4, None)) # SSL may accept 4 into buffer
+ except OSError as er:
+ dp(er)
+ print("write:", False) # should not raise
+ except ValueError as er: # CPython
+ dp(er)
+ print("write:", er.args[0] == "Write on closed or unwrapped SSL socket.")
+ s.close()
+ else: # fake it...
+ print("connect:", True)
+ if tls:
+ print("wrap:", True)
+ print("write:", True)
+
+ if hasSR:
+ # connect + recv
+ s = do_connect(peer_addr, tls, handshake)
+ # recv -> EAGAIN
+ try:
+ print("recv:", s.recv(10))
+ except OSError as er:
+ dp(er)
+ print("recv:", er.args[0] == errno.EAGAIN)
+ s.close()
+ else: # fake it...
+ print("connect:", True)
+ if tls:
+ print("wrap:", True)
+ print("recv:", True)
+
+ # connect + read
+ if hasRW:
+ s = do_connect(peer_addr, tls, handshake)
+ # read -> None
+ try:
+ ret = s.read(10)
+ print("read:", ret is None)
+ except OSError as er:
+ dp(er)
+ print("read:", False) # should not raise
+ except ValueError as er: # CPython
+ dp(er)
+ print("read:", er.args[0] == "Read on closed or unwrapped SSL socket.")
+ s.close()
+ else: # fake it...
+ print("connect:", True)
+ if tls:
+ print("wrap:", True)
+ print("read:", True)
+
+
+if __name__ == "__main__":
+ # these tests use a non-existent test IP address, this way the connect takes forever and
+ # we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737)
+ print("--- Plain sockets to nowhere ---")
+ test(socket.getaddrinfo("192.0.2.1", 80)[0][-1], False, False)
+ print("--- SSL sockets to nowhere ---")
+ # this test fails with AXTLS because do_handshake=False blocks on first read/write and
+ # there it times out until the connect is aborted
+ test(socket.getaddrinfo("192.0.2.1", 443)[0][-1], True, False)
+ print("--- Plain sockets ---")
+ test(socket.getaddrinfo("micropython.org", 80)[0][-1], False, True)
+ print("--- SSL sockets ---")
+ test(socket.getaddrinfo("micropython.org", 443)[0][-1], True, True)
diff --git a/tests/net_inet/ssl_errors.py b/tests/net_inet/ssl_errors.py
new file mode 100644
index 000000000..fd281b1c4
--- /dev/null
+++ b/tests/net_inet/ssl_errors.py
@@ -0,0 +1,51 @@
+# test that socket.connect() on a non-blocking socket raises EINPROGRESS
+# and that an immediate write/send/read/recv does the right thing
+
+import sys
+
+try:
+ import uerrno as errno, usocket as socket, ussl as ssl
+except:
+ import errno, socket, ssl
+
+
+def test(addr, hostname, block=True):
+ print("---", hostname or addr)
+ s = socket.socket()
+ s.setblocking(block)
+ try:
+ s.connect(addr)
+ print("connected")
+ except OSError as e:
+ if e.args[0] != errno.EINPROGRESS:
+ raise
+ print("EINPROGRESS")
+
+ try:
+ if sys.implementation.name == "micropython":
+ s = ssl.wrap_socket(s, do_handshake=block)
+ else:
+ s = ssl.wrap_socket(s, do_handshake_on_connect=block)
+ print("wrap: True")
+ except OSError:
+ print("wrap: error")
+
+ if not block:
+ try:
+ while s.write(b"0") is None:
+ pass
+ except (ValueError, OSError): # CPython raises ValueError, MicroPython raises OSError
+ print("write: error")
+ s.close()
+
+
+if __name__ == "__main__":
+ # connect to plain HTTP port, oops!
+ addr = socket.getaddrinfo("micropython.org", 80)[0][-1]
+ test(addr, None)
+ # connect to plain HTTP port, oops!
+ addr = socket.getaddrinfo("micropython.org", 80)[0][-1]
+ test(addr, None, False)
+ # connect to server with self-signed cert, oops!
+ addr = socket.getaddrinfo("test.mosquitto.org", 8883)[0][-1]
+ test(addr, "test.mosquitto.org")
diff --git a/tests/net_inet/test_tls_nonblock.py b/tests/net_inet/test_tls_nonblock.py
new file mode 100644
index 000000000..c27ead3d5
--- /dev/null
+++ b/tests/net_inet/test_tls_nonblock.py
@@ -0,0 +1,116 @@
+try:
+ import usocket as socket, ussl as ssl, uerrno as errno, sys
+except:
+ import socket, ssl, errno, sys, time, select
+
+
+def test_one(site, opts):
+ ai = socket.getaddrinfo(site, 443)
+ addr = ai[0][-1]
+ print(addr)
+
+ # Connect the raw socket
+ s = socket.socket()
+ s.setblocking(False)
+ try:
+ s.connect(addr)
+ raise OSError(-1, "connect blocks")
+ except OSError as e:
+ if e.args[0] != errno.EINPROGRESS:
+ raise
+
+ if sys.implementation.name != "micropython":
+ # in CPython we have to wait, otherwise wrap_socket is not happy
+ select.select([], [s], [])
+
+ try:
+ # Wrap with SSL
+ try:
+ if sys.implementation.name == "micropython":
+ s = ssl.wrap_socket(s, do_handshake=False)
+ else:
+ s = ssl.wrap_socket(s, do_handshake_on_connect=False)
+ except OSError as e:
+ if e.args[0] != errno.EINPROGRESS:
+ raise
+ print("wrapped")
+
+ # CPython needs to be told to do the handshake
+ if sys.implementation.name != "micropython":
+ while True:
+ try:
+ s.do_handshake()
+ break
+ except ssl.SSLError as err:
+ if err.args[0] == ssl.SSL_ERROR_WANT_READ:
+ select.select([s], [], [])
+ elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
+ select.select([], [s], [])
+ else:
+ raise
+ time.sleep(0.1)
+ # print("shook hands")
+
+ # Write HTTP request
+ out = b"GET / HTTP/1.0\r\nHost: %s\r\n\r\n" % bytes(site, "latin")
+ while len(out) > 0:
+ n = s.write(out)
+ if n is None:
+ continue
+ if n > 0:
+ out = out[n:]
+ elif n == 0:
+ raise OSError(-1, "unexpected EOF in write")
+ print("wrote")
+
+ # Read response
+ resp = b""
+ while True:
+ try:
+ b = s.read(128)
+ except OSError as err:
+ if err.args[0] == 2: # 2=ssl.SSL_ERROR_WANT_READ:
+ continue
+ raise
+ if b is None:
+ continue
+ if len(b) > 0:
+ if len(resp) < 1024:
+ resp += b
+ elif len(b) == 0:
+ break
+ print("read")
+
+ if resp[:7] != b"HTTP/1.":
+ raise ValueError("response doesn't start with HTTP/1.")
+ # print(resp)
+
+ finally:
+ s.close()
+
+
+SITES = [
+ "google.com",
+ {"host": "www.google.com"},
+ "micropython.org",
+ "pypi.org",
+ "api.telegram.org",
+ {"host": "api.pushbullet.com", "sni": True},
+]
+
+
+def main():
+ for site in SITES:
+ opts = {}
+ if isinstance(site, dict):
+ opts = site
+ site = opts["host"]
+ try:
+ test_one(site, opts)
+ print(site, "ok")
+ except Exception as e:
+ print(site, "error")
+ print("DONE")
+
+
+main()
diff --git a/tests/net_inet/test_tls_sites.py b/tests/net_inet/test_tls_sites.py
index d2cb928c8..3f945efb8 100644
--- a/tests/net_inet/test_tls_sites.py
+++ b/tests/net_inet/test_tls_sites.py
@@ -27,6 +27,8 @@ def test_one(site, opts):
s.write(b"GET / HTTP/1.0\r\nHost: %s\r\n\r\n" % bytes(site, "latin"))
resp = s.read(4096)
+ if resp[:7] != b"HTTP/1.":
+ raise ValueError("response doesn't start with HTTP/1.")
# print(resp)
finally:
@@ -36,10 +38,10 @@ def test_one(site, opts):
SITES = [
"google.com",
"www.google.com",
+ "micropython.org",
+ "pypi.org",
"api.telegram.org",
{"host": "api.pushbullet.com", "sni": True},
- # "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com",
- {"host": "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com", "sni": True},
]
diff --git a/tests/net_inet/test_tls_sites.py.exp b/tests/net_inet/test_tls_sites.py.exp
index 2f3c113d2..bc4a8dbd1 100644
--- a/tests/net_inet/test_tls_sites.py.exp
+++ b/tests/net_inet/test_tls_sites.py.exp
@@ -1,5 +1,6 @@
google.com ok
www.google.com ok
+micropython.org ok
+pypi.org ok
api.telegram.org ok
api.pushbullet.com ok
-w9rybpfril.execute-api.ap-southeast-2.amazonaws.com ok