diff options
author | Lifang Zhang <lifang.zhang@nokia.com> | 2023-12-02 07:05:44 +0200 |
---|---|---|
committer | Matias Elo <matias.elo@nokia.com> | 2024-02-21 16:08:18 +0200 |
commit | b04851ad21929cc143fa0a6f8c04a191c85f9c6e (patch) | |
tree | 01e83ab02e5f92f45f7c531451ff1acddb168995 | |
parent | 10753512c8a4c04350092c5faaa014ad485427b9 (diff) |
linux-gen: ml: implement the new ML API
Added implementation of the new ML API.
Signed-off-by: Lifang Zhang <lifang.zhang@nokia.com>
Signed-off-by: Jere Leppänen <jere.leppanen@nokia.com>
Reviewed-by: Matias Elo <matias.elo@nokia.com>
Reviewed-by: Petri Savolainen <petri.savolainen@nokia.com>
Reviewed-by: Tuomas Taipale <tuomas.taipale@nokia.com>
25 files changed, 3566 insertions, 9 deletions
diff --git a/DEPENDENCIES b/DEPENDENCIES index 7dbe86489..5af0d4ac3 100644 --- a/DEPENDENCIES +++ b/DEPENDENCIES @@ -305,6 +305,35 @@ Prerequisites for building the OpenDataPlane (ODP) API which would be the first actual queue in case 5 regular combined queues were configured (zero-indexing). +3.7 Machine Learning API support (optional) + Use ML API for model inferencing. ML implementation uses ONNX Runtime library + (https://github.com/microsoft/onnxruntime). ODP has been tested with ONNX + Runtime version 1.16.3. + +3.7.1 Prebuilt onnxruntime download + Download a default CPU version onnxruntime-linux-x64-*.tgz and unzip it to + any folder. + + $ wget -P ~ https://github.com/microsoft/onnxruntime/releases/download/v<version>/onnxruntime-linux-x64-<version>.tgz + $ mkdir <onnxruntime path> + $ cd <onnxruntime path>/ + $ tar --strip=1 -zxvf ~/onnxruntime-linux-x64-<version>.tgz + +3.7.2 Build onnxruntime from source + $ git clone --recursive https://github.com/Microsoft/onnxruntime.git + $ cd onnxruntime + + # Configure + $ ./build.sh --config RelWithDebInfo --build_shared_lib --parallel + $ tools/ci_build/github/linux/copy_strip_binary.sh -r build/Linux/ -a onnxruntime -l libonnxruntime.so.1.14.0 -c RelWithDebInfo -s . -t <commit id> + $ cp -r build/Linux/onnxruntime/ <onnxruntime path> + +3.7.3 Build ODP with ML support + After installing onnxruntime and example dependencies, ODP can be configured to be + built with ML support by giving onnxruntime path with --with-ort-path. 
+ + $ ../configure --with-ort-path=<onnxruntime path> + 4.0 Packages needed to build API tests CUnit test framework version 2.1-3 is required diff --git a/config/odp-linux-generic.conf b/config/odp-linux-generic.conf index 2d27752b2..93997ecb3 100644 --- a/config/odp-linux-generic.conf +++ b/config/odp-linux-generic.conf @@ -16,7 +16,7 @@ # Mandatory fields odp_implementation = "linux-generic" -config_file_version = "0.1.27" +config_file_version = "0.1.28" # System options system: { @@ -365,3 +365,34 @@ ipsec: { async_outbound = 0 } } + +ml: { + # Enable onnxruntime profiling, when enabled, a json file will be + # generated after inference. chrome://tracing/ can be used to check + # the profiling. Use 0 to disable and 1 to enable profiling. + enable_profiling = 0 + + # Choose onnxruntime execution mode, which can be "SEQUENTIAL" or + # "PARALLEL" + execution_mode = "SEQUENTIAL" + + # Set the number of threads used to parallelize the execution of the + # graph across nodes. A value of 0 means onnxruntime will pick a default. + inter_op_num_threads = 0 + + # Set the number of threads used to parallelize the execution within + # a node. A value of 0 means onnxruntime will pick a default. + intra_op_num_threads = 0 + + # Set graph optimization level. Valid values are: + # DISABLE_ALL: disables all optimizations + # ENABLE_BASIC: enables basic optimizations + # ENABLE_EXTENDED: enables basic and extended optimizations + # ENABLE_ALL: enables all available optimizations including layout optimization + graph_optimization_level = "ENABLE_ALL" + + # Serialize the optimized model to disk. When initializing a session + # with the same model, no need to apply optimization anymore, thus + # reducing model startup time. 
+ optimized_model_filepath = "" +} diff --git a/platform/linux-generic/Makefile.am b/platform/linux-generic/Makefile.am index c2aec362e..11cdb4c64 100644 --- a/platform/linux-generic/Makefile.am +++ b/platform/linux-generic/Makefile.am @@ -13,6 +13,7 @@ AM_CPPFLAGS += -I$(top_srcdir)/platform/$(with_platform)/arch/default AM_CPPFLAGS += -I$(top_srcdir)/platform/$(with_platform)/arch/common AM_CPPFLAGS += $(OPENSSL_CPPFLAGS) +AM_CPPFLAGS += $(ORT_CPPFLAGS) AM_CFLAGS += $(AARCH64CRYPTO_CFLAGS) AM_CFLAGS += $(DPDK_CFLAGS) @@ -141,6 +142,7 @@ noinst_HEADERS = \ include/odp_event_validation_internal.h \ include/odp_fdserver_internal.h \ include/odp_forward_typedefs_internal.h \ + include/odp_ml_fp16.h \ include/odp_global_data.h \ include/odp_init_internal.h \ include/odp_ipsec_internal.h \ @@ -229,6 +231,8 @@ __LIB__libodp_linux_la_SOURCES = \ odp_ishmphy.c \ odp_ishmpool.c \ odp_libconfig.c \ + odp_ml_fp16.c \ + odp_ml_quantize.c \ odp_name_table.c \ odp_packet.c \ odp_packet_vector.c \ @@ -298,6 +302,15 @@ __LIB__libodp_linux_la_SOURCES += \ endif endif endif + +if WITH_ML +__LIB__libodp_linux_la_SOURCES += \ + odp_ml.c +else +__LIB__libodp_linux_la_SOURCES += \ + odp_ml_null.c +endif + if ODP_ABI_COMPAT __LIB__libodp_linux_la_SOURCES += \ odp_atomic_api.c \ @@ -473,6 +486,7 @@ __LIB__libodp_linux_la_LIBADD += $(PTHREAD_LIBS) __LIB__libodp_linux_la_LIBADD += $(TIMER_LIBS) __LIB__libodp_linux_la_LIBADD += $(LIBXDP_LIBS) __LIB__libodp_linux_la_LIBADD += $(IPSEC_MB_LIBS) +__LIB__libodp_linux_la_LIBADD += $(ORT_LIBS) if ODP_PKTIO_PCAP __LIB__libodp_linux_la_LIBADD += $(PCAP_LIBS) diff --git a/platform/linux-generic/include/odp/api/plat/event_inlines.h b/platform/linux-generic/include/odp/api/plat/event_inlines.h index d30c6acbb..990575166 100644 --- a/platform/linux-generic/include/odp/api/plat/event_inlines.h +++ b/platform/linux-generic/include/odp/api/plat/event_inlines.h @@ -99,6 +99,7 @@ _ODP_INLINE void *odp_event_user_area(odp_event_t event) switch (type) { case 
ODP_EVENT_BUFFER: + case ODP_EVENT_ML_COMPL: case ODP_EVENT_DMA_COMPL: return _odp_buffer_get((odp_buffer_t)event, void *, uarea_addr); case ODP_EVENT_PACKET: @@ -121,6 +122,7 @@ _ODP_INLINE void *odp_event_user_area_and_flag(odp_event_t event, int *flag) switch (type) { case ODP_EVENT_BUFFER: case ODP_EVENT_DMA_COMPL: + case ODP_EVENT_ML_COMPL: *flag = -1; return _odp_buffer_get((odp_buffer_t)event, void *, uarea_addr); case ODP_EVENT_PACKET: diff --git a/platform/linux-generic/include/odp_config_internal.h b/platform/linux-generic/include/odp_config_internal.h index 8fd8c4be7..89d89936c 100644 --- a/platform/linux-generic/include/odp_config_internal.h +++ b/platform/linux-generic/include/odp_config_internal.h @@ -199,6 +199,15 @@ extern "C" { /* Enable timer scan performance benchmark. This works with inline enabled. */ #define CONFIG_TIMER_PROFILE_INLINE 0 +/* Maximum number of ML models that can be created or loaded. */ +#define CONFIG_ML_MAX_MODELS 4 + +/* Maximum number of inputs for a ML model. */ +#define CONFIG_ML_MAX_INPUTS 4 + +/* Maximum number of outputs for a ML model. 
*/ +#define CONFIG_ML_MAX_OUTPUTS 4 + #ifdef __cplusplus } #endif diff --git a/platform/linux-generic/include/odp_global_data.h b/platform/linux-generic/include/odp_global_data.h index f00e155de..2a87192df 100644 --- a/platform/linux-generic/include/odp_global_data.h +++ b/platform/linux-generic/include/odp_global_data.h @@ -80,6 +80,7 @@ typedef struct odp_global_data_ro_t { uint8_t ipsec; uint8_t stash; uint8_t traffic_mngr; + uint8_t ml; } disable; diff --git a/platform/linux-generic/include/odp_init_internal.h b/platform/linux-generic/include/odp_init_internal.h index 24e8346ad..ca5d68c87 100644 --- a/platform/linux-generic/include/odp_init_internal.h +++ b/platform/linux-generic/include/odp_init_internal.h @@ -105,6 +105,9 @@ int _odp_stash_term_global(void); int _odp_dma_init_global(void); int _odp_dma_term_global(void); +int _odp_ml_init_global(void); +int _odp_ml_term_global(void); + #ifdef __cplusplus } #endif diff --git a/platform/linux-generic/include/odp_ml_fp16.h b/platform/linux-generic/include/odp_ml_fp16.h new file mode 100644 index 000000000..5294a7c0b --- /dev/null +++ b/platform/linux-generic/include/odp_ml_fp16.h @@ -0,0 +1,21 @@ +/* SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2023 Nokia + */ + +#ifndef ODP_ML_FP16_H_ +#define ODP_ML_FP16_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +uint16_t _odp_float32_to_float16(float x); +float _odp_float16_to_float32(uint16_t f16); +uint16_t _odp_float32_to_bfloat16(float x); +float _odp_bfloat16_to_float32(uint16_t f16); + +#ifdef __cplusplus +} +#endif + +#endif /* ODP_ML_FP16_H_ */ diff --git a/platform/linux-generic/libodp-linux.pc.in b/platform/linux-generic/libodp-linux.pc.in index 05ba5b9d6..62589c1a3 100644 --- a/platform/linux-generic/libodp-linux.pc.in +++ b/platform/linux-generic/libodp-linux.pc.in @@ -8,5 +8,5 @@ Description: The ODP packet processing engine Version: @PKGCONFIG_VERSION@ Requires.private: libconfig@AARCH64CRYPTO_PKG@ Libs: -L${libdir} -l@ODP_LIB_NAME@ 
@ATOMIC_LIBS_NON_ABI_COMPAT@ -Libs.private: @OPENSSL_STATIC_LIBS@ @DPDK_LIBS@ @PCAP_LIBS@ @PTHREAD_LIBS@ @TIMER_LIBS@ @LIBXDP_LIBS@ -lpthread @ATOMIC_LIBS_ABI_COMPAT@ @IPSEC_MB_LIBS@ +Libs.private: @OPENSSL_STATIC_LIBS@ @DPDK_LIBS@ @PCAP_LIBS@ @PTHREAD_LIBS@ @TIMER_LIBS@ @LIBXDP_LIBS@ -lpthread @ATOMIC_LIBS_ABI_COMPAT@ @IPSEC_MB_LIBS@ @ORT_LIBS@ Cflags: -I${includedir} diff --git a/platform/linux-generic/m4/configure.m4 b/platform/linux-generic/m4/configure.m4 index 61b65540f..161bf48c5 100644 --- a/platform/linux-generic/m4/configure.m4 +++ b/platform/linux-generic/m4/configure.m4 @@ -31,10 +31,11 @@ m4_include([platform/linux-generic/m4/odp_pcapng.m4]) m4_include([platform/linux-generic/m4/odp_dpdk.m4]) m4_include([platform/linux-generic/m4/odp_wfe.m4]) m4_include([platform/linux-generic/m4/odp_xdp.m4]) +m4_include([platform/linux-generic/m4/odp_ml.m4]) ODP_EVENT_VALIDATION ODP_SCHEDULER -AS_VAR_APPEND([PLAT_DEP_LIBS], ["${ATOMIC_LIBS} ${AARCH64CRYPTO_LIBS} ${LIBCONFIG_LIBS} ${OPENSSL_LIBS} ${IPSEC_MB_LIBS} ${DPDK_LIBS_LT} ${LIBCLI_LIBS} ${LIBXDP_LIBS}"]) +AS_VAR_APPEND([PLAT_DEP_LIBS], ["${ATOMIC_LIBS} ${AARCH64CRYPTO_LIBS} ${LIBCONFIG_LIBS} ${OPENSSL_LIBS} ${IPSEC_MB_LIBS} ${DPDK_LIBS_LT} ${LIBCLI_LIBS} ${LIBXDP_LIBS} ${ORT_LIBS}"]) # Add text to the end of configure with platform specific settings. # Make sure it's aligned same as other lines in configure.ac. 
@@ -46,6 +47,7 @@ AS_VAR_APPEND([PLAT_CFG_TEXT], [" pcap: ${have_pcap} pcapng: ${have_pcapng} wfe_locks: ${use_wfe_locks} + ml_support: ${ml_support} default_config_path: ${default_config_path}"]) # Ignore Clang specific errors about fields with variable sized type not at the diff --git a/platform/linux-generic/m4/odp_libconfig.m4 b/platform/linux-generic/m4/odp_libconfig.m4 index a6d19f661..77095e0fe 100644 --- a/platform/linux-generic/m4/odp_libconfig.m4 +++ b/platform/linux-generic/m4/odp_libconfig.m4 @@ -3,7 +3,7 @@ ########################################################################## m4_define([_odp_config_version_generation], [0]) m4_define([_odp_config_version_major], [1]) -m4_define([_odp_config_version_minor], [27]) +m4_define([_odp_config_version_minor], [28]) m4_define([_odp_config_version], [_odp_config_version_generation._odp_config_version_major._odp_config_version_minor]) diff --git a/platform/linux-generic/m4/odp_ml.m4 b/platform/linux-generic/m4/odp_ml.m4 new file mode 100644 index 000000000..a7b9a4fd6 --- /dev/null +++ b/platform/linux-generic/m4/odp_ml.m4 @@ -0,0 +1,46 @@ +########################################################################## +# Onnxruntime library path and name +########################################################################## +# Optional configure parameter for a non-standard install prefix of onnxruntime +AC_ARG_WITH([ort-path], + [AS_HELP_STRING([--with-ort-path=DIR], + [path to onnxruntime libs and headers [default=system]])], + [ort_path_given=yes + ORT_CPPFLAGS="-I$withval/include" + ORT_LIBS="-L$withval/lib" + ORT_RPATH="-R$withval/lib"], + []) + +########################################################################## +# Save and set temporary compilation flags +########################################################################## +OLD_CPPFLAGS=$CPPFLAGS +OLD_LIBS=$LIBS +CPPFLAGS="$ORT_CPPFLAGS $CPPFLAGS" +LIBS="$ORT_LIBS $LIBS" + 
+######################################################################### +# If ort is available, enable ML API +######################################################################### +ml_support=no +AC_CHECK_HEADERS([onnxruntime_c_api.h], + [AC_CHECK_LIB(onnxruntime, OrtGetApiBase, [ml_support=yes], [], [])], + [AS_IF([test "x$ort_path_given" = "xyes"], + [AC_MSG_ERROR([ort not found at the specified path (--with-ort-path)])])]) + +AS_IF([test "x$ml_support" != "xno"], + [ORT_LIBS="$ORT_RPATH $ORT_LIBS -lonnxruntime -lm"], + [ORT_CPPFLAGS="" ORT_LIBS="-lm"]) + +AC_CONFIG_COMMANDS_PRE([dnl +AM_CONDITIONAL([WITH_ML], [test x$ml_support = xyes ]) +]) + +########################################################################## +# Restore old saved variables +########################################################################## +LIBS=$OLD_LIBS +CPPFLAGS=$OLD_CPPFLAGS + +AC_SUBST([ORT_CPPFLAGS]) +AC_SUBST([ORT_LIBS]) diff --git a/platform/linux-generic/odp_event.c b/platform/linux-generic/odp_event.c index 9ec4b4bfb..f3644f02b 100644 --- a/platform/linux-generic/odp_event.c +++ b/platform/linux-generic/odp_event.c @@ -12,6 +12,7 @@ #include <odp/api/packet.h> #include <odp/api/timer.h> #include <odp/api/pool.h> +#include <odp/api/ml.h> #include <odp_buffer_internal.h> #include <odp_ipsec_internal.h> @@ -69,6 +70,9 @@ static inline void event_free(odp_event_t event, _odp_ev_id_t id) case ODP_EVENT_DMA_COMPL: odp_dma_compl_free(odp_dma_compl_from_event(event)); break; + case ODP_EVENT_ML_COMPL: + odp_ml_compl_free(odp_ml_compl_from_event(event)); + break; default: _ODP_ABORT("Invalid event type: %d\n", odp_event_type(event)); } @@ -117,6 +121,8 @@ int odp_event_is_valid(odp_event_t event) /* Fall through */ case ODP_EVENT_DMA_COMPL: /* Fall through */ + case ODP_EVENT_ML_COMPL: + /* Fall through */ case ODP_EVENT_PACKET_TX_COMPL: break; default: diff --git a/platform/linux-generic/odp_init.c b/platform/linux-generic/odp_init.c index 05b693c94..795252df1 100644 
--- a/platform/linux-generic/odp_init.c +++ b/platform/linux-generic/odp_init.c @@ -51,6 +51,7 @@ enum init_stage { IPSEC_SAD_INIT, IPSEC_INIT, DMA_INIT, + ML_INIT, ALL_INIT /* All init stages completed */ }; @@ -95,6 +96,7 @@ static void disable_features(odp_global_data_ro_t *global_ro, global_ro->disable.traffic_mngr = init_param->not_used.feat.tm; global_ro->disable.compress = init_param->not_used.feat.compress; + global_ro->disable.ml = init_param->not_used.feat.ml; } void odp_init_param_init(odp_init_t *param) @@ -145,6 +147,13 @@ static int term_global(enum init_stage stage) switch (stage) { case ALL_INIT: + case ML_INIT: + if (_odp_ml_term_global()) { + _ODP_ERR("ODP ML term failed.\n"); + rc = -1; + } + /* Fall through */ + case DMA_INIT: if (_odp_dma_term_global()) { _ODP_ERR("ODP DMA term failed.\n"); @@ -509,6 +518,12 @@ int odp_init_global(odp_instance_t *instance, } stage = DMA_INIT; + if (_odp_ml_init_global()) { + _ODP_ERR("ODP ML init failed.\n"); + goto init_failed; + } + stage = ML_INIT; + *instance = (odp_instance_t)odp_global_ro.main_pid; return 0; diff --git a/platform/linux-generic/odp_ml.c b/platform/linux-generic/odp_ml.c new file mode 100644 index 000000000..fda06e7cb --- /dev/null +++ b/platform/linux-generic/odp_ml.c @@ -0,0 +1,2633 @@ +/* SPDX-License-Identifier: BSD-3-Clause + * Copyright (c) 2023 Nokia + */ + +#include <odp/api/ml.h> +#include <odp/api/queue.h> +#include <odp/api/plat/event_inline_types.h> + +#include <odp_global_data.h> +#include <odp_debug_internal.h> +#include <odp_init_internal.h> +#include <odp_pool_internal.h> +#include <odp_config_internal.h> +#include <odp_macros_internal.h> +#include <odp_libconfig_internal.h> + +#include <onnxruntime_c_api.h> + +#include <stdint.h> +#include <stdio.h> +#include <inttypes.h> +#include <string.h> + +#define ML_MAX_IO_SEGS UINT32_MAX +#define ML_MAX_COMPL_ID 32 +#define ML_MAX_CONFIG_STR_LEN 65 +#define ML_MAX_MODEL_SIZE (1024 * 1024 * 1024) +#define ML_MAX_MODELS_CREATED 
CONFIG_ML_MAX_MODELS +#define ML_MAX_MODELS_LOADED CONFIG_ML_MAX_MODELS + +/* Error codes */ +enum { + /* Feature not supported */ + ML_FEATURE_NOT_SUPPORTED = 1, + + /* Model is not created */ + ML_NOT_CREATED, + + /* Model was not loaded */ + ML_NOT_LOADED, + + /* Model has already loaded */ + ML_LOADED, + + /* Bad input */ + ML_BAD_INPUT, + + /* Fail from underlying library onnxruntime */ + ML_LIB_FAILED, + + /* Bad output */ + ML_BAD_OUTPUT, + + /* Bad handle */ + ML_BAD_HDL +}; + +typedef struct ort_run_opts_t { + int enable_profiling; + + ExecutionMode execution_mode; + + int inter_op_num_threads; + + int intra_op_num_threads; + + GraphOptimizationLevel graph_opt_level; + + char opt_model_filepath[ML_MAX_CONFIG_STR_LEN]; +} ort_run_opts_t; + +typedef struct ml_input_t { + /* Combined input start address */ + void *addr; + /* Data size in bytes */ + uint64_t size; +} ml_input_t; + +/* Onnxruntime model info */ +typedef struct ml_model_t { + /* Guards state, which must be accessed atomically */ + odp_ticketlock_t lock; + + enum { + ML_STATE_FREE = 0, /* Not allocated */ + ML_STATE_CREATED, /* Model is created */ + ML_STATE_LOADED, /* Model is loaded */ + ML_STATE_INFERENCING, /* Model is inferencing */ + } state; + + OrtSession *session; + OrtSessionOptions *session_opts; + uint32_t max_compl_id; + odp_atomic_u32_t compl_status[ML_MAX_COMPL_ID]; + + odp_ml_model_info_t info; + odp_ml_input_info_t input_info[CONFIG_ML_MAX_INPUTS]; + uint64_t input_sizes[CONFIG_ML_MAX_INPUTS]; + odp_ml_output_info_t output_info[CONFIG_ML_MAX_OUTPUTS]; + uint64_t output_sizes[CONFIG_ML_MAX_OUTPUTS]; + + struct { + void *user_ptr; + } result[ML_MAX_COMPL_ID]; +} ml_model_t; + +typedef struct ml_global_t { + odp_shm_t shm; + + odp_ml_capability_t capa; + odp_ml_config_t ml_config; + + odp_pool_param_t pool_param; + + const OrtApi *ort_api; + OrtEnv *env; + ort_run_opts_t ort_run_opts; + + ml_model_t models[ML_MAX_MODELS_CREATED]; + +} ml_global_t; + +static ml_global_t *_odp_ml_glb; 
+ +static inline ml_model_t *ml_model_from_handle(odp_ml_model_t model) +{ + return (ml_model_t *)(uintptr_t)model; +} + +int odp_ml_capability(odp_ml_capability_t *capa) +{ + odp_pool_capability_t pool_capa; + + memset(capa, 0, sizeof(odp_ml_capability_t)); + + if (odp_global_ro.disable.ml) { + _ODP_PRINT("ML is disabled\n"); + return 0; + } + + capa->max_model_size = ML_MAX_MODEL_SIZE; + capa->max_models = ML_MAX_MODELS_CREATED; + capa->max_models_loaded = ML_MAX_MODELS_LOADED; + capa->max_compl_id = ML_MAX_COMPL_ID; + capa->max_inputs = CONFIG_ML_MAX_INPUTS; + capa->max_outputs = CONFIG_ML_MAX_OUTPUTS; + capa->max_segs_per_input = ML_MAX_IO_SEGS; + capa->max_segs_per_output = ML_MAX_IO_SEGS; + capa->min_input_align = 1; + capa->min_output_align = 1; + + capa->load.compl_mode_mask = ODP_ML_COMPL_MODE_SYNC | + ODP_ML_COMPL_MODE_POLL | + ODP_ML_COMPL_MODE_EVENT; + capa->load.compl_queue_plain = 1; + capa->load.compl_queue_sched = 1; + + capa->run.compl_mode_mask = ODP_ML_COMPL_MODE_SYNC | + ODP_ML_COMPL_MODE_POLL | + ODP_ML_COMPL_MODE_EVENT; + capa->run.compl_queue_plain = 1; + capa->run.compl_queue_sched = 1; + + if (odp_pool_capability(&pool_capa)) { + _ODP_ERR("Pool capability failed\n"); + return -1; + } + + capa->pool.max_pools = pool_capa.buf.max_pools; + capa->pool.max_num = pool_capa.buf.max_num; + capa->pool.max_uarea_size = pool_capa.buf.max_uarea_size; + capa->pool.uarea_persistence = pool_capa.buf.uarea_persistence; + capa->pool.max_cache_size = pool_capa.buf.max_cache_size; + capa->pool.min_cache_size = pool_capa.buf.min_cache_size; + + return 0; +} + +void odp_ml_config_init(odp_ml_config_t *config) +{ + memset(config, 0, sizeof(odp_ml_config_t)); + config->max_models_created = 1; + config->max_models_loaded = 1; +} + +int odp_ml_config(const odp_ml_config_t *config) +{ + if (!config) { + _ODP_ERR("Error: config must not be NULL\n"); + return -1; + } + + if (config->max_model_size == 0 || config->max_models_created == 0 || + config->max_models_loaded 
== 0) { + _ODP_ERR("Error: max_model_size, max_models_created and max_models_loaded" + " must be bigger than 0\n"); + return -1; + } + + if (config->max_models_loaded > config->max_models_created) { + _ODP_ERR("Error: max_models_loaded %d exceeds max_models_created %d\n", + config->max_models_loaded, config->max_models_created); + return -1; + } + + if (config->max_models_created > ML_MAX_MODELS_CREATED) { + _ODP_ERR("Error: max_models_created %d exceeds maximum number" + " of models that can be created in this driver %d\n", + config->max_models_created, ML_MAX_MODELS_CREATED); + return -1; + } + + if (config->max_models_loaded > ML_MAX_MODELS_LOADED) { + _ODP_ERR("Error: max_models_loaded %d exceeds maximum number" + " of models that can be loaded in this driver %d\n", + config->max_models_loaded, ML_MAX_MODELS_LOADED); + return -1; + } + + if (config->max_model_size > ML_MAX_MODEL_SIZE) { + _ODP_ERR("max_model_size %" PRIu64 " exceeds supported maximum model size %d\n", + config->max_model_size, ML_MAX_MODEL_SIZE); + return -1; + } + + _odp_ml_glb->ml_config = *config; + return 0; +} + +void odp_ml_model_param_init(odp_ml_model_param_t *param) +{ + memset(param, 0, sizeof(odp_ml_model_param_t)); +} + +static int check_ortstatus(OrtStatus * const status) +{ + if (status != NULL) { + const char *msg = _odp_ml_glb->ort_api->GetErrorMessage(status); + + _ODP_ERR("%s\n", msg); + _odp_ml_glb->ort_api->ReleaseStatus(status); + return -1; + } + + return 0; +} + +/* Get model input and output count */ +static int get_model_io_count(OrtSession *model, uint32_t *num_inputs, uint32_t *num_outputs) +{ + size_t num = 0; + OrtStatus *status = NULL; + const OrtApi *ort_api = _odp_ml_glb->ort_api; + + status = ort_api->SessionGetInputCount(model, &num); + if (check_ortstatus(status)) { + _ODP_ERR("Get model input count failed\n"); + return -1; + } + + *num_inputs = num; + _ODP_DBG("num_inputs: %u\n", *num_inputs); + + status = ort_api->SessionGetOutputCount(model, &num); + if 
(check_ortstatus(status)) { + _ODP_ERR("Get model output count failed\n"); + return -1; + } + + *num_outputs = num; + _ODP_DBG("num_outputs: %u\n", *num_outputs); + + return 0; +} + +static odp_ml_data_type_t onnx_dtype_to_odp_dtype(ONNXTensorElementDataType onnx_dtype) +{ + switch (onnx_dtype) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ODP_ML_DATA_TYPE_FP32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + return ODP_ML_DATA_TYPE_UINT8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return ODP_ML_DATA_TYPE_INT8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + return ODP_ML_DATA_TYPE_UINT16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + return ODP_ML_DATA_TYPE_INT16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ODP_ML_DATA_TYPE_INT32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ODP_ML_DATA_TYPE_UINT32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return ODP_ML_DATA_TYPE_INT64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return ODP_ML_DATA_TYPE_UINT64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ODP_ML_DATA_TYPE_FP16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + return ODP_ML_DATA_TYPE_BFP16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return ODP_ML_DATA_TYPE_FP64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + /* Fall through */ + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: + /* Fall through */ + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: + /* Fall through */ + default: + _ODP_ERR("onnx_dtype %d not supported by odp_ml\n", onnx_dtype); + return ODP_ML_DATA_TYPE_NONE; + } +} + +/* Get the size of given odp_ml_data_type_t in bytes */ +static uint32_t size_of_odp_ml_data_type(odp_ml_data_type_t data_type) +{ + switch (data_type) { + case ODP_ML_DATA_TYPE_NONE: + return 0; + + case ODP_ML_DATA_TYPE_INT8: + /* Fall through */ + case ODP_ML_DATA_TYPE_UINT8: + return 1; + + case ODP_ML_DATA_TYPE_INT16: + /* Fall through */ + case ODP_ML_DATA_TYPE_UINT16: + /* Fall through */ + case ODP_ML_DATA_TYPE_FP16: + /* Fall through 
*/ + case ODP_ML_DATA_TYPE_BFP16: + return 2; + + case ODP_ML_DATA_TYPE_INT24: + /* Fall through */ + case ODP_ML_DATA_TYPE_UINT24: + return 3; + + case ODP_ML_DATA_TYPE_INT32: + /* Fall through */ + case ODP_ML_DATA_TYPE_UINT32: + /* Fall through */ + case ODP_ML_DATA_TYPE_FP32: + return 4; + + case ODP_ML_DATA_TYPE_INT64: + /* Fall through */ + case ODP_ML_DATA_TYPE_UINT64: + /* Fall through */ + case ODP_ML_DATA_TYPE_FP64: + return 8; + + default: + return 0; + } +} + +static int get_shape(int64_t dims[], odp_ml_shape_info_t *shape) +{ + uint32_t dyn_cnt = 0; + + for (uint32_t i = 0; i < shape->num_dim; i++) { + if (dims[i] == 0) { + _ODP_ERR("Dimension value: %" PRId64 " must be at least 1\n", dims[i]); + return -1; + } else if (dims[i] == -1) { /* Symbolic dimension */ + dyn_cnt++; + shape->dim[i] = ODP_ML_DIM_DYNAMIC; + shape->dim_min[i] = 0; /*unknown*/ + shape->dim_max[i] = 0; /*unknown*/ + } else if (dims[i] > 0 && dims[i] < UINT32_MAX) { + shape->dim[i] = dims[i]; + shape->dim_min[i] = dims[i]; + shape->dim_max[i] = dims[i]; + } else { + _ODP_ERR("Dimension value: %" PRId64 " invalid\n", dims[i]); + return -1; + } + } + + if (dyn_cnt == 0) { + shape->type = ODP_ML_SHAPE_STATIC; + } else if (dyn_cnt == 1) { + shape->type = ODP_ML_SHAPE_BATCH; + } else { + _ODP_ERR("Data shape type not supported by ODP\n"); + return -1; + } + + return 0; +} + +static inline void calculate_model_io_size(const odp_ml_shape_info_t *shape, uint64_t *size) +{ + /* Calculate the data size in bytes of this tensor, 0 for tensors with + * dynamic batch sizes */ + for (size_t i = 0; i < shape->num_dim; i++) { + /* Skip dynamic dimension size */ + if (shape->dim[i] == ODP_ML_DIM_DYNAMIC) { + *size = 0; + break; + } + (*size) *= shape->dim[i]; + } +} + +static int get_model_io_type_shape_size(OrtTypeInfo *type_info, odp_ml_shape_info_t *shape, + odp_ml_data_type_t *data_type, uint32_t *data_type_size, + uint64_t *size) +{ + ONNXTensorElementDataType tensor_type; + const 
OrtTensorTypeAndShapeInfo *tensor_info; + size_t num_dim = 0; + OrtStatus *status = NULL; + int64_t dims[ODP_ML_MAX_DIMS] = {0}; + const OrtApi *ort_api = _odp_ml_glb->ort_api; + + status = ort_api->CastTypeInfoToTensorInfo(type_info, &tensor_info); + if (check_ortstatus(status)) { + _ODP_ERR("CastTypeInfoToTensorInfo failed\n"); + return -1; + } + + status = ort_api->GetTensorElementType(tensor_info, &tensor_type); + if (check_ortstatus(status)) { + _ODP_ERR("GetTensorElementType failed\n"); + return -1; + } + + *data_type = onnx_dtype_to_odp_dtype(tensor_type); + if (*data_type == ODP_ML_DATA_TYPE_NONE) /* Type not supported by odp */ + return -1; + + status = ort_api->GetDimensionsCount(tensor_info, &num_dim); + if (check_ortstatus(status)) { + _ODP_ERR("GetDimensionsCount failed\n"); + return -1; + } + + if (num_dim > ODP_ML_MAX_DIMS) { + _ODP_ERR("Number of dimensions: %zu exceeds supported maximum number" + " of dimensions: %d\n", num_dim, ODP_ML_MAX_DIMS); + return -1; + } + shape->num_dim = num_dim; + + status = ort_api->GetDimensions(tensor_info, dims, num_dim); + if (check_ortstatus(status)) { + _ODP_ERR("GetDimensions failed\n"); + return -1; + } + + if (get_shape(dims, shape)) + return -1; + + *data_type_size = size_of_odp_ml_data_type(*data_type); + + *size = *data_type_size; + calculate_model_io_size(shape, size); + + return 0; +} + +/* Get model input and output info */ +static int get_model_io_info(OrtSession *session, ml_model_t *mdl, + const odp_ml_model_param_t *param) +{ + char *name; + OrtTypeInfo *type_info; + const odp_ml_data_format_t *data_format; + OrtStatus *status = NULL; + OrtAllocator *allocator = NULL; + const OrtApi *ort_api = _odp_ml_glb->ort_api; + odp_ml_input_info_t *input_info = mdl->input_info; + odp_ml_output_info_t *output_info = mdl->output_info; + + status = ort_api->GetAllocatorWithDefaultOptions(&allocator); + if (check_ortstatus(status)) { + _ODP_ERR("GetAllocatorWithDefaultOptions failed\n"); + return -1; + } + + /* 
Retrieve info about input array. */ + memset(input_info, 0, sizeof(mdl->input_info)); + for (uint32_t i = 0; i < mdl->info.num_inputs; i++) { + name = NULL; + status = ort_api->SessionGetInputName(session, i, allocator, &name); + if (check_ortstatus(status)) { + _ODP_ERR("Get %uth input name failed\n", i); + return -1; + } + + strncpy(input_info[i].name, name, ODP_ML_MODEL_IO_NAME_LEN - 1); + input_info[i].name[ODP_ML_MODEL_IO_NAME_LEN - 1] = 0; + + /* Free memory allocated by SessionGetInputName */ + status = ort_api->AllocatorFree(allocator, name); + if (check_ortstatus(status)) { + _ODP_ERR("AllocatorFree %uth input_name failed\n", i); + return -1; + } + + if (param->extra_info.num_inputs) { + data_format = ¶m->extra_info.input_format[i]; + + input_info[i].shape = data_format->shape; + input_info[i].data_type = data_format->data_type; + input_info[i].data_type_size = data_format->data_type_size; + + mdl->input_sizes[i] = input_info[i].data_type_size; + calculate_model_io_size(&data_format->shape, &mdl->input_sizes[i]); + continue; + } + + type_info = NULL; + status = ort_api->SessionGetInputTypeInfo(session, i, &type_info); + if (check_ortstatus(status)) { + _ODP_ERR("SessionGetInputTypeInfo failed\n"); + return -1; + } + + if (get_model_io_type_shape_size(type_info, &input_info[i].shape, + &input_info[i].data_type, + &input_info[i].data_type_size, + &mdl->input_sizes[i])) { + _ODP_ERR("get_model_io_type_shape_size() for input failed\n"); + ort_api->ReleaseTypeInfo(type_info); + return -1; + } + + ort_api->ReleaseTypeInfo(type_info); + } + + /* Retrieve info about output array. 
 */
	/* Collect per-output name, data type, shape and byte size. Caller-provided
	 * extra_info (if any) overrides the ONNX model metadata for an output. */
	memset(output_info, 0, sizeof(mdl->output_info));
	for (uint32_t i = 0; i < mdl->info.num_outputs; i++) {
		name = NULL;
		status = ort_api->SessionGetOutputName(session, i, allocator, &name);
		if (check_ortstatus(status)) {
			_ODP_ERR("Get %uth output name failed\n", i);
			return -1;
		}

		strncpy(output_info[i].name, name, ODP_ML_MODEL_IO_NAME_LEN - 1);
		output_info[i].name[ODP_ML_MODEL_IO_NAME_LEN - 1] = 0;

		/* Free memory allocated by SessionGetOutputName */
		status = ort_api->AllocatorFree(allocator, name);
		if (check_ortstatus(status)) {
			_ODP_ERR("AllocatorFree %uth output_name failed\n", i);
			return -1;
		}

		/* Output format supplied by the application: trust it and skip
		 * querying ONNX type info for this output. */
		if (param->extra_info.num_outputs) {
			data_format = &param->extra_info.output_format[i];

			output_info[i].shape = data_format->shape;
			output_info[i].data_type = data_format->data_type;
			output_info[i].data_type_size = data_format->data_type_size;

			mdl->output_sizes[i] = output_info[i].data_type_size;
			calculate_model_io_size(&data_format->shape, &mdl->output_sizes[i]);
			continue;
		}

		type_info = NULL;
		status = ort_api->SessionGetOutputTypeInfo(session, i, &type_info);
		if (check_ortstatus(status)) {
			_ODP_ERR("SessionGetOutputTypeInfo failed\n");
			return -1;
		}

		if (get_model_io_type_shape_size(type_info, &output_info[i].shape,
						 &output_info[i].data_type,
						 &output_info[i].data_type_size,
						 &mdl->output_sizes[i])) {
			_ODP_ERR("get_model_io_type_shape_size() for output failed\n");
			ort_api->ReleaseTypeInfo(type_info);
			return -1;
		}

		ort_api->ReleaseTypeInfo(type_info);
	}

	return 0;
}

/* Validate that the model's input/output counts fit the implementation limits
 * (CONFIG_ML_MAX_INPUTS/OUTPUTS) and, when param->extra_info provides counts,
 * that they match the counts read from the model metadata and come with the
 * corresponding format arrays. Returns 0 on success, -1 on error. */
static inline int check_model_io_num(const odp_ml_model_param_t *param,
				     uint32_t num_inputs, uint32_t num_outputs)
{
	/* Make sure the number of inputs/outputs not exceeding the supported
	 * model max inputs/outputs */
	if (num_inputs > CONFIG_ML_MAX_INPUTS) {
		_ODP_ERR("The model's number of inputs %u exceeds the maximum "
			 "number of inputs supported in a model %u\n",
			 num_inputs, CONFIG_ML_MAX_INPUTS);

		return -1;
	}

	if (num_outputs > CONFIG_ML_MAX_OUTPUTS) {
		_ODP_ERR("The model's number of outputs %u exceeds the maximum "
			 "number of outputs supported in a model %u\n",
			 num_outputs, CONFIG_ML_MAX_OUTPUTS);

		return -1;
	}

	/* Make sure the numbers of inputs/outputs provided in the extra_info of
	 * param match the numbers defined in model metadata. */
	if (param->extra_info.num_inputs &&
	    param->extra_info.num_inputs != num_inputs) {
		_ODP_ERR("Provided param->extra_info.num_inputs %u does not match the"
			 " number of inputs defined in model metadata: %u\n",
			 param->extra_info.num_inputs, num_inputs);
		return -1;
	}

	if (param->extra_info.num_outputs && param->extra_info.num_outputs != num_outputs) {
		_ODP_ERR("Provided param->extra_info.num_outputs %u does not match the"
			 " number of outputs defined in model metadata: %u\n",
			 param->extra_info.num_outputs, num_outputs);
		return -1;
	}

	if (param->extra_info.num_inputs && !param->extra_info.input_format) {
		_ODP_ERR("num_inputs is provided but not input_format in param->extra_info\n");
		return -1;
	}

	if (param->extra_info.num_outputs && !param->extra_info.output_format) {
		_ODP_ERR("num_outputs is provided but not output_format in param->extra_info\n");
		return -1;
	}

	return 0;
}

/* Create an ORT session from the in-memory model blob in param and fill mdl
 * with I/O counts, model version and per-I/O info. On any failure the session
 * (and metadata, when already acquired) is released and -1 is returned;
 * on success *session stays alive and owned by the caller. */
static int create_ort_model(const odp_ml_model_param_t *param, OrtSession **session,
			    ml_model_t *mdl, OrtSessionOptions *session_opts)
{
	OrtStatus *status;
	int64_t model_version;
	uint32_t num_inputs = 0;
	uint32_t num_outputs = 0;
	/* NOTE(review): scalar pointer initialized with {0}; plain NULL would be
	 * the conventional spelling — confirm and simplify. */
	OrtModelMetadata *metadata = {0};
	const OrtApi *ort_api = _odp_ml_glb->ort_api;

	status = ort_api->CreateSessionFromArray(_odp_ml_glb->env,
						 param->model,
						 param->size,
						 session_opts,
						 session);
	if (check_ortstatus(status) || !(*session)) {
		_ODP_ERR("CreateSessionFromArray failed\n");
		return -1;
	}

	if (get_model_io_count(*session, &num_inputs, &num_outputs)) {
		_ODP_ERR("get_model_io_count() failed\n");

		ort_api->ReleaseSession(*session);
		return -1;
	}

	if (check_model_io_num(param, num_inputs, num_outputs)) {
		ort_api->ReleaseSession(*session);
		return -1;
	}

	mdl->max_compl_id = param->max_compl_id;
	mdl->info.num_inputs = num_inputs;
	mdl->info.num_outputs = num_outputs;

	/* Get metadata */
	status = ort_api->SessionGetModelMetadata(*session, &metadata);
	if (check_ortstatus(status) || !metadata) {
		_ODP_ERR("SessionGetModelMetadata failed\n");
		ort_api->ReleaseSession(*session);
		return -1;
	}

	/* Get model version */
	status = ort_api->ModelMetadataGetVersion(metadata, &model_version);
	if (check_ortstatus(status)) {
		_ODP_ERR("ModelMetadataGetVersion failed\n");
		ort_api->ReleaseModelMetadata(metadata);
		ort_api->ReleaseSession(*session);
		return -1;
	}
	mdl->info.model_version = model_version;
	mdl->info.interface_version = 0;

	if (get_model_io_info(*session, mdl, param)) {
		_ODP_ERR("get_model_io_info() failed\n");
		ort_api->ReleaseModelMetadata(metadata);
		ort_api->ReleaseSession(*session);
		return -1;
	}

	ort_api->ReleaseModelMetadata(metadata);
	return 0;
}

/* Apply the globally configured ORT run options (profiling, execution mode,
 * intra/inter-op thread counts, graph optimization level, optimized-model
 * output path) to the given session options. Returns 0 on success, -1 on
 * the first ORT call that fails. */
static int set_ort_run_opts(const char *name, OrtSessionOptions *se_opts)
{
	OrtStatus *status;
	ort_run_opts_t *opts = &_odp_ml_glb->ort_run_opts;
	const OrtApi *ort_api = _odp_ml_glb->ort_api;

	if (opts->enable_profiling) {
		status = ort_api->EnableProfiling(se_opts, name);
		if (check_ortstatus(status)) {
			_ODP_ERR("Enable profiling failed\n");
			return -1;
		}
	}

	status = ort_api->SetSessionExecutionMode(se_opts, opts->execution_mode);
	if (check_ortstatus(status)) {
		_ODP_ERR("SetSessionExecutionMode failed\n");
		return -1;
	}

	/* Zero thread counts mean "leave ORT defaults in place" */
	if (opts->intra_op_num_threads) {
		status = ort_api->SetIntraOpNumThreads(se_opts, opts->intra_op_num_threads);
		if (check_ortstatus(status)) {
			_ODP_ERR("SetIntraOpNumThreads failed\n");
			return -1;
		}
	}

	if (opts->inter_op_num_threads) {
		status = ort_api->SetInterOpNumThreads(se_opts,
						       opts->inter_op_num_threads);
		if (check_ortstatus(status)) {
			_ODP_ERR("SetInterOpNumThreads failed\n");
			return -1;
		}
	}

	status = ort_api->SetSessionGraphOptimizationLevel(se_opts, opts->graph_opt_level);
	if (check_ortstatus(status)) {
		_ODP_ERR("SetSessionGraphOptimizationLevel failed\n");
		return -1;
	}

	/* Optimized model file path is not provided */
	if (opts->opt_model_filepath[0] == '\0')
		return 0;

	status = ort_api->SetOptimizedModelFilePath(se_opts, opts->opt_model_filepath);
	if (check_ortstatus(status)) {
		_ODP_ERR("SetOptimizedModelFilePath failed\n");
		return -1;
	}

	return 0;
}

/* Zero the model info block and the cached per-I/O metadata/sizes. Used to
 * roll back a model entry that was partially filled before a failure. */
static inline void reset_mdl_info_sizes(ml_model_t *mdl)
{
	memset(&mdl->info, 0, sizeof(odp_ml_model_info_t));
	memset(mdl->input_info, 0, sizeof(mdl->input_info));
	memset(mdl->output_info, 0, sizeof(mdl->output_info));
	memset(mdl->input_sizes, 0, sizeof(mdl->input_sizes));
	memset(mdl->output_sizes, 0, sizeof(mdl->output_sizes));
}

/* Reject models with undefined I/O shape types, and batch-shaped inputs whose
 * dynamic dimensions have no dim_max (needed to size buffers). Returns 0 when
 * all I/O shapes are usable, -1 otherwise. */
static int check_io_shape(ml_model_t *mdl)
{
	odp_ml_shape_info_t *shape;

	for (uint32_t i = 0; i < mdl->info.num_inputs; i++) {
		shape = &mdl->input_info[i].shape;

		if (shape->type == ODP_ML_SHAPE_NONE) {
			_ODP_ERR("Undefined shape type for model input[%u]\n", i);
			return -1;
		}

		if (shape->type == ODP_ML_SHAPE_STATIC)
			continue;

		/* shape->type == ODP_ML_SHAPE_BATCH */
		for (uint32_t j = 0; j < shape->num_dim; j++) {
			if (shape->dim[j] == ODP_ML_DIM_DYNAMIC && !shape->dim_max[j]) {
				_ODP_ERR("Missing dim_max[%u] for dynamic sized input[%u], please"
					 " provide via the extra_info of model param\n", j, i);
				return -1;
			}
		}
	}

	for (uint32_t i = 0; i < mdl->info.num_outputs; i++) {
		if (mdl->output_info[i].shape.type == ODP_ML_SHAPE_NONE) {
			_ODP_ERR("Undefined shape type for model output[%u]\n", i);
			return -1;
		}
	}

	return 0;
}

odp_ml_model_t odp_ml_model_create(const char *name, const odp_ml_model_param_t *param)
{
	OrtStatus *status;
	odp_ml_model_info_t
	*info;
	OrtSessionOptions *session_opts;
	uint32_t i = 0;
	ml_model_t *mdl = NULL;
	OrtSession *session = NULL;
	const OrtApi *ort_api = _odp_ml_glb->ort_api;

	if (odp_unlikely(odp_global_ro.disable.ml)) {
		_ODP_ERR("ML is disabled\n");
		return ODP_ML_MODEL_INVALID;
	}

	if (odp_unlikely(param->size > _odp_ml_glb->ml_config.max_model_size)) {
		_ODP_ERR("Model size %" PRIu64 " exceeds maximum model size configured %" PRIu64 "\n",
			 param->size, _odp_ml_glb->ml_config.max_model_size);
		return ODP_ML_MODEL_INVALID;
	}

	if (odp_unlikely(!param->size || !param->model)) {
		_ODP_ERR("Invalid model param: param->model: %p, param->size: %" PRIu64 "\n",
			 param->model, param->size);
		return ODP_ML_MODEL_INVALID;
	}

	if (odp_unlikely(param->max_compl_id > ML_MAX_COMPL_ID)) {
		_ODP_ERR("param->max_compl_id: %u exceeds maximum completion id supported: %d\n",
			 param->max_compl_id, ML_MAX_COMPL_ID);
		return ODP_ML_MODEL_INVALID;
	}

	/* Find an empty slot to store the new model. The state is checked once
	 * without the lock to skip busy entries cheaply, then re-checked under
	 * the lock to win the slot race. */
	for (i = 0; i < ML_MAX_MODELS_CREATED; i++) {
		if (_odp_ml_glb->models[i].state)
			continue;

		odp_ticketlock_lock(&_odp_ml_glb->models[i].lock);

		if (_odp_ml_glb->models[i].state) {
			odp_ticketlock_unlock(&_odp_ml_glb->models[i].lock);
			continue;
		}

		mdl = &_odp_ml_glb->models[i];
		break;
	}

	if (i == ML_MAX_MODELS_CREATED) {
		_ODP_ERR("Maximum number of models has already been created!\n");
		return ODP_ML_MODEL_INVALID;
	}

	/* Free model entry was found and is now locked */
	mdl->state = ML_STATE_CREATED;

	status = ort_api->CreateSessionOptions(&session_opts);
	if (check_ortstatus(status) || !session_opts) {
		_ODP_ERR("Error: CreateSessionOptions failed.\n");
		mdl->state = ML_STATE_FREE;
		odp_ticketlock_unlock(&mdl->lock);
		return ODP_ML_MODEL_INVALID;
	}

	if (set_ort_run_opts(name, session_opts)) {
		_odp_ml_glb->ort_api->ReleaseSessionOptions(session_opts);
		mdl->state = ML_STATE_FREE;
		odp_ticketlock_unlock(&mdl->lock);
		return ODP_ML_MODEL_INVALID;
	}

	/* Store model info */
	info = &mdl->info;
	memset(info, 0, sizeof(odp_ml_model_info_t));

	if (create_ort_model(param, &session, mdl, session_opts)) {
		mdl->state = ML_STATE_FREE;

		/* Initialize info back to 0 when some fields have been filled
		 * while later failed */
		reset_mdl_info_sizes(mdl);
		odp_ticketlock_unlock(&mdl->lock);

		_odp_ml_glb->ort_api->ReleaseSessionOptions(session_opts);
		_ODP_ERR("create_ort_model() failed\n");
		return ODP_ML_MODEL_INVALID;
	}

	if (check_io_shape(mdl)) {
		mdl->state = ML_STATE_FREE;
		reset_mdl_info_sizes(mdl);
		odp_ticketlock_unlock(&mdl->lock);

		ort_api->ReleaseSession(session);
		_odp_ml_glb->ort_api->ReleaseSessionOptions(session_opts);
		return ODP_ML_MODEL_INVALID;
	}

	mdl->session = session;
	mdl->session_opts = session_opts;
	info->index = i;

	if (name) {
		strncpy(info->name, name, ODP_ML_MODEL_NAME_LEN - 1);
		info->name[ODP_ML_MODEL_NAME_LEN - 1] = 0;
	}

	mdl->max_compl_id = param->max_compl_id;
	/* Mark all completion slots idle (1 = done/free, 0 = in flight).
	 * NOTE(review): poll-mode users index compl_status up to max_compl_id,
	 * which may equal ML_MAX_COMPL_ID; confirm the array size and whether
	 * element [ML_MAX_COMPL_ID] should also be initialized here. */
	for (uint32_t j = 0; j < ML_MAX_COMPL_ID; j++)
		odp_atomic_init_u32(&mdl->compl_status[j], 1);

	odp_ticketlock_unlock(&mdl->lock);
	return (odp_ml_model_t)mdl;
}

/* Destroy a created (not loaded) model: release the ORT session and session
 * options and return the slot to the free pool. */
int odp_ml_model_destroy(odp_ml_model_t model)
{
	ml_model_t *mdl = ml_model_from_handle(model);

	if (model == ODP_ML_MODEL_INVALID) {
		_ODP_ERR("Bad ML model handle\n");
		return -1;
	}

	odp_ticketlock_lock(&mdl->lock);

	/* Only the CREATED state may be destroyed; a loaded model must be
	 * unloaded first. */
	if (mdl->state != ML_STATE_CREATED) {
		_ODP_ERR("Model not created\n");
		odp_ticketlock_unlock(&mdl->lock);
		return -1;
	}

	_odp_ml_glb->ort_api->ReleaseSessionOptions(mdl->session_opts);
	_odp_ml_glb->ort_api->ReleaseSession(mdl->session);
	mdl->state = ML_STATE_FREE;
	mdl->session = NULL;
	odp_ticketlock_unlock(&mdl->lock);

	return 0;
}

/* Copy the model info snapshot to *info under the model lock. */
int odp_ml_model_info(odp_ml_model_t model, odp_ml_model_info_t *info)
{
	ml_model_t *mdl = ml_model_from_handle(model);

	if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) {
		_ODP_ERR("Bad ML model handle\n");
		return -1;
	}

	if (odp_unlikely(!info)) {
		_ODP_ERR("info must not be NULL\n");
		return -1;
	}

	odp_ticketlock_lock(&mdl->lock);
	if (odp_unlikely(mdl->state == ML_STATE_FREE)) {
		_ODP_ERR("Model not created\n");
		odp_ticketlock_unlock(&mdl->lock);
		return -1;
	}

	*info = mdl->info;

	odp_ticketlock_unlock(&mdl->lock);
	return 0;
}

/* Copy up to num input info entries; always returns the model's total input
 * count so the caller can size the array. num == 0 only queries the count. */
uint32_t odp_ml_model_input_info(odp_ml_model_t model, odp_ml_input_info_t info[], uint32_t num)
{
	uint32_t num_model_inputs;
	uint32_t num_written;
	ml_model_t *mdl = ml_model_from_handle(model);

	if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) {
		_ODP_ERR("Bad ML model handle\n");
		return 0;
	}

	odp_ticketlock_lock(&mdl->lock);
	num_model_inputs = mdl->info.num_inputs;
	num_written = num_model_inputs >= num ? num : num_model_inputs;

	if (num == 0) {
		odp_ticketlock_unlock(&mdl->lock);
		return num_model_inputs;
	}

	for (uint32_t i = 0; i < num_written; i++)
		info[i] = mdl->input_info[i];

	odp_ticketlock_unlock(&mdl->lock);
	return num_model_inputs;
}

/* Output-side counterpart of odp_ml_model_input_info(). */
uint32_t odp_ml_model_output_info(odp_ml_model_t model, odp_ml_output_info_t info[], uint32_t num)
{
	uint32_t num_model_outputs;
	uint32_t num_written;
	ml_model_t *mdl = ml_model_from_handle(model);

	if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) {
		_ODP_ERR("Bad ML model handle\n");
		return 0;
	}

	odp_ticketlock_lock(&mdl->lock);
	num_model_outputs = mdl->info.num_outputs;
	num_written = num_model_outputs >= num ? num : num_model_outputs;

	if (num == 0) {
		odp_ticketlock_unlock(&mdl->lock);
		return num_model_outputs;
	}

	for (uint32_t i = 0; i < num_written; i++)
		info[i] = mdl->output_info[i];

	odp_ticketlock_unlock(&mdl->lock);
	return num_model_outputs;
}

/* Linear search of all non-free model slots for a matching name. Returns the
 * first match or ODP_ML_MODEL_INVALID. */
odp_ml_model_t odp_ml_model_lookup(const char *name)
{
	uint32_t i;
	ml_model_t *mdl;

	for (i = 0; i < ML_MAX_MODELS_CREATED; i++) {
		mdl = &_odp_ml_glb->models[i];

		odp_ticketlock_lock(&mdl->lock);

		if (mdl->state == ML_STATE_FREE) {
			odp_ticketlock_unlock(&mdl->lock);
			continue;
		}

		if (!strcmp(mdl->info.name, name)) {
			/* found it */
			odp_ticketlock_unlock(&mdl->lock);
			return (odp_ml_model_t)mdl;
		}
		odp_ticketlock_unlock(&mdl->lock);
	}

	return ODP_ML_MODEL_INVALID;
}

uint64_t odp_ml_model_to_u64(odp_ml_model_t model)
{
	return _odp_pri(model);
}

/* Human-readable name for an ODP ML data type (for prints only). */
static const char *data_type_str(odp_ml_data_type_t data_type)
{
	switch (data_type) {
	case ODP_ML_DATA_TYPE_INT8:
		return "int8";
	case ODP_ML_DATA_TYPE_UINT8:
		return "uint8";
	case ODP_ML_DATA_TYPE_UINT16:
		return "uint16";
	case ODP_ML_DATA_TYPE_INT16:
		return "int16";
	case ODP_ML_DATA_TYPE_INT32:
		return "int32";
	case ODP_ML_DATA_TYPE_UINT32:
		return "uint32";
	case ODP_ML_DATA_TYPE_INT64:
		return "int64";
	case ODP_ML_DATA_TYPE_UINT64:
		return "uint64";
	case ODP_ML_DATA_TYPE_FP16:
		return "fp16";
	case ODP_ML_DATA_TYPE_FP32:
		return "fp32";
	case ODP_ML_DATA_TYPE_BFP16:
		return "bfp16";
	default:
		return "unknown";
	}
}

/* Human-readable name for an ODP ML shape type (for prints only). */
static const char *shape_type_str(odp_ml_shape_type_t shape_type)
{
	switch (shape_type) {
	case ODP_ML_SHAPE_NONE:
		return "none";
	case ODP_ML_SHAPE_STATIC:
		return "static";
	case ODP_ML_SHAPE_BATCH:
		return "batch";
	default:
		return "Unknown";
	}
}

/* Print a shape as "Shape: <type> [d0, d1, ...]"; dynamic dims print as
 * "Dyn" and a scalar (num_dim == 0) prints as an empty bracket pair. */
static void print_shape(const odp_ml_shape_info_t *shape)
{
	/* Print shape */
	_ODP_PRINT("Shape: %s [", shape_type_str(shape->type));

	for (uint32_t i = 0; i < shape->num_dim; i++) {
		if
(shape->dim[i] == ODP_ML_DIM_DYNAMIC) + _ODP_PRINT("Dyn"); + else + _ODP_PRINT("%" PRIu32, shape->dim[i]); + + if (i == (shape->num_dim - 1)) + _ODP_PRINT("]\n"); + else + _ODP_PRINT(", "); + } + + /* The number of dimensions for a scalar input is 0, in which case did not + * go into above for loop */ + if (shape->num_dim == 0) + _ODP_PRINT("]\n"); +} + +void odp_ml_model_print(odp_ml_model_t model) +{ + ml_model_t *mdl = ml_model_from_handle(model); + const odp_ml_model_info_t * const info = &mdl->info; + const odp_ml_input_info_t * const input_info = mdl->input_info; + const odp_ml_output_info_t * const output_info = mdl->output_info; + + if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) { + _ODP_ERR("Bad ML model handle\n"); + return; + } + + odp_ticketlock_lock(&mdl->lock); + if (odp_unlikely(mdl->state == ML_STATE_FREE)) { + odp_ticketlock_unlock(&mdl->lock); + _ODP_ERR("Model not created\n"); + return; + } + + _ODP_PRINT("\nModel info\n"); + _ODP_PRINT("----------\n"); + _ODP_PRINT(" Model handle: 0x%" PRIx64 "\n", odp_ml_model_to_u64(model)); + _ODP_PRINT(" Name: %s\n", info->name); + _ODP_PRINT(" Model version: %" PRIu64 "\n", info->model_version); + _ODP_PRINT(" Model interface version: %" PRIu64 "\n", info->interface_version); + _ODP_PRINT(" Index: %u\n", info->index); + _ODP_PRINT(" Number of inputs: %u\n", info->num_inputs); + + for (uint32_t i = 0; i < info->num_inputs; i++) { + _ODP_PRINT(" Input[%u]: ", i); + _ODP_PRINT("Name: %s, ", input_info[i].name); + _ODP_PRINT("Data_type: %s, ", data_type_str(input_info[i].data_type)); + print_shape(&input_info[i].shape); + } + + _ODP_PRINT(" Number of outputs: %u\n", info->num_outputs); + for (uint32_t i = 0; i < info->num_outputs; i++) { + _ODP_PRINT(" Output[%u]: ", i); + _ODP_PRINT("Name: %s, ", output_info[i].name); + _ODP_PRINT("Data_type: %s, ", data_type_str(output_info[i].data_type)); + print_shape(&output_info[i].shape); + } + + odp_ticketlock_unlock(&mdl->lock); + + _ODP_PRINT("\n"); +} + +static 
inline void mode_print(odp_ml_compl_mode_t compl_mode_mask) +{ + if (compl_mode_mask & ODP_ML_COMPL_MODE_SYNC) + _ODP_PRINT(" syn"); + + if (compl_mode_mask & ODP_ML_COMPL_MODE_POLL) + _ODP_PRINT(" poll"); + + if (compl_mode_mask & ODP_ML_COMPL_MODE_EVENT) + _ODP_PRINT(" event"); +} + +void odp_ml_print(void) +{ + _ODP_PRINT("\nML info\n"); + _ODP_PRINT("-----------\n"); + _ODP_PRINT(" max_model_size: %u\n", ML_MAX_MODEL_SIZE); + _ODP_PRINT(" max_compl_id: %u\n", ML_MAX_COMPL_ID); + _ODP_PRINT(" max_models_created: %u\n", ML_MAX_MODELS_CREATED); + _ODP_PRINT(" max_models_loaded: %u\n", ML_MAX_MODELS_LOADED); + _ODP_PRINT(" model_max_inputs: %u\n", CONFIG_ML_MAX_INPUTS); + _ODP_PRINT(" model_max_outputs: %u\n", CONFIG_ML_MAX_OUTPUTS); + + _ODP_PRINT(" load:\n"); + _ODP_PRINT(" completion mode: "); + mode_print(_odp_ml_glb->capa.load.compl_mode_mask); + _ODP_PRINT(", plain queue: %c, schedule queue: %c\n", + _odp_ml_glb->capa.load.compl_queue_plain ? 'Y' : 'N', + _odp_ml_glb->capa.load.compl_queue_sched ? 'Y' : 'N'); + + _ODP_PRINT(" run:\n"); + _ODP_PRINT(" completion mode:"); + mode_print(_odp_ml_glb->capa.run.compl_mode_mask); + _ODP_PRINT(", plain queue: %c, schedule queue: %c\n", + _odp_ml_glb->capa.run.compl_queue_plain ? 'Y' : 'N', + _odp_ml_glb->capa.run.compl_queue_sched ? 
'Y' : 'N'); + _ODP_PRINT("\n"); +} + +int odp_ml_model_extra_stat_info(odp_ml_model_t model, + odp_ml_extra_stat_info_t info[] ODP_UNUSED, + int num ODP_UNUSED) +{ + if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) { + _ODP_ERR("Bad ML model handle\n"); + return -1; + } + + return 0; +} + +int odp_ml_model_extra_stats(odp_ml_model_t model, uint64_t stats[] ODP_UNUSED, int num ODP_UNUSED) +{ + if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) { + _ODP_ERR("Bad ML model handle\n"); + return -1; + } + + return 0; +} + +void odp_ml_compl_pool_param_init(odp_ml_compl_pool_param_t *pool_param) +{ + if (odp_unlikely(!pool_param)) { + _ODP_ERR("Param 'pool_param' must not NULL\n"); + return; + } + + memset(pool_param, 0, sizeof(odp_ml_compl_pool_param_t)); + + pool_param->cache_size = _odp_ml_glb->pool_param.buf.cache_size; +} + +odp_pool_t odp_ml_compl_pool_create(const char *name, const odp_ml_compl_pool_param_t *pool_param) +{ + odp_pool_t pool; + odp_pool_param_t ml_pool_param; + uint32_t num = pool_param->num; + uint32_t uarea_size = pool_param->uarea_size; + uint32_t cache_size = pool_param->cache_size; + uint32_t buf_size = _ODP_MAX(sizeof(odp_ml_run_result_t), + sizeof(odp_ml_load_result_t)); + + if (num > _odp_ml_glb->capa.pool.max_num) { + _ODP_ERR("Too many ML completion events: %u\n", num); + return ODP_POOL_INVALID; + } + + if (uarea_size > _odp_ml_glb->capa.pool.max_uarea_size) { + _ODP_ERR("Bad uarea size: %u\n", uarea_size); + return ODP_POOL_INVALID; + } + + if (cache_size < _odp_ml_glb->capa.pool.min_cache_size || + cache_size > _odp_ml_glb->capa.pool.max_cache_size) { + _ODP_ERR("Bad cache size: %u\n", cache_size); + return ODP_POOL_INVALID; + } + + odp_pool_param_init(&ml_pool_param); + ml_pool_param.type = ODP_POOL_BUFFER; + ml_pool_param.uarea_init.init_fn = pool_param->uarea_init.init_fn; + ml_pool_param.uarea_init.args = pool_param->uarea_init.args; + ml_pool_param.buf.num = num; + ml_pool_param.buf.cache_size = cache_size; + ml_pool_param.buf.size = 
buf_size; + ml_pool_param.buf.uarea_size = uarea_size; + + pool = _odp_pool_create(name, &ml_pool_param, ODP_POOL_ML_COMPL); + + return pool; +} + +odp_ml_compl_t odp_ml_compl_alloc(odp_pool_t pool) +{ + odp_buffer_t buf; + odp_event_t ev; + odp_ml_run_result_t *result; + uint32_t buf_size = _ODP_MAX(sizeof(odp_ml_run_result_t), + sizeof(odp_ml_load_result_t)); + + buf = odp_buffer_alloc(pool); + + if (odp_unlikely(buf == ODP_BUFFER_INVALID)) + return ODP_ML_COMPL_INVALID; + + result = odp_buffer_addr(buf); + memset(result, 0, buf_size); + + ev = odp_buffer_to_event(buf); + _odp_event_type_set(ev, ODP_EVENT_ML_COMPL); + + return (odp_ml_compl_t)(uintptr_t)buf; +} + +void odp_ml_compl_free(odp_ml_compl_t ml_compl) +{ + odp_event_t ev; + odp_buffer_t buf = (odp_buffer_t)(uintptr_t)ml_compl; + + if (odp_unlikely(ml_compl == ODP_ML_COMPL_INVALID)) { + _ODP_ERR("Bad ML job completion handle\n"); + return; + } + + ev = odp_buffer_to_event(buf); + _odp_event_type_set(ev, ODP_EVENT_BUFFER); + + odp_buffer_free(buf); +} + +int odp_ml_compl_run_result(odp_ml_compl_t ml_compl, odp_ml_run_result_t *result) +{ + odp_event_subtype_t subtype; + odp_ml_run_result_t *run_result; + odp_buffer_t buf = (odp_buffer_t)(uintptr_t)ml_compl; + odp_event_t ev = odp_buffer_to_event(buf); + + if (odp_unlikely(ml_compl == ODP_ML_COMPL_INVALID)) { + _ODP_ERR("Given ML completion event is invalid\n"); + return -2; + } + + if (odp_event_types(ev, &subtype) != ODP_EVENT_ML_COMPL || + subtype != ODP_EVENT_ML_COMPL_RUN) { + _ODP_ERR("Given completion event has wrong event type or subtype\n"); + return -2; + } + + run_result = odp_buffer_addr(buf); + if (result) + *result = *run_result; + + return run_result->error_code ? 
-1 : 0;
}

