SourceXtractorPlusPlus 1.0.3
SourceXtractor++, the next generation SExtractor
Loading...
Searching...
No Matches
MLSegmentation.cpp
Go to the documentation of this file.
1
17
18#include <memory>
19#include <vector>
20#include <list>
21#include <iostream>
22
23#include <onnxruntime_cxx_api.h>
24
27
29
32
35
37
40
42
43
46
49
51
52
53
54namespace SourceXtractor {
55
56namespace {
57class LutzLabellingListener : public Lutz::LutzListener {
58public:
59 LutzLabellingListener(Segmentation::LabellingListener& listener, std::shared_ptr<SourceFactory> source_factory,
60 int window_size) :
61 m_listener(listener),
62 m_source_factory(source_factory),
63 m_window_size(window_size) {}
64
65 virtual ~LutzLabellingListener() = default;
66
67 void publishGroup(Lutz::PixelGroup& pixel_group) override {
68 auto source = m_source_factory->createSource();
69 source->setProperty<PixelCoordinateList>(pixel_group.pixel_list);
70 source->setProperty<SourceId>();
71 m_listener.publishSource(std::move(source));
72 }
73
74 void notifyProgress(int line, int total) override {
75 m_listener.notifyProgress(line, total);
76
77 if (m_window_size > 0 && line > m_window_size) {
78 m_listener.requestProcessing(
79 ProcessSourcesEvent(std::make_shared<LineSelectionCriteria>(line - m_window_size))
80 );
81 }
82 }
83
84private:
85 Segmentation::LabellingListener& m_listener;
86 std::shared_ptr<SourceFactory> m_source_factory;
87 int m_window_size;
88};
89
90
91}
92
95
97
98 auto input_shape = model.getInputShape();
99 auto output_shape = model.getOutputShape();
100
101 // TODO add sanity check
102
103 int tile_size = output_shape[1];
104 int data_planes = output_shape[3];
105 float average_rms = frame->getBackgroundMedianRms();
106 float detection_threshold = m_ml_threshold;
107
108 onnx_logger.info() << "Onnx tile size: " << tile_size << " Data planes: " << data_planes << " RMS: " << average_rms;
109
110 if (model.getInputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
111 throw Elements::Exception() << "Only ONNX models with float input are supported";
112 }
113
114 if (model.getOutputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
115 throw Elements::Exception() << "Only ONNX models with float output are supported";
116 }
117
118 if (model.getInputNb() != 1) {
119 throw Elements::Exception() << "Only ONNX models with a single input tensor are supported";
120 }
121
122 // allocate memory
123 std::vector<float> input_data(tile_size * tile_size);
124 std::vector<float> output_data(tile_size * tile_size * data_planes);
125
126 auto image = frame->getSubtractedImage();
127 ImageAccessor<SeFloat> image_acc(image);
128
131 for (int i=0; i < data_planes; i++) {
132 tmp_images.emplace_back(FitsWriter::newTemporaryImage<float>("_tmp_ml_seg%%%%%%.fits", image->getWidth(), image->getHeight()));
133 check_images.emplace_back(CheckImages::getInstance().getMLDetectionImage(i, frame->getHduIndex()));
134 }
135
136 Lutz lutz;
137 LutzLabellingListener lutz_listener(listener, m_source_factory, 0);
138
139 for (int ox = 0; ox + tile_size * 3 / 4 < image->getWidth(); ox += tile_size / 2) {
140 for (int oy = 0; oy + tile_size * 3 / 4 < image->getHeight(); oy += tile_size / 2) {
141
142 for (int x = 0; x < tile_size; x++) {
143 for (int y = 0; y < tile_size; y++) {
144 if (ox+x < image->getWidth() && oy+y < image->getHeight()) {
145 input_data[x+y*tile_size] = image_acc.getValue(ox+x, oy+y) / average_rms;
146 } else {
147 input_data[x+y*tile_size] = 0;
148 }
149 }
150 }
151
152 model.run<float, float>(input_data, output_data);
153
154 int start_x = (ox == 0) ? 0 : tile_size / 4;
155 int start_y = (oy == 0) ? 0 : tile_size / 4;
156
157 int end_x = (ox + tile_size * 5 / 4 < image->getWidth()) ? tile_size * 3 / 4 : tile_size ;
158 int end_y = (oy + tile_size * 5 / 4 < image->getHeight()) ? tile_size * 3 / 4 : tile_size;
159
160 for (int x = start_x; x < end_x; x++) {
161 for (int y = start_y; y < end_y; y++) {
162 if (ox+x < image->getWidth() && oy+y < image->getHeight()) {
163 for (int i=0; i<data_planes; i++) {
164 tmp_images[i]->setValue(ox + x, oy + y, output_data[(x+y*tile_size) * data_planes + i] - detection_threshold);
165 if (check_images[i] != nullptr) {
166 check_images[i]->setValue(ox+x, oy+y, output_data[(x+y*tile_size) * data_planes + i]);
167 }
168 }
169 }
170 }
171 }
172 }
173 }
174 for (int i=0; i<data_planes; i++) {
175 lutz.labelImage(lutz_listener, *tmp_images[i]);
176 }
177}
178
179}
static Logging getLogger(const std::string &name="")
static CheckImages & getInstance()
static std::shared_ptr< WriteableImage< T > > newTemporaryImage(const std::string &pattern, int width, int height)
Definition FitsWriter.h:77
Implements a Segmentation based on the Lutz algorithm.
Definition Lutz.h:37
void labelImage(LutzListener &listener, const DetectionImage &image, PixelCoordinate offset=PixelCoordinate(0, 0))
Definition Lutz.cpp:59
void labelImage(Segmentation::LabellingListener &listener, std::shared_ptr< const DetectionImageFrame > frame) override
std::shared_ptr< SourceFactory > m_source_factory
void run(std::vector< T > &input_data, std::vector< U > &output_data) const
Definition OnnxModel.h:29
ONNXTensorElementDataType getInputType() const
Definition OnnxModel.h:108
ONNXTensorElementDataType getOutputType() const
Definition OnnxModel.h:112
const std::vector< std::int64_t > & getOutputShape() const
Definition OnnxModel.h:120
const std::vector< std::int64_t > & getInputShape() const
Definition OnnxModel.h:116
size_t getInputNb() const
Definition OnnxModel.h:144
T emplace_back(T... args)
T make_shared(T... args)
T move(T... args)
Elements::Logging onnx_logger
Logger for the ONNX plugin.