diff options
Diffstat (limited to 'samples/net/http_server/src/ssl_utils.c')
-rw-r--r-- | samples/net/http_server/src/ssl_utils.c | 295 |
1 files changed, 295 insertions, 0 deletions
diff --git a/samples/net/http_server/src/ssl_utils.c b/samples/net/http_server/src/ssl_utils.c new file mode 100644 index 000000000..ac75aadd0 --- /dev/null +++ b/samples/net/http_server/src/ssl_utils.c @@ -0,0 +1,295 @@ +/* + * Copyright (c) 2017 Intel Corporation + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include <zephyr.h> +#include <net/net_core.h> +#include <net/net_context.h> +#include <net/nbuf.h> +#include <net/net_if.h> +#include <string.h> +#include <errno.h> +#include <misc/printk.h> + +#if !defined(CONFIG_MBEDTLS_CFG_FILE) +#include "mbedtls/config.h" +#else +#include CONFIG_MBEDTLS_CFG_FILE +#endif + +#include "mbedtls/ssl.h" + +#include "config.h" +#include "ssl_utils.h" + +#define RX_FIFO_DEPTH 4 + +K_MEM_POOL_DEFINE(rx_pkts, 4, 64, RX_FIFO_DEPTH, 4); + +static void ssl_received(struct net_context *context, + struct net_buf *buf, int status, void *user_data) +{ + struct ssl_context *ctx = user_data; + struct rx_fifo_block *rx_data = NULL; + struct k_mem_block block; + + ARG_UNUSED(context); + ARG_UNUSED(status); + + if (!net_nbuf_appdatalen(buf)) { + net_nbuf_unref(buf); + return; + } + + k_mem_pool_alloc(&rx_pkts, &block, + sizeof(struct rx_fifo_block), K_FOREVER); + rx_data = block.data; + rx_data->buf = buf; + + /* For freeing memory later */ + memcpy(&rx_data->block, &block, sizeof(struct k_mem_block)); + k_fifo_put(&ctx->rx_fifo, (void *)rx_data); +} + +static inline void ssl_sent(struct net_context *context, + int status, void *token, void *user_data) +{ + struct ssl_context *ctx = user_data; + + k_sem_give(&ctx->tx_sem); +} + +int ssl_tx(void *context, const unsigned char *buf, size_t size) +{ + struct ssl_context *ctx = context; + struct net_context *net_ctx; + struct net_buf *send_buf; + + int rc, len; + + net_ctx = ctx->net_ctx; + + send_buf = net_nbuf_get_tx(net_ctx, K_NO_WAIT); + if (!send_buf) { + return MBEDTLS_ERR_SSL_ALLOC_FAILED; + } + + rc = net_nbuf_append(send_buf, size, (uint8_t *) buf, K_FOREVER); + if (!rc) { + net_nbuf_unref(send_buf); + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; + } + + len = net_buf_frags_len(send_buf); + + rc = net_context_send(send_buf, ssl_sent, K_NO_WAIT, NULL, ctx); + + if (rc < 0) { + net_nbuf_unref(send_buf); + return MBEDTLS_ERR_SSL_INTERNAL_ERROR; + } + + k_sem_take(&ctx->tx_sem, K_FOREVER); + return len; +} + +int ssl_rx(void *context, unsigned char *buf, size_t size) +{ + struct ssl_context *ctx = context; + uint16_t read_bytes; + struct rx_fifo_block *rx_data; + uint8_t *ptr; + int pos; + int len; + int rc = 0; + + if (ctx->frag == NULL) { + rx_data = k_fifo_get(&ctx->rx_fifo, K_FOREVER); + ctx->rx_nbuf = rx_data->buf; + k_mem_pool_free(&rx_data->block); + + read_bytes = net_nbuf_appdatalen(ctx->rx_nbuf); + + ctx->remaining = read_bytes; + ctx->frag = ctx->rx_nbuf->frags; + ptr = net_nbuf_appdata(ctx->rx_nbuf); + + len = ptr - ctx->frag->data; + net_buf_pull(ctx->frag, len); + } else { + read_bytes = ctx->remaining; + ptr = ctx->frag->data; + } + + len = ctx->frag->len; + pos = 0; + if (read_bytes > size) { + while (ctx->frag) { + read_bytes = len < (size - pos) ? len : (size - pos); + memcpy(buf + pos, ptr, read_bytes); + pos += read_bytes; + if (pos < size) { + ctx->frag = ctx->frag->frags; + ptr = ctx->frag->data; + len = ctx->frag->len; + } else { + if (read_bytes == len) { + ctx->frag = ctx->frag->frags; + } else { + net_buf_pull(ctx->frag, read_bytes); + } + + ctx->remaining -= size; + return size; + } + } + } else { + while (ctx->frag) { + memcpy(buf + pos, ptr, len); + pos += len; + ctx->frag = ctx->frag->frags; + if (!ctx->frag) { + break; + } + + ptr = ctx->frag->data; + len = ctx->frag->len; + } + + net_nbuf_unref(ctx->rx_nbuf); + ctx->rx_nbuf = NULL; + ctx->frag = NULL; + ctx->remaining = 0; + + if (read_bytes != pos) { + return -EIO; + } + + rc = read_bytes; + } + + return rc; +} + +static void ssl_accepted(struct net_context *context, + struct sockaddr *addr, + socklen_t addrlen, int error, void *user_data) +{ + int ret; + struct ssl_context *ctx = user_data; + + ctx->net_ctx = context; + ret = net_context_recv(context, ssl_received, 0, user_data); + if (ret < 0) { + printk("Cannot receive TCP packet (family %d)", + net_context_get_family(context)); + } + +} + +#if defined(CONFIG_NET_IPV6) +int ssl_init(struct ssl_context *ctx, void *addr) +{ + struct net_context *tcp_ctx = { 0 }; + struct sockaddr_in6 my_addr = { 0 }; + struct in6_addr *server_addr = addr; + int rc; + + k_sem_init(&ctx->tx_sem, 0, UINT_MAX); + k_fifo_init(&ctx->rx_fifo); + + my_mcast_addr.sin6_family = AF_INET6; + + net_ipaddr_copy(&my_addr.sin6_addr, server_addr); + my_addr.sin6_family = AF_INET6; + my_addr.sin6_port = htons(SERVER_PORT); + + rc = net_context_get(AF_INET6, SOCK_STREAM, IPPROTO_TCP, &tcp_ctx); + if (rc < 0) { + printk("Cannot get network context for IPv6 TCP (%d)", rc); + return -EIO; + } + + rc = net_context_bind(tcp_ctx, (struct sockaddr *)&my_addr, + sizeof(struct sockaddr_in6)); + if (rc < 0) { + printk("Cannot bind IPv6 TCP port %d (%d)", SERVER_PORT, rc); + goto error; + } + + ctx->rx_nbuf = NULL; + ctx->remaining = 0; + ctx->net_ctx = tcp_ctx; + + rc = net_context_listen(ctx->net_ctx, 0); + if (rc < 0) { + printk("Cannot listen IPv6 TCP (%d)", rc); + return -EIO; + } + + rc = net_context_accept(ctx->net_ctx, ssl_accepted, 0, ctx); + if (rc < 0) { + printk("Cannot accept IPv4 (%d)", rc); + return -EIO; + } + + return 0; + +error: + net_context_put(tcp_ctx); + return -EINVAL; +} + +#else +int ssl_init(struct ssl_context *ctx, void *addr) +{ + struct net_context *tcp_ctx = { 0 }; + struct sockaddr_in my_addr4 = { 0 }; + struct in_addr *server_addr = addr; + int rc; + + k_sem_init(&ctx->tx_sem, 0, UINT_MAX); + k_fifo_init(&ctx->rx_fifo); + + net_ipaddr_copy(&my_addr4.sin_addr, server_addr); + my_addr4.sin_family = AF_INET; + my_addr4.sin_port = htons(SERVER_PORT); + + rc = net_context_get(AF_INET, SOCK_STREAM, IPPROTO_TCP, &tcp_ctx); + if (rc < 0) { + printk("Cannot get network context for IPv4 TCP (%d)", rc); + return -EIO; + } + + rc = net_context_bind(tcp_ctx, (struct sockaddr *)&my_addr4, + sizeof(struct sockaddr_in)); + if (rc < 0) { + printk("Cannot bind IPv4 TCP port %d (%d)", SERVER_PORT, rc); + goto error; + } + + ctx->rx_nbuf = NULL; + ctx->remaining = 0; + ctx->net_ctx = tcp_ctx; + + rc = net_context_listen(ctx->net_ctx, 0); + if (rc < 0) { + printk("Cannot listen IPv4 (%d)", rc); + return -EIO; + } + + rc = net_context_accept(ctx->net_ctx, ssl_accepted, 0, ctx); + if (rc < 0) { + printk("Cannot accept IPv4 (%d)", rc); + return -EIO; + } + + return 0; + +error: + net_context_put(tcp_ctx); + return -EINVAL; +} +#endif |