/* Extract the load/unload result from a completion event. Returns 0 on
 * success, -1 when the recorded operation failed, -2 on a bad handle/event
 * type. */
int odp_ml_compl_load_result(odp_ml_compl_t ml_compl, odp_ml_load_result_t *result)
{
	odp_event_subtype_t subtype;
	odp_ml_load_result_t *load_result;
	odp_buffer_t buf = (odp_buffer_t)(uintptr_t)ml_compl;
	/* NOTE(review): ev is derived from buf before the handle is validated
	 * below; harmless if odp_buffer_to_event() is a pure cast — confirm. */
	odp_event_t ev = odp_buffer_to_event(buf);

	if (odp_unlikely(ml_compl == ODP_ML_COMPL_INVALID)) {
		_ODP_ERR("Given ML completion event is invalid\n");
		return -2;
	}

	if (odp_event_types(ev, &subtype) != ODP_EVENT_ML_COMPL ||
	    subtype != ODP_EVENT_ML_COMPL_LOAD) {
		_ODP_ERR("Given completion event has wrong event type or subtype\n");
		return -2;
	}

	load_result = odp_buffer_addr(buf);
	if (result)
		*result = *load_result;

	return load_result->error_code ? -1 : 0;
}

void *odp_ml_compl_user_area(odp_ml_compl_t ml_compl)
{
	return odp_buffer_user_area((odp_buffer_t)(uintptr_t)ml_compl);
}

