SourceXtractorPlusPlus
1.0.3
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
SEImplementation
src
lib
Plugin
Onnx
OnnxTaskFactory.cpp
Go to the documentation of this file.
1
17
18
#include <onnxruntime_cxx_api.h>
19
20
#include <AlexandriaKernel/memory_tools.h>
21
#include <NdArray/NdArray.h>
22
23
#include "
SEImplementation/Common/OnnxCommon.h
"
24
25
#include "
SEImplementation/Plugin/Onnx/OnnxPlugin.h
"
26
#include "
SEImplementation/Plugin/Onnx/OnnxSourceTask.h
"
27
#include "
SEImplementation/Plugin/Onnx/OnnxProperty.h
"
28
#include "
SEImplementation/Plugin/Onnx/OnnxConfig.h
"
29
30
#include "
SEImplementation/Plugin/Onnx/OnnxTaskFactory.h
"
31
32
namespace
SourceXtractor
{
33
37
static
std::string
generatePropertyName
(
const
OnnxModel
& model) {
38
std::stringstream
prop_name;
39
40
std::string
domain = model.
getDomain
();
41
if
(!domain.
empty
()) {
42
prop_name << domain <<
'.'
;
43
}
44
45
std::string
graph_name = model.
getGraphName
();
46
if
(!graph_name.
empty
()) {
47
prop_name << graph_name <<
'.'
;
48
}
49
50
prop_name << model.
getOutputName
();
51
52
return
prop_name.
str
();
53
}
54
55
/// Default constructor: performs no work of its own, the models are loaded later by configure().
OnnxTaskFactory::OnnxTaskFactory() = default;
56
57
std::shared_ptr<Task>
OnnxTaskFactory::createTask
(
const
PropertyId
& property_id)
const
{
58
if
(property_id ==
PropertyId::create<OnnxProperty>
()) {
59
return
std::make_shared<OnnxSourceTask>
(
m_model_infos
);
60
}
61
return
nullptr
;
62
}
63
64
/**
 * Registers all the Configuration dependencies.
 *
 * The only dependency declared here is OnnxConfig, which configure() later reads.
 *
 * @param manager
 *  Configuration manager on which the dependency is registered
 */
void OnnxTaskFactory::reportConfigDependencies(Euclid::Configuration::ConfigManager& manager) const {
  manager.registerConfiguration<OnnxConfig>();
}
67
68
/**
 * Method which should initialize the object.
 *
 * For every model path reported by OnnxConfig, the model is loaded and checked, and an
 * OnnxModelInfo (model + generated property name) is appended to m_model_infos.
 *
 * @param manager
 *  Configuration manager from which OnnxConfig is retrieved
 * @throws Elements::Exception
 *  If a model input is not of float type, or if its input layer does not have
 *  exactly 4 axes
 */
void OnnxTaskFactory::configure(Euclid::Configuration::ConfigManager& manager) {
  const auto& onnx_config = manager.getConfiguration<OnnxConfig>();
  const auto& models = onnx_config.getModels();

  // Iterate by const reference: the paths are only read, so there is no need
  // to copy each string on every iteration
  for (const auto& model_path : models) {
    auto model = std::make_shared<OnnxModel>(model_path);

    // Only float input tensors are handled
    if (model->getInputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
      throw Elements::Exception() << "Only ONNX models with float input are supported";
    }

    // The input layer must have exactly 4 axes
    if (model->getInputShape().size() != 4) {
      throw Elements::Exception() << "Expected 4 axes for the input layer, got " << model->getInputShape().size();
    }

    auto prop_name = generatePropertyName(*model);
    onnx_logger.info() << "Output name will be " << prop_name;

    m_model_infos.emplace_back(OnnxSourceTask::OnnxModelInfo{model, prop_name});
  }
}
90
91
template
<
typename
T>
92
static
void
registerColumnConverter
(
OutputRegistry
& registry,
const
OnnxSourceTask::OnnxModelInfo
& model_info) {
93
auto
key = model_info.
prop_name
;
94
95
registry.
registerColumnConverter
<
OnnxProperty
,
Euclid::NdArray::NdArray<T>
>(
96
model_info.
prop_name
, [key](
const
OnnxProperty
& prop) {
97
return
prop.getData<T>(key);
98
},
""
, model_info.
model
->getModelPath()
99
);
100
}
101
102
/**
 * Register one output column per configured ONNX model.
 *
 * The element type of the column follows the output type of the model: float and
 * 32 bits integer outputs are handled.
 *
 * @param registry
 *  Registry on which the column converters are registered
 * @throws Elements::Exception
 *  If the output type of a model is neither float nor int32
 */
void OnnxTaskFactory::registerPropertyInstances(OutputRegistry& registry) {
  for (const auto& info : m_model_infos) {
    const auto output_type = info.model->getOutputType();

    if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
      registerColumnConverter<float>(registry, info);
    }
    else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
      registerColumnConverter<int32_t>(registry, info);
    }
    else {
      throw Elements::Exception() << "Unsupported output type: " << output_type;
    }
  }
}
116
117
}
// end of namespace SourceXtractor
OnnxCommon.h
OnnxConfig.h
OnnxPlugin.h
OnnxProperty.h
OnnxSourceTask.h
OnnxTaskFactory.h
std::string
std::stringstream
Elements::Exception
Euclid::Configuration::ConfigManager
Euclid::Configuration::ConfigManager::registerConfiguration
void registerConfiguration()
Euclid::Configuration::ConfigManager::getConfiguration
T & getConfiguration()
Euclid::NdArray::NdArray
SourceXtractor::OnnxConfig
Definition
OnnxConfig.h:28
SourceXtractor::OnnxConfig::getModels
const std::vector< std::string > & getModels() const
Definition
OnnxConfig.h:44
SourceXtractor::OnnxModel
Definition
OnnxModel.h:23
SourceXtractor::OnnxModel::getGraphName
std::string getGraphName() const
Definition
OnnxModel.h:128
SourceXtractor::OnnxModel::getDomain
std::string getDomain() const
Definition
OnnxModel.h:124
SourceXtractor::OnnxModel::getOutputName
std::string getOutputName() const
Definition
OnnxModel.h:136
SourceXtractor::OnnxProperty
Definition
OnnxProperty.h:30
SourceXtractor::OnnxTaskFactory::reportConfigDependencies
void reportConfigDependencies(Euclid::Configuration::ConfigManager &manager) const override
Registers all the Configuration dependencies.
Definition
OnnxTaskFactory.cpp:64
SourceXtractor::OnnxTaskFactory::OnnxTaskFactory
OnnxTaskFactory()
Definition
OnnxTaskFactory.cpp:55
SourceXtractor::OnnxTaskFactory::createTask
std::shared_ptr< Task > createTask(const PropertyId &property_id) const override
Returns a Task producing a Property corresponding to the given PropertyId.
Definition
OnnxTaskFactory.cpp:57
SourceXtractor::OnnxTaskFactory::registerPropertyInstances
void registerPropertyInstances(OutputRegistry &registry) override
Definition
OnnxTaskFactory.cpp:102
SourceXtractor::OnnxTaskFactory::m_model_infos
std::vector< OnnxSourceTask::OnnxModelInfo > m_model_infos
Definition
OnnxTaskFactory.h:49
SourceXtractor::OnnxTaskFactory::configure
void configure(Euclid::Configuration::ConfigManager &manager) override
Method which should initialize the object.
Definition
OnnxTaskFactory.cpp:68
SourceXtractor::OutputRegistry
Definition
OutputRegistry.h:37
SourceXtractor::OutputRegistry::registerColumnConverter
void registerColumnConverter(std::string column_name, ColumnConverter< PropertyType, OutType > converter, std::string column_unit="", std::string column_description="")
Definition
OutputRegistry.h:47
SourceXtractor::PropertyId
Identifier used to set and retrieve properties.
Definition
PropertyId.h:40
SourceXtractor::PropertyId::create
static PropertyId create(unsigned int index=0)
Definition
PropertyId.h:45
std::string::empty
T empty(T... args)
std::make_shared
T make_shared(T... args)
SourceXtractor
Definition
Aperture.h:30
SourceXtractor::registerColumnConverter
static void registerColumnConverter(OutputRegistry &registry, const OnnxSourceTask::OnnxModelInfo &model_info)
Definition
OnnxTaskFactory.cpp:92
SourceXtractor::generatePropertyName
static std::string generatePropertyName(const OnnxModel &model)
Definition
OnnxTaskFactory.cpp:37
SourceXtractor::onnx_logger
Elements::Logging onnx_logger
Logger for the ONNX plugin.
Definition
OnnxPlugin.cpp:26
std::shared_ptr
std::stringstream::str
T str(T... args)
SourceXtractor::OnnxSourceTask::OnnxModelInfo
Definition
OnnxSourceTask.h:31
SourceXtractor::OnnxSourceTask::OnnxModelInfo::model
std::shared_ptr< OnnxModel > model
Definition
OnnxSourceTask.h:32
SourceXtractor::OnnxSourceTask::OnnxModelInfo::prop_name
std::string prop_name
Definition
OnnxSourceTask.h:33
Generated by
1.14.0