SourceXtractorPlusPlus
1.0.3
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
SEImplementation
SEImplementation
Common
OnnxModel.h
Go to the documentation of this file.
1
/*
 * OnnxModel.h
 *
 *  Created on: Feb 16, 2021
 *      Author: mschefer
 */
7
8
#ifndef _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
9
#define _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
10
11
#include <cstdint>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>
20
21
namespace
SourceXtractor
{
22
23
class
OnnxModel
{
24
public
:
25
26
explicit
OnnxModel
(
const
std::string
& model_path);
27
28
template
<
typename
T,
typename
U>
29
void
run
(
std::vector<T>
& input_data,
std::vector<U>
& output_data)
const
{
30
Ort::RunOptions run_options;
31
auto
mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
32
33
// Allocate memory
34
std::vector<int64_t>
input_shape(
m_input_shapes
[0].
begin
(),
m_input_shapes
[0].
end
());
35
input_shape[0] = 1;
36
size_t
input_size =
std::accumulate
(input_shape.
begin
(), input_shape.
end
(), 1u,
std::multiplies<size_t>
());
37
38
std::vector<int64_t>
output_shape(
m_output_shape
.begin(),
m_output_shape
.end());
39
output_shape[0] = 1;
40
size_t
output_size =
std::accumulate
(output_shape.
begin
(), output_shape.
end
(), 1u,
std::multiplies<size_t>
());
41
42
// Check input and output size are OK
43
if
(input_data.
size
() < input_size || output_data.
size
() < output_size) {
44
throw
Elements::Exception
() <<
"OnnxModel: Insufficient buffer size "
;
45
}
46
47
// Setup input/output tensors
48
auto
input_tensor = Ort::Value::CreateTensor<T>(
49
mem_info, input_data.
data
(), input_data.
size
(), input_shape.
data
(), input_shape.
size
());
50
auto
output_tensor = Ort::Value::CreateTensor<U>(
51
mem_info, output_data.
data
(), output_data.
size
(), output_shape.
data
(), output_shape.
size
());
52
53
// Run the model
54
const
char
*input_name =
m_input_names
[0].c_str();
55
const
char
*output_name =
m_output_name
.c_str();
56
57
m_session
->Run(run_options, &input_name, &input_tensor, 1, &output_name, &output_tensor, 1);
58
}
59
60
template
<
typename
T,
typename
U>
61
void
runMultiInput
(
std::map
<
std::string
,
std::vector<T>
>& input_data,
std::vector<U>
& output_data)
const
{
62
Ort::RunOptions run_options;
63
auto
mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
64
65
std::vector<const char *>
input_names;
66
std::vector<Ort::Value>
input_tensors;
67
68
int
inputs_nb =
m_input_names
.size();
69
for
(
int
i=0; i<inputs_nb; i++) {
70
input_names.
emplace_back
(
m_input_names
[i].c_str());
71
72
// Allocate memory
73
std::vector<int64_t>
input_shape(
m_input_shapes
[i].
begin
(),
m_input_shapes
[i].
end
());
74
input_shape[0] = 1;
75
size_t
input_size =
std::accumulate
(input_shape.
begin
(), input_shape.
end
(), 1u,
std::multiplies<size_t>
());
76
77
// Check input size is OK
78
if
(input_data[
m_input_names
[i]].size() < input_size) {
79
throw
Elements::Exception
() <<
"OnnxModel: Insufficient buffer size "
;
80
}
81
82
input_tensors.
emplace_back
(Ort::Value::CreateTensor<T>(
83
mem_info, input_data[
m_input_names
[i]].data(), input_data[
m_input_names
[i]].size(),
84
input_shape.
data
(), input_shape.
size
()));
85
}
86
87
// Output name and shape
88
const
char
*output_name =
m_output_name
.c_str();
89
std::vector<int64_t>
output_shape(
m_output_shape
.begin(),
m_output_shape
.end());
90
output_shape[0] = 1;
91
92
// Setup output tensor
93
size_t
output_size =
std::accumulate
(output_shape.
begin
(), output_shape.
end
(), 1u,
std::multiplies<size_t>
());
94
95
// Check output and output size are OK
96
if
(output_data.
size
() < output_size) {
97
throw
Elements::Exception
() <<
"OnnxModel: Insufficient buffer size "
;
98
}
99
100
auto
output_tensor = Ort::Value::CreateTensor<U>(
101
mem_info, output_data.
data
(), output_data.
size
(), output_shape.
data
(), output_shape.
size
());
102
103
// Run the model
104
m_session
->Run(run_options, &input_names[0], &input_tensors[0], inputs_nb, &output_name, &output_tensor, 1);
105
}
106
107
108
ONNXTensorElementDataType
getInputType
()
const
{
109
return
m_input_types
[0];
110
}
111
112
ONNXTensorElementDataType
getOutputType
()
const
{
113
return
m_output_type
;
114
}
115
116
const
std::vector<std::int64_t>
&
getInputShape
()
const
{
117
return
m_input_shapes
[0];
118
}
119
120
const
std::vector<std::int64_t>
&
getOutputShape
()
const
{
121
return
m_output_shape
;
122
}
123
124
std::string
getDomain
()
const
{
125
return
m_domain_name
;
126
}
127
128
std::string
getGraphName
()
const
{
129
return
m_graph_name
;
130
}
131
132
std::string
getInputName
()
const
{
133
return
m_input_names
[0];
134
}
135
136
std::string
getOutputName
()
const
{
137
return
m_output_name
;
138
}
139
140
std::string
getModelPath
()
const
{
141
return
m_model_path
;
142
}
143
144
size_t
getInputNb
()
const
{
145
return
m_input_names
.size();
146
}
147
148
size_t
getOutputNb
()
const
{
149
return
1U;
150
}
151
152
private
:
153
std::string
m_domain_name
;
154
std::string
m_graph_name
;
155
std::vector<std::string>
m_input_names
;
156
std::string
m_output_name
;
157
std::vector<ONNXTensorElementDataType>
m_input_types
;
158
ONNXTensorElementDataType
m_output_type
;
159
std::vector<std::vector<std::int64_t>
>
m_input_shapes
;
160
std::vector<std::int64_t>
m_output_shape
;
161
std::string
m_model_path
;
162
std::unique_ptr<Ort::Session>
m_session
;
163
};
164
165
}
166
167
168
#endif
/* _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_ */
std::accumulate
T accumulate(T... args)
std::string
std::begin
T begin(T... args)
Elements::Exception
SourceXtractor::OnnxModel::run
void run(std::vector< T > &input_data, std::vector< U > &output_data) const
Definition
OnnxModel.h:29
SourceXtractor::OnnxModel::getInputType
ONNXTensorElementDataType getInputType() const
Definition
OnnxModel.h:108
SourceXtractor::OnnxModel::getOutputType
ONNXTensorElementDataType getOutputType() const
Definition
OnnxModel.h:112
SourceXtractor::OnnxModel::m_input_types
std::vector< ONNXTensorElementDataType > m_input_types
Input type.
Definition
OnnxModel.h:157
SourceXtractor::OnnxModel::m_session
std::unique_ptr< Ort::Session > m_session
Session, one per model. In theory, it is thread-safe.
Definition
OnnxModel.h:162
SourceXtractor::OnnxModel::getGraphName
std::string getGraphName() const
Definition
OnnxModel.h:128
SourceXtractor::OnnxModel::getDomain
std::string getDomain() const
Definition
OnnxModel.h:124
SourceXtractor::OnnxModel::m_output_name
std::string m_output_name
Output tensor name.
Definition
OnnxModel.h:156
SourceXtractor::OnnxModel::getOutputNb
size_t getOutputNb() const
Definition
OnnxModel.h:148
SourceXtractor::OnnxModel::getOutputShape
const std::vector< std::int64_t > & getOutputShape() const
Definition
OnnxModel.h:120
SourceXtractor::OnnxModel::getOutputName
std::string getOutputName() const
Definition
OnnxModel.h:136
SourceXtractor::OnnxModel::m_output_type
ONNXTensorElementDataType m_output_type
Output type.
Definition
OnnxModel.h:158
SourceXtractor::OnnxModel::m_input_names
std::vector< std::string > m_input_names
Input tensor name.
Definition
OnnxModel.h:155
SourceXtractor::OnnxModel::OnnxModel
OnnxModel(const std::string &model_path)
Definition
OnnxModel.cpp:17
SourceXtractor::OnnxModel::getInputName
std::string getInputName() const
Definition
OnnxModel.h:132
SourceXtractor::OnnxModel::m_output_shape
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition
OnnxModel.h:160
SourceXtractor::OnnxModel::m_graph_name
std::string m_graph_name
graph name
Definition
OnnxModel.h:154
SourceXtractor::OnnxModel::runMultiInput
void runMultiInput(std::map< std::string, std::vector< T > > &input_data, std::vector< U > &output_data) const
Definition
OnnxModel.h:61
SourceXtractor::OnnxModel::m_domain_name
std::string m_domain_name
domain name
Definition
OnnxModel.h:153
SourceXtractor::OnnxModel::getInputShape
const std::vector< std::int64_t > & getInputShape() const
Definition
OnnxModel.h:116
SourceXtractor::OnnxModel::m_model_path
std::string m_model_path
Path to the ONNX model.
Definition
OnnxModel.h:161
SourceXtractor::OnnxModel::getInputNb
size_t getInputNb() const
Definition
OnnxModel.h:144
SourceXtractor::OnnxModel::m_input_shapes
std::vector< std::vector< std::int64_t > > m_input_shapes
Input tensor shape.
Definition
OnnxModel.h:159
SourceXtractor::OnnxModel::getModelPath
std::string getModelPath() const
Definition
OnnxModel.h:140
std::vector::data
T data(T... args)
std::vector::emplace_back
T emplace_back(T... args)
std::end
T end(T... args)
std::map
std::multiplies
SourceXtractor
Definition
Aperture.h:30
std::vector::size
T size(T... args)
std::unique_ptr
std::vector
Generated by
1.14.0