odp_ml_compl_t odp_ml_compl_from_event(odp_event_t event)
{
	_ODP_ASSERT(_odp_event_hdr_field(event, int8_t, event_type) == ODP_EVENT_ML_COMPL);

	return (odp_ml_compl_t)(uintptr_t)event;
}

odp_event_t odp_ml_compl_to_event(odp_ml_compl_t ml_compl)
{
	return (odp_event_t)(uintptr_t)ml_compl;
}

uint64_t odp_ml_compl_to_u64(odp_ml_compl_t ml_compl)
{
	return (uint64_t)(uintptr_t)ml_compl;
}

/* Initialize completion params: zeroed, with queue/event explicitly invalid. */
void odp_ml_compl_param_init(odp_ml_compl_param_t *compl_param)
{
	memset(compl_param, 0, sizeof(odp_ml_compl_param_t));

	compl_param->queue = ODP_QUEUE_INVALID;
	compl_param->event = ODP_EVENT_INVALID;
}

/* Synchronously "load" a model: transition CREATED -> LOADED under the model
 * lock. result (when given) always receives the outcome, also on failure. */
int odp_ml_model_load(odp_ml_model_t model, odp_ml_load_result_t *result)
{
	odp_ml_load_result_t result_local;
	int ret = -1;
	ml_model_t *mdl = ml_model_from_handle(model);

	memset(&result_local, 0, sizeof(result_local));

	if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) {
		_ODP_ERR("Bad ML model handle\n");
		result_local.error_code = ML_BAD_HDL;
		goto load_fail;
	}

	odp_ticketlock_lock(&mdl->lock);
	if (odp_unlikely(mdl->state != ML_STATE_CREATED)) {
		_ODP_ERR("Model has not been created yet or is already loaded\n");
		odp_ticketlock_unlock(&mdl->lock);
		result_local.error_code = ML_NOT_CREATED;
		goto load_fail;
	}

	mdl->state = ML_STATE_LOADED;
	odp_ticketlock_unlock(&mdl->lock);
	ret = 0;

