SourceXtractorPlusPlus 1.0.3
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
OnnxModel.cpp
Go to the documentation of this file.
1/*
2 * OnnxModel.cpp
3 *
4 * Created on: Feb 16, 2021
5 * Author: mschefer
6 */
7
8#include "ElementsKernel/Exception.h"
9#include "ElementsKernel/Logging.h"
10#include "AlexandriaKernel/memory_tools.h"
11
14
15namespace SourceXtractor {
16
18 m_model_path = model_path;
19
21 auto allocator = Ort::AllocatorWithDefaultOptions();
22
23 onnx_logger.info() << "Loading ONNX model " << model_path;
24 m_session = Euclid::make_unique<Ort::Session>(ORT_ENV, model_path.c_str(), Ort::SessionOptions{nullptr});
25
26 if (m_session->GetOutputCount() != 1) {
27 throw Elements::Exception() << "Only ONNX models with a single output tensor are supported";
28 }
29
30 for (size_t i=0; i<m_session->GetInputCount(); i++) {
31 auto input_type = m_session->GetInputTypeInfo(i);
32
33 m_input_names.emplace_back(m_session->GetInputNameAllocated(i, allocator).get());
34 m_input_shapes.emplace_back(input_type.GetTensorTypeAndShapeInfo().GetShape());
35 m_input_types.emplace_back(input_type.GetTensorTypeAndShapeInfo().GetElementType());
36 }
37
38 m_output_name = std::string(m_session->GetOutputNameAllocated(0, allocator).get());
39 m_domain_name = std::string(m_session->GetModelMetadata().GetDomainAllocated(allocator).get());
40 m_graph_name = std::string(m_session->GetModelMetadata().GetGraphNameAllocated(allocator).get());
41
42 auto output_type = m_session->GetOutputTypeInfo(0);
43
44 m_output_shape = output_type.GetTensorTypeAndShapeInfo().GetShape();
45 m_output_type = output_type.GetTensorTypeAndShapeInfo().GetElementType();
46
47// onnx_logger.info() << "ONNX model with input of " << formatShape(m_input_shapes[0]);
48// onnx_logger.info() << "ONNX model with output of " << formatShape(m_output_shape);
49}
50
51}
T c_str(T... args)
static Logging getLogger(const std::string &name="")
std::vector< ONNXTensorElementDataType > m_input_types
Input tensor element types, one per model input.
Definition OnnxModel.h:157
std::unique_ptr< Ort::Session > m_session
ONNX Runtime session, one per model. In theory, it is thread-safe.
Definition OnnxModel.h:162
std::string m_output_name
Output tensor name.
Definition OnnxModel.h:156
ONNXTensorElementDataType m_output_type
Output tensor element type.
Definition OnnxModel.h:158
std::vector< std::string > m_input_names
Input tensor names, one per model input.
Definition OnnxModel.h:155
OnnxModel(const std::string &model_path)
Definition OnnxModel.cpp:17
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition OnnxModel.h:160
std::string m_graph_name
Graph name, read from the model metadata.
Definition OnnxModel.h:154
std::string m_domain_name
Domain name, read from the model metadata.
Definition OnnxModel.h:153
std::string m_model_path
Path to the ONNX model.
Definition OnnxModel.h:161
std::vector< std::vector< std::int64_t > > m_input_shapes
Input tensor shapes, one per model input.
Definition OnnxModel.h:159
std::unique_ptr< T > make_unique(Args &&... args)
Elements::Logging onnx_logger
Logger for the ONNX plugin.
Ort::Env ORT_ENV