SourceXtractorPlusPlus 1.0.3
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
OnnxTaskFactory.cpp
Go to the documentation of this file.
1
17
18#include <onnxruntime_cxx_api.h>
19
20#include <AlexandriaKernel/memory_tools.h>
21#include <NdArray/NdArray.h>
22
24
29
31
32namespace SourceXtractor {
33
// Body of generatePropertyName(const OnnxModel& model); the signature line
// itself is not present in this listing extract.
//
// Builds the property name for a model as
//   [<domain>.][<graph name>.]<output name>
// The domain and the graph name are each included only when non-empty, and
// each is followed by a '.' separator. The output name is always appended.
// Returns the assembled name as a std::string.
38 std::stringstream prop_name;
39
// Optional "<domain>." prefix.
40 std::string domain = model.getDomain();
41 if (!domain.empty()) {
42 prop_name << domain << '.';
43 }
44
// Optional "<graph name>." component.
45 std::string graph_name = model.getGraphName();
46 if (!graph_name.empty()) {
47 prop_name << graph_name << '.';
48 }
49
// The output name is appended unconditionally, with no trailing separator.
50 prop_name << model.getOutputName();
51
52 return prop_name.str();
53}
54
56
// Fragment of createTask(const PropertyId& property_id); the signature line
// and listing line 59 are not present in this extract.
//
// Only the PropertyId created for OnnxProperty is handled by the branch
// below; for any other property id the function returns nullptr.
// NOTE(review): the body of this branch (listing line 59) is missing here;
// presumably it returns the ONNX source task - confirm against the original file.
58 if (property_id == PropertyId::create<OnnxProperty>()) {
60 }
61 return nullptr;
62}
63
67
// Body of configure(Euclid::Configuration::ConfigManager& manager); the
// signature line is not present in this listing extract.
//
// Reads the list of models from OnnxConfig, constructs an OnnxModel for each
// entry, validates it, and appends {model, property name} to m_model_infos.
//
// Throws Elements::Exception when a model's input element type is not float,
// or when its input shape does not have exactly 4 axes.
69 const auto& onnx_config = manager.getConfiguration<OnnxConfig>();
70 const auto& models = onnx_config.getModels();
71
// NOTE(review): `auto model_path` copies each std::string out of the vector
// returned by getModels(); `const auto&` would avoid the per-iteration copy.
72 for (auto model_path : models) {
73 auto model = std::make_shared<OnnxModel>(model_path);
74
// Reject models whose input tensor element type is not float.
75 if (model->getInputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
76 throw Elements::Exception() << "Only ONNX models with float input are supported";
77 }
78
// Reject models whose input shape does not have exactly 4 axes.
79 if (model->getInputShape().size() != 4) {
80 throw Elements::Exception() << "Expected 4 axes for the input layer, got " << model->getInputShape().size();
81 }
82
// The generated name is logged and stored alongside the model; it is later
// used as the output column name and as the lookup key in OnnxProperty.
83 auto prop_name = generatePropertyName(*model);
84 onnx_logger.info() << "Output name will be " << prop_name;
85
86 m_model_infos.emplace_back(OnnxSourceTask::OnnxModelInfo {model, prop_name});
87
88 }
89}
90
// File-local helper that registers one output column for a model.
//
// T is the element type requested from OnnxProperty::getData<T>().
// registry   - output registry the column is registered with.
// model_info - supplies the property name and the model.
//
// NOTE(review): listing line 95 (the start of the call expression) is missing
// from this extract; presumably it is a call to
// registry.registerColumnConverter(...) - confirm against the original file.
// The visible arguments are, in order: model_info.prop_name, a converter
// lambda, an empty string, and the model path.
91template<typename T>
92static void registerColumnConverter(OutputRegistry& registry, const OnnxSourceTask::OnnxModelInfo& model_info) {
// Copy the name so the lambda captures its own std::string by value rather
// than referring to model_info.
93 auto key = model_info.prop_name;
94
96 model_info.prop_name, [key](const OnnxProperty& prop) {
97 return prop.getData<T>(key);
98 }, "", model_info.model->getModelPath()
99 );
100}
101
// Body of registerPropertyInstances(OutputRegistry& registry); the signature
// line is not present in this listing extract.
//
// For every entry in m_model_infos, registers a column converter whose value
// type matches the model's output element type:
//   ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT -> float
//   ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 -> int32_t
// Throws Elements::Exception for any other output element type.
103 for (const auto& model_info : m_model_infos) {
104 switch (model_info.model->getOutputType()) {
105 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
106 registerColumnConverter<float>(registry, model_info);
107 break;
108 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
109 registerColumnConverter<int32_t>(registry, model_info);
110 break;
111 default:
112 throw Elements::Exception() << "Unsupported output type: " << model_info.model->getOutputType();
113 }
114 }
115}
116
117} // end of namespace SourceXtractor
const std::vector< std::string > & getModels() const
Definition OnnxConfig.h:44
std::string getGraphName() const
Definition OnnxModel.h:128
std::string getDomain() const
Definition OnnxModel.h:124
std::string getOutputName() const
Definition OnnxModel.h:136
void reportConfigDependencies(Euclid::Configuration::ConfigManager &manager) const override
Registers all the Configuration dependencies.
std::shared_ptr< Task > createTask(const PropertyId &property_id) const override
Returns a Task producing a Property corresponding to the given PropertyId.
void registerPropertyInstances(OutputRegistry &registry) override
std::vector< OnnxSourceTask::OnnxModelInfo > m_model_infos
void configure(Euclid::Configuration::ConfigManager &manager) override
Method which should initialize the object.
void registerColumnConverter(std::string column_name, ColumnConverter< PropertyType, OutType > converter, std::string column_unit="", std::string column_description="")
Identifier used to set and retrieve properties.
Definition PropertyId.h:40
static PropertyId create(unsigned int index=0)
Definition PropertyId.h:45
T empty(T... args)
T make_shared(T... args)
static void registerColumnConverter(OutputRegistry &registry, const OnnxSourceTask::OnnxModelInfo &model_info)
static std::string generatePropertyName(const OnnxModel &model)
Elements::Logging onnx_logger
Logger for the ONNX plugin.
T str(T... args)