load_fail:
	if (result)
		*result = result_local;

	return ret;
}

/* Validate an asynchronous completion parameter block against the configured
 * load/run mode masks. is_load selects which mask applies. Poll mode checks
 * compl_id bounds; event mode checks the event and queue. Returns 0/-1. */
static inline int check_compl_param(const odp_ml_compl_param_t *compl_param,
				    uint32_t max_compl_id, odp_bool_t is_load)
{
	odp_ml_config_t *config = &_odp_ml_glb->ml_config;

	switch (compl_param->mode) {
	case ODP_ML_COMPL_MODE_POLL:
		if (is_load && !(config->load_mode_mask & ODP_ML_COMPL_MODE_POLL)) {
			_ODP_ERR("Poll mode loading/unloading is not configured\n");
			return -1;
		}

		if (!is_load && !(config->run_mode_mask & ODP_ML_COMPL_MODE_POLL)) {
			_ODP_ERR("Poll mode run is not configured\n");
			return -1;
		}

		if (compl_param->compl_id > max_compl_id) {
			_ODP_ERR("Bad compl_id: %u, exceeding model max completion id %u\n",
				 compl_param->compl_id, max_compl_id);
			return -1;
		}
		break;
	case ODP_ML_COMPL_MODE_EVENT:
		if (is_load && !(config->load_mode_mask & ODP_ML_COMPL_MODE_EVENT)) {
			_ODP_ERR("Event mode loading/unloading is not configured\n");
			return -1;
		}

		if (!is_load && !(config->run_mode_mask & ODP_ML_COMPL_MODE_EVENT)) {
			_ODP_ERR("Event mode run is not configured\n");
			return -1;
		}

		if (compl_param->event == ODP_EVENT_INVALID ||
		    compl_param->queue == ODP_QUEUE_INVALID) {
			_ODP_ERR("Bad event or queue\n");
			return -1;
		}

		if (odp_event_type(compl_param->event) != ODP_EVENT_ML_COMPL) {
			_ODP_ERR("Bad completion event type\n");
			return -1;
		}
		break;
	default:
		/* Including ODP_ML_COMPL_MODE_SYNC, which is not supported by
		 * asynchronous functions (e.g. *_start()) either.
		 */
		_ODP_ERR("Invalid completion mode %u\n", compl_param->mode);
		return -1;
	}

	return 0;
}

/* Start an asynchronous model load. Poll mode: clear the completion slot,
 * load, then publish the result with a release store. Event mode: load, then
 * enqueue the prepared completion event; on enqueue failure the load is
 * rolled back with an unload. */
int odp_ml_model_load_start(odp_ml_model_t model, const odp_ml_compl_param_t *compl_param)
{
	int ret;
	ml_model_t *mdl = ml_model_from_handle(model);

	if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) {
		_ODP_ERR("Bad model handle\n");
		return -1;
	}

	if (odp_unlikely(check_compl_param(compl_param, mdl->max_compl_id, true)))
		return -1;

	if (compl_param->mode == ODP_ML_COMPL_MODE_POLL)
		odp_atomic_store_rel_u32(&mdl->compl_status[compl_param->compl_id], 0);

	ret = odp_ml_model_load(model, NULL);

	if (odp_unlikely(ret))
		return -1;

	/* Send a completion event to the given queue */
	if (compl_param->mode == ODP_ML_COMPL_MODE_EVENT) {
		odp_ml_load_result_t *result;
		odp_buffer_t buf = (odp_buffer_t)(uintptr_t)compl_param->event;

		_odp_buffer_subtype_set(buf, ODP_EVENT_ML_COMPL_LOAD);

		result = odp_buffer_addr(buf);
		result->error_code = 0;
		result->user_ptr = compl_param->user_ptr;

		if (odp_unlikely(odp_queue_enq(compl_param->queue, compl_param->event))) {
			_ODP_ERR("Completion event enqueue failed %" PRIu64 "\n",
				 odp_queue_to_u64(compl_param->queue));
			if (odp_ml_model_unload(model, NULL))
				_ODP_ERR("Failed to unload model\n");
			return -1;
		}

		return 0;
	}

	mdl->result[compl_param->compl_id].user_ptr = compl_param->user_ptr;
	odp_atomic_store_rel_u32(&mdl->compl_status[compl_param->compl_id], 1);
	return 0;
}

/* Poll-mode completion status: returns >0 when done (filling result if
 * given), 0 when still in flight, -2 on bad model/compl_id. */
int odp_ml_model_load_status(odp_ml_model_t model, uint32_t compl_id, odp_ml_load_result_t *result)
{
	int ret;
	ml_model_t *mdl = ml_model_from_handle(model);

	if (odp_unlikely(model == ODP_ML_MODEL_INVALID || compl_id > mdl->max_compl_id)) {
		_ODP_ERR("Invalid model or compl_id: %u\n", compl_id);
		return -2;
	}

	/* Acquire pairs with the release store in *_start() */
	ret = odp_atomic_load_acq_u32(&mdl->compl_status[compl_id]);

	if (ret && result) {
		result->error_code = 0;
		result->user_ptr = mdl->result[compl_id].user_ptr;
	}

	return ret;
}

/* Synchronously "unload" a model: transition LOADED -> CREATED under the
 * model lock. result (when given) always receives the outcome. */
int odp_ml_model_unload(odp_ml_model_t model, odp_ml_load_result_t *result)
{
	odp_ml_load_result_t result_local;
	int ret = -1;
	ml_model_t *mdl = ml_model_from_handle(model);

	memset(&result_local, 0, sizeof(result_local));

	if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) {
		result_local.error_code = ML_BAD_HDL;
		_ODP_ERR("Bad ML model handle\n");
		goto unload_fail;
	}

	odp_ticketlock_lock(&mdl->lock);
	/* mdl->state == ML_STATE_FREE, ML_STATE_CREATED, ML_STATE_INFERENCING */
	if (odp_unlikely(mdl->state != ML_STATE_LOADED)) {
		_ODP_ERR("Model has not been created/loaded or inferencing has not finished yet\n");
		odp_ticketlock_unlock(&mdl->lock);
		result_local.error_code = ML_NOT_LOADED;
		goto unload_fail;
	}

	mdl->state = ML_STATE_CREATED;
	odp_ticketlock_unlock(&mdl->lock);

	ret = 0;

unload_fail:
	if (result)
		*result = result_local;

	return ret;
}

/* Start an asynchronous model unload; mirrors odp_ml_model_load_start(). */
int odp_ml_model_unload_start(odp_ml_model_t model, const odp_ml_compl_param_t *compl_param)
{
	int ret;
	ml_model_t *mdl = ml_model_from_handle(model);

	if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) {
		_ODP_ERR("Bad model handle\n");
		return -1;
	}

	if (odp_unlikely(check_compl_param(compl_param, mdl->max_compl_id, true)))
		return -1;

	if (compl_param->mode == ODP_ML_COMPL_MODE_POLL)
		odp_atomic_store_rel_u32(&mdl->compl_status[compl_param->compl_id], 0);

	ret = odp_ml_model_unload(model, NULL);

	if (odp_unlikely(ret))
		return -1;

	/* Upon successful unloading, send a completion event to the given queue */
	if (compl_param->mode == ODP_ML_COMPL_MODE_EVENT) {
		odp_ml_load_result_t *result;
		odp_buffer_t buf = (odp_buffer_t)(uintptr_t)compl_param->event;

		_odp_buffer_subtype_set(buf, ODP_EVENT_ML_COMPL_LOAD);

		result = odp_buffer_addr(buf);
		result->error_code = 0;
		result->user_ptr = compl_param->user_ptr;

		if (odp_unlikely(odp_queue_enq(compl_param->queue, compl_param->event))) {
			_ODP_ERR("Completion event
enqueue failed %" PRIu64 "\n", + odp_queue_to_u64(compl_param->queue)); + return -1; + } + + return 0; + } + + mdl->result[compl_param->compl_id].user_ptr = compl_param->user_ptr; + odp_atomic_store_rel_u32(&mdl->compl_status[compl_param->compl_id], 1); + return 0; +} + +int odp_ml_model_unload_status(odp_ml_model_t model, uint32_t compl_id, + odp_ml_load_result_t *result) +{ + return odp_ml_model_load_status(model, compl_id, result); +} + +void odp_ml_run_param_init(odp_ml_run_param_t *param) +{ + memset(param, 0, sizeof(odp_ml_run_param_t)); +} + +static void ml_shape_to_int64(const odp_ml_shape_info_t *shape, uint32_t batch_size, int64_t *array) +{ + for (uint32_t i = 0; i < shape->num_dim; i++) { + /* Replace dynamic dimension size with provided batch_size */ + if (shape->dim[i] == ODP_ML_DIM_DYNAMIC) + array[i] = batch_size; + else + array[i] = shape->dim[i]; + } +} + +/* Get the number of elements in given shape */ +static inline uint64_t get_num_elem(uint32_t batch_size, const odp_ml_shape_info_t *shape) +{ + uint64_t num_elements = 1; + int64_t dim[ODP_ML_MAX_DIMS] = {0}; + + ml_shape_to_int64(shape, batch_size, dim); + + for (uint32_t i = 0; i < shape->num_dim; i++) + num_elements *= (uint64_t)dim[i]; + + return num_elements; +} + +static inline uint32_t dyn_io_size(const odp_ml_shape_info_t *shape, uint32_t data_type_size, + const odp_ml_run_param_t *param) +{ + uint32_t size; + + if (!param || !param->batch_size) { + _ODP_ERR("Parameter 'param' must not be NULL and batch_size must be " + "provided when a input/output has dynamic dimension size\n"); + return 0; + } + + size = get_num_elem(param->batch_size, shape); + size *= data_type_size; + + return size; +} + +static int verify_run_params(odp_ml_model_t model, const odp_ml_data_t *data, + const odp_ml_run_param_t *param) +{ + const ml_model_t *mdl = ml_model_from_handle(model); + + if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) { + _ODP_ERR("Bad ML model handle\n"); + return -1; + } + + if 
(odp_unlikely(!data)) { + _ODP_ERR("Parameter 'data' must not be NULL\n"); + return -1; + } + + /* Make sure that the number of input data segments equals or bigger than + * the number of model inputs. */ + if (mdl->info.num_inputs > data->num_input_seg) { + _ODP_ERR("The num of input data segments %u must not less than " + "the number of model inputs %u\n", data->num_input_seg, + mdl->info.num_inputs); + return -1; + } + + if (mdl->info.num_outputs > data->num_output_seg) { + _ODP_ERR("The num of output data segments %u must not less than " + "the number of model outputs %u\n", data->num_output_seg, + mdl->info.num_outputs); + return -1; + } + + if (data->num_input_seg > mdl->info.num_inputs && + (_odp_ml_glb->capa.max_segs_per_input == 1)) { + _ODP_ERR("Segmented input data is not supported\n"); + return -1; + } + + if (data->num_output_seg > mdl->info.num_outputs && + (_odp_ml_glb->capa.max_segs_per_output == 1)) { + _ODP_ERR("Segmented output data is not supported"); + return -1; + } + + uint32_t size = 0; + uint32_t input_index = 0; + uint32_t seg_size_sum = 0; + odp_bool_t index_new = true; + uint32_t segs_per_input = 1; + + for (uint32_t i = 0; i < data->num_input_seg; i++) { + if (data->input_seg[i].addr == NULL) { + _ODP_ERR("data->input_seg[%u].addr must not NULL\n", i); + return -1; + }; + + if (index_new) { + if (input_index > mdl->info.num_inputs - 1) { + _ODP_ERR("Too much number of input segments given\n"); + return -1; + } + + /* Input with dynamic batch size */ + if (mdl->input_info[input_index].shape.type == ODP_ML_SHAPE_BATCH) + size = dyn_io_size(&mdl->input_info[input_index].shape, + mdl->input_info[input_index].data_type_size, + param); + else + size = mdl->input_sizes[input_index]; + + if (!size) { + _ODP_ERR("Size for %uth input is 0\n", input_index); + return -1; + } + } + + seg_size_sum += data->input_seg[i].size; + + if (seg_size_sum > size) { + _ODP_ERR("Sum of segment sizes %u exceeds %uth input data size %u\n", + seg_size_sum, 
input_index, size); + return -1; + } + + if (seg_size_sum == size) { + if (segs_per_input > _odp_ml_glb->capa.max_segs_per_input) { + _ODP_ERR("Number of segments %u for input[%u] exceeds maximum" + " number of data segments per model input %u\n", + segs_per_input, input_index, + _odp_ml_glb->capa.max_segs_per_input); + return -1; + } + input_index++; + index_new = true; + seg_size_sum = 0; + segs_per_input = 1; + } else { + segs_per_input++; + index_new = false; + } + } + + if (input_index != mdl->info.num_inputs) { + _ODP_ERR("Data is not provided for all model inputs\n"); + return -1; + } + + seg_size_sum = 0; + index_new = true; + uint32_t output_index = 0; + uint32_t segs_per_output = 1; + + for (uint32_t i = 0; i < data->num_output_seg; i++) { + if (data->output_seg[i].addr == NULL) { + _ODP_ERR("data->output_seg[%u].addr must not NULL\n", i); + return -1; + } + + if (index_new) { + if (output_index > mdl->info.num_outputs - 1) { + _ODP_ERR("Too much number of output segments given\n"); + return -1; + } + + /* Output with dynamic batch size */ + if (mdl->output_info[output_index].shape.type == ODP_ML_SHAPE_BATCH) + size = dyn_io_size(&mdl->output_info[output_index].shape, + mdl->output_info[output_index].data_type_size, + param); + else + size = mdl->output_sizes[output_index]; + + if (!size) { + _ODP_ERR("Size for %uth output is 0\n", output_index); + return -1; + } + } + + seg_size_sum += data->output_seg[i].size; + + if (seg_size_sum > size) { + _ODP_ERR("Sum of segment sizes %u exceeds %uth output data size %u\n", + seg_size_sum, output_index, size); + return -1; + } + + if (seg_size_sum >= size) { + if (segs_per_output > _odp_ml_glb->capa.max_segs_per_output) { + _ODP_ERR("Number of segments %u for output[%u] exceeds maximum" + " number of data segments per model output %u\n", + segs_per_output, output_index, + _odp_ml_glb->capa.max_segs_per_output); + return -1; + } + output_index++; + index_new = true; + seg_size_sum = 0; + segs_per_output = 1; + } 
else { + segs_per_output++; + index_new = false; + } + } + + if (output_index != mdl->info.num_outputs) { + _ODP_ERR("Not enough output_segs to hold all output data\n"); + return -1; + } + + return 0; +} + +static ONNXTensorElementDataType onnx_dtype_from_odp_dtype(odp_ml_data_type_t data_type) +{ + switch (data_type) { + case ODP_ML_DATA_TYPE_NONE: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + case ODP_ML_DATA_TYPE_INT8: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; + case ODP_ML_DATA_TYPE_UINT8: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + case ODP_ML_DATA_TYPE_INT16: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16; + case ODP_ML_DATA_TYPE_UINT16: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16; + case ODP_ML_DATA_TYPE_INT24: + /* Fall through*/ + case ODP_ML_DATA_TYPE_UINT24: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + case ODP_ML_DATA_TYPE_FP64: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE; + case ODP_ML_DATA_TYPE_INT32: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32; + case ODP_ML_DATA_TYPE_UINT32: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32; + case ODP_ML_DATA_TYPE_INT64: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + case ODP_ML_DATA_TYPE_UINT64: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64; + case ODP_ML_DATA_TYPE_FP16: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + case ODP_ML_DATA_TYPE_FP32: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + case ODP_ML_DATA_TYPE_BFP16: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; + default: + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + } +} + +static int verify_tensor(const OrtValue *tensor, odp_ml_data_type_t expected_type, + const odp_ml_shape_info_t *expected_shape, uint32_t batch_size) +{ + OrtTensorTypeAndShapeInfo *tensor_info; + ONNXTensorElementDataType tensor_type; + size_t dim_count; + OrtStatus *status = NULL; + int64_t dims[ODP_ML_MAX_DIMS] = {0}; + int64_t shape_arr[ODP_ML_MAX_DIMS] = {0}; + const OrtApi *ort_api = _odp_ml_glb->ort_api; + + status = 
ort_api->GetTensorTypeAndShape(tensor, &tensor_info); + if (check_ortstatus(status)) { + _ODP_ERR("GetTensorTypeAndShape() failed\n"); + return -1; + } + + status = ort_api->GetTensorElementType(tensor_info, &tensor_type); + if (check_ortstatus(status)) { + ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + _ODP_ERR("GetTensorElementType() failed\n"); + return -1; + } + + if (onnx_dtype_to_odp_dtype(tensor_type) != expected_type) { + ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + _ODP_ERR("Tensor type does not match model type\n"); + return -1; + } + + status = ort_api->GetDimensionsCount(tensor_info, &dim_count); + if (check_ortstatus(status)) { + ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + _ODP_ERR("GetDimensionsCount() failed\n"); + return -1; + } + + if (dim_count != expected_shape->num_dim) { + ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + _ODP_ERR("Tensor dimension does not match shape_dim\n"); + return -1; + } + + status = ort_api->GetDimensions(tensor_info, dims, dim_count); + if (check_ortstatus(status)) { + ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + _ODP_ERR("GetDimensions() failed\n"); + return -1; + } + + ml_shape_to_int64(expected_shape, batch_size, shape_arr); + + for (uint32_t i = 0; i < dim_count; i++) { + if (dims[i] != shape_arr[i]) { + ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + _ODP_ERR("Shape[%u]: %" PRIu64 " does not match expected: %" PRIu64 "\n", + i, dims[i], shape_arr[i]); + return -1; + } + } + + ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info); + return 0; +} + +static int input_data_to_tensor(const odp_ml_input_info_t *input_info, uint32_t num_seg, + const odp_ml_data_seg_t *input_seg, uint32_t *seg_idx, + uint32_t batch_size, OrtValue **input_tensor) +{ + int is_tensor; + uint64_t input_size; + OrtAllocator *allocator; + void *data = NULL; + OrtStatus *status = NULL; + int64_t shape[ODP_ML_MAX_DIMS] = {0}; + const OrtApi *ort_api = _odp_ml_glb->ort_api; + 
	ONNXTensorElementDataType onnx_dtype = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;

	/* Resolve dynamic (batch) dimensions into a concrete int64 shape */
	ml_shape_to_int64(&input_info->shape, batch_size, shape);

	onnx_dtype = onnx_dtype_from_odp_dtype(input_info->data_type);
	_ODP_ASSERT(onnx_dtype != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);

	status = ort_api->GetAllocatorWithDefaultOptions(&allocator);
	if (check_ortstatus(status)) {
		_ODP_ERR("GetAllocatorWithDefaultOptions() failed\n");
		return -1;
	}

	/* ORT owns the tensor buffer; input segments are copied into it below.
	 * On failure the caller releases any tensors already created. */
	status = ort_api->CreateTensorAsOrtValue(allocator,
						 shape,
						 input_info->shape.num_dim,
						 onnx_dtype,
						 input_tensor);
	if (check_ortstatus(status) || !input_tensor[0]) {
		_ODP_ERR("CreateTensorWithDataAsOrtValue() failed\n");
		return -1;
	}

	input_size = input_info->data_type_size * get_num_elem(batch_size, &input_info->shape);

	status = ort_api->GetTensorMutableData(input_tensor[0], &data);
	if (check_ortstatus(status) || !data) {
		_ODP_ERR("GetTensorMutableData() failed\n");
		return -1;
	}

	/* Gather segments into the tensor buffer, consuming *seg_idx as we go.
	 * Assumes each input's data ends exactly on a segment boundary (a
	 * segment never spans two model inputs) — enforced by
	 * verify_run_params() in debug builds. */
	for (uint64_t i = 0; i < input_size; ) {
		if (*seg_idx >= num_seg) {
			_ODP_ERR("Insufficient input data\n");
			return -1;
		}

		uint64_t seg_size = input_seg[*seg_idx].size;

		if (i + seg_size > input_size) {
			_ODP_ERR("Excess input data in segment %" PRIu32 "\n", *seg_idx);
			return -1;
		}

		memcpy((uint8_t *)data + i, input_seg[(*seg_idx)++].addr, seg_size);
		i += seg_size;
	}

	/* Extra verification only in debug builds */
	if (!ODP_DEBUG)
		return 0;

	status = ort_api->IsTensor(input_tensor[0], &is_tensor);
	if (check_ortstatus(status) || !is_tensor) {
		_ODP_ERR("input_tensor IsTensor failed\n");
		return -1;
	}

	/* Make sure tensor shape matches input_shape */
	if (verify_tensor(input_tensor[0], input_info->data_type,
			  &input_info->shape, batch_size)) {
		_ODP_ERR("Verify input_tensor failed\n");
		return -1;
	}

	return 0;
}

/* Debug-build check that one output tensor is a tensor of the expected
 * type/shape. Returns 0 on success, -1 on mismatch or ORT error. */
static int verify_output_tensor(OrtValue *output_tensor, odp_ml_data_type_t expected_type,
				const odp_ml_shape_info_t *expected_shape, uint32_t batch_size)
{
	int is_tensor = 0;
	const OrtApi *ort_api = _odp_ml_glb->ort_api;
	OrtStatus *status = ort_api->IsTensor(output_tensor, &is_tensor);

	if (check_ortstatus(status) || !is_tensor) {
		_ODP_ERR("output_tensor IsTensor failed\n");
		return -1;
	}

	/* Make sure tensor shape matches output_shape */
	if (verify_tensor(output_tensor, expected_type, expected_shape, batch_size)) {
		_ODP_ERR("Verify output_tensor failed\n");
		return -1;
	}

	return 0;
}

/* Compute the byte size of a tensor as element count * data_type_size.
 * NOTE(review): the product is size_t * uint32_t stored into a uint32_t —
 * could truncate for very large tensors; confirm outputs stay < 4 GB. */
static int get_tensor_data_size(OrtValue *tensor, uint32_t *size, uint32_t data_type_size)
{
	size_t num_elem;
	OrtStatus *status;
	OrtTensorTypeAndShapeInfo *tensor_info;
	const OrtApi *ort_api = _odp_ml_glb->ort_api;

	status = ort_api->GetTensorTypeAndShape(tensor, &tensor_info);
	if (check_ortstatus(status)) {
		_ODP_ERR("GetTensorTypeAndShape() failed\n");
		return -1;
	}

	status = ort_api->GetTensorShapeElementCount(tensor_info, &num_elem);
	if (check_ortstatus(status)) {
		ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
		_ODP_ERR("GetTensorShapeElementCount() failed\n");
		return -1;
	}
	*size = data_type_size * num_elem;

	ort_api->ReleaseTensorTypeAndShapeInfo(tensor_info);
	return 0;
}

/* Debug-build check that the output segments starting at seg_idx provide
 * enough room for one output tensor of out_tensor_data_size bytes. */
static int check_output_size(odp_bool_t is_segmented, uint32_t output_idx, uint32_t seg_idx,
			     uint64_t out_tensor_data_size, const odp_ml_data_t data[])
{
	uint64_t output_size = 0;

	/* Output is not segmented */
	if (!is_segmented) {
		/* Make sure tensor data size does not exceed size allocated for
		 * data->output_seg[seg_idx].addr */
		if (out_tensor_data_size > data->output_seg[seg_idx].size) {
			/* NOTE(review): %d used for uint32_t output_idx — %u
			 * would be the matching specifier */
			_ODP_ERR("Malloc at least %" PRIu64 " bytes for %dth output tensor\n",
				 out_tensor_data_size, output_idx);
			return -1;
		}

		return 0;
	}

	/* Output is segmented, first calculate total size for one tensor */
	for (; seg_idx < data->num_output_seg; seg_idx++) {
		output_size += data->output_seg[seg_idx].size;
		if (output_size >= out_tensor_data_size)
			break;
	}

	if (0 == output_size) {
		_ODP_ERR("No output data segments for %uth output tensor\n", output_idx);
		return -1;
	}

	if (out_tensor_data_size > output_size) {
		_ODP_ERR("Output segments (%" PRIu64 " bytes in total) for %uth output"
			 " is expected to be at least %" PRIu64 " bytes\n",
			 output_size, output_idx, out_tensor_data_size);
		return -1;
	}

	return 0;
}

/* Copy every output tensor's data into the caller's output segments. On
 * failure sets result_local->error_code and returns -1. */
static int output_tensors_to_data(OrtValue **output_tensors,
				  uint32_t model_num_outputs,
				  const odp_ml_run_param_t *param,
				  const odp_ml_output_info_t *output_info,
				  const odp_ml_data_t *data,
				  odp_ml_run_result_t *result_local)
{
	uint32_t seg_idx;
	uint64_t seg_size;
	uint64_t cpy_size;
	uint64_t left_size;
	uint64_t output_val_offset;
	uint32_t out_tensor_data_size;
	void *output_val = NULL; /* Pointer to store one raw output value */
	OrtStatus *status = NULL;
	uint32_t batch_size = (param && param->batch_size) ? param->batch_size : 0;
	const OrtApi *ort_api = _odp_ml_glb->ort_api;
	/* More segments than model outputs means segmented output */
	odp_bool_t is_segmented = (data->num_output_seg != model_num_outputs);

	seg_idx = 0;
	for (uint32_t i = 0; i < model_num_outputs; i++) {
		if (ODP_DEBUG &&
		    verify_output_tensor(output_tensors[i], output_info[i].data_type,
					 &output_info[i].shape, batch_size)) {
			result_local->error_code = ML_BAD_OUTPUT;
			return -1;
		}

		/* Get tensor data size */
		if (get_tensor_data_size(output_tensors[i], &out_tensor_data_size,
					 output_info[i].data_type_size)) {
			result_local->error_code = ML_LIB_FAILED;
			return -1;
		}

		/* When output_tensor is an empty tensor [], skip getting data */
		if (out_tensor_data_size == 0)
			continue;

		if (ODP_DEBUG && check_output_size(is_segmented, i, seg_idx,
						   out_tensor_data_size, data)) {
			result_local->error_code = ML_BAD_OUTPUT;
			return -1;
		}

		/* Following assumes param and data->output_seg are valid */
		/* Get tensor data */
		output_val = NULL;
		status = ort_api->GetTensorMutableData(output_tensors[i], &output_val);
		if (check_ortstatus(status) || !output_val) {
result_local->error_code = ML_LIB_FAILED; + return -1; + } + + /* Output is not segmented */ + if (!is_segmented) { + /* Store output data to data->output_seg[i].addr */ + memcpy(data->output_seg[i].addr, output_val, out_tensor_data_size); + seg_idx++; + continue; + } + + /* Output is segmented */ + output_val_offset = 0; + left_size = out_tensor_data_size; + for (; seg_idx < data->num_output_seg; seg_idx++) { + seg_size = data->output_seg[seg_idx].size; + cpy_size = left_size > seg_size ? seg_size : left_size; + memcpy(data->output_seg[seg_idx].addr, + ((char *)output_val) + output_val_offset, cpy_size); + + output_val_offset += cpy_size; + left_size = out_tensor_data_size - output_val_offset; + + if (!left_size) { + seg_idx++; + break; + } + } + } + + return 0; +} + +int odp_ml_run(odp_ml_model_t model, const odp_ml_data_t *data, const odp_ml_run_param_t *param) +{ + odp_ml_run_result_t result_local; + + int retval = -1; /* Return value of this function */ + int ret = 0; + OrtStatus *status = NULL; + uint32_t batch_size = 0; + + OrtValue *input_tensor[CONFIG_ML_MAX_INPUTS] = {0}; + OrtValue *output_tensors[CONFIG_ML_MAX_OUTPUTS] = {0}; + const char *input_names[CONFIG_ML_MAX_INPUTS] = {0}; + const char *output_names[CONFIG_ML_MAX_OUTPUTS] = {0}; + + const OrtApi *ort_api = _odp_ml_glb->ort_api; + ml_model_t *mdl = ml_model_from_handle(model); + const odp_ml_model_info_t *ml_info = &mdl->info; + const odp_ml_input_info_t *input_info = mdl->input_info; + const odp_ml_output_info_t *output_info = mdl->output_info; + OrtSession *session = mdl->session; + + odp_ticketlock_lock(&mdl->lock); + if (odp_unlikely(mdl->state == ML_STATE_INFERENCING)) { + odp_ticketlock_unlock(&mdl->lock); + return 0; + } + if (odp_unlikely(mdl->state != ML_STATE_LOADED)) { + _ODP_ERR("Wrong model state: not created or not loaded\n"); + odp_ticketlock_unlock(&mdl->lock); + return -1; + } + mdl->state = ML_STATE_INFERENCING; + odp_ticketlock_unlock(&mdl->lock); + + memset(&result_local, 0, 
sizeof(result_local)); + + if (ODP_DEBUG && verify_run_params(model, data, param)) { + result_local.error_code = ML_BAD_INPUT; + goto init_fail; + } + + if (param && param->batch_size) + batch_size = param->batch_size; + + uint32_t seg_idx = 0; + + /* Transfer input data to tensor */ + for (uint32_t i = 0; i < ml_info->num_inputs; i++) { + ret = input_data_to_tensor(&input_info[i], + data->num_input_seg, + data->input_seg, + &seg_idx, + batch_size, + &input_tensor[i]); + if (ret) { + _ODP_ERR("%uth input data to tensor failed\n", i); + result_local.error_code = ML_LIB_FAILED; + goto release_input_tensors; + } + + _ODP_DBG("input_tensor[%u]: %p\n", i, input_tensor[i]); + + /* Model input names */ + input_names[i] = input_info[i].name; + } + + if (seg_idx < data->num_input_seg) { + _ODP_ERR("Excess input segments\n"); + ret = -1; + } + + for (uint32_t i = 0; i < ml_info->num_outputs; i++) + output_names[i] = output_info[i].name; + + /* Run inference */ + status = ort_api->Run(session, + NULL, + (const char * const *)input_names, + (const OrtValue * const*)input_tensor, + ml_info->num_inputs, + (const char * const *)output_names, + ml_info->num_outputs, + output_tensors); + + if (check_ortstatus(status)) { + _ODP_ERR("Run inference failed\n"); + result_local.error_code = ML_LIB_FAILED; + goto release_all_tensors; + } + + /* Verify output tensors and store them to output */ + if (output_tensors_to_data(output_tensors, ml_info->num_outputs, param, + output_info, data, &result_local)) { + _ODP_ERR("Output tensors to data failed\n"); + goto release_all_tensors; + } + + retval = 1; + +release_all_tensors: + for (uint32_t i = 0; i < ml_info->num_outputs; i++) + ort_api->ReleaseValue(output_tensors[i]); + +release_input_tensors: + for (uint32_t i = 0; i < ml_info->num_inputs; i++) + ort_api->ReleaseValue(input_tensor[i]); + +init_fail: + if (param && param->result) + *param->result = result_local; + + odp_ticketlock_lock(&mdl->lock); + mdl->state = ML_STATE_LOADED; + 
odp_ticketlock_unlock(&mdl->lock); + + return retval; +} + +int odp_ml_run_multi(odp_ml_model_t model, const odp_ml_data_t data[], + const odp_ml_run_param_t param[], int num) +{ + int i; + int ret; + + if (odp_unlikely(num < 1)) { + _ODP_ERR("Bad number of runs\n"); + return -1; + } + + for (i = 0; i < num; i++) { + if (param) + ret = odp_ml_run(model, &data[i], ¶m[i]); + else + ret = odp_ml_run(model, &data[i], NULL); + + if (odp_unlikely(ret != 1)) + break; + } + + if (odp_unlikely(i == 0)) + return ret; + + return i; +} + +int odp_ml_run_start(odp_ml_model_t model, const odp_ml_data_t *data, + const odp_ml_compl_param_t *compl_param, + const odp_ml_run_param_t *run_param) +{ + int ret; + ml_model_t *mdl = ml_model_from_handle(model); + + if (odp_unlikely(model == ODP_ML_MODEL_INVALID)) { + _ODP_ERR("Bad model handle\n"); + return -1; + } + + if (odp_unlikely(!compl_param)) { + _ODP_ERR("Completion parameter is NULL\n"); + return -1; + } + + /* Check completion mode */ + if (odp_unlikely(check_compl_param(compl_param, mdl->max_compl_id, false))) { + _ODP_ERR("Bad ML job completion parameter\n"); + return -1; + } + + if (compl_param->mode == ODP_ML_COMPL_MODE_POLL) + odp_atomic_store_rel_u32(&mdl->compl_status[compl_param->compl_id], 0); + + ret = odp_ml_run(model, data, run_param); + + if (odp_unlikely(ret < 1)) + return ret; + + /* Send a completion event to the given queue */ + if (compl_param->mode == ODP_ML_COMPL_MODE_EVENT) { + odp_ml_run_result_t *result; + odp_buffer_t buf = (odp_buffer_t)(uintptr_t)compl_param->event; + + _odp_buffer_subtype_set(buf, ODP_EVENT_ML_COMPL_RUN); + + result = odp_buffer_addr(buf); + result->error_code = 0; + result->user_ptr = compl_param->user_ptr; + + if (odp_unlikely(odp_queue_enq(compl_param->queue, compl_param->event))) { + _ODP_ERR("Completion event enqueue failed %" PRIu64 "\n", + odp_queue_to_u64(compl_param->queue)); + return -1; + } + + return 1; + } + + /* compl_param->mode == ODP_ML_COMPL_MODE_POLL */ + 
mdl->result[compl_param->compl_id].user_ptr = compl_param->user_ptr; + odp_atomic_store_rel_u32(&mdl->compl_status[compl_param->compl_id], 1); + + return 1; +} + +int odp_ml_run_start_multi(odp_ml_model_t model, const odp_ml_data_t data[], + const odp_ml_compl_param_t compl_param[], + const odp_ml_run_param_t run_param[], int num) +{ + int i; + int ret = 0; + + if (odp_unlikely(num < 1)) { + _ODP_ERR("Bad number of runs\n"); + return -1; + } + + for (i = 0; i < num; i++) { + if (run_param) + ret = odp_ml_run_start(model, &data[i], &compl_param[i], &run_param[i]); + else + ret = odp_ml_run_start(model, &data[i], &compl_param[i], NULL); + + if (odp_unlikely(ret != 1)) + break; + } + + if (odp_unlikely(i == 0)) + return ret; + + return i; +} + +int odp_ml_run_status(odp_ml_model_t model, uint32_t compl_id, odp_ml_run_result_t *result) +{ + int ret; + ml_model_t *mdl = ml_model_from_handle(model); + + if (odp_unlikely(model == ODP_ML_MODEL_INVALID || + compl_id > mdl->max_compl_id)) { + _ODP_ERR("Invalid model handle or completion id: %u\n", compl_id); + return -2; + } + + ret = odp_atomic_load_acq_u32(&mdl->compl_status[compl_id]); + + if (result) { + result->error_code = 0; + result->user_ptr = mdl->result[compl_id].user_ptr; + } + + return ret; +} + +static int opt_level_from_str(const char *level_str, GraphOptimizationLevel *level) +{ + if (strcmp(level_str, "DISABLE_ALL") == 0) + *level = ORT_DISABLE_ALL; + else if (strcmp(level_str, "ENABLE_BASIC") == 0) + *level = ORT_ENABLE_BASIC; + else if (strcmp(level_str, "ENABLE_EXTENDED") == 0) + *level = ORT_ENABLE_EXTENDED; + else if (strcmp(level_str, "ENABLE_ALL") == 0) + *level = ORT_ENABLE_ALL; + else + return -1; + + return 0; +} + +static int execution_mode_from_str(const char *mode_str, ExecutionMode *mode) +{ + if (strcmp(mode_str, "SEQUENTIAL") == 0) + *mode = ORT_SEQUENTIAL; + else if (strcmp(mode_str, "PARALLEL") == 0) + *mode = ORT_PARALLEL; + else + return -1; + + return 0; +} + +static int 
read_config_file(ort_run_opts_t *opts)
{
	/* Read all ml.* options from the ODP libconfig configuration into
	 * 'opts'. Every option is mandatory: a missing or unparsable value
	 * fails the whole ML init. Each accepted value is echoed with
	 * _ODP_PRINT. Returns 0 on success, -1 on error. */
	const char *conf_str;
	char mode_str[ML_MAX_CONFIG_STR_LEN];
	char opt_level_str[ML_MAX_CONFIG_STR_LEN];

	_ODP_PRINT("ML config:\n");

	conf_str = "ml.enable_profiling";
	if (!_odp_libconfig_lookup_int(conf_str, &opts->enable_profiling)) {
		_ODP_ERR("Config option '%s' not found.\n", conf_str);
		return -1;
	}
	_ODP_PRINT(" %s: %i\n", conf_str, opts->enable_profiling);

	conf_str = "ml.execution_mode";
	if (_odp_libconfig_lookup_str(conf_str, mode_str, ML_MAX_CONFIG_STR_LEN) < 0) {
		_ODP_ERR("Config option '%s' not found.\n", conf_str);
		return -1;
	}

	if (execution_mode_from_str(mode_str, &opts->execution_mode)) {
		_ODP_ERR("Unsupported execution mode: %s\n", mode_str);
		return -1;
	}
	_ODP_PRINT(" %s: %s\n", conf_str, mode_str);

	conf_str = "ml.inter_op_num_threads";
	if (!_odp_libconfig_lookup_int(conf_str, &opts->inter_op_num_threads)) {
		_ODP_ERR("Config option '%s' not found.\n", conf_str);
		return -1;
	}
	_ODP_PRINT(" %s: %i\n", conf_str, opts->inter_op_num_threads);

	conf_str = "ml.intra_op_num_threads";
	if (!_odp_libconfig_lookup_int(conf_str, &opts->intra_op_num_threads)) {
		_ODP_ERR("Config option '%s' not found.\n", conf_str);
		return -1;
	}
	_ODP_PRINT(" %s: %i\n", conf_str, opts->intra_op_num_threads);

	conf_str = "ml.graph_optimization_level";
	if (_odp_libconfig_lookup_str(conf_str, opt_level_str,
				      ML_MAX_CONFIG_STR_LEN) < 0) {
		_ODP_ERR("Config option '%s' not found.\n", conf_str);
		return -1;
	}

	if (opt_level_from_str(opt_level_str, &opts->graph_opt_level)) {
		_ODP_ERR("Graph optimize level %s not supported\n", opt_level_str);
		return -1;
	}
	_ODP_PRINT(" %s: %s\n", conf_str, opt_level_str);

	conf_str = "ml.optimized_model_filepath";
	if (_odp_libconfig_lookup_str(conf_str, opts->opt_model_filepath,
				      ML_MAX_CONFIG_STR_LEN) < 0) {
		_ODP_ERR("Config option '%s' not found.\n", conf_str);
		return -1;
	}
	_ODP_PRINT(" %s: %s\n", conf_str, opts->opt_model_filepath);

	return 0;
}

/* Global ML init: reserve shared memory, read configuration and create the
 * ONNX Runtime environment. Returns 0 on success (also when ML is
 * disabled), -1 on error.
 * NOTE(review): error paths after the shm reserve return -1 without
 * freeing the shm — confirm the init framework calls _odp_ml_term_global()
 * on failed init, otherwise this leaks the reservation. */
int _odp_ml_init_global(void)
{
	int i;
	OrtEnv *env;
	odp_shm_t shm;
	OrtStatus *status;
	const OrtApi *ort_api;

	if (odp_global_ro.disable.ml) {
		_ODP_ERR("ML is disabled\n");
		return 0;
	}

	shm = odp_shm_reserve("_odp_ml_global", sizeof(ml_global_t), ODP_CACHE_LINE_SIZE, 0);
	_odp_ml_glb = odp_shm_addr(shm);

	if (_odp_ml_glb == NULL) {
		_ODP_ERR("SHM reserve failed for odp_ml\n");
		return -1;
	}

	memset(_odp_ml_glb, 0, sizeof(ml_global_t));
	_odp_ml_glb->shm = shm;

	if (odp_ml_capability(&_odp_ml_glb->capa)) {
		_ODP_ERR("ML capability failed\n");
		return -1;
	}

	odp_pool_param_init(&_odp_ml_glb->pool_param);

	if (read_config_file(&_odp_ml_glb->ort_run_opts))
		return -1;

	ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
	if (!ort_api) {
		_ODP_ERR("Failed to init ONNX Runtime engine.\n");
		return -1;
	}
	_odp_ml_glb->ort_api = ort_api;

	status = ort_api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "Default", &env);
	if (check_ortstatus(status) || !env) {
		_ODP_ERR("ort_api->CreateEnv() failed.\n");
		return -1;
	}
	_odp_ml_glb->env = env;

	for (i = 0; i < ML_MAX_MODELS_CREATED; i++)
		odp_ticketlock_init(&_odp_ml_glb->models[i].lock);

	return 0;
}

/* Global ML termination: release the ORT environment and the shared
 * memory. Safe to call when ML is disabled or was never initialized. */
int _odp_ml_term_global(void)
{
	if (odp_global_ro.disable.ml)
		return 0;

	if (_odp_ml_glb == NULL)
		return 0;

	if (_odp_ml_glb->env)
		_odp_ml_glb->ort_api->ReleaseEnv(_odp_ml_glb->env);

	if (odp_shm_free(_odp_ml_glb->shm)) {
		_ODP_ERR("Shm free failed for odp_ml\n");
		return -1;
	}

	return 0;
}
diff --git a/platform/linux-generic/odp_ml_fp16.c b/platform/linux-generic/odp_ml_fp16.c
new file mode 100644
index 000000000..f135f8b5a
--- /dev/null
+++ b/platform/linux-generic/odp_ml_fp16.c
@@ -0,0 +1,425 @@
/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright (c) 2022-2023 Marvell.
 * Copyright (c) 2023 Nokia
 *
 * Based on
 * - dpdk/lib/mldev/mldev_utils_scalar.h
 * - dpdk/lib/mldev/mldev_utils_scalar.c
 * - dpdk/lib/mldev/mldev_utils_scalar_bfloat16.c
 */

#include <errno.h>
#include <stdint.h>

#include <odp_ml_fp16.h>

#ifndef BIT
#define BIT(nr) (1UL << (nr))
#endif

#ifndef BITS_PER_LONG
#define BITS_PER_LONG (__SIZEOF_LONG__ * 8)
#endif

/* Bit mask with bits [h:l] set (inclusive) */
#ifndef GENMASK_U32
#define GENMASK_U32(h, l) (((~0UL) << (l)) & (~0UL >> (BITS_PER_LONG - 1 - (h))))
#endif

/* float32: bit index of MSB & LSB of sign, exponent and mantissa */
#define FP32_LSB_M 0
#define FP32_MSB_M 22
#define FP32_LSB_E 23
#define FP32_MSB_E 30
#define FP32_LSB_S 31
#define FP32_MSB_S 31

/* float32: bitmask for sign, exponent and mantissa */
#define FP32_MASK_S GENMASK_U32(FP32_MSB_S, FP32_LSB_S)
#define FP32_MASK_E GENMASK_U32(FP32_MSB_E, FP32_LSB_E)
#define FP32_MASK_M GENMASK_U32(FP32_MSB_M, FP32_LSB_M)

/* float16: bit index of MSB & LSB of sign, exponent and mantissa */
#define FP16_LSB_M 0
#define FP16_MSB_M 9
#define FP16_LSB_E 10
#define FP16_MSB_E 14
#define FP16_LSB_S 15
#define FP16_MSB_S 15

/* float16: bitmask for sign, exponent and mantissa */
#define FP16_MASK_S GENMASK_U32(FP16_MSB_S, FP16_LSB_S)
#define FP16_MASK_E GENMASK_U32(FP16_MSB_E, FP16_LSB_E)
#define FP16_MASK_M GENMASK_U32(FP16_MSB_M, FP16_LSB_M)

/* bfloat16: bit index of MSB & LSB of sign, exponent and mantissa */
#define BF16_LSB_M 0
#define BF16_MSB_M 6
#define BF16_LSB_E 7
#define BF16_MSB_E 14
#define BF16_LSB_S 15
#define BF16_MSB_S 15

/* bfloat16: bitmask for sign, exponent and mantissa */
#define BF16_MASK_S GENMASK_U32(BF16_MSB_S, BF16_LSB_S)
#define BF16_MASK_E GENMASK_U32(BF16_MSB_E, BF16_LSB_E)
#define BF16_MASK_M GENMASK_U32(BF16_MSB_M, BF16_LSB_M)

/* Exponent bias */
#define FP32_BIAS_E 127
#define FP16_BIAS_E 15
#define BF16_BIAS_E 127

/* Assemble a raw float value from sign/exponent/mantissa fields */
#define FP32_PACK(sign, exponent, mantissa) \
	(((sign) << FP32_LSB_S) | ((exponent) << FP32_LSB_E) | (mantissa))

#define FP16_PACK(sign, exponent, mantissa) \
	(((sign) << FP16_LSB_S) | ((exponent) << FP16_LSB_E) | (mantissa))

#define BF16_PACK(sign, exponent, mantissa) \
	(((sign) << BF16_LSB_S) | ((exponent) << BF16_LSB_E) | (mantissa))

/* Represent float32 as float and uint32_t */
union float32 {
	float f;
	uint32_t u;
};

/* Convert a single precision floating point number (float32) into a half precision
 * floating point number (float16) using round to nearest rounding mode.
 */
static uint16_t
__float32_to_float16_scalar_rtn(float x)
{
	union float32 f32; /* float32 input */
	uint32_t f32_s;    /* float32 sign */
	uint32_t f32_e;    /* float32 exponent */
	uint32_t f32_m;    /* float32 mantissa */
	uint16_t f16_s;    /* float16 sign */
	uint16_t f16_e;    /* float16 exponent */
	uint16_t f16_m;    /* float16 mantissa */
	uint32_t tbits;    /* number of truncated bits */
	uint32_t tmsb;     /* MSB position of truncated bits */
	uint32_t m_32;     /* temporary float32 mantissa */
	uint16_t m_16;     /* temporary float16 mantissa */
	uint16_t u16;      /* float16 output */
	int be_16;         /* float16 biased exponent, signed */

	f32.f = x;
	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;

	f16_s = f32_s;
	f16_e = 0;
	f16_m = 0;

	switch (f32_e) {
	case (0): /* float32: zero or subnormal number */
		f16_e = 0;
		f16_m = 0; /* convert to zero */
		break;
	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
		f16_e = FP16_MASK_E >> FP16_LSB_E;
		if (f32_m == 0) { /* infinity */
			f16_m = 0;
		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
			f16_m = f32_m >> (FP32_MSB_M - FP16_MSB_M);
			f16_m |= BIT(FP16_MSB_M);
		}
		break;
	default: /* float32: normal number */
		/* compute biased exponent for float16 */
		be_16 = (int)f32_e - FP32_BIAS_E + FP16_BIAS_E;

		/* overflow, be_16 = [31-INF], set to infinity */
		if (be_16 >= (int)(FP16_MASK_E >> FP16_LSB_E)) {
			f16_e = FP16_MASK_E >> FP16_LSB_E;
			f16_m = 0;
		} else if ((be_16 >= 1) && (be_16 < (int)(FP16_MASK_E >> FP16_LSB_E))) {
			/* normal float16, be_16 = [1:30]*/
			f16_e = be_16;
			m_16 = f32_m >> (FP32_LSB_E - FP16_LSB_E);
			tmsb = FP32_MSB_M - FP16_MSB_M - 1;
			if ((f32_m & GENMASK_U32(tmsb, 0)) > BIT(tmsb)) {
				/* round: non-zero truncated bits except MSB */
				m_16++;

				/* overflow into exponent */
				if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
					f16_e++;
			} else if ((f32_m & GENMASK_U32(tmsb, 0)) == BIT(tmsb)) {
				/* round: MSB of truncated bits and LSB of m_16 is set */
				if ((m_16 & 0x1) == 0x1) {
					m_16++;

					/* overflow into exponent */
					if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
						f16_e++;
				}
			}
			f16_m = m_16 & FP16_MASK_M;
		} else if ((be_16 >= -(int)(FP16_MSB_M)) && (be_16 < 1)) {
			/* underflow: zero / subnormal, be_16 = [-9:0] */
			f16_e = 0;

			/* add implicit leading zero */
			m_32 = f32_m | BIT(FP32_LSB_E);
			tbits = FP32_LSB_E - FP16_LSB_E - be_16 + 1;
			m_16 = m_32 >> tbits;

			/* if non-leading truncated bits are set */
			if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
				m_16++;

				/* overflow into exponent */
				if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
					f16_e++;
			} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
				/* if leading truncated bit is set */
				if ((m_16 & 0x1) == 0x1) {
					m_16++;

					/* overflow into exponent */
					if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
						f16_e++;
				}
			}
			f16_m = m_16 & FP16_MASK_M;
		} else if (be_16 == -(int)(FP16_MSB_M + 1)) {
			/* underflow: zero, be_16 = [-10] */
			f16_e = 0;
			if (f32_m != 0)
				f16_m = 1;
			else
				f16_m = 0;
		} else {
			/* underflow: zero, be_16 = [-INF:-11] */
			f16_e = 0;
			f16_m = 0;
		}

		break;
	}

	u16 = FP16_PACK(f16_s, f16_e, f16_m);

	return u16;
}

/* Convert a half precision floating point number (float16) into a single precision
 * floating point number
(float32). + */ +static float +__float16_to_float32_scalar_rtx(uint16_t f16) +{ + union float32 f32; /* float32 output */ + uint16_t f16_s; /* float16 sign */ + uint16_t f16_e; /* float16 exponent */ + uint16_t f16_m; /* float16 mantissa */ + uint32_t f32_s; /* float32 sign */ + uint32_t f32_e; /* float32 exponent */ + uint32_t f32_m; /* float32 mantissa*/ + uint8_t shift; /* number of bits to be shifted */ + uint32_t clz; /* count of leading zeroes */ + int e_16; /* float16 exponent unbiased */ + + f16_s = (f16 & FP16_MASK_S) >> FP16_LSB_S; + f16_e = (f16 & FP16_MASK_E) >> FP16_LSB_E; + f16_m = (f16 & FP16_MASK_M) >> FP16_LSB_M; + + f32_s = f16_s; + switch (f16_e) { + case (FP16_MASK_E >> FP16_LSB_E): /* float16: infinity or nan */ + f32_e = FP32_MASK_E >> FP32_LSB_E; + if (f16_m == 0x0) { /* infinity */ + f32_m = f16_m; + } else { /* nan, propagate mantissa, set MSB of mantissa to 1 */ + f32_m = f16_m; + shift = FP32_MSB_M - FP16_MSB_M; + f32_m = (f32_m << shift) & FP32_MASK_M; + f32_m |= BIT(FP32_MSB_M); + } + break; + case 0: /* float16: zero or sub-normal */ + f32_m = f16_m; + if (f16_m == 0) { /* zero signed */ + f32_e = 0; + } else { /* subnormal numbers */ + clz = __builtin_clz((uint32_t)f16_m) - sizeof(uint32_t) * 8 + FP16_LSB_E; + e_16 = (int)f16_e - clz; + f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E; + + shift = clz + (FP32_MSB_M - FP16_MSB_M) + 1; + f32_m = (f32_m << shift) & FP32_MASK_M; + } + break; + default: /* normal numbers */ + f32_m = f16_m; + e_16 = (int)f16_e; + f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E; + + shift = (FP32_MSB_M - FP16_MSB_M); + f32_m = (f32_m << shift) & FP32_MASK_M; + } + + f32.u = FP32_PACK(f32_s, f32_e, f32_m); + + return f32.f; +} + +/* Convert a single precision floating point number (float32) into a + * brain float number (bfloat16) using round to nearest rounding mode. 
 */
static uint16_t
__float32_to_bfloat16_scalar_rtn(float x)
{
	union float32 f32; /* float32 input */
	uint32_t f32_s;    /* float32 sign */
	uint32_t f32_e;    /* float32 exponent */
	uint32_t f32_m;    /* float32 mantissa */
	uint16_t b16_s;    /* float16 sign */
	uint16_t b16_e;    /* float16 exponent */
	uint16_t b16_m;    /* float16 mantissa */
	uint32_t tbits;    /* number of truncated bits */
	uint16_t u16;      /* float16 output */

	f32.f = x;
	f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
	f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
	f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;

	b16_s = f32_s;
	b16_e = 0;
	b16_m = 0;

	switch (f32_e) {
	case (0): /* float32: zero or subnormal number */
		b16_e = 0;
		if (f32_m == 0) /* zero */
			b16_m = 0;
		else /* subnormal float32 number, normal bfloat16 */
			goto bf16_normal;
		break;
	case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
		b16_e = BF16_MASK_E >> BF16_LSB_E;
		if (f32_m == 0) { /* infinity */
			b16_m = 0;
		} else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
			b16_m = f32_m >> (FP32_MSB_M - BF16_MSB_M);
			b16_m |= BIT(BF16_MSB_M);
		}
		break;
	default: /* float32: normal number, normal bfloat16 */
		goto bf16_normal;
	}

	goto bf16_pack;

bf16_normal:
	/* bfloat16 shares the float32 exponent; only the mantissa is
	 * truncated (round to nearest, ties to even) */
	b16_e = f32_e;
	tbits = FP32_MSB_M - BF16_MSB_M;
	b16_m = f32_m >> tbits;

	/* if non-leading truncated bits are set */
	if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
		b16_m++;

		/* if overflow into exponent */
		if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
			b16_e++;
	} else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
		/* if only leading truncated bit is set */
		if ((b16_m & 0x1) == 0x1) {
			b16_m++;

			/* if overflow into exponent */
			if (((b16_m & BF16_MASK_E) >> BF16_LSB_E) == 0x1)
				b16_e++;
		}
	}
	b16_m = b16_m & BF16_MASK_M;

bf16_pack:
	u16 = BF16_PACK(b16_s, b16_e, b16_m);

	return u16;
}

/* Convert a brain float number (bfloat16) into a
 *
 * single precision floating point number (float32).
 */
static float
__bfloat16_to_float32_scalar_rtx(uint16_t f16)
{
	union float32 f32; /* float32 output */
	uint16_t b16_s;    /* float16 sign */
	uint16_t b16_e;    /* float16 exponent */
	uint16_t b16_m;    /* float16 mantissa */
	uint32_t f32_s;    /* float32 sign */
	uint32_t f32_e;    /* float32 exponent */
	uint32_t f32_m;    /* float32 mantissa*/
	uint8_t shift;     /* number of bits to be shifted */

	b16_s = (f16 & BF16_MASK_S) >> BF16_LSB_S;
	b16_e = (f16 & BF16_MASK_E) >> BF16_LSB_E;
	b16_m = (f16 & BF16_MASK_M) >> BF16_LSB_M;

	f32_s = b16_s;
	switch (b16_e) {
	case (BF16_MASK_E >> BF16_LSB_E): /* bfloat16: infinity or nan */
		f32_e = FP32_MASK_E >> FP32_LSB_E;
		if (b16_m == 0x0) { /* infinity */
			f32_m = 0;
		} else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
			f32_m = b16_m;
			shift = FP32_MSB_M - BF16_MSB_M;
			f32_m = (f32_m << shift) & FP32_MASK_M;
			f32_m |= BIT(FP32_MSB_M);
		}
		break;
	case 0: /* bfloat16: zero or subnormal */
		f32_m = b16_m;
		if (b16_m == 0) { /* zero signed */
			f32_e = 0;
		} else { /* subnormal numbers */
			goto fp32_normal;
		}
		break;
	default: /* bfloat16: normal number */
		goto fp32_normal;
	}

	goto fp32_pack;

fp32_normal:
	/* re-bias exponent and widen the mantissa */
	f32_m = b16_m;
	f32_e = FP32_BIAS_E + b16_e - BF16_BIAS_E;

	shift = (FP32_MSB_M - BF16_MSB_M);
	f32_m = (f32_m << shift) & FP32_MASK_M;

fp32_pack:
	f32.u = FP32_PACK(f32_s, f32_e, f32_m);

	return f32.f;
}

/* Public ODP wrappers around the scalar conversion routines */
uint16_t _odp_float32_to_float16(float x)
{
	return __float32_to_float16_scalar_rtn(x);
}

float _odp_float16_to_float32(uint16_t f16)
{
	return __float16_to_float32_scalar_rtx(f16);
}

uint16_t _odp_float32_to_bfloat16(float x)
{
	return __float32_to_bfloat16_scalar_rtn(x);
}

float _odp_bfloat16_to_float32(uint16_t f16)
{
	return __bfloat16_to_float32_scalar_rtx(f16);
}
diff --git a/platform/linux-generic/odp_ml_null.c b/platform/linux-generic/odp_ml_null.c
new file mode 100644
index
000000000..faf431997 --- /dev/null +++ b/platform/linux-generic/odp_ml_null.c @@ -0,0 +1,229 @@
/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright (c) 2023 Nokia
 */

#include <odp/api/ml.h>
#include <odp_init_internal.h>

#include <string.h>

/* Dummy ML API implementation, no capability and just return error for
 * other functions.
 */
int _odp_ml_init_global(void)
{
	return 0;
}

int _odp_ml_term_global(void)
{
	return 0;
}

/* Report an all-zero capability: ML is not supported in this build */
int odp_ml_capability(odp_ml_capability_t *capa)
{
	memset(capa, 0, sizeof(odp_ml_capability_t));
	return 0;
}

void odp_ml_config_init(odp_ml_config_t *config ODP_UNUSED)
{
}

int odp_ml_config(const odp_ml_config_t *config ODP_UNUSED)
{
	return -1;
}

void odp_ml_model_param_init(odp_ml_model_param_t *param ODP_UNUSED)
{
}

/* Model operations: creation/lookup always fail with the invalid handle */
odp_ml_model_t odp_ml_model_create(const char *name ODP_UNUSED,
				   const odp_ml_model_param_t *param ODP_UNUSED)
{
	return ODP_ML_MODEL_INVALID;
}

int odp_ml_model_destroy(odp_ml_model_t model ODP_UNUSED)
{
	return -1;
}

int odp_ml_model_info(odp_ml_model_t model ODP_UNUSED, odp_ml_model_info_t *info ODP_UNUSED)
{
	return -1;
}

uint32_t odp_ml_model_input_info(odp_ml_model_t model ODP_UNUSED,
				 odp_ml_input_info_t info[] ODP_UNUSED,
				 uint32_t num ODP_UNUSED)
{
	return 0;
}

uint32_t odp_ml_model_output_info(odp_ml_model_t model ODP_UNUSED,
				  odp_ml_output_info_t info[] ODP_UNUSED,
				  uint32_t num ODP_UNUSED)
{
	return 0;
}

odp_ml_model_t odp_ml_model_lookup(const char *name ODP_UNUSED)
{
	return ODP_ML_MODEL_INVALID;
}

uint64_t odp_ml_model_to_u64(odp_ml_model_t model ODP_UNUSED)
{
	return 0;
}

void odp_ml_model_print(odp_ml_model_t model ODP_UNUSED)
{
}

void odp_ml_print(void)
{
}

void odp_ml_compl_pool_param_init(odp_ml_compl_pool_param_t *pool_param)
{
	memset(pool_param, 0, sizeof(odp_ml_compl_pool_param_t));
}

/* Completion event operations: nothing can be created, allocated or
 * converted; invalid handles are returned throughout.
 */
odp_pool_t odp_ml_compl_pool_create(const char *name ODP_UNUSED,
				    const odp_ml_compl_pool_param_t *pool_param ODP_UNUSED)
{
	return ODP_POOL_INVALID;
}

odp_ml_compl_t odp_ml_compl_alloc(odp_pool_t pool ODP_UNUSED)
{
	return ODP_ML_COMPL_INVALID;
}

void odp_ml_compl_free(odp_ml_compl_t ml_compl ODP_UNUSED)
{
}

int odp_ml_compl_run_result(odp_ml_compl_t ml_compl ODP_UNUSED,
			    odp_ml_run_result_t *result ODP_UNUSED)
{
	return -1;
}

int odp_ml_compl_load_result(odp_ml_compl_t ml_compl ODP_UNUSED,
			     odp_ml_load_result_t *result ODP_UNUSED)
{
	return -1;
}

void *odp_ml_compl_user_area(odp_ml_compl_t ml_compl ODP_UNUSED)
{
	return NULL;
}

odp_ml_compl_t odp_ml_compl_from_event(odp_event_t event ODP_UNUSED)
{
	return ODP_ML_COMPL_INVALID;
}

odp_event_t odp_ml_compl_to_event(odp_ml_compl_t ml_compl ODP_UNUSED)
{
	return ODP_EVENT_INVALID;
}

uint64_t odp_ml_compl_to_u64(odp_ml_compl_t ml_compl ODP_UNUSED)
{
	return 0;
}

void odp_ml_compl_param_init(odp_ml_compl_param_t *compl_param ODP_UNUSED)
{
}

/* Model load/unload: both synchronous and asynchronous variants fail */
int odp_ml_model_load(odp_ml_model_t model ODP_UNUSED, odp_ml_load_result_t *result ODP_UNUSED)
{
	return -1;
}

int odp_ml_model_load_start(odp_ml_model_t model ODP_UNUSED,
			    const odp_ml_compl_param_t *compl_param ODP_UNUSED)
{
	return -1;
}

int odp_ml_model_load_status(odp_ml_model_t model ODP_UNUSED, uint32_t compl_id ODP_UNUSED,
			     odp_ml_load_result_t *result ODP_UNUSED)
{
	return -1;
}

int odp_ml_model_unload(odp_ml_model_t model ODP_UNUSED, odp_ml_load_result_t *result ODP_UNUSED)
{
	return -1;
}

int odp_ml_model_unload_start(odp_ml_model_t model ODP_UNUSED,
			      const odp_ml_compl_param_t *compl_param ODP_UNUSED)
{
	return -1;
}

int odp_ml_model_unload_status(odp_ml_model_t model ODP_UNUSED, uint32_t compl_id ODP_UNUSED,
			       odp_ml_load_result_t *result ODP_UNUSED)
{
	return -1;
}

void odp_ml_run_param_init(odp_ml_run_param_t *param ODP_UNUSED)
{
}

/* Inference runs: always fail */
int odp_ml_run(odp_ml_model_t model ODP_UNUSED, const odp_ml_data_t *data ODP_UNUSED,
	       const odp_ml_run_param_t *param ODP_UNUSED)
{
	return -1;
}

int odp_ml_run_multi(odp_ml_model_t model ODP_UNUSED, const odp_ml_data_t data[] ODP_UNUSED,
		     const odp_ml_run_param_t param[] ODP_UNUSED, int num ODP_UNUSED)
{
	return -1;
}

int odp_ml_run_start(odp_ml_model_t model ODP_UNUSED, const odp_ml_data_t *data ODP_UNUSED,
		     const odp_ml_compl_param_t *compl_param ODP_UNUSED,
		     const odp_ml_run_param_t *run_param ODP_UNUSED)
{
	return -1;
}

int odp_ml_run_start_multi(odp_ml_model_t model ODP_UNUSED,
			   const odp_ml_data_t data[] ODP_UNUSED,
			   const odp_ml_compl_param_t compl_param[] ODP_UNUSED,
			   const odp_ml_run_param_t run_param[] ODP_UNUSED,
			   int num ODP_UNUSED)
{
	return -1;
}

int odp_ml_run_status(odp_ml_model_t model ODP_UNUSED, uint32_t compl_id ODP_UNUSED,
		      odp_ml_run_result_t *result ODP_UNUSED)
{
	return -1;
}

/* Extra statistics: none are available in the dummy implementation */
int odp_ml_model_extra_stat_info(odp_ml_model_t model ODP_UNUSED,
				 odp_ml_extra_stat_info_t info[] ODP_UNUSED,
				 int num ODP_UNUSED)
{
	return -1;
}

int odp_ml_model_extra_stats(odp_ml_model_t model ODP_UNUSED,
			     uint64_t stats[] ODP_UNUSED, int num ODP_UNUSED)
{
	return -1;
}
diff --git a/platform/linux-generic/odp_ml_quantize.c b/platform/linux-generic/odp_ml_quantize.c new file mode 100644 index 000000000..0678f15ef --- /dev/null +++ b/platform/linux-generic/odp_ml_quantize.c @@ -0,0 +1,79 @@
/* SPDX-License-Identifier: BSD-3-Clause
 * Copyright (c) 2023 Nokia
 */

#include <stdint.h>
#include <math.h>

#include <odp/api/ml_quantize.h>

#include <odp_macros_internal.h>
#include <odp_debug_internal.h>
#include <odp_ml_fp16.h>

/* Quantize float32 values to uint8: q = clamp(round(x / scale) + zerop, 0, 255).
 * nearbyintf() rounds according to the current FP rounding mode
 * (round-to-nearest-even by default).
 */
void odp_ml_fp32_to_uint8(uint8_t *u8, const float *fp32, uint32_t num, float scale,
			  uint8_t zerop)
{
	float fval;

	_ODP_ASSERT(scale != 0);

	for (uint32_t i = 0; i < num; i++) {
		/* Range mapping: map real values to the quantized integer range */
		fval = nearbyintf(fp32[i] / scale) + (float)zerop;

		/* clip to the uint8 value range */
		fval = _ODP_MAX(fval, 0.f);
		fval = _ODP_MIN(fval, 255.f);
		u8[i] = (uint8_t)(int32_t)fval;
	}
}

/* Dequantize uint8 back to float32: x = (q - zerop) * scale */
void odp_ml_fp32_from_uint8(float *fp32, const uint8_t *u8, uint32_t num, float scale,
			    uint8_t zerop)
{
	for (uint32_t i = 0; i < num; i++)
		fp32[i] = (float)(u8[i] - zerop) * scale;
}

/* Quantize float32 values to int8: q = clamp(round(x / scale) + zerop, -127, 127) */
void odp_ml_fp32_to_int8(int8_t *i8, const float *fp32, uint32_t num, float scale, int8_t zerop)
{
	float fval;

	_ODP_ASSERT(scale != 0);

	for (uint32_t i = 0; i < num; i++) {
		/* Range mapping: map real values to signed integer */
		fval = nearbyintf(fp32[i] / scale) + (float)zerop;

		/* NOTE: Clamps signed quantization values to [-127,127] instead of [-128,127].
		 * This is to ensure that symmetric quantization results in a zero
		 * point of exactly 0 for signed 8 bit ints.
		 */
		fval = _ODP_MAX(fval, -127.f);
		fval = _ODP_MIN(fval, 127.f);
		i8[i] = (int8_t)(int32_t)fval;
	}
}

/* Dequantize int8 back to float32: x = (q - zerop) * scale */
void odp_ml_fp32_from_int8(float *fp32, const int8_t *i8, uint32_t num, float scale, int8_t zerop)
{
	for (uint32_t i = 0; i < num; i++)
		fp32[i] = (float)(i8[i] - zerop) * scale;
}

/* Convert an array of float32 values to half-precision (fp16) bit patterns */
void odp_ml_fp32_to_fp16(uint16_t *fp16, const float *fp32, uint32_t num)
{
	uint32_t i;

	for (i = 0; i < num; i++)
		fp16[i] = _odp_float32_to_float16(fp32[i]);
}

/* Convert an array of half-precision (fp16) bit patterns to float32 */
void odp_ml_fp32_from_fp16(float *fp32, const uint16_t *fp16, uint32_t num)
{
	uint32_t i;

	for (i = 0; i < num; i++)
		fp32[i] = _odp_float16_to_float32(fp16[i]);
}
diff --git a/platform/linux-generic/odp_pool.c b/platform/linux-generic/odp_pool.c index 1e9767821..d3fde70f6 100644 --- a/platform/linux-generic/odp_pool.c +++ b/platform/linux-generic/odp_pool.c @@ -1257,6 +1257,10 @@ int odp_pool_info(odp_pool_t pool_hdl, odp_pool_info_t *info) info->dma_pool_param.uarea_size = pool->params.buf.uarea_size; info->dma_pool_param.cache_size = pool->params.buf.cache_size; + } else if (pool->type_2 == ODP_POOL_ML_COMPL) { + info->ml_pool_param.num = pool->params.buf.num; + info->ml_pool_param.uarea_size = pool->params.buf.uarea_size; + info->ml_pool_param.cache_size = pool->params.buf.cache_size; } else { info->params =
pool->params; } @@ -1559,6 +1563,8 @@ static const char *get_long_type_str(odp_pool_type_t type) return "vector"; case ODP_POOL_DMA_COMPL: return "dma completion"; + case ODP_POOL_ML_COMPL: + return "ml completion"; default: return "unknown"; } @@ -1577,6 +1583,8 @@ static const char *get_short_type_str(odp_pool_type_t type) return "V"; case ODP_POOL_DMA_COMPL: return "D"; + case ODP_POOL_ML_COMPL: + return "M"; default: return "-"; } @@ -1875,6 +1883,7 @@ int odp_pool_ext_capability(odp_pool_type_t type, odp_pool_ext_capability_t *cap case ODP_POOL_TIMEOUT: case ODP_POOL_VECTOR: case ODP_POOL_DMA_COMPL: + case ODP_POOL_ML_COMPL: memset(capa, 0, sizeof(odp_pool_ext_capability_t)); return 0; default: diff --git a/platform/linux-generic/odp_system_info.c b/platform/linux-generic/odp_system_info.c index bea77fb23..a2593b531 100644 --- a/platform/linux-generic/odp_system_info.c +++ b/platform/linux-generic/odp_system_info.c @@ -627,5 +627,8 @@ void odp_sys_config_print(void) _ODP_PRINT("CONFIG_IPSEC_MAX_NUM_SA: %i\n", CONFIG_IPSEC_MAX_NUM_SA); _ODP_PRINT("CONFIG_TIMER_128BIT_ATOMICS: %i\n", CONFIG_TIMER_128BIT_ATOMICS); _ODP_PRINT("CONFIG_TIMER_PROFILE_INLINE: %i\n", CONFIG_TIMER_PROFILE_INLINE); + _ODP_PRINT("CONFIG_ML_MAX_MODELS: %i\n", CONFIG_ML_MAX_MODELS); + _ODP_PRINT("CONFIG_ML_MAX_INPUTS: %i\n", CONFIG_ML_MAX_INPUTS); + _ODP_PRINT("CONFIG_ML_MAX_OUTPUTS: %i\n", CONFIG_ML_MAX_OUTPUTS); _ODP_PRINT("\n"); } diff --git a/platform/linux-generic/test/inline-timer.conf b/platform/linux-generic/test/inline-timer.conf index d645bef3c..fa3b6982f 100644 --- a/platform/linux-generic/test/inline-timer.conf +++ b/platform/linux-generic/test/inline-timer.conf @@ -1,6 +1,6 @@ # Mandatory fields odp_implementation = "linux-generic" -config_file_version = "0.1.27" +config_file_version = "0.1.28" timer: { # Enable inline timer implementation diff --git a/platform/linux-generic/test/packet_align.conf b/platform/linux-generic/test/packet_align.conf index 427674bb2..fb1418348 100644 
--- a/platform/linux-generic/test/packet_align.conf +++ b/platform/linux-generic/test/packet_align.conf @@ -1,6 +1,6 @@ # Mandatory fields odp_implementation = "linux-generic" -config_file_version = "0.1.27" +config_file_version = "0.1.28" pool: { pkt: { diff --git a/platform/linux-generic/test/process-mode.conf b/platform/linux-generic/test/process-mode.conf index 5bfcb9f2f..f4c6f7952 100644 --- a/platform/linux-generic/test/process-mode.conf +++ b/platform/linux-generic/test/process-mode.conf @@ -1,6 +1,6 @@ # Mandatory fields odp_implementation = "linux-generic" -config_file_version = "0.1.27" +config_file_version = "0.1.28" # Shared memory options shm: { diff --git a/platform/linux-generic/test/sched-basic.conf b/platform/linux-generic/test/sched-basic.conf index 1a401298e..8a6d0ac98 100644 --- a/platform/linux-generic/test/sched-basic.conf +++ b/platform/linux-generic/test/sched-basic.conf @@ -1,6 +1,6 @@ # Mandatory fields odp_implementation = "linux-generic" -config_file_version = "0.1.27" +config_file_version = "0.1.28" # Test scheduler with an odd spread value and without dynamic load balance sched_basic: { diff --git a/platform/linux-generic/test/stash-custom.conf b/platform/linux-generic/test/stash-custom.conf index b96c1cf45..6a2496303 100644 --- a/platform/linux-generic/test/stash-custom.conf +++ b/platform/linux-generic/test/stash-custom.conf @@ -1,6 +1,6 @@ # Mandatory fields odp_implementation = "linux-generic" -config_file_version = "0.1.27" +config_file_version = "0.1.28" # Test overflow safe stash variant stash: { |